Split NICO++ according to the paper 'Change is Hard' (#302)
Co-authored-by: Myles Bartlett <[email protected]>
tmke8 and MylesBartlett authored May 12, 2023
1 parent 71e7eaa commit 131c654
Showing 2 changed files with 94 additions and 102 deletions.
17 changes: 4 additions & 13 deletions conduit/data/datamodules/vision/nico_plus_plus.py
@@ -6,8 +6,7 @@
 from typing_extensions import override
 
 from conduit.data.datamodules.vision.base import CdtVisionDataModule
-from conduit.data.datasets.utils import stratified_split
-from conduit.data.datasets.vision import NICOPP, NicoPPTarget, SampleType
+from conduit.data.datasets.vision import NICOPP, NicoPPSplit, NicoPPTarget, SampleType
 from conduit.data.structures import TrainValTestSplit
 
 __all__ = ["NICOPPDataModule"]
@@ -19,7 +18,6 @@ class NICOPPDataModule(CdtVisionDataModule[NICOPP, SampleType]):
 
     image_size: int = 224
     superclasses: Optional[List[NicoPPTarget]] = None
-    make_biased: bool = True
 
     @property
     @override
@@ -44,16 +42,9 @@ def prepare_data(self, *args: Any, **kwargs: Any) -> None:
 
     @override
     def _get_splits(self) -> TrainValTestSplit[NICOPP]:
-        all_data = NICOPP(root=self.root, superclasses=self.superclasses, transform=None)
-        train_val_prop = 1 - self.test_prop
-        train_val_data, test_data = stratified_split(
-            all_data,
-            default_train_prop=train_val_prop,
-            train_props=all_data.default_train_props if self.make_biased else None,
-            seed=self.seed,
-        )
-        val_data, train_data = train_val_data.random_split(
-            props=self.val_prop / train_val_prop, seed=self.seed
+        train_data, val_data, test_data = (
+            NICOPP(root=self.root, superclasses=self.superclasses, transform=None, split=split)
+            for split in NicoPPSplit
         )
 
        return TrainValTestSplit(train=train_data, val=val_data, test=test_data)
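
The rewritten `_get_splits` leans on the fact that iterating over an `Enum` yields its members in definition order, so unpacking a generator over `NicoPPSplit` produces the train, val, and test datasets in exactly that order. A minimal, self-contained sketch of the pattern (the `Split` enum below is a stand-in for `NicoPPSplit`, and strings stand in for the `NICOPP` instances):

    from enum import Enum

    class Split(Enum):  # stand-in for NicoPPSplit
        TRAIN = 0
        VAL = 1
        TEST = 2

    # Unpacking a generator over the enum yields one item per member, in
    # definition order: TRAIN, VAL, TEST.
    train, val, test = (f"dataset[{s.name}]" for s in Split)
    assert (train, val, test) == ("dataset[TRAIN]", "dataset[VAL]", "dataset[TEST]")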
179 changes: 90 additions & 89 deletions conduit/data/datasets/vision/nico_plus_plus.py
@@ -1,7 +1,9 @@
 """NICO Dataset."""
-from enum import auto
+from enum import Enum, auto
 from functools import cached_property
+import json
 from pathlib import Path
+import random
 from typing import ClassVar, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union
 
 import pandas as pd
@@ -15,7 +17,7 @@
 from .base import CdtVisionDataset
 from .utils import ImageTform
 
-__all__ = ["NICOPP", "NicoPPTarget", "NicoPPAttr"]
+__all__ = ["NICOPP", "NicoPPTarget", "NicoPPAttr", "NicoPPSplit"]
 
 
 class NicoPPTarget(StrEnum):
@@ -94,6 +96,12 @@ class NicoPPAttr(StrEnum):
     WATER = auto()
 
 
+class NicoPPSplit(Enum):
+    TRAIN = 0
+    VAL = 1
+    TEST = 2
+
+
 SampleType: TypeAlias = TernarySample
 
 
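The integer values of `NicoPPSplit` are the codes that `_extract_metadata` (further down) writes into the metadata's `split` column, and `NICOPP.__init__` filters rows by comparing against `self.split.value`. A toy illustration of that filtering step, with a fabricated dataframe in place of the real metadata:

    import pandas as pd

    # Fabricated metadata; the real "split" column is produced by _extract_metadata.
    df = pd.DataFrame({"filename": ["a.jpg", "b.jpg", "c.jpg"], "split": [0, 1, 2]})
    # Select the validation rows (NicoPPSplit.VAL.value == 1), as __init__ does.
    assert df[df["split"] == 1]["filename"].tolist() == ["b.jpg"]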
@@ -103,31 +111,11 @@ class NICOPP(CdtVisionDataset[TernarySample, Tensor, Tensor]):
     SampleType: TypeAlias = TernarySample
     Superclass: TypeAlias = NicoPPTarget
     Subclass: TypeAlias = NicoPPAttr
+    Split: TypeAlias = NicoPPSplit
 
-    less_than_75_samples: ClassVar[Dict[NicoPPTarget, Tuple[NicoPPAttr, ...]]] = {
-        NicoPPTarget.CACTUS: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.CORN: (NicoPPAttr.ROCK,),
-        NicoPPTarget.CRAB: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.CROCODILE: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.DOLPHIN: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.FOOTBALL: (NicoPPAttr.ROCK,),
-        NicoPPTarget.HAT: (NicoPPAttr.ROCK,),
-        NicoPPTarget.LIFEBOAT: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.LIZARD: (NicoPPAttr.DIM,),
-        NicoPPTarget.OSTRICH: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.PINEAPPLE: (NicoPPAttr.AUTUMN, NicoPPAttr.DIM, NicoPPAttr.ROCK),
-        NicoPPTarget.RACKET: (NicoPPAttr.AUTUMN, NicoPPAttr.ROCK),
-        NicoPPTarget.SEAL: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.SHRIMP: (
-            NicoPPAttr.AUTUMN,
-            NicoPPAttr.DIM,
-            NicoPPAttr.OUTDOOR,
-            NicoPPAttr.ROCK,
-        ),
-        NicoPPTarget.SPIDER: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.SUNFLOWER: (NicoPPAttr.AUTUMN,),
-        NicoPPTarget.WHEAT: (NicoPPAttr.ROCK,),
-    }
+    data_split_seed: ClassVar[int] = 666  # this is the seed from the paper
+    num_samples_val_test: ClassVar[int] = 75  # this is the number from the paper
+    subpath: ClassVar[Path] = Path("public_dg_0416") / "train"
 
     @parsable
     def __init__(
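
The two new constants encode the protocol from 'Change is Hard': with seed 666, each (superclass, attribute) group contributes up to 75 held-out images, split 1:2 into validation and test (25 and 50 for a full group). Since NICO++ has 60 superclasses and 6 attributes (360 groups), this caps the held-out pool at 360 × 25 = 9,000 validation and 360 × 50 = 18,000 test images; groups with fewer than 75 images contribute proportionally fewer.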
@@ -136,13 +124,16 @@ def __init__(
         *,
         transform: Optional[ImageTform] = None,
         superclasses: Optional[Sequence[Union[NicoPPTarget, str]]] = None,
+        split: Optional[Union[NicoPPSplit, str]] = None,
     ) -> None:
         self.superclasses: Optional[List[NicoPPTarget]] = None
         if superclasses is not None:
             assert superclasses, "superclasses should be a non-empty list"
             self.superclasses = [NicoPPTarget(superclass) for superclass in superclasses]
+        self.split = NicoPPSplit[split.upper()] if isinstance(split, str) else split
 
         self.root = Path(root)
-        self._base_dir = self.root / "nico_plus_plus" / "track_1" / "public_dg_0416" / "train"
+        self._base_dir = self.root / "nico_plus_plus"
+        self._metadata_path = self._base_dir / "metadata.csv"
 
         if not self._check_unzipped():
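
Note that the new `split` argument accepts either a `NicoPPSplit` member or its name as a case-insensitive string, coerced via the by-name lookup `NicoPPSplit[split.upper()]`. A self-contained sketch of that coercion:

    from enum import Enum
    from typing import Optional, Union

    class NicoPPSplit(Enum):  # same definition as above
        TRAIN = 0
        VAL = 1
        TEST = 2

    def coerce(split: Optional[Union[NicoPPSplit, str]]) -> Optional[NicoPPSplit]:
        # Enum[name] looks a member up by name, so "val".upper() -> NicoPPSplit.VAL.
        return NicoPPSplit[split.upper()] if isinstance(split, str) else split

    assert coerce("val") is NicoPPSplit.VAL
    assert coerce(NicoPPSplit.TEST) is NicoPPSplit.TEST
    assert coerce(None) is None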
@@ -152,90 +143,100 @@ def __init__(
         if not self._metadata_path.exists():
             self._extract_metadata()
 
-        self.metadata = pd.read_csv(self._base_dir / "metadata.csv")
+        self.metadata = pd.read_csv(self._metadata_path)
 
         if self.superclasses is not None:
             self.metadata = self.metadata[self.metadata["superclass"].isin(self.superclasses)]
 
+        if self.split is not None:
+            self.metadata = self.metadata[self.metadata["split"] == self.split.value]
+
         # Divide up the dataframe into its constituent arrays because indexing with pandas is
         # substantially slower than indexing with numpy/torch
-        x = self.metadata["filepath"].to_numpy()
-        y = torch.as_tensor(self.metadata["superclass_le"].to_numpy(), dtype=torch.long)
-        s = torch.as_tensor(self.metadata["subclass_le"].to_numpy(), dtype=torch.long)
+        x = self.metadata["filename"].to_numpy()
+        y = torch.as_tensor(self.metadata["y"].to_numpy(), dtype=torch.long)
+        s = torch.as_tensor(self.metadata["a"].to_numpy(), dtype=torch.long)
 
         super().__init__(x=x, y=y, s=s, transform=transform, image_dir=self._base_dir)
 
-    @property
-    def default_train_props(self) -> Dict[int, Dict[int, float]]:
-        """Zero out the (s,y) pairs which have fewer than 75 samples."""
-        atoi = self.subclass_label_encoder
-        ytoi = self.superclass_label_encoder
-        return {
-            ytoi[target]: {atoi[attr]: 0.0 for attr in attrs}
-            for target, attrs in self.less_than_75_samples.items()
-        }
-
     @cached_property
     def class_tree(self) -> Dict[str, Set[str]]:
-        return (
-            self.metadata[["superclass", "subclass"]]
-            .drop_duplicates()
-            .groupby("superclass")
-            .agg(set)
-            .to_dict()["subclass"]
-        )
+        return self.metadata[["y", "a"]].drop_duplicates().groupby("y").agg(set).to_dict()["a"]
 
     @cached_property
-    def superclass_label_encoder(self) -> Dict[NicoPPTarget, int]:
-        return dict(
-            (NicoPPTarget(name), val) for name, val in self._get_label_mapping("superclass")
-        )
-
-    @cached_property
-    def subclass_label_encoder(self) -> Dict[NicoPPAttr, int]:
-        return dict((NicoPPAttr(name), val) for name, val in self._get_label_mapping("subclass"))
+    def superclass_label_decoder(self) -> Dict[int, NicoPPTarget]:
+        return dict((val, NicoPPTarget(name)) for name, val in self._get_label_mapping("y"))
 
     @cached_property
-    def superclass_label_decoder(self) -> Dict[int, str]:
-        return dict((val, name) for name, val in self._get_label_mapping("superclass"))
+    def subclass_label_decoder(self) -> Dict[int, NicoPPAttr]:
+        return dict((val, NicoPPAttr(name)) for name, val in self._get_label_mapping("a"))
 
-    @cached_property
-    def subclass_label_decoder(self) -> Dict[int, str]:
-        return dict((val, name) for name, val in self._get_label_mapping("subclass"))
-
-    def _get_label_mapping(self, level: Literal["superclass", "subclass"]) -> List[Tuple[str, int]]:
+    def _get_label_mapping(self, level: Literal["y", "a"]) -> List[Tuple[str, int]]:
         """Get a list of all possible (name, numerical value) pairs."""
         return list(
-            self.metadata[[level, f"{level}_le"]]
+            self.metadata[[f"{level}_name", level]]
             .drop_duplicates()
             .itertuples(index=False, name=None)
         )
 
     def _check_unzipped(self) -> bool:
-        return all((self._base_dir / attr).exists() for attr in NicoPPAttr)
+        if not all((self._base_dir / self.subpath / attr).exists() for attr in NicoPPAttr):
+            return False
+        if not (self._base_dir / "dg_label_id_mapping.json").exists():
+            return False
+        return True
 
     def _extract_metadata(self) -> None:
-        """Extract concept/context/superclass information from the image filepaths and it save to csv."""
-        self.logger.info("Extracting metadata.")
-        image_paths: List[Path] = []
-        image_paths.extend(self._base_dir.glob(f"**/*.jpg"))
-        image_paths_str = [str(image.relative_to(self._base_dir)) for image in image_paths]
-        filepaths = pd.Series(image_paths_str)
-        metadata = filepaths.str.split("/", expand=True).rename(
-            columns={0: "subclass", 1: "superclass", 2: "filename"}
-        )
-
-        metadata["filepath"] = filepaths
-        metadata.sort_index(axis=1, inplace=True)
-        metadata.sort_values(by=["filepath"], axis=0, inplace=True)
-        metadata = self._label_encode_metadata(metadata)
-        metadata.to_csv(self._metadata_path, index=False)
-
-    @staticmethod
-    def _label_encode_metadata(metadata: pd.DataFrame) -> pd.DataFrame:
-        """Label encode the extracted concept/context/superclass information."""
-        for col in metadata.columns:
-            # Skip over filepath and filename columns
-            if "file" not in col:
-                # Add a new column containing the label-encoded data
-                metadata[f"{col}_le"] = metadata[col].factorize()[0]
-        return metadata
+        self.logger.info("Generating metadata for NICO++...")
+        attributes = ["autumn", "dim", "grass", "outdoor", "rock", "water"]  # 6 attrs, 60 labels
+        meta = json.load(open(self._base_dir / "dg_label_id_mapping.json", "r"))
+
+        def _make_balanced_testset(
+            df: pd.DataFrame, *, seed: int, num_samples_val_test: int
+        ) -> pd.DataFrame:
+            # each group has a test set size of (2/3 * num_samples_val_test) and a val set size of
+            # (1/3 * num_samples_val_test); if total samples in original group < num_samples_val_test,
+            # val/test will still be split by 1:2, but no training samples remain
+
+            random.seed(seed)
+            val_set, test_set = [], []
+            for g in pd.unique(df["g"]):
+                df_group = df[df["g"] == g]
+                curr_data = df_group["filename"].values
+                random.shuffle(curr_data)
+                split_size = min(len(curr_data), num_samples_val_test)
+                val_set += list(curr_data[: split_size // 3])
+                test_set += list(curr_data[split_size // 3 : split_size])
+            self.logger.info(f"Val: {len(val_set)}, Test: {len(test_set)}")
+            assert len(set(val_set).intersection(set(test_set))) == 0
+            combined_set = dict(zip(val_set, [NicoPPSplit.VAL.value for _ in range(len(val_set))]))
+            combined_set.update(
+                dict(zip(test_set, [NicoPPSplit.TEST.value for _ in range(len(test_set))]))
+            )
+            df["split"] = df["filename"].map(combined_set)
+            df["split"].fillna(NicoPPSplit.TRAIN.value, inplace=True)
+            df["split"] = df.split.astype(int)
+            return df
+
+        all_data = []
+        for c, attr in enumerate(attributes):
+            for label in meta:
+                folder_path = self._base_dir / self.subpath / attr / label
+                y = meta[label]
+                for img_path in Path(folder_path).glob("*.jpg"):
+                    all_data.append(
+                        {
+                            "filename": str(img_path.relative_to(self._base_dir)),
+                            "y": y,
+                            "a": c,
+                            "y_name": label,
+                            "a_name": attr,
+                        }
+                    )
+        df = pd.DataFrame(all_data)
+        df["g"] = df["a"] + df["y"] * len(attributes)
+        df = _make_balanced_testset(
+            df, seed=self.data_split_seed, num_samples_val_test=self.num_samples_val_test
+        )
+        df = df.drop(columns=["g"])
+        df.to_csv(self._metadata_path, index=False)
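
For intuition, the rule in `_make_balanced_testset` can be replayed on a single toy (y, a) group: up to `num_samples_val_test` shuffled filenames are held out, the first third for validation and the next two thirds for test, and whatever is left over stays in the training split. A sketch with fabricated filenames:

    import random

    random.seed(666)  # data_split_seed
    group = [f"img_{i}.jpg" for i in range(200)]  # one (y, a) group with 200 images
    random.shuffle(group)
    split_size = min(len(group), 75)  # num_samples_val_test
    val = group[: split_size // 3]              # 25 filenames
    test = group[split_size // 3 : split_size]  # 50 filenames
    train = group[split_size:]                  # the remaining 125 filenames
    assert (len(val), len(test), len(train)) == (25, 50, 125)

For a group with fewer than 75 images, `split_size` shrinks to the group size, the 1:2 val/test ratio is preserved, and no training samples remain, as the comment in the function notes.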
