Skip to content

Commit

Permalink
Merge branch 'main' into tutorial_be
Browse files Browse the repository at this point in the history
  • Loading branch information
burakekim authored Jan 11, 2025
2 parents 65fd5a1 + 68e0cfe commit 15a921d
Show file tree
Hide file tree
Showing 31 changed files with 100 additions and 100 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.0
rev: v0.9.1
hooks:
- id: ruff
types_or:
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@
"sample = dataset[idx]\n",
"rgb = sample['image'][0, 1:4]\n",
"image = T.ToPILImage()(rgb)\n",
"print(f\"Class Label: {dataset.classes[sample['label']]}\")\n",
"print(f'Class Label: {dataset.classes[sample[\"label\"]]}')\n",
"image.resize((256, 256), resample=Image.BILINEAR)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion experiments/torchgeo/run_resisc45_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights in itertools.product(
model_options, lr_options, loss_options, weight_options
):
experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}"
experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}'

output_dir = os.path.join('output', 'resisc45_experiments')
log_dir = os.path.join(output_dir, 'logs')
Expand Down
2 changes: 1 addition & 1 deletion experiments/torchgeo/run_so2sat_byol_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights, bands in itertools.product(
model_options, lr_options, loss_options, weight_options, bands_options
):
experiment_name = f"{model}_{lr}_{loss}_byol_{bands}-{weights.split('/')[-2]}"
experiment_name = f'{model}_{lr}_{loss}_byol_{bands}-{weights.split("/")[-2]}'

output_dir = os.path.join('output', 'so2sat_experiments')
log_dir = os.path.join(output_dir, 'logs')
Expand Down
2 changes: 1 addition & 1 deletion experiments/torchgeo/run_so2sat_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights in itertools.product(
model_options, lr_options, loss_options, weight_options
):
experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}"
experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}'

output_dir = os.path.join('output', 'so2sat_experiments')
log_dir = os.path.join(output_dir, 'logs')
Expand Down
2 changes: 1 addition & 1 deletion experiments/torchgeo/run_so2sat_seed_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def do_work(work: 'Queue[str]', gpu_idx: int) -> bool:
for model, lr, loss, weights, seed in itertools.product(
model_options, lr_options, loss_options, weight_options, seeds
):
experiment_name = f"{model}_{lr}_{loss}_{weights.replace('_', '-')}_{seed}"
experiment_name = f'{model}_{lr}_{loss}_{weights.replace("_", "-")}_{seed}'

