Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

457 geovex model saving loading #460

Merged
merged 17 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 80 additions & 1 deletion srai/embedders/geovex/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
[1] https://openreview.net/forum?id=7bvWopYY1H
"""

import json
from pathlib import Path
from typing import Any, Optional, TypeVar, Union

import geopandas as gpd
Expand All @@ -15,7 +17,7 @@

from srai._optional import import_optional_dependencies
from srai.constants import REGIONS_INDEX
from srai.embedders import CountEmbedder
from srai.embedders import CountEmbedder, ModelT
from srai.embedders.geovex.dataset import HexagonalDataset
from srai.embedders.geovex.model import GeoVexModel
from srai.exceptions import ModelNotFitException
Expand Down Expand Up @@ -259,3 +261,80 @@
raise ValueError(
f"The convolutional layers in GeoVex expect >= {conv_layer_size} features."
)

def save(self, path: Union[str, Any]) -> None:
"""
Save the model to a directory.

Args:
path (Union[str, Any]): Path to the directory.
"""
# embedder_config must match the constructor signature:
# target_features: Union[list[str], OsmTagsFilter, GroupedOsmTagsFilter],
# batch_size: Optional[int] = 32,
# neighbourhood_radius: int = 4,
# convolutional_layers: int = 2,
# embedding_size: int = 32,
# convolutional_layer_size: int = 256,
embedder_config = {
"target_features": self.expected_output_features.to_json(),
"batch_size": self._batch_size,
"neighbourhood_radius": self._r,
"convolutional_layers": self._convolutional_layers,
"embedding_size": self._embedding_size,
"convolutional_layer_size": self._convolutional_layer_size,
}
self._save(path, embedder_config)

def _save(self, path: Union[str, Any], embedder_config: dict[str, Any]) -> None:
if isinstance(path, str):
path = Path(path)

Check warning on line 291 in srai/embedders/geovex/embedder.py

View check run for this annotation

Codecov / codecov/patch

srai/embedders/geovex/embedder.py#L291

Added line #L291 was not covered by tests

self._check_is_fitted()

path.mkdir(parents=True, exist_ok=True)

# save model and config
self._model.save(path / "model.pt") # type: ignore
# combine model config and embedder config
if self._model is not None:
model_config = self._model.get_config()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a check for a model not being None? If self._model was None, the save function would throw an exception.

Should this be a check for a model config, or should the save function be moved inside this if statement?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can see in the Hex2Vec implementation that there is no such check with if statement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@RaczeQ thanks for the feedback. The type linter was complaining about _model being None so we put in a check. But let us remove it and push again.


config = {
"model_config": model_config,
"embedder_config": embedder_config,
}

with (path / "config.json").open("w") as f:
json.dump(config, f, ensure_ascii=False, indent=4)

@classmethod
def load(cls, path: Union[Path, str]) -> "GeoVexEmbedder":
"""
Load the model from a directory.

Args:
path (Union[Path, str]): Path to the directory.
model_module (type[ModelT]): Model class.

Returns:
GeoVexEmbedder: GeoVexEmbedder object.
"""
return cls._load(path, GeoVexModel)

@classmethod
def _load(cls, path: Union[Path, str], model_module: type[ModelT]) -> "GeoVexEmbedder":
if isinstance(path, str):
path = Path(path)

Check warning on line 328 in srai/embedders/geovex/embedder.py

View check run for this annotation

Codecov / codecov/patch

srai/embedders/geovex/embedder.py#L328

Added line #L328 was not covered by tests
with (path / "config.json").open("r") as f:
config = json.load(f)

config["embedder_config"]["target_features"] = json.loads(
config["embedder_config"]["target_features"]
)
embedder = cls(**config["embedder_config"])
model_path = path / "model.pt"
model = model_module.load(model_path, **config["model_config"])
embedder._model = model
embedder._is_fitted = True
return embedder
19 changes: 19 additions & 0 deletions srai/embedders/geovex/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,8 @@ def __init__(
self.R = radius
self.lr = learning_rate
self.emb_size = emb_size
self.conv_layer_size = conv_layer_size
self.conv_layers = conv_layers

# input size is 2R + 2
self.M = get_shape(self.R)
Expand Down Expand Up @@ -573,3 +575,20 @@ def configure_optimizers(self) -> list["torch.optim.Optimizer"]:
lr=self.lr,
)
return [opt]

# override get_config to return the model configuration
def get_config(self) -> dict[str, int | float]:
"""
Get the model configuration.

