diff --git a/avalanche/benchmarks/classic/__init__.py b/avalanche/benchmarks/classic/__init__.py
index a75946368..d6c8b17bf 100644
--- a/avalanche/benchmarks/classic/__init__.py
+++ b/avalanche/benchmarks/classic/__init__.py
@@ -13,3 +13,4 @@
 from .clear import *
 from .stream51 import *
 from .ex_model import *
+from .concon import *
\ No newline at end of file
diff --git a/avalanche/benchmarks/classic/concon.py b/avalanche/benchmarks/classic/concon.py
new file mode 100644
index 000000000..df80e695b
--- /dev/null
+++ b/avalanche/benchmarks/classic/concon.py
@@ -0,0 +1,252 @@
+import random
+
+from pathlib import Path
+from typing import Optional, Union, Any, List, TypeVar
+
+from torchvision import transforms
+
+from avalanche.benchmarks.utils.data import AvalancheDataset
+from avalanche.benchmarks.utils.classification_dataset import (
+    _as_taskaware_supervised_classification_dataset,
+)
+from avalanche.benchmarks import benchmark_from_datasets, CLScenario
+
+from avalanche.benchmarks.datasets.concon import ConConDataset
+
+
+TCLDataset = TypeVar("TCLDataset", bound="AvalancheDataset")
+
+
+_default_train_transform = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.5, 0.5, 0.5],
+            std=[0.5, 0.5, 0.5]
+        )
+    ]
+)
+
+_default_eval_transform = transforms.Compose(
+    [
+        transforms.ToTensor(),
+        transforms.Normalize(
+            mean=[0.5, 0.5, 0.5],
+            std=[0.5, 0.5, 0.5]
+        )
+    ]
+)
+
+
+def build_concon_scenario(
+    list_train_dataset: List[TCLDataset],
+    list_test_dataset: List[TCLDataset],
+    seed: Optional[int] = None,
+    n_experiences: int = 3,
+    shuffle_order: bool = False,
+):
+    """Assembles a ConCon benchmark from per-task train/test datasets.
+
+    With ``n_experiences == 3`` each task becomes one experience; with
+    ``n_experiences == 1`` all tasks are concatenated into a single
+    experience.
+    """
+    if shuffle_order and n_experiences != 1:
+        # Re-seed before each shuffle so the train and test streams
+        # receive the same task order.
+        random.seed(seed)
+        random.shuffle(list_train_dataset)
+        random.seed(seed)
+        random.shuffle(list_test_dataset)
+
+    if n_experiences == 1:
+        # Concatenate all task datasets into a single experience.
+        new_list_train_dataset = [list_train_dataset[0]]
+        for i in range(1, len(list_train_dataset)):
+            new_list_train_dataset[0] = new_list_train_dataset[0].concat(
+                list_train_dataset[i]
+            )
+        list_train_dataset = new_list_train_dataset
+
+        new_list_test_dataset = [list_test_dataset[0]]
+        for i in range(1, len(list_test_dataset)):
+            new_list_test_dataset[0] = new_list_test_dataset[0].concat(
+                list_test_dataset[i]
+            )
+        list_test_dataset = new_list_test_dataset
+
+    return benchmark_from_datasets(
+        train=list_train_dataset,
+        test=list_test_dataset
+    )
+
+
+def ConConDisjoint(
+    n_experiences: int,
+    *,
+    seed: Optional[int] = None,
+    shuffle_order: bool = False,
+    train_transform: Optional[Any] = _default_train_transform,
+    eval_transform: Optional[Any] = _default_eval_transform,
+    dataset_root: Optional[Union[str, Path]] = None,
+) -> CLScenario:
+    """
+    Creates a ConCon Disjoint benchmark.
+
+    If the dataset is not present on the machine, this method will
+    automatically download and store it.
+
+    The returned benchmark will be a domain-incremental one, where each task
+    is a different domain with different confounders. In this setting,
+    task-specific confounders never appear in other tasks.
+
+    The benchmark instance returned by this method will have two fields,
+    `train_stream` and `test_stream`, which can be iterated to obtain
+    training and test :class:`Experience`. Each Experience contains the
+    `dataset` and the associated task label.
+
+    :param dataset_root: The root directory of the dataset.
+    :param n_experiences: The number of experiences to use. Must be 1
+        (all tasks merged into a single experience) or 3 (one experience
+        per task).
+    :param seed: The seed used to shuffle the experience order (only
+        relevant when `shuffle_order` is True).
+    :param shuffle_order: Whether to shuffle the order of the experiences.
+    :param train_transform: The training transform to use.
+    :param eval_transform: The evaluation transform to use.
+
+    :returns: The ConCon Disjoint benchmark.
+    """
+    assert n_experiences in (1, 3), \
+        "n_experiences must be 1 or 3 for ConCon Disjoint"
+    list_train_dataset = []
+    list_test_dataset = []
+
+    for i in range(3):
+        train_dataset = ConConDataset("disjoint", i, root=dataset_root, train=True)
+        test_dataset = ConConDataset("disjoint", i, root=dataset_root, train=False)
+        train_dataset = _as_taskaware_supervised_classification_dataset(
+            train_dataset,
+            transform=train_transform
+        )
+        test_dataset = _as_taskaware_supervised_classification_dataset(
+            test_dataset,
+            transform=eval_transform
+        )
+        list_train_dataset.append(train_dataset)
+        list_test_dataset.append(test_dataset)
+
+    return build_concon_scenario(
+        list_train_dataset,
+        list_test_dataset,
+        seed=seed,
+        n_experiences=n_experiences,
+        shuffle_order=shuffle_order
+    )
+
+
+def ConConStrict(
+    n_experiences: int,
+    *,
+    seed: Optional[int] = None,
+    shuffle_order: bool = False,
+    train_transform: Optional[Any] = _default_train_transform,
+    eval_transform: Optional[Any] = _default_eval_transform,
+    dataset_root: Optional[Union[str, Path]] = None,
+) -> CLScenario:
+    """
+    Creates a ConCon Strict benchmark.
+
+    If the dataset is not present on the machine, this method will
+    automatically download and store it.
+
+    The returned benchmark will be a domain-incremental one, where each task
+    is a different domain with different confounders. In this setting,
+    task-specific confounders may appear in other tasks as random features
+    in both positive and negative samples.
+
+    The benchmark instance returned by this method will have two fields,
+    `train_stream` and `test_stream`, which can be iterated to obtain
+    training and test :class:`Experience`. Each Experience contains the
+    `dataset` and the associated task label.
+
+    :param dataset_root: The root directory of the dataset.
+    :param n_experiences: The number of experiences to use. Must be 1
+        (all tasks merged into a single experience) or 3 (one experience
+        per task).
+    :param seed: The seed used to shuffle the experience order (only
+        relevant when `shuffle_order` is True).
+    :param shuffle_order: Whether to shuffle the order of the experiences.
+    :param train_transform: The training transform to use.
+    :param eval_transform: The evaluation transform to use.
+
+    :returns: The ConCon Strict benchmark.
+    """
+    assert n_experiences in (1, 3), \
+        "n_experiences must be 1 or 3 for ConCon Strict"
+    list_train_dataset = []
+    list_test_dataset = []
+
+    for i in range(3):
+        train_dataset = ConConDataset("strict", i, root=dataset_root, train=True)
+        test_dataset = ConConDataset("strict", i, root=dataset_root, train=False)
+        train_dataset = _as_taskaware_supervised_classification_dataset(
+            train_dataset,
+            transform=train_transform
+        )
+        test_dataset = _as_taskaware_supervised_classification_dataset(
+            test_dataset,
+            transform=eval_transform
+        )
+        list_train_dataset.append(train_dataset)
+        list_test_dataset.append(test_dataset)
+
+    return build_concon_scenario(
+        list_train_dataset,
+        list_test_dataset,
+        seed=seed,
+        n_experiences=n_experiences,
+        shuffle_order=shuffle_order
+    )
+
+
+def ConConUnconfounded(
+    *,
+    train_transform: Optional[Any] = _default_train_transform,
+    eval_transform: Optional[Any] = _default_eval_transform,
+    dataset_root: Optional[Union[str, Path]] = None,
+) -> CLScenario:
+    """
+    Creates a ConCon Unconfounded benchmark.
+
+    If the dataset is not present on the machine, this method will
+    automatically download and store it.
+
+    The returned benchmark will only contain one task, where no task-specific
+    confounders are present.
+
+    The benchmark instance returned by this method will have two fields,
+    `train_stream` and `test_stream`, which can be iterated to obtain
+    training and test :class:`Experience`. Each Experience contains the
+    `dataset` and the associated task label.
+
+    :param dataset_root: The root directory of the dataset.
+    :param train_transform: The training transform to use.
+    :param eval_transform: The evaluation transform to use.
+
+    :returns: The ConCon Unconfounded benchmark.
+    """
+    train_dataset = []
+    test_dataset = []
+
+    train_dataset.append(ConConDataset(
+        "unconfounded", 0, root=dataset_root, train=True))
+    test_dataset.append(ConConDataset(
+        "unconfounded", 0, root=dataset_root, train=False))
+
+    train_dataset[0] = _as_taskaware_supervised_classification_dataset(
+        train_dataset[0],
+        transform=train_transform
+    )
+
+    test_dataset[0] = _as_taskaware_supervised_classification_dataset(
+        test_dataset[0],
+        transform=eval_transform
+    )
+
+    return benchmark_from_datasets(
+        train=train_dataset,
+        test=test_dataset
+    )
+
+
+__all__ = [
+    "ConConDisjoint",
+    "ConConStrict",
+    "ConConUnconfounded",
+]
diff --git a/avalanche/benchmarks/datasets/__init__.py b/avalanche/benchmarks/datasets/__init__.py
index 589d79192..b887ca51d 100644
--- a/avalanche/benchmarks/datasets/__init__.py
+++ b/avalanche/benchmarks/datasets/__init__.py
@@ -12,3 +12,4 @@
 from .inaturalist import *
 from .penn_fudan import *
 from .clear import *
+from .concon import *
\ No newline at end of file
diff --git a/avalanche/benchmarks/datasets/concon/__init__.py b/avalanche/benchmarks/datasets/concon/__init__.py
new file mode 100644
index 000000000..3aa02e6cd
--- /dev/null
+++ b/avalanche/benchmarks/datasets/concon/__init__.py
@@ -0,0 +1 @@
+from .concon import *
\ No newline at end of file
diff --git a/avalanche/benchmarks/datasets/concon/concon.py b/avalanche/benchmarks/datasets/concon/concon.py
new file mode 100644
index 000000000..eb13eee58
--- /dev/null
+++ b/avalanche/benchmarks/datasets/concon/concon.py
@@ -0,0 +1,138 @@
+from pathlib import Path
+from typing import Union, Optional
+
+from PIL import Image
+from torchvision.transforms import ToTensor
+
+from avalanche.benchmarks.datasets import (
+    SimpleDownloadableDataset,
+    default_dataset_location,
+)
+
+
+class ConConDataset(SimpleDownloadableDataset):
+    """
+    ConConDataset represents a continual learning task with two classes:
+    positive and negative. All data instances are images based on the CLEVR
+    framework. A ground truth rule can be used to determine the binary class
+    affiliation of any image. The dataset is designed to be used in a
+    continual learning setting with three sequential tasks, each confounded
+    by a task-specific confounder. The challenge arises from the fact that
+    task-specific confounders change across tasks. There are three dataset
+    variants:
+
+    - Disjoint: Task-specific confounders never appear in other tasks.
+    - Strict: Task-specific confounders may appear in other tasks as random
+      features in both positive and negative samples.
+    - Unconfounded: No task-specific confounders.
+
+    Reference:
+    Busch, Florian Peter, et al. "Where is the Truth? The Risk of Getting
+    Confounded in a Continual World." arXiv preprint arXiv:2402.06434 (2024).
+
+    Args:
+        variant (str): The variant of the dataset, must be one of 'strict',
+            'disjoint', 'unconfounded'.
+        scenario (int): The scenario number, must be between 0 and 2.
+        root (str or Path): The root directory where the dataset will be
+            stored. If None, the default avalanche dataset location will
+            be used.
+        train (bool): If True, use the training set, otherwise use the
+            test set.
+        download (bool): If True, download the dataset.
+        transform: A function/transform that takes in a PIL image and
+            returns a transformed version. E.g., ``transforms.RandomCrop``
+            for data augmentation.
+    """
+
+    urls = {
+        "strict": "https://zenodo.org/records/10630482/files/case_strict_main.zip",
+        "disjoint": "https://zenodo.org/records/10630482/files/case_disjoint_main.zip",
+        "unconfounded": "https://zenodo.org/records/10630482/files/unconfounded.zip"
+    }
+
+    def __init__(self,
+                 variant: str,
+                 scenario: int,
+                 root: Optional[Union[str, Path]] = None,
+                 train: bool = True,
+                 download: bool = True,
+                 transform=None,
+                 ):
+        assert variant in ["strict", "disjoint", "unconfounded"], \
+            "Invalid variant, must be one of 'strict', 'disjoint', 'unconfounded'"
+        assert scenario in range(0, 3), \
+            "Invalid scenario, must be between 0 and 2"
+        assert variant != "unconfounded" or scenario == 0, \
+            "The unconfounded variant only has a single scenario (0)"
+
+        if root is None:
+            root = default_dataset_location("concon")
+
+        self.root = Path(root)
+
+        url = self.urls[variant]
+
+        super(ConConDataset, self).__init__(
+            self.root, url, None, download=download, verbose=True
+        )
+
+        # The archives extract into differently named directories.
+        if variant == "strict":
+            self.variant = "case_strict_main"
+        elif variant == "disjoint":
+            self.variant = "case_disjoint_main"
+        else:
+            self.variant = variant
+
+        self.scenario = scenario
+        self.train = train
+        self.transform = transform
+        self._load_dataset()
+
+    def _load_metadata(self) -> bool:
+        root = self.root / self.variant
+
+        if self.train:
+            images_dir = root / "train"
+        else:
+            images_dir = root / "test"
+
+        images_dir = images_dir / "images" / f"t{self.scenario}"
+
+        self.image_paths = []
+        self.targets = []
+
+        # Sort the directories so that class ids and sample order are
+        # deterministic across file systems.
+        for class_id, class_dir in enumerate(sorted(images_dir.iterdir())):
+            for image_path in sorted(class_dir.iterdir()):
+                self.image_paths.append(image_path)
+                self.targets.append(class_id)
+
+        return True
+
+    def __len__(self):
+        return len(self.image_paths)
+
+    def __getitem__(self, idx):
+        image_path = self.image_paths[idx]
+        image = Image.open(image_path).convert("RGB")
+
+        if self.transform is not None:
+            image = self.transform(image)
+
+        target = self.targets[idx]
+        return image, target
+
+
+if __name__ == "__main__":
+    # This small example script visualizes the first image loaded
+    # from the dataset.
+    from torch.utils.data.dataloader import DataLoader
+    import matplotlib.pyplot as plt
+    from torchvision import transforms
+    import torch
+
+    train_data = ConConDataset(
+        "strict", 0, "data_debug/concon", transform=ToTensor()
+    )
+    dataloader = DataLoader(train_data, batch_size=1)
+
+    for batch_data in dataloader:
+        x, y = batch_data
+        plt.imshow(transforms.ToPILImage()(torch.squeeze(x)))
+        plt.show()
+        print(x.shape)
+        print(y.shape)
+        break
+
+
+__all__ = ["ConConDataset"]
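For reference, below is a minimal usage sketch of the new benchmark API. It is not part of the patch: it assumes the Zenodo archives download successfully and that the default `dataset_root` is acceptable, and it only relies on the `train_stream`/`test_stream` fields and the per-experience `dataset` attribute described in the docstrings above.

# Usage sketch (not part of the patch): build the three-task ConCon Strict
# benchmark and inspect the size of each experience in its streams.
from avalanche.benchmarks.classic import ConConStrict

benchmark = ConConStrict(n_experiences=3, seed=0, shuffle_order=False)

for i, experience in enumerate(benchmark.train_stream):
    # Each experience wraps one confounded task.
    print(f"train experience {i}: {len(experience.dataset)} samples")

for i, experience in enumerate(benchmark.test_stream):
    print(f"test experience {i}: {len(experience.dataset)} samples")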