output_dir = os.path.join('output', 'so2sat_seed_experiments')
log_dir = os.path.join(output_dir, 'logs')
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ docs = [
style = [
# mypy 0.900+ required for pyproject.toml support
"mypy>=0.900",
# ruff 0.8+ required for removal of ANN101, ANN102
"ruff>=0.8",
# ruff 0.9+ required for 2025 style guide
"ruff>=0.9",
]
tests = [
# nbmake 1.3.3+ required for variable mocking
Expand Down
2 changes: 1 addition & 1 deletion requirements/required.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# setup
setuptools==75.6.0
setuptools==75.8.0

# install
einops==0.8.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/style.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# style
mypy==1.14.1
ruff==0.8.5
ruff==0.9.1
6 changes: 3 additions & 3 deletions tests/data/inria/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def generate_test_data(root: str, n_samples: int = 2) -> str:
lbl = np.random.randint(2, size=size, dtype=dtype)
timg = np.random.randint(dtype_max, size=size, dtype=dtype)

img_path = os.path.join(img_dir, f'austin{i+1}.tif')
lbl_path = os.path.join(lbl_dir, f'austin{i+1}.tif')
timg_path = os.path.join(timg_dir, f'austin{i+10}.tif')
img_path = os.path.join(img_dir, f'austin{i + 1}.tif')
lbl_path = os.path.join(lbl_dir, f'austin{i + 1}.tif')
timg_path = os.path.join(timg_dir, f'austin{i + 10}.tif')

write_data(img_path, img, driver, crs, transform)
write_data(lbl_path, lbl, driver, crs, transform)
Expand Down
2 changes: 1 addition & 1 deletion tests/data/seasonet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
os.remove(archive)

for grid, comp in zip(grids, name_comps):
file_name = f"{comp[0]}_{''.join(comp[1:8])}_{'_'.join(comp[8:])}"
file_name = f'{comp[0]}_{"".join(comp[1:8])}_{"_".join(comp[8:])}'
dir = os.path.join(season, f'grid{grid}', file_name)
os.makedirs(dir)

Expand Down
4 changes: 2 additions & 2 deletions tests/data/ssl4eo_benchmark_landsat/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def create_tarballs(directories: str) -> None:
# mask directory cdl
mask_keep = ['tm_toa', 'etm_sr', 'oli_sr']
mask_filenames = {
f"ssl4eo_l_{key.split('_')[0]}_cdl": val
f'ssl4eo_l_{key.split("_")[0]}_cdl': val
for key, val in filenames.items()
if key in mask_keep
}
Expand All @@ -203,7 +203,7 @@ def create_tarballs(directories: str) -> None:

# mask directory nlcd
mask_filenames = {
f"ssl4eo_l_{key.split('_')[0]}_nlcd": val
f'ssl4eo_l_{key.split("_")[0]}_nlcd': val
for key, val in filenames.items()
if key in mask_keep
}
Expand Down
12 changes: 6 additions & 6 deletions tests/datamodules/test_digital_typhoon.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ def find_max_time_per_id(
# Assert that each max value in train_max_values is lower
# than in val_max_values for each key id
for id, max_value in train_max_values.items():
assert (
id not in val_max_values or max_value < val_max_values[id]
), f'Max value for id {id} in train is not lower than in validation.'
assert id not in val_max_values or max_value < val_max_values[id], (
f'Max value for id {id} in train is not lower than in validation.'
)
else:
train_ids = {seq['id'] for seq in train_sequences}
val_ids = {seq['id'] for seq in val_sequences}

# Assert that the intersection between train_ids and val_ids is empty
assert (
len(train_ids & val_ids) == 0
), 'Train and validation datasets have overlapping ids.'
assert len(train_ids & val_ids) == 0, (
'Train and validation datasets have overlapping ids.'
)
6 changes: 3 additions & 3 deletions torchgeo/datamodules/digital_typhoon.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def __init__(
"""
super().__init__(DigitalTyphoon, batch_size, num_workers, **kwargs)

assert (
split_by in self.valid_split_types
), f'Please choose from {self.valid_split_types}'
assert split_by in self.valid_split_types, (
f'Please choose from {self.valid_split_types}'
)
self.split_by = split_by

def _split_dataset(
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datamodules/ftw.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(
Raises:
AssertionError: If 'countries' are specified in kwargs
"""
assert (
'countries' not in kwargs
), "Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"
assert 'countries' not in kwargs, (
"Please specify 'train_countries', 'val_countries', and 'test_countries' instead of 'countries' inside kwargs"
)

super().__init__(FieldsOfTheWorld, batch_size, num_workers, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/agrifieldnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert (
set(classes) <= self.cmap.keys()
), f'Only the following classes are valid: {list(self.cmap.keys())}.'
assert set(classes) <= self.cmap.keys(), (
f'Only the following classes are valid: {list(self.cmap.keys())}.'
)
assert 0 in classes, 'Classes must include the background class: 0'

self.paths = paths
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/datasets/bigearthnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,9 +565,9 @@ def plot(
ax.imshow(image)
ax.axis('off')
if show_titles:
title = f"Labels: {', '.join(labels)}"
title = f'Labels: {", ".join(labels)}'
if showing_predictions:
title += f"\nPredictions: {', '.join(predictions)}"
title += f'\nPredictions: {", ".join(predictions)}'
ax.set_title(title)

if suptitle is not None:
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/biomassters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __init__(
"""
self.root = root

assert (
split in self.valid_splits
), f'Please choose one of the valid splits: {self.valid_splits}.'
assert split in self.valid_splits, (
f'Please choose one of the valid splits: {self.valid_splits}.'
)
self.split = split

assert set(sensors).issubset(
set(self.valid_sensors)
), f'Please choose a subset of valid sensors: {self.valid_sensors}.'
assert set(sensors).issubset(set(self.valid_sensors)), (
f'Please choose a subset of valid sensors: {self.valid_sensors}.'
)
self.sensors = sensors
self.as_time_series = as_time_series

Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/cdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,9 @@ def __init__(
'CDL data product only exists for the following years: '
f'{list(self.md5s.keys())}.'
)
assert (
set(classes) <= self.cmap.keys()
), f'Only the following classes are valid: {list(self.cmap.keys())}.'
assert set(classes) <= self.cmap.keys(), (
f'Only the following classes are valid: {list(self.cmap.keys())}.'
)
assert 0 in classes, 'Classes must include the background class: 0'

self.paths = paths
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/cms_mangrove_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,15 +204,15 @@ def __init__(
self.checksum = checksum

assert isinstance(country, str), 'Country argument must be a str.'
assert (
country in self.all_countries
), f'You have selected an invalid country, please choose one of {self.all_countries}'
assert country in self.all_countries, (
f'You have selected an invalid country, please choose one of {self.all_countries}'
)
self.country = country

assert isinstance(measurement, str), 'Measurement must be a string.'
assert (
measurement in self.measurements
), f'You have entered an invalid measurement, please choose one of {self.measurements}.'
assert measurement in self.measurements, (
f'You have entered an invalid measurement, please choose one of {self.measurements}.'
)
self.measurement = measurement

self.filename_glob = f'**/Mangrove_{self.measurement}_{self.country}*'
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/digital_typhoon.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ def __init__(
self.min_feature_value = min_feature_value
self.max_feature_value = max_feature_value

assert (
task in self.valid_tasks
), f'Please choose one of {self.valid_tasks}, you provided {task}.'
assert task in self.valid_tasks, (
f'Please choose one of {self.valid_tasks}, you provided {task}.'
)
self.task = task

assert set(features).issubset(set(self.valid_features))
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/loveda.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def __init__(
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.splits
assert set(scene).intersection(
set(self.scenes)
), "The possible scenes are 'rural' and/or 'urban'"
assert set(scene).intersection(set(self.scenes)), (
"The possible scenes are 'rural' and/or 'urban'"
)
assert len(scene) <= 2, "There are no other scenes than 'rural' or 'urban'"

self.root = root
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/mdas.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,13 @@ def __init__(
"""
self.root = root
self.download = download
assert all(
sub in self.valid_subareas for sub in subareas
), f'Subareas must be one of {self.valid_subareas}'
assert all(sub in self.valid_subareas for sub in subareas), (
f'Subareas must be one of {self.valid_subareas}'
)
self.subareas = subareas
assert all(
mod in self.valid_modalities for mod in modalities
), f'Modalities must be one of {self.valid_modalities}'
assert all(mod in self.valid_modalities for mod in modalities), (
f'Modalities must be one of {self.valid_modalities}'
)
self.modalities = modalities
self.transforms = transforms
self.checksum = checksum
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/mmearth.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ def __init__(
"""
lazy_import('h5py')

assert (
normalization_mode in self.norm_modes
), f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}'
assert (
subset in self.subsets
), f'Invalid dataset version: {subset}, please choose from {self.subsets}'
assert normalization_mode in self.norm_modes, (
f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}'
)
assert subset in self.subsets, (
f'Invalid dataset version: {subset}, please choose from {self.subsets}'
)

self._validate_modalities(modalities)
self.modalities = modalities
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/nlcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def __init__(
'NLCD data product only exists for the following years: '
f'{list(self.md5s.keys())}.'
)
assert (
set(classes) <= self.cmap.keys()
), f'Only the following classes are valid: {list(self.cmap.keys())}.'
assert set(classes) <= self.cmap.keys(), (
f'Only the following classes are valid: {list(self.cmap.keys())}.'
)
assert 0 in classes, 'Classes must include the background class: 0'

self.paths = paths
Expand Down
2 changes: 1 addition & 1 deletion torchgeo/datasets/seasonet.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,7 @@ def plot(
axs[ax].imshow(image)
axs[ax].axis('off')
if show_titles:
axs[ax].set_title(f'Image {ax+1}')
axs[ax].set_title(f'Image {ax + 1}')

axs[ax + 1].imshow(mask, vmin=0, vmax=32, cmap=plt_cmap, interpolation='none')
axs[ax + 1].axis('off')
Expand Down
12 changes: 6 additions & 6 deletions torchgeo/datasets/skippd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,14 @@ def __init__(
"""
lazy_import('h5py')

assert (
split in self.valid_splits
), f'Please choose one of these valid data splits {self.valid_splits}.'
assert split in self.valid_splits, (
f'Please choose one of these valid data splits {self.valid_splits}.'
)
self.split = split

assert (
task in self.valid_tasks
), f'Please choose one of these valid tasks {self.valid_tasks}.'
assert task in self.valid_tasks, (
f'Please choose one of these valid tasks {self.valid_tasks}.'
)
self.task = task

self.root = root
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/datasets/south_africa_crop_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,9 @@ def __init__(
Raises:
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert (
set(classes) <= self.cmap.keys()
), f'Only the following classes are valid: {list(self.cmap.keys())}.'
assert set(classes) <= self.cmap.keys(), (
f'Only the following classes are valid: {list(self.cmap.keys())}.'
)
assert 0 in classes, 'Classes must include the background class: 0'

self.paths = paths
Expand Down
Loading

0 comments on commit 15a921d

Please sign in to comment.