Skip to content

Commit

Permalink
Convert numpy types back to builtin types
Browse files Browse the repository at this point in the history
Scikit-learn or numpy changed the typing of the parameters
(seen in a masked array, not sure if also outside of that).
Convert these values back to Python builtins.
  • Loading branch information
PGijsbers committed Jul 4, 2024
1 parent a2e7022 commit 828a7a4
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion openml/extensions/sklearn/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,14 @@ def _extract_trace_data(self, model, rep_no, fold_no):
for key in model.cv_results_:
if key.startswith("param_"):
value = model.cv_results_[key][itt_no]
serialized_value = json.dumps(value) if value is not np.ma.masked else np.nan
# Built-in serializer does not convert all numpy types,
# these methods convert them to built-in types instead.
if value is np.ma.masked:
value = np.nan
if isinstance(value, np.generic):
# For scalars it actually returns scalars, not a list
value = value.tolist()
serialized_value = json.dumps(value)
arff_line.append(serialized_value)
arff_tracecontent.append(arff_line)
return arff_tracecontent
Expand Down Expand Up @@ -2215,6 +2222,8 @@ def _obtain_arff_trace(
# int float
supported_basic_types = (bool, int, float, str)
for param_value in model.cv_results_[key]:
if isinstance(param_value, np.generic):
param_value = param_value.tolist() # noqa: PLW2901
if (
isinstance(param_value, supported_basic_types)
or param_value is None
Expand Down

0 comments on commit 828a7a4

Please sign in to comment.