Skip to content

Commit

Permalink
Fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Apr 20, 2024
1 parent 58ecf46 commit b37f223
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/datasets/test_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __getitem__(self, query: BoundingBox) -> dict[str, BoundingBox]:


class CustomRasterDataset(RasterDataset):
def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> Any:
def __init__(self, dtype: torch.dtype, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._dtype = dtype

Expand Down Expand Up @@ -287,15 +287,15 @@ def test_getitem_uint_dtype(self, dtype: str) -> None:
assert x["image"].dtype == torch.float32

@pytest.mark.parametrize("dtype", [torch.float, torch.double])
def test_resampling_float_dtype(self, dtype: str) -> None:
def test_resampling_float_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join("tests", "data", "raster", "uint16")
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
assert x["image"].dtype == dtype
assert ds.resampling == Resampling.cubic

@pytest.mark.parametrize("dtype", [torch.long, torch.bool])
def test_resampling_int_dtype(self, dtype: str) -> None:
def test_resampling_int_dtype(self, dtype: torch.dtype) -> None:
paths = os.path.join("tests", "data", "raster", "uint16")
ds = CustomRasterDataset(dtype, paths)
x = ds[ds.bounds]
Expand Down

0 comments on commit b37f223

Please sign in to comment.