Skip to content

Commit

Permalink
Specify types for all parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 24, 2023
1 parent bd4f864 commit 0ddc60f
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 86 deletions.
174 changes: 89 additions & 85 deletions pysr/sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
from typing import List, Optional
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -659,90 +659,92 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):

def __init__(
self,
model_selection="best",
model_selection: Literal["best", "accuracy", "score"] = "best",
*,
binary_operators=None,
unary_operators=None,
niterations=40,
populations=15,
population_size=33,
max_evals=None,
maxsize=20,
maxdepth=None,
warmup_maxsize_by=0.0,
timeout_in_seconds=None,
constraints=None,
nested_constraints=None,
loss=None,
full_objective=None,
complexity_of_operators=None,
complexity_of_constants=1,
complexity_of_variables=1,
parsimony=0.0032,
dimensional_constraint_penalty=None,
use_frequency=True,
use_frequency_in_tournament=True,
adaptive_parsimony_scaling=20.0,
alpha=0.1,
annealing=False,
early_stop_condition=None,
ncyclesperiteration=550,
fraction_replaced=0.000364,
fraction_replaced_hof=0.035,
weight_add_node=0.79,
weight_insert_node=5.1,
weight_delete_node=1.7,
weight_do_nothing=0.21,
weight_mutate_constant=0.048,
weight_mutate_operator=0.47,
weight_randomize=0.00023,
weight_simplify=0.0020,
weight_optimize=0.0,
crossover_probability=0.066,
skip_mutation_failures=True,
migration=True,
hof_migration=True,
topn=12,
should_simplify=None,
should_optimize_constants=True,
optimizer_algorithm="BFGS",
optimizer_nrestarts=2,
optimize_probability=0.14,
optimizer_iterations=8,
perturbation_factor=0.076,
tournament_selection_n=10,
tournament_selection_p=0.86,
procs=cpu_count(),
multithreading=None,
cluster_manager=None,
heap_size_hint_in_bytes=None,
batching=False,
batch_size=50,
fast_cycle=False,
turbo=False,
precision=32,
enable_autodiff=False,
random_state=None,
deterministic=False,
warm_start=False,
verbosity=1,
update_verbosity=None,
print_precision=5,
progress=True,
equation_file=None,
temp_equation_file=False,
tempdir=None,
delete_tempfiles=True,
julia_project=None,
update=False,
output_jax_format=False,
output_torch_format=False,
extra_sympy_mappings=None,
extra_torch_mappings=None,
extra_jax_mappings=None,
denoise=False,
select_k_features=None,
julia_kwargs=None,
binary_operators: Optional[List[str]] = None,
unary_operators: Optional[List[str]] = None,
niterations: int = 40,
populations: int = 15,
population_size: int = 33,
max_evals: Optional[int] = None,
maxsize: int = 20,
maxdepth: Optional[int] = None,
warmup_maxsize_by: Optional[float] = None,
timeout_in_seconds: Optional[float] = None,
constraints: Optional[Dict[str, Union[int, Tuple[int, int]]]] = None,
nested_constraints: Optional[Dict[str, Dict[str, int]]] = None,
loss: Optional[str] = None,
full_objective: Optional[str] = None,
complexity_of_operators: Optional[Dict[str, Union[int, float]]] = None,
complexity_of_constants: Union[int, float] = 1,
complexity_of_variables: Union[int, float] = 1,
parsimony: float = 0.0032,
dimensional_constraint_penalty: Optional[float] = None,
use_frequency: bool = True,
use_frequency_in_tournament: bool = True,
adaptive_parsimony_scaling: float = 20.0,
alpha: float = 0.1,
annealing: bool = False,
early_stop_condition: Optional[Union[float, str]] = None,
ncyclesperiteration: int = 550,
fraction_replaced: float = 0.000364,
fraction_replaced_hof: float = 0.035,
weight_add_node: float = 0.79,
weight_insert_node: float = 5.1,
weight_delete_node: float = 1.7,
weight_do_nothing: float = 0.21,
weight_mutate_constant: float = 0.048,
weight_mutate_operator: float = 0.47,
weight_randomize: float = 0.00023,
weight_simplify: float = 0.0020,
weight_optimize: float = 0.0,
crossover_probability: float = 0.066,
skip_mutation_failures: bool = True,
migration: bool = True,
hof_migration: bool = True,
topn: int = 12,
should_simplify: Optional[bool] = None,
should_optimize_constants: bool = True,
optimizer_algorithm: str = "BFGS",
optimizer_nrestarts: int = 2,
optimize_probability: float = 0.14,
optimizer_iterations: int = 8,
perturbation_factor: float = 0.076,
tournament_selection_n: int = 10,
tournament_selection_p: float = 0.86,
procs: int = cpu_count(),
multithreading: Optional[bool] = None,
cluster_manager: Optional[
Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]
] = None,
heap_size_hint_in_bytes: Optional[int] = None,
batching: bool = False,
batch_size: int = 50,
fast_cycle: bool = False,
turbo: bool = False,
precision: int = 32,
enable_autodiff: bool = False,
random_state: Optional[Union[int, np.random.RandomState]] = None,
deterministic: bool = False,
warm_start: bool = False,
verbosity: int = 1,
update_verbosity: Optional[int] = None,
print_precision: int = 5,
progress: bool = True,
equation_file: Optional[str] = None,
temp_equation_file: bool = False,
tempdir: Optional[str] = None,
delete_tempfiles: bool = True,
julia_project: Optional[str] = None,
update: bool = False,
output_jax_format: bool = False,
output_torch_format: bool = False,
extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
extra_torch_mappings: Optional[Dict[Callable, Callable]] = None,
extra_jax_mappings: Optional[Dict[Callable, str]] = None,
denoise: bool = False,
select_k_features: Optional[int] = None,
julia_kwargs: Optional[Dict] = None,
**kwargs,
):
# Hyperparameters
Expand Down Expand Up @@ -1645,7 +1647,9 @@ def _run(self, X, y, mutated_params, weights, seed):
fraction_replaced_hof=self.fraction_replaced_hof,
should_simplify=self.should_simplify,
should_optimize_constants=self.should_optimize_constants,
warmup_maxsize_by=self.warmup_maxsize_by,
warmup_maxsize_by=0.0
if self.warmup_maxsize_by is None
else self.warmup_maxsize_by,
use_frequency=self.use_frequency,
use_frequency_in_tournament=self.use_frequency_in_tournament,
adaptive_parsimony_scaling=self.adaptive_parsimony_scaling,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
],
python_requires=">=3.7",
python_requires=">=3.8",
)

0 comments on commit 0ddc60f

Please sign in to comment.