Skip to content

Commit

Permalink
fix(pre-commit.ci): auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 21, 2024
1 parent ad4f19b commit 284b994
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
2 changes: 1 addition & 1 deletion run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ source /home/shoaib/scratch/venv/bin/activate
# python -m pip install --ignore-installed torch
# python -m pip install --ignore-installed lightning_fabric
# python -m pip install srai[torch]
PYTHONPATH=. pytest tests/embedders/geovex/test_embedder.py::test_embedder_save_load
PYTHONPATH=. pytest tests/embedders/geovex/test_embedder.py::test_embedder_save_load
11 changes: 5 additions & 6 deletions srai/embedders/geovex/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
[1] https://openreview.net/forum?id=7bvWopYY1H
"""

from typing import Any, Optional, TypeVar, Union
import json
from pathlib import Path
from typing import Any, Optional, TypeVar, Union

import geopandas as gpd
import numpy as np
Expand Down Expand Up @@ -269,8 +269,6 @@ def save(self, path: Union[str, Any]) -> None:
Args:
path (Union[str, Any]): Path to the directory.
"""
import torch

# embedder_config must match the constructor signature:
# target_features: Union[list[str], OsmTagsFilter, GroupedOsmTagsFilter],
# batch_size: Optional[int] = 32,
Expand All @@ -287,7 +285,7 @@ def save(self, path: Union[str, Any]) -> None:
"convolutional_layer_size": self._convolutional_layer_size,
}
self._save(path, embedder_config)

def _save(self, path: Union[str, Any], embedder_config: dict[str, Any]) -> None:
if isinstance(path, str):
path = Path(path)
Expand All @@ -309,7 +307,6 @@ def _save(self, path: Union[str, Any], embedder_config: dict[str, Any]) -> None:
with (path / "config.json").open("w") as f:
json.dump(config, f, ensure_ascii=False, indent=4)


@classmethod
def load(cls, path: Union[Path, str]) -> "GeoVexEmbedder":
"""
Expand All @@ -331,7 +328,9 @@ def _load(cls, path: Union[Path, str], model_module: type[ModelT]) -> "GeoVexEmb
with (path / "config.json").open("r") as f:
config = json.load(f)

config["embedder_config"]["target_features"] = json.loads(config["embedder_config"]["target_features"])
config["embedder_config"]["target_features"] = json.loads(
config["embedder_config"]["target_features"]
)
embedder = cls(**config["embedder_config"])
model_path = path / "model.pt"
model = model_module.load(model_path, **config["model_config"])
Expand Down
4 changes: 2 additions & 2 deletions srai/embedders/geovex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,5 @@ def get_config(self) -> dict[str, int]:
"conv_layers": self.conv_layers,
"emb_size": self.emb_size,
"learning_rate": self.lr,
"conv_layer_size": self.conv_layer_size
}
"conv_layer_size": self.conv_layer_size,
}
4 changes: 2 additions & 2 deletions tests/embedders/geovex/test_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ def test_embedder_save_load() -> None:
joint_gdf,
neighbourhood,
trainer_kwargs=TRAINER_KWARGS,
learning_rate=0.001
learning_rate=0.001,
)

# verify that the model was loaded correctly
assert_frame_equal(result_df, loaded_result_df, atol=1e-1)

# check type of model
assert isinstance(loaded_embedder._model, GeoVexModel)

# safely clean up tmp_models directory
os.remove(tmp_models_dir / "test_model" / "model.pt")

Check failure on line 190 in tests/embedders/geovex/test_embedder.py

View workflow job for this annotation

GitHub Actions / Run pre-commit manual stage

Refurb FURB144

Replace `os.remove(x)` with `x.unlink()`
os.remove(tmp_models_dir / "test_model" / "config.json")

Check failure on line 191 in tests/embedders/geovex/test_embedder.py

View workflow job for this annotation

GitHub Actions / Run pre-commit manual stage

Refurb FURB144

Replace `os.remove(x)` with `x.unlink()`
Expand Down

0 comments on commit 284b994

Please sign in to comment.