diff --git a/srai/embedders/geovex/embedder.py b/srai/embedders/geovex/embedder.py
index 4633e144..83e6d27a 100644
--- a/srai/embedders/geovex/embedder.py
+++ b/srai/embedders/geovex/embedder.py
@@ -91,6 +91,7 @@ def transform(
         regions_gdf: gpd.GeoDataFrame,
         features_gdf: gpd.GeoDataFrame,
         joint_gdf: gpd.GeoDataFrame,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         """
         Create region embeddings.
@@ -116,6 +117,7 @@ def transform(
             neighbourhood,
             self._batch_size,
             shuffle=False,
+            dataloader_kwargs=dataloader_kwargs,
         )
 
         return self._transform(dataset=self._dataset, dataloader=dataloader)
@@ -124,9 +126,15 @@ def _transform(
         self,
         dataset: HexagonalDataset[T],
         dataloader: Optional[DataLoader] = None,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
+        dataloader_kwargs = dataloader_kwargs or {}
+        if "batch_size" not in dataloader_kwargs:
+            dataloader_kwargs["batch_size"] = self._batch_size
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = False
         if dataloader is None:
-            dataloader = DataLoader(dataset, batch_size=self._batch_size, shuffle=False)
+            dataloader = DataLoader(dataset, **dataloader_kwargs)
 
         embeddings = [
             self._model.encoder(batch).detach().numpy()  # type: ignore
@@ -149,6 +157,7 @@ def fit(
         neighbourhood: H3Neighbourhood,
         learning_rate: float = 0.001,
         trainer_kwargs: Optional[dict[str, Any]] = None,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Fit the model to the data.
@@ -167,7 +176,13 @@ def fit(
         trainer_kwargs = self._prepare_trainer_kwargs(trainer_kwargs)
 
         counts_df, dataloader, dataset = self._prepare_dataset(  # type: ignore
-            regions_gdf, features_gdf, joint_gdf, neighbourhood, self._batch_size, shuffle=True
+            regions_gdf,
+            features_gdf,
+            joint_gdf,
+            neighbourhood,
+            self._batch_size,
+            shuffle=True,
+            dataloader_kwargs=dataloader_kwargs,
         )
 
         self._prepare_model(counts_df, learning_rate)
@@ -196,6 +211,7 @@ def _prepare_dataset(
         neighbourhood: H3Neighbourhood,
         batch_size: Optional[int],
         shuffle: bool = True,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> tuple[pd.DataFrame, DataLoader, HexagonalDataset[T]]:
         counts_df = self._get_raw_counts(regions_gdf, features_gdf, joint_gdf)
         dataset: HexagonalDataset[T] = HexagonalDataset(
@@ -203,7 +219,12 @@ def _prepare_dataset(
             neighbourhood,
             neighbor_k_ring=self._r,
         )
-        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
+        dataloader_kwargs = dataloader_kwargs or {}
+        if "batch_size" not in dataloader_kwargs:
+            dataloader_kwargs["batch_size"] = batch_size
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = shuffle
+        dataloader = DataLoader(dataset, **dataloader_kwargs)
         return counts_df, dataloader, dataset
 
     def fit_transform(
@@ -214,6 +235,7 @@ def fit_transform(
         neighbourhood: H3Neighbourhood,
         learning_rate: float = 0.001,
         trainer_kwargs: Optional[dict[str, Any]] = None,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         """
         Fit the model to the data and create region embeddings.
@@ -236,6 +258,7 @@ def fit_transform(
             neighbourhood=neighbourhood,
             learning_rate=learning_rate,
             trainer_kwargs=trainer_kwargs,
+            dataloader_kwargs=dataloader_kwargs,
         )
         assert self._dataset is not None  # for mypy
         return self._transform(dataset=self._dataset)
diff --git a/srai/embedders/gtfs2vec/embedder.py b/srai/embedders/gtfs2vec/embedder.py
index 1d95977e..4f99f085 100644
--- a/srai/embedders/gtfs2vec/embedder.py
+++ b/srai/embedders/gtfs2vec/embedder.py
@@ -83,6 +83,7 @@ def fit(
         regions_gdf: gpd.GeoDataFrame,
         features_gdf: gpd.GeoDataFrame,
         joint_gdf: gpd.GeoDataFrame,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Fit model to a given data.
@@ -101,13 +102,14 @@ def fit(
         features = self._prepare_features(regions_gdf, features_gdf, joint_gdf)
 
         if not self._skip_autoencoder:
-            self._model = self._train_model_unsupervised(features)
+            self._model = self._train_model_unsupervised(features, dataloader_kwargs)
 
     def fit_transform(
         self,
         regions_gdf: gpd.GeoDataFrame,
         features_gdf: gpd.GeoDataFrame,
         joint_gdf: gpd.GeoDataFrame,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         """
         Fit model and transform a given data.
@@ -131,7 +133,7 @@ def fit_transform(
         if self._skip_autoencoder:
             return features
         else:
-            self._model = self._train_model_unsupervised(features)
+            self._model = self._train_model_unsupervised(features, dataloader_kwargs)
             return self._embed(features)
 
     def _maybe_get_model(self) -> GTFS2VecModel:
@@ -228,7 +230,9 @@ def _normalize_features(self, features: pd.DataFrame) -> pd.DataFrame:
 
         return features
 
-    def _train_model_unsupervised(self, features: pd.DataFrame) -> GTFS2VecModel:
+    def _train_model_unsupervised(
+        self, features: pd.DataFrame, dataloader_kwargs: Optional[dict[str, Any]] = None
+    ) -> GTFS2VecModel:
         """
         Train model unsupervised.
 
@@ -244,7 +248,14 @@ def _train_model_unsupervised(self, features: pd.DataFrame) -> GTFS2VecModel:
             n_embed=self._embedding_size,
         )
         X = features.to_numpy().astype(np.float32)
-        x_dataloader = DataLoader(X, batch_size=24, shuffle=True, num_workers=4)
+        dataloader_kwargs = dataloader_kwargs or {}
+        if "num_workers" not in dataloader_kwargs:
+            dataloader_kwargs["num_workers"] = 4
+        if "batch_size" not in dataloader_kwargs:
+            dataloader_kwargs["batch_size"] = 24
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = True
+        x_dataloader = DataLoader(X, **dataloader_kwargs)
         trainer = pl.Trainer(max_epochs=10)
         trainer.fit(model, x_dataloader)
 
diff --git a/srai/embedders/hex2vec/embedder.py b/srai/embedders/hex2vec/embedder.py
index e1af7484..893b6055 100644
--- a/srai/embedders/hex2vec/embedder.py
+++ b/srai/embedders/hex2vec/embedder.py
@@ -103,6 +103,7 @@ def fit(
         batch_size: int = 32,
         learning_rate: float = 0.001,
         trainer_kwargs: Optional[dict[str, Any]] = None,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         """
         Fit the model to the data.
@@ -141,7 +142,13 @@ def fit(
             layer_sizes=[num_features, *self._encoder_sizes], learning_rate=learning_rate
         )
         dataset = NeighbourDataset(counts_df, neighbourhood, negative_sample_k_distance)
-        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+        dataloader_kwargs = dataloader_kwargs or {}
+        if "batch_size" not in dataloader_kwargs:
+            dataloader_kwargs["batch_size"] = batch_size
+        if "shuffle" not in dataloader_kwargs:
+            dataloader_kwargs["shuffle"] = True
+
+        dataloader = DataLoader(dataset, **dataloader_kwargs)
 
         trainer = pl.Trainer(**trainer_kwargs)
         trainer.fit(self._model, dataloader)
@@ -157,6 +164,7 @@ def fit_transform(
         batch_size: int = 32,
         learning_rate: float = 0.001,
         trainer_kwargs: Optional[dict[str, Any]] = None,
+        dataloader_kwargs: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         """
         Fit the model to the data and return the embeddings.
@@ -192,6 +200,7 @@ def fit_transform(
             batch_size,
             learning_rate,
             trainer_kwargs,
+            dataloader_kwargs,
         )
         return self.transform(regions_gdf, features_gdf, joint_gdf)
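
Usage note (not part of the diff): a minimal sketch of how the new dataloader_kwargs argument could be used, shown here with Hex2VecEmbedder; the same keyword applies to GeoVexEmbedder and GTFS2VecEmbedder. The regions_gdf, features_gdf and joint_gdf GeoDataFrames are assumed to come from the usual srai regionalizer/loader/joiner pipeline, and any explicit "batch_size"/"shuffle" keys override the embedder defaults seen above.

# Illustrative sketch only: regions_gdf, features_gdf and joint_gdf are assumed
# to be prepared beforehand (e.g. H3Regionalizer + OSMOnlineLoader + IntersectionJoiner).
from srai.embedders import Hex2VecEmbedder
from srai.neighbourhoods import H3Neighbourhood

embedder = Hex2VecEmbedder(encoder_sizes=[150, 75, 50])
neighbourhood = H3Neighbourhood(regions_gdf)

embeddings = embedder.fit_transform(
    regions_gdf,
    features_gdf,
    joint_gdf,
    neighbourhood,
    trainer_kwargs={"max_epochs": 5, "accelerator": "cpu"},
    # New keyword from this change: extra arguments forwarded to torch.utils.data.DataLoader.
    dataloader_kwargs={"num_workers": 4, "batch_size": 64},
)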