Returns:
Dict[str, int | float]: The model configuration.
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are still supporting the Python 3.9 version, so it doesn't accept the pipe syntax yet and tests on this Python version failed. Please change to the Union version. It will also require additional import at the top from typing.

Suggested change
def get_config(self) -> dict[str, int | float]:
"""
Get the model configuration.
Returns:
Dict[str, int | float]: The model configuration.
"""
def get_config(self) -> dict[str, Union[int, float]]:
"""
Get the model configuration.
Returns:
Dict[str, Union[int, float]]: The model configuration.
"""

return {
"k_dim": self.k_dim,
"radius": self.R,
"conv_layers": self.conv_layers,
"emb_size": self.emb_size,
"learning_rate": self.lr,
"conv_layer_size": self.conv_layer_size,
}
87 changes: 87 additions & 0 deletions tests/embedders/geovex/test_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,90 @@
print(result_df.head())
print(expected.head())
assert_frame_equal(result_df, expected, atol=1e-1)


def test_embedder_save_load() -> None:
"""Test GeoVexEmbedder model saving and loading."""
test_files_path = Path(__file__).parent / "test_files"
for test_case in PREDEFINED_TEST_CASES:
name = test_case["test_case_name"]
seed = test_case["seed"]
radius: int = test_case["model_radius"] # type: ignore

# Load data from parquet files
regions_gdf = gpd.read_parquet(test_files_path / f"{name}_regions.parquet")
features_gdf = gpd.read_parquet(test_files_path / f"{name}_features.parquet")
joint_gdf = pd.read_parquet(test_files_path / f"{name}_joint.parquet")

# Set seed for reproducibility
seed_everything(seed, workers=True)
os.environ["PYTHONHASHSEED"] = str(seed)
torch.use_deterministic_algorithms(True)

# Initialize neighbourhood and target features for the embedder
neighbourhood = H3Neighbourhood(regions_gdf)
target_features = [
f"{st}_{t}"
for st in test_case["tags"] # type: ignore
for t in test_case["tags"][st] # type: ignore
]

# Initialize GeoVexEmbedder with the given parameters
embedder = GeoVexEmbedder(
target_features=target_features,
batch_size=10,
neighbourhood_radius=radius,
embedding_size=EMBEDDING_SIZE,
convolutional_layers=test_case["num_layers"], # type: ignore
convolutional_layer_size=test_case["convolutional_layer_size"], # type: ignore
)

# Prepare dataset for the embedder
counts_df, _, _ = embedder._prepare_dataset(
regions_gdf, features_gdf, joint_gdf, neighbourhood, embedder._batch_size, shuffle=True
)

embedder._prepare_model(counts_df, 0.001)

# Initialize model parameters to a constant value for reproducibility
for _, param in cast(GeoVexModel, embedder._model).named_parameters():
param.data.fill_(0.01)

result_df = embedder.fit_transform(
regions_gdf,
features_gdf,
joint_gdf,
neighbourhood,
trainer_kwargs=TRAINER_KWARGS,
learning_rate=0.001,
)

tmp_models_dir = Path(__file__).parent / "test_files" / "tmp_models"

# test model saving functionality
embedder.save(tmp_models_dir / "test_model")

# load the saved model
loaded_embedder = GeoVexEmbedder.load(tmp_models_dir / "test_model")

# get embeddings from the loaded model
loaded_result_df = loaded_embedder.fit_transform(
regions_gdf,
features_gdf,
joint_gdf,
neighbourhood,
trainer_kwargs=TRAINER_KWARGS,
learning_rate=0.001,
)

# verify that the model was loaded correctly
assert_frame_equal(result_df, loaded_result_df, atol=1e-1)

# check type of model
assert isinstance(loaded_embedder._model, GeoVexModel)

# safely clean up tmp_models directory
os.remove(tmp_models_dir / "test_model" / "model.pt")

Check failure on line 190 in tests/embedders/geovex/test_embedder.py

View workflow job for this annotation

GitHub Actions / Run pre-commit manual stage

Refurb FURB144

Replace `os.remove(x)` with `x.unlink()`
os.remove(tmp_models_dir / "test_model" / "config.json")

Check failure on line 191 in tests/embedders/geovex/test_embedder.py

View workflow job for this annotation

GitHub Actions / Run pre-commit manual stage

Refurb FURB144

Replace `os.remove(x)` with `x.unlink()`
os.rmdir(tmp_models_dir / "test_model")
os.rmdir(tmp_models_dir)
Loading