Skip to content

Commit

Permalink
Force conversion to Vector
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jan 26, 2024
1 parent 68ea1be commit e530637
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 18 deletions.
2 changes: 2 additions & 0 deletions pysr/julia_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

jl = juliacall.newmodule("PySR")

from juliacall import convert as jl_convert

juliainfo = None
julia_initialized = False
julia_kwargs_at_initialization = None
Expand Down
43 changes: 25 additions & 18 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
from .export_torch import sympy2torch
from .feature_selection import run_feature_selection
from .julia_helpers import _escape_filename, _load_cluster_manager, jl
from .julia_helpers import _escape_filename, _load_cluster_manager, jl, jl_convert
from .utils import (
_csv_filename_to_pkl_filename,
_preprocess_julia_floats,
Expand Down Expand Up @@ -1609,12 +1609,11 @@ def _run(self, X, y, mutated_params, weights, seed):

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/OptionsStruct.jl
print(bin_constraints)
options = SymbolicRegression.Options(
binary_operators=jl.seval(str(binary_operators).replace("'", "")),
unary_operators=jl.seval(str(unary_operators).replace("'", "")),
bin_constraints=bin_constraints,
una_constraints=una_constraints,
bin_constraints=jl_convert(jl.Vector, bin_constraints),
una_constraints=jl_convert(jl.Vector, una_constraints),
complexity_of_operators=complexity_of_operators,
complexity_of_constants=self.complexity_of_constants,
complexity_of_variables=self.complexity_of_variables,
Expand Down Expand Up @@ -1679,18 +1678,18 @@ def _run(self, X, y, mutated_params, weights, seed):
np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]

# This converts the data into a Julia array:
Main.X = np.array(X, dtype=np_dtype).T
jl_X = jl_convert(jl.Array, np.array(X, dtype=np_dtype).T)
if len(y.shape) == 1:
Main.y = np.array(y, dtype=np_dtype)
jl_y = jl_convert(jl.Vector, np.array(y, dtype=np_dtype))
else:
Main.y = np.array(y, dtype=np_dtype).T
jl_y = jl_convert(jl.Array, np.array(y, dtype=np_dtype).T)
if weights is not None:
if len(weights.shape) == 1:
Main.weights = np.array(weights, dtype=np_dtype)
jl_weights = jl_convert(jl.Vector, np.array(weights, dtype=np_dtype))
else:
Main.weights = np.array(weights, dtype=np_dtype).T
jl_weights = jl_convert(jl.Array, np.array(weights, dtype=np_dtype).T)
else:
Main.weights = None
jl_weights = None

if self.procs == 0 and not multithreading:
parallelism = "serial"
Expand All @@ -1703,22 +1702,30 @@ def _run(self, X, y, mutated_params, weights, seed):
None if parallelism in ["serial", "multithreading"] else int(self.procs)
)

y_variable_names = None
if len(y.shape) > 1:
# We set these manually so that they respect Python's 0 indexing
# (by default Julia will use y1, y2...)
y_variable_names = [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
jl_y_variable_names = jl_convert(
jl.Vector, [f"y{_subscriptify(i)}" for i in range(y.shape[1])]
)
else:
jl_y_variable_names = None

jl_feature_names = jl_convert(jl.Vector, self.feature_names_in_.tolist())
jl_display_feature_names = jl_convert(
jl.Vector, self.display_feature_names_in_.tolist()
)

# Call to Julia backend.
# See https://github.com/MilesCranmer/SymbolicRegression.jl/blob/master/src/SymbolicRegression.jl
self.raw_julia_state_ = SymbolicRegression.equation_search(
Main.X,
Main.y,
weights=Main.weights,
jl_X,
jl_y,
weights=jl_weights,
niterations=int(self.niterations),
variable_names=self.feature_names_in_.tolist(),
display_variable_names=self.display_feature_names_in_.tolist(),
y_variable_names=y_variable_names,
variable_names=jl_feature_names,
display_variable_names=jl_display_feature_names,
y_variable_names=jl_y_variable_names,
X_units=self.X_units_,
y_units=self.y_units_,
options=options,
Expand Down

0 comments on commit e530637

Please sign in to comment.