Skip to content

Commit

Permalink
id ood length method and polishing
Browse files Browse the repository at this point in the history
  • Loading branch information
burakekim committed Feb 1, 2025
1 parent 0f3ceb7 commit a74c99f
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions torchgeo/datasets/xview.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,10 @@ def __init__(
self.files = self._load_split_files_by_disaster_and_type(
self.all_files, id_ood_disaster[0], id_ood_disaster[1]
)
print(
f"Loaded for disasters ID and OOD: {len(self.files['train'])} train, {len(self.files['test'])} test files."
)

train_size, test_size = self.get_id_ood_sizes()
print(f"ID sample len: {train_size}, OOD sample len: {test_size}")

def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
"""Get an item from the dataset at the given index."""
file_info = (
Expand Down Expand Up @@ -396,6 +396,12 @@ def __len__(self) -> int:
else len(self.files['test'])
)


def get_id_ood_sizes(self) -> tuple[int, int]:
    """Report how many samples the ID (train) and OOD (test) splits hold.

    Returns:
        A ``(train_count, test_count)`` pair, where each count is the
        length of the corresponding entry in ``self.files``.
    """
    id_count = len(self.files['train'])
    ood_count = len(self.files['test'])
    return id_count, ood_count


def _initialize_files(self, root: str) -> list[dict[str, str]]:
"""Initialize the dataset by loading file paths and computing basenames with sample numbers."""
all_files = []
Expand Down Expand Up @@ -454,7 +460,7 @@ def _load_split_files_by_disaster_and_type(

disaster_list.append(disaster_name)

# Filter for in-domain (ID) training set
# Filter for in-distribution (ID) training set
if disaster_name == id_disaster['disaster_name']:
if (
id_disaster.get('pre-post') == 'both'
Expand Down

0 comments on commit a74c99f

Please sign in to comment.