Skip to content

Commit

Permalink
Merge pull request #277 from xopt-org/imports
Browse files Browse the repository at this point in the history
update init py files
  • Loading branch information
roussel-ryan authored Feb 21, 2025
2 parents ba6acb9 + 671e297 commit f34967e
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion xopt/generators/bayesian/bayesian_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def validate_torch_modules(cls, v):
if v.startswith("base64:"):
v = decode_torch_module(v)
elif os.path.exists(v):
v = torch.load(v)
v = torch.load(v, weights_only=False)
return v

@field_validator("gp_constructor", mode="before")
Expand Down
2 changes: 1 addition & 1 deletion xopt/generators/bayesian/models/standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def validate_torch_modules(cls, v):
if val.startswith("base64:"):
v[key] = decode_torch_module(val)
elif os.path.exists(val):
v[key] = torch.load(val)
v[key] = torch.load(val, weights_only=False)

return v

Expand Down
3 changes: 3 additions & 0 deletions xopt/generators/es/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from xopt.generators.es.extremumseeking import ExtremumSeekingGenerator

__all__ = ["ExtremumSeekingGenerator"]
3 changes: 3 additions & 0 deletions xopt/generators/rcds/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from xopt.generators.rcds.rcds import RCDSGenerator

__all__ = ["RCDSGenerator"]
4 changes: 4 additions & 0 deletions xopt/generators/scipy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from xopt.generators.scipy.latin_hypercube import LatinHypercubeGenerator
from xopt.generators.scipy.neldermead import NelderMeadGenerator

__all__ = ["LatinHypercubeGenerator", "NelderMeadGenerator"]
4 changes: 2 additions & 2 deletions xopt/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def decode_torch_module(modulestr: str):
decoded = base64.standard_b64decode(base64str)
decompressed = gzip.decompress(decoded)
bytestream = io.BytesIO(decompressed)
module = torch.load(bytestream)
module = torch.load(bytestream, weights_only=False)
return module


Expand All @@ -194,7 +194,7 @@ def validate_files(cls, value, info: ValidationInfo):
if os.path.exists(value):
extension = value.split(".")[-1]
if extension == "pt":
value = torch.load(value)
value = torch.load(value, weights_only=False)

return value

Expand Down

0 comments on commit f34967e

Please sign in to comment.