Always pretty-print hyperparameters whose schema has been customized if that customization changed their default. (#948)

But only if the customized schema itself is not being printed as well.
hirzel authored Jan 14, 2022
1 parent 77a5c0b commit 965f003
Showing 3 changed files with 125 additions and 32 deletions.
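To illustrate the new behavior, here is a minimal sketch that mirrors the test_customize_schema_print_defaults test added below; it assumes lale and scikit-learn are installed:

# Illustrative sketch, not part of the diff; mirrors the new test below.
from lale.lib.sklearn import RandomForestRegressor

pipeline = RandomForestRegressor.customize_schema(
    bootstrap={"type": "boolean", "default": True},  # default unchanged
    random_state={
        "anyOf": [
            {"laleType": "numpy.random.RandomState"},
            {"enum": [None]},
            {"type": "integer"},
        ],
        "default": 33,  # default changed
    },
)(n_estimators=50)

# With customize_schema=False the schema itself is not printed, so the
# changed default is surfaced as an explicit hyperparameter instead:
print(pipeline.pretty_print(customize_schema=False))
# pipeline = RandomForestRegressor(n_estimators=50, random_state=33)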
94 changes: 78 additions & 16 deletions lale/json_operator.py
@@ -313,22 +313,36 @@ def __call__(self, prefix: str) -> str:
return result


def _hps_to_json_rec(hps, cls2label: Dict[str, str], gensym: _GenSym, steps) -> Any:
def _hps_to_json_rec(
hps,
cls2label: Dict[str, str],
gensym: _GenSym,
steps,
add_custom_default: bool,
) -> Any:
if isinstance(hps, lale.operators.Operator):
step_uid, step_jsn = _op_to_json_rec(hps, cls2label, gensym)
step_uid, step_jsn = _op_to_json_rec(hps, cls2label, gensym, add_custom_default)
steps[step_uid] = step_jsn
return {"$ref": f"../steps/{step_uid}"}
elif isinstance(hps, dict):
return {
hp_name: _hps_to_json_rec(hp_val, cls2label, gensym, steps)
hp_name: _hps_to_json_rec(
hp_val, cls2label, gensym, steps, add_custom_default
)
for hp_name, hp_val in hps.items()
}
elif isinstance(hps, tuple):
return tuple(
[_hps_to_json_rec(hp_val, cls2label, gensym, steps) for hp_val in hps]
[
_hps_to_json_rec(hp_val, cls2label, gensym, steps, add_custom_default)
for hp_val in hps
]
)
elif isinstance(hps, list):
return [_hps_to_json_rec(hp_val, cls2label, gensym, steps) for hp_val in hps]
return [
_hps_to_json_rec(hp_val, cls2label, gensym, steps, add_custom_default)
for hp_val in hps
]
else:
return hps
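The change above only threads the new add_custom_default flag through the recursion; a stripped-down sketch of the traversal shape (hypothetical walk helper, standard library only):

# Illustrative sketch, not part of the diff: recurse into dicts, tuples,
# and lists of hyperparameter values, passing an extra flag along unchanged.
def walk(hps, flag):
    if isinstance(hps, dict):
        return {name: walk(val, flag) for name, val in hps.items()}
    elif isinstance(hps, tuple):
        return tuple(walk(val, flag) for val in hps)
    elif isinstance(hps, list):
        return [walk(val, flag) for val in hps]
    else:
        return hps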

@@ -378,12 +392,37 @@ def list_equal_modulo(l1, l2, mod):
for hp_name, hp_schema in after.items()
if hp_name not in before or hp_schema != before[hp_name]
}
result = {"properties": {"hyperparams": {"allOf": [hp_diff]}}}
result = {
"properties": {
"hyperparams": {"allOf": [{"type": "object", "properties": hp_diff}]}
}
}
return result


def _top_schemas_to_hparams(top_level_schemas) -> JSON_TYPE:
if not isinstance(top_level_schemas, dict):
return {}
return top_level_schemas.get("properties", {}).get("hyperparams", {})


def _hparams_schemas_to_props(hparams_schemas) -> JSON_TYPE:
if not isinstance(hparams_schemas, dict):
return {}
return hparams_schemas.get("allOf", [{}])[0].get("properties", {})


def _top_schemas_to_hp_props(top_level_schemas) -> JSON_TYPE:
hparams = _top_schemas_to_hparams(top_level_schemas)
props = _hparams_schemas_to_props(hparams)
return props
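The three helpers above peel apart the nesting that _get_customize_schema now produces; a standalone sketch with an illustrative schema:

# Illustrative sketch, not part of the diff: the nesting unwrapped by
# _top_schemas_to_hparams and _hparams_schemas_to_props.
top_level_schemas = {
    "properties": {
        "hyperparams": {
            "allOf": [
                {
                    "type": "object",
                    "properties": {
                        "random_state": {"type": "integer", "default": 33}
                    },
                }
            ]
        }
    }
}
hparams = top_level_schemas.get("properties", {}).get("hyperparams", {})
props = hparams.get("allOf", [{}])[0].get("properties", {})
assert props["random_state"]["default"] == 33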


def _op_to_json_rec(
op: "lale.operators.Operator", cls2label: Dict[str, str], gensym: _GenSym
op: "lale.operators.Operator",
cls2label: Dict[str, str],
gensym: _GenSym,
add_custom_default: bool,
) -> Tuple[str, JSON_TYPE]:
jsn: JSON_TYPE = {}
jsn["class"] = op.class_name()
@@ -402,17 +441,15 @@ def _op_to_json_rec(
if hyperparams is None:
jsn["hyperparams"] = None
else:
hp_schema = (
op.hyperparam_schema().get("allOf", [{}])[0].get("properties", {})
)
hp_schema = _hparams_schemas_to_props(op.hyperparam_schema())
hyperparams = {
k: v
for k, v in hyperparams.items()
if not hp_schema.get(k, {}).get("transient", False)
}
steps: Dict[str, JSON_TYPE] = {}
jsn["hyperparams"] = _hps_to_json_rec(
hyperparams, cls2label, gensym, steps
hyperparams, cls2label, gensym, steps, add_custom_default
)
if len(steps) > 0:
jsn["steps"] = steps
@@ -426,12 +463,31 @@ def _op_to_json_rec(
orig_schemas = lale.operators.get_lib_schemas(op.impl_class)
if op._schemas is not orig_schemas:
jsn["customize_schema"] = _get_customize_schema(op._schemas, orig_schemas)
if add_custom_default and isinstance(
jsn.get("customize_schema", None), dict
):
if isinstance(jsn.get("hyperparams", None), dict):
assert jsn["hyperparams"] is not None # to help pyright
orig = _top_schemas_to_hp_props(orig_schemas)
cust = _top_schemas_to_hp_props(jsn["customize_schema"])
for hp_name, hp_schema in cust.items():
if "default" in hp_schema:
if hp_name not in jsn["hyperparams"]:
cust_default = hp_schema["default"]
if hp_name in orig and "default" in orig[hp_name]:
orig_default = orig[hp_name]["default"]
if cust_default != orig_default:
jsn["hyperparams"][hp_name] = cust_default
else:
jsn["hyperparams"][hp_name] = cust_default
elif isinstance(op, lale.operators.BasePipeline):
uid = gensym("pipeline")
child2uid: Dict[lale.operators.Operator, str] = {}
child2jsn: Dict[lale.operators.Operator, JSON_TYPE] = {}
for idx, child in enumerate(op.steps_list()):
child_uid, child_jsn = _op_to_json_rec(child, cls2label, gensym)
child_uid, child_jsn = _op_to_json_rec(
child, cls2label, gensym, add_custom_default
)
child2uid[child] = child_uid
child2jsn[child] = child_jsn
jsn["edges"] = [[child2uid[x], child2uid[y]] for x, y in op.edges()]
@@ -442,19 +498,25 @@ def _op_to_json_rec(
jsn["state"] = "planned"
jsn["steps"] = {}
for step in op.steps_list():
child_uid, child_jsn = _op_to_json_rec(step, cls2label, gensym)
child_uid, child_jsn = _op_to_json_rec(
step, cls2label, gensym, add_custom_default
)
jsn["steps"][child_uid] = child_jsn
else:
raise ValueError(f"Unexpected argument of type: {type(op)}")
return uid, jsn


def to_json(op: "lale.operators.Operator", call_depth: int = 1) -> JSON_TYPE:
def to_json(
op: "lale.operators.Operator",
call_depth: int = 1,
add_custom_default: bool = False,
) -> JSON_TYPE:
from lale.settings import disable_hyperparams_schema_validation

cls2label = _get_cls2label(call_depth + 1)
gensym = _GenSym(op, cls2label)
uid, jsn = _op_to_json_rec(op, cls2label, gensym)
uid, jsn = _op_to_json_rec(op, cls2label, gensym, add_custom_default)
if not disable_hyperparams_schema_validation:
jsonschema.validate(jsn, SCHEMA, jsonschema.Draft4Validator)
return jsn
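Callers opt in through the new keyword argument; a hedged usage sketch, assuming this customized schema validates for the operator:

# Illustrative sketch, not part of the diff.
import lale.json_operator
from lale.lib.sklearn import RandomForestRegressor

op = RandomForestRegressor.customize_schema(
    random_state={"type": "integer", "default": 33}
)(n_estimators=50)
jsn = lale.json_operator.to_json(op, add_custom_default=True)
# jsn["hyperparams"] should now hold both the explicit n_estimators=50 and
# the changed default random_state=33.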
@@ -497,7 +559,7 @@ def _op_from_json_rec(jsn: JSON_TYPE) -> "lale.operators.Operator":
name = jsn["operator"]
result = lale.operators.make_operator(impl, schemas, name)
if jsn.get("customize_schema", {}) != {}:
new_hps = jsn["customize_schema"]["properties"]["hyperparams"]["allOf"][0]
new_hps = _top_schemas_to_hp_props(jsn["customize_schema"])
result = result.customize_schema(**new_hps)
if jsn["state"] in ["trainable", "trained"]:
if _get_state(result) == "planned":
14 changes: 9 additions & 5 deletions lale/pretty_print.py
@@ -188,7 +188,7 @@ def value_to_string(value):
gen.imports.append(f"import {module}")
printed = f"{module}.{printed}"
if printed.startswith("<"):
m = re.match(r"<(\w[\w.]*)\.(\w+) object at 0x[0-9a-f]+>$", printed)
m = re.match(r"<(\w[\w.]*)\.(\w+) object at 0x[0-9a-fA-F]+>$", printed)
if m:
module, clazz = m.group(1), m.group(2)
if gen is not None:
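The regex tweak above additionally accepts uppercase hex digits in repr addresses; a quick standalone check:

# Illustrative sketch, not part of the diff; standard library only.
import re

pattern = r"<(\w[\w.]*)\.(\w+) object at 0x[0-9a-fA-F]+>$"
for printed in (
    "<sklearn.tree._classes.DecisionTreeClassifier object at 0x7f3a9c1d0e80>",
    "<sklearn.tree._classes.DecisionTreeClassifier object at 0x7F3A9C1D0E80>",  # old pattern missed this
):
    m = re.match(pattern, printed)
    assert m is not None
    module, clazz = m.group(1), m.group(2)
    # ("sklearn.tree._classes", "DecisionTreeClassifier")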
@@ -519,9 +519,9 @@ def print_for_comb(step_uid, step_val):
if jsn["customize_schema"] == "not_available":
logger.warning(f"missing {label}.customize_schema(..) call")
elif jsn["customize_schema"] != {}:
new_hps = jsn["customize_schema"]["properties"]["hyperparams"]["allOf"][
0
]
new_hps = lale.json_operator._top_schemas_to_hp_props(
jsn["customize_schema"]
)
customize_schema_string = ",".join(
[
f"{hp_name}={json_to_string(hp_schema)}"
@@ -657,7 +657,11 @@ def to_string(
if lale.type_checking.is_schema(arg):
return json_to_string(cast(JSON_TYPE, arg))
elif isinstance(arg, lale.operators.Operator):
jsn = lale.json_operator.to_json(arg, call_depth=call_depth + 1)
jsn = lale.json_operator.to_json(
arg,
call_depth=call_depth + 1,
add_custom_default=not customize_schema,
)
return _operator_jsn_to_string(
jsn, show_imports, combinators, customize_schema, astype
)
49 changes: 38 additions & 11 deletions test/test_json_pretty_viz.py
@@ -946,7 +946,7 @@ def test_customize_schema_none_and_boolean(self):
],
"default": 33,
},
)
)(n_estimators=50)
expected = """from sklearn.ensemble import RandomForestRegressor
import lale
Expand All @@ -961,9 +961,33 @@ def test_customize_schema_none_and_boolean(self):
],
"default": 33,
},
)"""
)(n_estimators=50)"""
# this should not include "random_state=33" because that would be
# redundant with the schema, and would prevent automated search
self._roundtrip(expected, pipeline.pretty_print(customize_schema=True))

def test_customize_schema_print_defaults(self):
from lale.lib.sklearn import RandomForestRegressor

pipeline = RandomForestRegressor.customize_schema(
bootstrap={"type": "boolean", "default": True}, # default unchanged
random_state={
"anyOf": [
{"laleType": "numpy.random.RandomState"},
{"enum": [None]},
{"type": "integer"},
],
"default": 33, # default changed
},
)(n_estimators=50)
expected = """from sklearn.ensemble import RandomForestRegressor
import lale
lale.wrap_imported_operators()
pipeline = RandomForestRegressor(n_estimators=50, random_state=33)"""
# print exactly those defaults that changed
self._roundtrip(expected, pipeline.pretty_print(customize_schema=False))

def test_user_operator_in_toplevel_module(self):
import importlib
import os.path
@@ -1346,15 +1370,18 @@ def test_customize_schema(self):
"hyperparams": {
"allOf": [
{
"solver": {
"default": "liblinear",
"enum": ["lbfgs", "liblinear"],
},
"tol": {
"type": "number",
"minimum": 0.00001,
"maximum": 0.1,
"default": 0.0001,
"type": "object",
"properties": {
"solver": {
"default": "liblinear",
"enum": ["lbfgs", "liblinear"],
},
"tol": {
"type": "number",
"minimum": 0.00001,
"maximum": 0.1,
"default": 0.0001,
},
},
}
]
