diff --git a/README.md b/README.md index 9bf5c44..da0d08f 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ Place the datasets in the `data/` directory. #### 2.4 Representation Learning - Navigate to the `representation/` directory. -- **TBD** +- Follow the instructions in the README.md file in that directory. ## Contact Information: - Grégoire Montavon: [gregoire.montavon@fu-berlin.de](mailto:gregoire.montavon@fu-berlin.de) diff --git a/representation/README.md b/representation/README.md new file mode 100644 index 0000000..4769d42 --- /dev/null +++ b/representation/README.md @@ -0,0 +1,42 @@ +# Representation Learning Experiments + +This folder contains the code to reproduce our results for the representation learning models. +Models used in the paper are: `r50-sup`, `r50-barlowtwins`, `r50-clip`, `simclr-rn50`. + +### Install dependencies + +```bash +pip install -r requirements.txt +``` + +### Compute embeddings + +Embeddings for the different ResNet-50 models can be extracted with the following script. + +```bash +python extract_embeddings.py --data-root <data_root> \ + --model <model> \ + --output-dir <output_dir> \ + --dataset <dataset> \ + --device cuda \ + --split <train|test> + +``` + +### Generate BiLRP Heatmaps + +BiLRP heatmaps can be generated by first computing the LRP relevances (`compute_bilrp.py`) and +then plotting the result (`plots/plot_bilrp.py`). + +### Linear Classifiers + +To train linear classifiers on the extracted embeddings, use `linear_probing.py`. This generates +JSON files with the predictions of the linear classifiers. + +### Plot classifier results + +The linear probing results can then be analyzed and plotted with the notebook `plots/representation.ipynb`. + +### t-SNE plots +The t-SNE plots can be generated from the extracted features with `plots/fish_tsne.py` and `plots/trucks_tsne.py`.
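For reference, here is a minimal end-to-end sketch of the pipeline described above, assuming the commands are run from the `src/` directory where the scripts live. The flags are taken from the argument parsers of `extract_embeddings.py`, `compute_bilrp.py`, and `linear_probing.py`; the concrete paths, the `fish` dataset name, and the output directories are placeholders and may need to be adapted to your setup.

```bash
# 1) Extract embeddings; the script writes <output-dir>/<model>.npz containing
#    the embeddings and labels (paths and dataset name below are placeholders).
python extract_embeddings.py --data-root resources/data/subsets/fish/50-50 \
    --model r50-sup \
    --output-dir resources/features/fish/std/train \
    --dataset fish \
    --device cuda \
    --split train

# 2) Compute BiLRP relevances for a dataset/model pair; results are saved
#    under <output>/<dataset>/<model>/.
python compute_bilrp.py --dataset fish-tench --model SimCLR \
    --output bilrp/relevances --device cuda

# 3) Train a linear probe on the extracted embeddings and write the
#    predictions to a JSON file (output directory is a placeholder).
python linear_probing.py --model r50-sup --dataset fish --output-dir results/probing
```

Note that `linear_probing.py` expects the embeddings under `resources/features/<dataset>/std/train/<model>.npz` and `resources/features/<dataset>/std/test/<model>.npz`, which is why the embedding output directory in the sketch follows that layout.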
+ diff --git a/representation/requirements.txt b/representation/requirements.txt new file mode 100644 index 0000000..1400c76 --- /dev/null +++ b/representation/requirements.txt @@ -0,0 +1,4 @@ +torchvision==0.19 +torch==2.4.0 +scikit-learn==1.5.1 +Pillow \ No newline at end of file diff --git a/representation/src/bilrp/bilrp.py b/representation/src/bilrp/bilrp.py new file mode 100644 index 0000000..727f2f3 --- /dev/null +++ b/representation/src/bilrp/bilrp.py @@ -0,0 +1,67 @@ +import torch +from tqdm import tqdm +import numpy as np +from zennit.attribution import Gradient +from bilrp.plotting import plot_relevances, clip, get_alpha + + +def compute_branch(x, model, composite, device='cuda'): + e = model.forward(x) + y = e.squeeze() + n_features = y.shape + + R = [] + for k, yk in tqdm(enumerate(y)): + z = np.zeros((n_features[0])) + z[k] = y[k].detach().cpu().numpy().squeeze() + r_proj = ( + torch.FloatTensor((z.reshape([1, n_features[0], 1, 1]))) + .to(device) + .data.squeeze(2) + .squeeze(2) + ) + model.zero_grad() + x.grad = None + with Gradient(model=model, composite=composite) as attributor: + out, relevance = attributor(x, r_proj) + relevance = relevance.squeeze().detach().cpu().numpy() + R.append(relevance) + del out, relevance + return R, e + + +def pool(X, stride): + K = [ + torch.nn.functional.avg_pool2d( + torch.from_numpy(o).unsqueeze(0).unsqueeze(1), + kernel_size=stride, + stride=stride, + padding=0, + ) + .squeeze() + .numpy() + for o in X + ] + return K + + +def compute_rel(r1, r2, poolsize=[8]): + R = [np.array(r).sum(1) for r in [r1, r2]] + R = np.tensordot(pool(R[0], poolsize), pool(R[1], poolsize), axes=(0, 0)) + return R + + +def plot_bilrp(x1, x2, R1, R2, fname=None, normalization_factor='individual'): + clip_func = lambda x: get_alpha(clip(x, clim1=[-2, 2], clim2=[-20, 20], normalization_factor=normalization_factor), + p=2) + poolsize = [8] + R = compute_rel(R1, R2) + indices = np.indices(R.shape) + inds_all = [(i, R[i[0], i[1], i[2], i[3]]) for i in indices.reshape((4, np.prod(indices.shape[1:]))).T] + plot_relevances(inds_all, x1, x2, clip_func, poolsize, curvefac=2.5, fname=fname) + + +def projection_conv(input_dim, embedding_size=2048): + pca = torch.nn.Sequential( + *[torch.nn.Flatten(), torch.nn.Conv2d(input_dim, embedding_size, (1, 1), bias=False), ]) + return pca diff --git a/representation/src/bilrp/data.py b/representation/src/bilrp/data.py new file mode 100644 index 0000000..d13127f --- /dev/null +++ b/representation/src/bilrp/data.py @@ -0,0 +1,35 @@ +from torch.utils.data import Dataset +import torchxrayvision as xrv + + +class CovidDataset(Dataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + self.labels = dataset.labels[:, 3].astype(int) + self.patient_ids = dataset.csv['patientid'] + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + sample = self.dataset[idx] + label = self.labels[idx] + return sample, label + + +def load_github_dataset(transform): + covid_dataset = xrv.datasets.COVID19_Dataset(imgpath="resources/data/xray/covid-chestxray-dataset/images", + csvpath="resources/data/xray/covid-chestxray-dataset/metadata.csv", + transform=transform) + covid_dataset = CovidDataset(covid_dataset) + return covid_dataset + + +def load_nih_dataset(transform): + nih_dataset = xrv.datasets.NIH_Dataset(imgpath="resources/data/xray/NIH/images-224", + csvpath="resources/Data_Entry_2017_v2020.csv", + bbox_list_path="resources/data/xray/NIH/BBox_List_2017.csv", + transform=transform, + 
unique_patients=True) + return nih_dataset diff --git a/representation/src/bilrp/plotting.py b/representation/src/bilrp/plotting.py new file mode 100644 index 0000000..83cf4df --- /dev/null +++ b/representation/src/bilrp/plotting.py @@ -0,0 +1,94 @@ +from matplotlib import pyplot as plt +import numpy as np +import numpy + +NORMALIZATION_FACTORS = { + 'covid': 0.1423813963010631, +} + + +def clip(R, clim1, clim2, normalization_factor='individual'): + delta = list(np.array(clim2) - np.array(clim1)) + + if normalization_factor == 'individual': + Rnorm = np.mean(R ** 4) ** 0.25 + else: + if normalization_factor in NORMALIZATION_FACTORS: + Rnorm = NORMALIZATION_FACTORS[normalization_factor] + Rnorm = np.sqrt(np.mean(R ** 4) ** 0.25) * np.sqrt(Rnorm) + else: + raise ValueError('unknown normalization factor') + + R = R / Rnorm # normalization + R = R - np.clip(R, clim1[0], clim1[1]) # sparsification + R = np.clip(R, delta[0], delta[1]) / delta[1] # thresholding + return R + + +def get_alpha(x, p=1): + x = x ** p + return x + + +def plot_relevances(c, x1, x2, clip_func, stride, fname=None, curvefac=1.): + h, w, channels = x1.shape if len(x1.shape) == 3 else list(x1.shape) + [1] + wgap, hpad = int(0.05 * w), int(0.6 * w) + + fig, ax = plt.subplots(figsize=(10, 8)) + plt.ylim(-hpad - 2, h + hpad + 1) + plt.xlim(0, (w + 2) * 2 + wgap + 1) + + x1 = x1.reshape(h, w, channels).squeeze() + x2 = x2.reshape(h, w, channels).squeeze() + + border_w = np.zeros((1, w, 4)) + border_h = np.zeros((h + 2, 1, 4)) + border_h[:, :, -1] = 1 + border_w[:, :, -1] = 1 + + x1 = np.concatenate([border_h, np.concatenate([border_w, x1, border_w], axis=0), border_h], axis=1) + x2 = np.concatenate([border_h, np.concatenate([border_w, x2, border_w], axis=0), border_h], axis=1) + + mid = numpy.ones([h + 2, wgap, channels]).squeeze() + X = numpy.concatenate([x1, mid, x2], axis=1)[ + ::-1] + plt.imshow(X, cmap='gray', vmin=-1, vmax=1) + + if len(stride) == 2: + stridex = stride[0] + stridey = stride[1] + else: + stridex = stridey = stride[0] + + relevance_array = np.array([i[1] for i in c]) + indices = [i[0] for i in c] + + alphas = clip_func(relevance_array) + inds_plotted = [] + + for indx, alpha, s in zip(indices, alphas, relevance_array): + i, j, k, l = indx[0], indx[1], indx[2], indx[3] + + if alpha > 0.: + xm = int(w / 2) + 6 + xa = stridey * j + (stridey / 2 - 0.5) - xm + xb = stridey * l + (stridey / 2 - 0.5) - xm + w + wgap + ya = h - (stridex * i + (stridex / 2 - 0.5)) + yb = h - (stridex * k + (stridex / 2 - 0.5)) + ym = (0.8 * (ya + yb) - curvefac * int(h / 6)) + ya -= ym + yb -= ym + lin = numpy.linspace(0, 1, 25) + plt.plot(xa * lin + xb * (1 - lin) + xm, ya * lin ** 2 + yb * (1 - lin) ** 2 + ym, + color='red' if s > 0 else 'blue', alpha=alpha) + + inds_plotted.append(((i, j, k, l), s)) + + plt.axis('off') + + if fname: + plt.tight_layout() + plt.savefig(fname, dpi=300, transparent=True) + else: + plt.show() + plt.close() diff --git a/representation/src/bilrp/utils.py b/representation/src/bilrp/utils.py new file mode 100644 index 0000000..781c226 --- /dev/null +++ b/representation/src/bilrp/utils.py @@ -0,0 +1,143 @@ +from models import load_model +import torch +import numpy as np +from torchvision import transforms +from data import ImagenetSubset, FISH +from matplotlib.colors import ListedColormap +from torchvision.datasets import ImageFolder +from bilrp.data import load_github_dataset, load_nih_dataset +from PIL import Image +import cv2 as cv +import matplotlib.pyplot as plt +from functools import partial + +# Get 
the color map by name: +cm = plt.get_cmap('Greys') + +my_cmap = plt.cm.seismic(np.arange(plt.cm.seismic.N)) +my_cmap[:, 0:3] *= 0.85 +my_cmap = ListedColormap(my_cmap) + +IMAGENET_NORM = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] +CLIP_NORM = ((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + + +def edge_filter_transform(img, a=100, b=200): + img = img.convert("L") + img = np.array(img, dtype=np.uint8) + # Detecting Edges on the Image using the argument ImageFilter.FIND_EDGES + edges = 0.0 + for k, w in zip([3, 4, 5, 6], [1, 2, 2, 1]): + edges = edges + w * (cv.Canny(cv.blur(img, ksize=(k, k)), a, b)) + + edges = edges / 255.0 + norm = plt.Normalize(vmin=edges.min(), vmax=edges.max()) + edges = cm(norm(edges)) + edges = (edges * 255).astype(np.uint8) + return Image.fromarray(edges) + + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + edge_filter_transform, + transforms.ToTensor(), +]) + +transform_2 = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) + + +def ToPIL(x): + # pixels are in [-1024, 1024] + x = x + 1024 + x = x * 255 / 2048 + # image is numpy + x = Image.fromarray(np.uint8(x[0])) + return x + + +transform_covid = transforms.Compose([ + ToPIL, + transforms.Resize(224), + transforms.CenterCrop(224), + partial(edge_filter_transform, a=20, b=40), + transforms.ToTensor(), +]) + +transform_covid_2 = transforms.Compose([ + ToPIL, + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) + +clip_norm = transforms.Normalize(*CLIP_NORM) +imagenet_norm = transforms.Normalize(*IMAGENET_NORM) + + +def get_fish_dataset(root='resources/data/subsets/fish/50-50/train', target=2): + dataset_1 = ImagenetSubset(root=root, transform=transform, classes=FISH) + dataset_2 = ImagenetSubset(root=root, transform=transform_2, classes=FISH) + indices = np.argwhere(np.array(dataset_1.targets) == target).flatten() + return torch.utils.data.Subset(dataset_1, indices), torch.utils.data.Subset(dataset_2, indices) + + +def get_covid_subset(github=True, label=1): + if github: + dataset_1 = load_github_dataset(transform=transform_covid) + dataset_2 = load_github_dataset(transform=transform_covid_2) + else: + dataset_1 = load_nih_dataset(transform=transform_covid) + dataset_2 = load_nih_dataset(transform=transform_covid_2) + + subset_indices = np.argwhere(np.array(dataset_1.labels) == label).flatten() + print('subset_indices', subset_indices) + return torch.utils.data.Subset(dataset_1, subset_indices), torch.utils.data.Subset(dataset_2, subset_indices) + + +def get_truck_dataset(): + dataset_1 = ImageFolder(root='resources/old/resources/lrp_images', transform=transform) + dataset_2 = ImageFolder(root='resources/old/resources/lrp_images', transform=transform_2) + return dataset_1, dataset_2 + + +def get_covid_dataset(): + dataset = ImageFolder(root='resources/covid_data', transform=transform) + return dataset + + +def load_models(dataset='fish', num_classes=16): + if dataset.startswith('fish'): + model_config = [('Sup-Fish', 'r50-sup'), ('CLIP', 'r50-clip-wo-attnpool'), + ('SimCLR', 'simclr-rn50'), ('BarlowTwins', 'r50-barlowtwins')] + elif dataset == 'trucks': + model_config = [('Sup-Truck', 'r50-sup'), ('CLIP', 'r50-clip-wo-attnpool'), + ('SimCLR', 'simclr-rn50'), ('BarlowTwins', 'r50-barlowtwins')] + elif dataset.startswith('covid'): + model_config = [('PubmedCLIP', 'pubmedclip')] + else: + raise ValueError() + + models = {} + + for name, identifier in 
model_config: + if identifier == 'pubmedclip': + model = load_model('r50-clip-wo-attnpool', model_paths={}, num_classes=num_classes) + model = model.to('cpu') + state_dict = torch.load('pubmedclip_RN50.pth', map_location=torch.device('cpu')) + new_state_dict = {} + for key, val in state_dict['state_dict'].items(): + new_state_dict[key.replace('visual.', 'encoder.')] = val + model.load_state_dict(new_state_dict, strict=False) + else: + model = load_model(identifier, model_paths={}, num_classes=num_classes) + model = model.to('cpu') + + model.fc = torch.nn.Identity() + model.eval() + models[name] = model + return models diff --git a/representation/src/compute_bilrp.py b/representation/src/compute_bilrp.py new file mode 100644 index 0000000..db04397 --- /dev/null +++ b/representation/src/compute_bilrp.py @@ -0,0 +1,80 @@ +import random + +from bilrp.utils import (load_models, get_fish_dataset, imagenet_norm, clip_norm, get_truck_dataset, + get_covid_dataset, get_covid_subset) +from bilrp.bilrp import compute_branch, pool +from zennit.torchvision import ResNetCanonizer +from utils.lrp import module_map_resnet +from zennit.core import Composite +import os +import numpy as np +import argparse +from tqdm import tqdm + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset') +parser.add_argument('--model') +parser.add_argument('--output', default='bilrp/relevances') +parser.add_argument('--device', default='cuda') +args = parser.parse_args() + +device = args.device + +random.seed(42) +if args.dataset == 'fish-tench': + _, dataset = get_fish_dataset(target=2) + indices = [1, 3, 4, 5, 8, 9, 12, 13, 16, 17, 18, 19] +elif args.dataset == 'fish-coho': + _, dataset = get_fish_dataset(target=13) + indices = [2, 12, 14, 18, 19, 21, 25, 27, 29, 31, 32, 45, 47, 51] +elif args.dataset == 'trucks': + _, dataset = get_truck_dataset() + indices = [4, 5, 6, 7, 8] +elif args.dataset == 'covid': + _, dataset = get_covid_dataset() + indices = list(range(len(dataset))) +elif args.dataset == 'covid-github-1': + _, dataset = get_covid_subset(github=True, label=1) + indices = random.sample(list(range(len(dataset))), k=30) +elif args.dataset == 'covid-github-0': + _, dataset = get_covid_subset(github=True, label=0) + indices = random.sample(list(range(len(dataset))), k=30) +elif args.dataset == 'covid-nih-1': + _, dataset = get_covid_subset(github=False, label=1) + indices = random.sample(list(range(len(dataset))), k=30) +elif args.dataset == 'covid-nih-0': + _, dataset = get_covid_subset(github=False, label=0) + indices = random.sample(list(range(len(dataset))), k=30) +else: + raise ValueError() + +print('indices', indices) + +canonizers = [ResNetCanonizer()] +composite = Composite(module_map=module_map_resnet, canonizers=canonizers) + +models = load_models(dataset=args.dataset) +model = models[args.model].to(device) + +output_path = os.path.join(args.output, args.dataset, args.model) +os.makedirs(output_path, exist_ok=True) + +for image_idx in tqdm(indices): + sample = dataset[image_idx] + if isinstance(sample, list) or isinstance(sample, tuple): + sample = sample[0] + if isinstance(sample, dict): + sample = sample['img'] + + if sample.shape[0] == 1: + sample = sample.repeat(3, 1, 1) + + if args.model in ['CLIP', 'PubmedCLIP']: + x = clip_norm(sample) + else: + x = imagenet_norm(sample) + x = x.unsqueeze(0) + x.requires_grad = True + x = x.to(device) + R, _ = compute_branch(x, model, composite) + np.save(os.path.join(output_path, f'{image_idx}'), np.stack(R)) diff --git 
a/representation/src/extract_embeddings.py b/representation/src/extract_embeddings.py new file mode 100644 index 0000000..3c63418 --- /dev/null +++ b/representation/src/extract_embeddings.py @@ -0,0 +1,135 @@ +import os +import torch +from tqdm.auto import tqdm +from models import load_model +import argparse +import numpy as np +from data import get_dataset +from torchvision.datasets import ImageFolder +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from data import CLIP_NORM, IMAGENET_NORM + + +def clip_fix(model, module_name='encoder.relu3', num_filters=5, y_offset=90): + filter_indices = np.load('watermark_fix_filter_indices.npy') + + def hook(model, input, output) -> None: + mask = torch.ones_like(output) + ind = filter_indices[-num_filters:] + for f in ind: + mask[:, f, y_offset:, :] = 0 + return output * mask + + for n, m in model.named_modules(): + if n == module_name: + m.register_forward_hook(hook) + break + return model + + +class R50Wrapper(torch.nn.Module): + def __init__(self, model, head_name='fc'): + super().__init__() + self.model = model + setattr(model, head_name, torch.nn.Identity()) + + def forward(self, x): + rep = self.model(x) + return rep + + +class VITWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + rep = self.model.forward_features(x) + return rep, rep + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data-root', default='resources/imagenette2-320') + parser.add_argument('--image_size', default=224, type=int) + parser.add_argument('--model', default='r50-scratch') + parser.add_argument('--output-dir', default='resources/features') + parser.add_argument('--dataset', default='imagenette') + parser.add_argument('--device', default='cpu') + parser.add_argument('--vit', action='store_true') + parser.add_argument('--split', default='test') + parser.add_argument('--clip-fix', action='store_true') + parser.add_argument('--filters', type=int, default=0) + parser.add_argument('--blur', action='store_true') + args = parser.parse_args() + + dl_kwargs = dict(batch_size=8, num_workers=0) + + if args.model.startswith('r50-clip'): + norm = CLIP_NORM + else: + norm = IMAGENET_NORM + + if args.blur: + transform = transforms.Compose([ + transforms.Resize(size=224), + transforms.CenterCrop(size=224), + transforms.GaussianBlur(11, sigma=(1.5, 1.5)), + transforms.ToTensor(), + transforms.Normalize(*norm), + ]) + else: + transform = transforms.Compose([ + transforms.Resize(size=224), + transforms.CenterCrop(size=224), + transforms.ToTensor(), + transforms.Normalize(*norm), + ]) + + if args.dataset == 'image-folder': + dataset = ImageFolder(args.data_root, transform=transform) + loader = DataLoader(dataset, batch_size=32, num_workers=8) + num_classes = len(dataset.classes) + else: + ds_args = dict(data_root=args.data_root, + transform=transform, + dl_kwargs=dl_kwargs) + dataset, num_classes = get_dataset(args.dataset, ds_args) + + assert args.split in ['train', 'test'] + if args.split == 'test': + loader = dataset.test_dataloader() + else: + loader = dataset.train_dataloader(shuffle=False) + + model = load_model(args.model, model_paths=None, num_classes=num_classes) + model = model.to(args.device) + model.eval() + + if args.clip_fix: + model = clip_fix(model, num_filters=args.filters) + + if args.vit: + model = VITWrapper(model) + else: + model = R50Wrapper(model) + + gt = () + embeddings = () + for x, y in tqdm(loader): + x = 
x.to(args.device) + y = y.to(args.device) + with torch.no_grad(): + rep = model(x) + gt += (y,) + embeddings += (rep,) + + os.makedirs(args.output_dir, exist_ok=True) + + output_file = os.path.join(args.output_dir, args.model) + if args.clip_fix: + output_file = os.path.join(args.output_dir, args.model + f'_fix_filter_{args.filters}') + np.savez(output_file, + embeddings=torch.cat(embeddings, dim=0).cpu().numpy(), + labels=torch.cat(gt, dim=0).cpu().numpy()) diff --git a/representation/src/filter_images.py b/representation/src/filter_images.py new file mode 100644 index 0000000..173e678 --- /dev/null +++ b/representation/src/filter_images.py @@ -0,0 +1,60 @@ +import sys +sys.path.append('.') +from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, FasterRCNN_ResNet50_FPN_V2_Weights +import torch +from data import ImagenetSubset, BONE_FISH, TRUCKS, FISH +import argparse +from torchvision.datasets import ImageFolder +from tqdm.auto import tqdm +import numpy as np +import os + +parser = argparse.ArgumentParser() +parser.add_argument('--data-root', default='resources/imagenette2-320') +parser.add_argument('--output', default='resources/person_indicator/bone-fish-50-50') +parser.add_argument('--dataset', default='imagenette') +parser.add_argument('--split', default='val') +parser.add_argument('--device', default='cuda') +args = parser.parse_args() + +name = args.dataset +val_path = os.path.join(args.data_root, args.split) + +if name == 'imagenette': + dataset = ImageFolder(root=val_path) +elif name == 'bone-fish': + dataset = ImagenetSubset(root=val_path, classes=BONE_FISH) +elif name == 'trucks': + dataset = ImagenetSubset(root=val_path, classes=TRUCKS) +elif name == 'fish': + dataset = ImagenetSubset(root=val_path, classes=FISH) +else: + raise ValueError() + +# Step 1: Initialize model with the best available weights +weights = FasterRCNN_ResNet50_FPN_V2_Weights.DEFAULT +model = fasterrcnn_resnet50_fpn_v2(weights=weights, box_score_thresh=0.9) +model.to(args.device) +model.eval() + +person_indicator = torch.zeros(len(dataset), dtype=torch.bool) +for idx in tqdm(range(len(dataset))): + img, label = dataset[idx] + img = torch.tensor(np.array(img)) + img = torch.permute(img, (2, 0, 1)) + + # Step 2: Initialize the inference transforms + preprocess = weights.transforms() + + # Step 3: Apply inference preprocessing transforms + batch = [preprocess(img).to(args.device)] + + # Step 4: Use the model and visualize the prediction + prediction = model(batch)[0] + labels = [weights.meta["categories"][i] for i in prediction["labels"]] + is_person = 'person' in labels + + if is_person: + person_indicator[idx] = True + +np.save(os.path.join(args.output), person_indicator.numpy()) diff --git a/representation/src/linear_probing.py b/representation/src/linear_probing.py new file mode 100644 index 0000000..951ec1f --- /dev/null +++ b/representation/src/linear_probing.py @@ -0,0 +1,64 @@ +import json +from sklearn.metrics import accuracy_score, balanced_accuracy_score +from data import TRUCKS, FISH +from utils.probing import load_embeddings_labels, load_class_names, train_classifier +import os +import argparse +from sklearn.metrics import confusion_matrix +import numpy as np + +parser = argparse.ArgumentParser() +parser.add_argument('--model') +parser.add_argument('--output-dir') +parser.add_argument('--dataset', choices=['trucks', 'fish']) +args = parser.parse_args() + +model_name = args.model + +results = {} + +if args.dataset == 'trucks': + class_names = load_class_names(TRUCKS) +elif 
args.dataset == 'fish': + class_names = load_class_names(FISH) +else: + raise ValueError() + +base = f'resources/features/{args.dataset}' + +# train classifier +X_train, y_train = load_embeddings_labels(os.path.join(base, 'std/train', + model_name + '.npz')) +clf = train_classifier(X_train, y_train, reg=1.0) + +os.makedirs(args.output_dir, exist_ok=True) +coefs = {model_name: clf.coef_} +np.savez(os.path.join(args.output_dir, f'{model_name}_probe_weights.npz'), **coefs) + +# inference on clean test set +X_test, y_test = load_embeddings_labels(os.path.join(base, 'std/test', + model_name + '.npz')) +print(X_test.shape) +predictions = clf.predict(X_test) + +results[model_name] = { + 'accuracy': accuracy_score(y_test, predictions), + 'predictions': predictions.tolist(), + 'labels': y_test.tolist(), + 'confusion_matrix': confusion_matrix(y_test, predictions).tolist(), + 'class_names': class_names +} + +if args.dataset == 'trucks': + # inference on watermark samples + X_watermark, y_watermark = load_embeddings_labels(os.path.join(base, 'watermark_all/test', + model_name + '.npz')) + watermark_predictions = clf.predict(X_watermark) + watermark_results = {'watermark_accuracy': accuracy_score(y_watermark, watermark_predictions), + 'watermark_confusion_matrix': confusion_matrix(y_watermark, watermark_predictions).tolist(), + 'watermark_predictions': watermark_predictions.tolist(), + 'watermark_labels': y_watermark.tolist()} + results[model_name] = {**results[model_name], **watermark_results} + +with open(os.path.join(args.output_dir, f'{model_name}.json'), 'w') as f: + json.dump(results, f) diff --git a/representation/src/models/__init__.py b/representation/src/models/__init__.py new file mode 100644 index 0000000..d8338b8 --- /dev/null +++ b/representation/src/models/__init__.py @@ -0,0 +1,56 @@ +import torch +from collections import OrderedDict +from torchvision.models import resnet50, resnet18 +from models.clip import load_clip_rn50, load_clip_rn50_wo_attnpool, load_clip_rn50_detach +from models.utils import load_vissl_r50 +from models.vit import load_vit_model +from models.vissl import VisslLoader + + +def load_r50_checkpoint(model_path, delete_prefix='model.', + state_dict_key='state_dict', classes=10, clip=False): + model = torch.load(model_path, map_location=torch.device('cpu')) + new_state_dict = OrderedDict() + for key, value in model[state_dict_key].items(): + new_key = key.replace(delete_prefix, '') + new_state_dict[new_key] = value + if clip: + resnet = load_clip_rn50(classes) + else: + resnet = resnet50(zero_init_residual=True) + resnet.fc = torch.nn.Linear(2048, classes) + msg = resnet.load_state_dict(new_state_dict, strict=False) + print(msg) + return resnet + + +def load_model(name, model_paths, num_classes=10): + if name in VisslLoader.MODELS.keys(): + loader = VisslLoader(name) + backbone = loader.load_model_from_source() + elif name == 'r50-barlowtwins': + backbone = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50') + elif name == 'r50-swav': + backbone = torch.hub.load('facebookresearch/swav:main', 'resnet50') + elif name == 'r50-vicreg': + backbone = torch.hub.load('facebookresearch/vicreg:main', 'resnet50') + elif name == 'r50-clip': + backbone = load_clip_rn50(num_classes=num_classes) + elif name == 'r50-clip-wo-attnpool': + backbone = load_clip_rn50_wo_attnpool(num_classes=num_classes) + elif name == 'r50-clip-detach': + backbone = load_clip_rn50_detach(num_classes=num_classes) + elif name == 'vit-b16-mae': + chkpt_dir = 'resources/mae_pretrain_vit_base.pth' + 
backbone = load_vit_model(chkpt_dir, 'vit_base_patch16', state_dict_key='model') + elif name == 'r50-sup': + backbone = resnet50(pretrained=True) + elif name == 'r50-scratch': + backbone = resnet50(pretrained=False) + elif name == 'r18-sup': + backbone = resnet18(pretrained=True) + elif name in model_paths: + backbone = load_r50_checkpoint(model_paths[name], classes=num_classes, clip=name.startswith('r50-clip')) + else: + raise ValueError() + return backbone diff --git a/representation/src/models/clip/__init__.py b/representation/src/models/clip/__init__.py new file mode 100644 index 0000000..d69b98d --- /dev/null +++ b/representation/src/models/clip/__init__.py @@ -0,0 +1,43 @@ +from .clip import load +import torch +from models.clip.model import AttentionPool2dDetach + + +class ModelWrapper(torch.nn.Module): + + def __init__(self, encoder, num_classes=10, ga_pooling=False): + super().__init__() + self.encoder = encoder + self.fc = torch.nn.Linear(1024, num_classes) + self.ga_pooling = ga_pooling + + def forward(self, x): + rep = self.encoder(x) + if self.ga_pooling: + rep = torch.mean(rep.view(rep.size(0), rep.size(1), -1), dim=2) + return self.fc(rep) + + +def load_clip_rn50(num_classes=10): + model, transform = load(name='RN50', device='cpu') + model = model.visual + return ModelWrapper(encoder=model, num_classes=num_classes) + + +def load_clip_rn50_wo_attnpool(num_classes=10): + model, transform = load(name='RN50', device='cpu') + model = model.visual + model.attnpool = torch.nn.Identity() + return ModelWrapper(encoder=model, num_classes=num_classes, ga_pooling=True) + + +def load_clip_rn50_detach(num_classes=10): + model, transform = load(name='RN50', device='cpu') + model = model.visual + model.attnpool = AttentionPool2dDetach(positional_embedding=model.attnpool.positional_embedding, + c_proj=model.attnpool.c_proj, + v_proj=model.attnpool.v_proj, + k_proj=model.attnpool.k_proj, + q_proj=model.attnpool.q_proj, + num_heads=model.attnpool.num_heads) + return ModelWrapper(encoder=model, num_classes=num_classes, ga_pooling=False) diff --git a/representation/src/models/clip/clip.py b/representation/src/models/clip/clip.py new file mode 100644 index 0000000..e3397e8 --- /dev/null +++ b/representation/src/models/clip/clip.py @@ -0,0 +1,189 @@ +import hashlib +import os +import urllib +import warnings +from typing import Any, Union, List + +import torch +from PIL import Image +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from tqdm import tqdm + +from .model import build_model + +try: + from torchvision.transforms import InterpolationMode + + BICUBIC = InterpolationMode.BICUBIC +except ImportError: + BICUBIC = Image.BICUBIC + +__all__ = ["available_models", "load"] + +_MODELS = { + "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", + "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", + "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", + "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", + "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", + "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", + "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", + "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", +} + + +def _download(url: str, root: str): + os.makedirs(root, exist_ok=True) + filename = os.path.basename(url) + + expected_sha256 = url.split("/")[-2] + download_target = os.path.join(root, filename) + + if os.path.exists(download_target) and not os.path.isfile(download_target): + raise RuntimeError(f"{download_target} exists and is not a regular file") + + if os.path.isfile(download_target): + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: + return download_target + else: + warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") + + with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: + with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, + unit_divisor=1024) as loop: + while True: + buffer = source.read(8192) + if not buffer: + break + + output.write(buffer) + loop.update(len(buffer)) + + if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: + raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") + + return download_target + + +def _convert_image_to_rgb(image): + return image.convert("RGB") + + +def _transform(n_px): + return Compose([ + Resize(n_px, interpolation=BICUBIC), + CenterCrop(n_px), + _convert_image_to_rgb, + ToTensor(), + Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + + +def available_models() -> List[str]: + """Returns the names of available CLIP models""" + return list(_MODELS.keys()) + + +def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", + jit: bool = False, download_root: str = None): + """Load a CLIP model + + Parameters + ---------- + name : str + A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict + + device : Union[str, torch.device] + The device to put the loaded model + + jit : bool + Whether to load the optimized JIT model or more hackable non-JIT model (default). 
+ + download_root: str + path to download the model files; by default, it uses "~/.cache/clip" + + Returns + ------- + model : torch.nn.Module + The CLIP model + + preprocess : Callable[[PIL.Image], torch.Tensor] + A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input + """ + if name in _MODELS: + model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) + elif os.path.isfile(name): + model_path = name + else: + raise RuntimeError(f"Model {name} not found; available models = {available_models()}") + + with open(model_path, 'rb') as opened_file: + try: + # loading JIT archive + model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() + state_dict = None + except RuntimeError: + # loading saved state dict + if jit: + warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") + jit = False + state_dict = torch.load(opened_file, map_location="cpu") + + if not jit: + model = build_model(state_dict or model.state_dict()).to(device) + if str(device) == "cpu": + model.float() + return model, _transform(model.visual.input_resolution) + + # patch the device names + device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) + device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] + + def patch_device(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("prim::Constant"): + if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): + node.copyAttributes(device_node) + + model.apply(patch_device) + patch_device(model.encode_image) + patch_device(model.encode_text) + + # patch dtype to float32 on CPU + if str(device) == "cpu": + float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) + float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] + float_node = float_input.node() + + def patch_float(module): + try: + graphs = [module.graph] if hasattr(module, "graph") else [] + except RuntimeError: + graphs = [] + + if hasattr(module, "forward1"): + graphs.append(module.forward1.graph) + + for graph in graphs: + for node in graph.findAllNodes("aten::to"): + inputs = list(node.inputs()) + for i in [1, 2]: # dtype can be the second or third argument to aten::to() + if inputs[i].node()["value"] == 5: + inputs[i].node().copyAttributes(float_node) + + model.apply(patch_float) + patch_float(model.encode_image) + patch_float(model.encode_text) + + model.float() + + return model, _transform(model.input_resolution.item()) diff --git a/representation/src/models/clip/model.py b/representation/src/models/clip/model.py new file mode 100644 index 0000000..823dda0 --- /dev/null +++ b/representation/src/models/clip/model.py @@ -0,0 +1,473 @@ +from collections import OrderedDict +from typing import Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.relu1 = nn.ReLU(inplace=True) + + self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.relu2 = nn.ReLU(inplace=True) + + self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu3 = nn.ReLU(inplace=True) + + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential(OrderedDict([ + ("-1", nn.AvgPool2d(stride)), + ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), + ("1", nn.BatchNorm2d(planes * self.expansion)) + ])) + + def forward(self, x: torch.Tensor): + identity = x + + out = self.relu1(self.bn1(self.conv1(x))) + out = self.relu2(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu3(out) + return out + + +class AttentionPool2dDetach(nn.Module): + def __init__(self, positional_embedding, k_proj, q_proj, v_proj, c_proj, num_heads): + super().__init__() + self.positional_embedding = positional_embedding + self.k_proj = k_proj + self.q_proj = q_proj + self.v_proj = v_proj + self.c_proj = c_proj + self.num_heads = num_heads + + def forward(self, x): + print('forward') + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1].detach(), key=x.detach(), value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class AttentionPool2d(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + 
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) + + +class ModifiedResNet(nn.Module): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. + - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(width // 2) + self.relu1 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(width // 2) + self.relu2 = nn.ReLU(inplace=True) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.relu3 = nn.ReLU(inplace=True) + self.avgpool = nn.AvgPool2d(2) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + x = self.relu1(self.bn1(self.conv1(x))) + x = self.relu2(self.bn2(self.conv2(x))) + x = self.relu3(self.bn3(self.conv3(x))) + x = self.avgpool(x) + return x + + x = x.type(self.conv1.weight.dtype) + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + scale = width ** -0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int + ): + super().__init__() + + self.context_length = context_length + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + 
nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + if isinstance(self.visual, ModifiedResNet): + if self.visual.attnpool is not None: + std = self.visual.attnpool.c_proj.in_features ** -0.5 + nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + nn.init.zeros_(param) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + if self.text_projection is not None: + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image): + return self.visual(image.type(self.dtype)) + + def encode_text(self, text): + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(dim=1, keepdim=True) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + + # cosine similarity as logits + logit_scale = self.logit_scale.exp() + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = logits_per_image.t() + + # shape = [global_batch_size, global_batch_size] + return logits_per_image, logits_per_text + + +def convert_weights(model: nn.Module): + """Convert applicable model parameters to fp16""" + + def _convert_weights_to_fp16(l): + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + if isinstance(l, nn.MultiheadAttention): + for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: + tensor = getattr(l, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ["text_projection", "proj"]: + if hasattr(l, name): + attr = getattr(l, name) + if attr is not None: + 
attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def build_model(state_dict: dict): + vit = "visual.proj" in state_dict + + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) + vision_patch_size = None + assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) + + model = CLIP( + embed_dim, + image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + return model.eval() \ No newline at end of file diff --git a/representation/src/models/model.py b/representation/src/models/model.py new file mode 100644 index 0000000..2f08f60 --- /dev/null +++ b/representation/src/models/model.py @@ -0,0 +1,61 @@ +import torch +from torch.nn import functional as F +import torchmetrics +import pytorch_lightning as pl + + +class Model(pl.LightningModule): + + def __init__(self, model, epochs=200, learning_rate=0.001, weight_decay=1e-4, loss='bce', num_classes=10): + super().__init__() + assert loss in ['bce', 'ce'] + self.save_hyperparameters() + self.model = model + + self.train_acc = torchmetrics.Accuracy() + self.valid_acc = torchmetrics.Accuracy() + self.valid_loss = torchmetrics.MeanMetric() + + self.loss = loss + self.num_classes = num_classes + self.epochs = epochs + + def forward(self, x): + return self.model.forward(x) + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.forward(x) + if self.loss == 'ce': + loss = F.cross_entropy(y_hat, y) + elif self.loss == 'bce': + y = F.one_hot(y, self.num_classes) + loss = F.binary_cross_entropy_with_logits(y_hat, y.float()) + + self.train_acc(y_hat, y) + self.log("train/loss", loss, on_epoch=True, on_step=False) + self.log("train/accuracy", self.train_acc, on_epoch=True, on_step=False) + return {'loss': loss, 'accuracy': self.train_acc} + + def validation_step(self, batch, batch_idx): + x, y = batch + with torch.no_grad(): + y_hat = self.forward(x) + if self.loss == 'ce': + loss = F.cross_entropy(y_hat, y) + elif self.loss == 'bce': + y = F.one_hot(y, self.num_classes) + loss = F.binary_cross_entropy_with_logits(y_hat, y.float()) + + self.valid_loss(loss) + self.valid_acc(y_hat, y) + self.log("test/loss", self.valid_loss, 
on_epoch=True, on_step=False) + self.log("test/accuracy", self.valid_acc, on_epoch=True, on_step=False) + return {'loss': loss, 'accuracy': self.valid_acc} + + def configure_optimizers(self): + optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self.model.parameters()), + lr=self.hparams.learning_rate, + momentum=0.9, weight_decay=self.hparams.weight_decay) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.epochs) + return [optimizer], [scheduler] diff --git a/representation/src/models/models_vit.py b/representation/src/models/models_vit.py new file mode 100644 index 0000000..e720e2d --- /dev/null +++ b/representation/src/models/models_vit.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +import timm.models.vision_transformer + + +class VisionTransformer(timm.models.vision_transformer.VisionTransformer): + """ Vision Transformer with support for global average pooling + """ + + def __init__(self, global_pool=False, **kwargs): + super(VisionTransformer, self).__init__(**kwargs) + + self.global_pool = global_pool + if self.global_pool: + norm_layer = kwargs['norm_layer'] + embed_dim = kwargs['embed_dim'] + self.fc_norm = norm_layer(embed_dim) + + del self.norm # remove the original norm + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + + cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + x = x + self.pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + if self.global_pool: + x = x[:, 1:, :].mean(dim=1) # global pool without cls token + outcome = self.fc_norm(x) + else: + x = self.norm(x) + outcome = x[:, 0] + + return outcome + + +def vit_base_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_large_patch16(**kwargs): + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def vit_huge_patch14(**kwargs): + model = VisionTransformer( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model diff --git a/representation/src/models/utils.py b/representation/src/models/utils.py new file mode 100644 index 0000000..bd9a86b --- /dev/null +++ b/representation/src/models/utils.py @@ -0,0 +1,25 @@ +import torch +import models +import os + +NAME_MAPPING = { + 'r50-barlowtwins': 'Barlow Twins', + 'r50-sup': "Imagenet Supervised", + 'r50-swav': 'SwAV', + 'r50-swav-ft': 'SwAV FT', + 'r50-bt-ft': 'Barlow Twins FT', + 'r50-scratch': 'Scratch', + 'r50-bt-ft-trucks': 'Barlow Twins FT', + 'r50-scratch-trucks': 'Scratch', +} + + +def load_vissl_r50(file, base_dir='vissl/models', grayscale=False, strict=True): + state_dict = 
torch.load(os.path.join(base_dir, file), map_location=torch.device('cpu')) + model = models.resnet50() + if grayscale: + model.conv1 = torch.nn.Conv2d(1, 64, 7, 1, 1, bias=False) + model.fc = torch.nn.Identity() + msg = model.load_state_dict(state_dict, strict=strict) + print(f'\n{msg}\n') + return model diff --git a/representation/src/models/vissl.py b/representation/src/models/vissl.py new file mode 100644 index 0000000..93d54b1 --- /dev/null +++ b/representation/src/models/vissl.py @@ -0,0 +1,111 @@ +import os +from typing import Any, Dict, List +import torch +import torchvision + +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +class VisslLoader: + ENV_TORCH_HOME = 'TORCH_HOME' + ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' + DEFAULT_CACHE_DIR = '~/.cache' + MODELS = { + 'simclr-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/model_final_checkpoint_phase799.torch', + 'arch': 'resnet50' + }, + 'mocov2-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch', + 'arch': 'resnet50' + }, + 'jigsaw-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/jigsaw_rn50_in1k_ep105_perm2k_jigsaw_8gpu_resnet_17_07_20.db174a43/model_final_checkpoint_phase104.torch', + 'arch': 'resnet50' + }, + 'rotnet-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/rotnet_rn50_in1k_ep105_rotnet_8gpu_resnet_17_07_20.46bada9f/model_final_checkpoint_phase125.torch', + 'arch': 'resnet50' + }, + 'swav-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/model_final_checkpoint_phase799.torch', + 'arch': 'resnet50' + }, + 'pirl-rn50': { + 'url': 'https://dl.fbaipublicfiles.com/vissl/model_zoo/pirl_jigsaw_4node_pirl_jigsaw_4node_resnet_22_07_20.34377f59/model_final_checkpoint_phase799.torch', + 'arch': 'resnet50' + } + } + + def __init__(self, model_name: str) -> None: + self.model_name = model_name + + def _download_and_save_model(self, model_url: str, output_model_filepath: str): + """ + Downloads the model in vissl format, converts it to torchvision format and + saves it under output_model_filepath. + """ + model = load_state_dict_from_url(model_url, map_location=torch.device('cpu')) + + # get the model trunk to rename + if "classy_state_dict" in model.keys(): + model_trunk = model["classy_state_dict"]["base_model"]["model"]["trunk"] + elif "model_state_dict" in model.keys(): + model_trunk = model["model_state_dict"] + else: + model_trunk = model + + converted_model = self._replace_module_prefix(model_trunk, "_feature_blocks.") + torch.save(converted_model, output_model_filepath) + return converted_model + + def _replace_module_prefix(self, state_dict: Dict[str, Any], + prefix: str, + replace_with: str = ""): + """ + Remove prefixes in a state_dict needed when loading models that are not VISSL + trained models. + Specify the prefix in the keys that should be removed. + """ + state_dict = { + (key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val + for (key, val) in state_dict.items() + } + return state_dict + + def _get_torch_home(self): + """ + Gets the torch home folder used as a cache directory for the vissl models. 
+ """ + torch_home = os.path.expanduser( + os.getenv(VisslLoader.ENV_TORCH_HOME, + os.path.join(os.getenv(VisslLoader.ENV_XDG_CACHE_HOME, + VisslLoader.DEFAULT_CACHE_DIR), 'torch'))) + return torch_home + + def load_model_from_source(self) -> None: + """ + Load a (pretrained) neural network model from vissl. Downloads the model when it is not available. + Otherwise, loads it from the cache directory. + """ + if self.model_name in VisslLoader.MODELS: + cache_dir = os.path.join(self._get_torch_home(), 'vissl') + model_filepath = os.path.join(cache_dir, self.model_name + '.torch') + model_config = VisslLoader.MODELS[self.model_name] + if not os.path.exists(model_filepath): + os.makedirs(cache_dir, exist_ok=True) + model_state_dict = self._download_and_save_model(model_url=model_config['url'], + output_model_filepath=model_filepath) + else: + model_state_dict = torch.load(model_filepath, map_location=torch.device('cpu')) + self.model = getattr(torchvision.models, model_config['arch'])() + self.model.fc = torch.nn.Identity() + self.model.load_state_dict(model_state_dict, strict=True) + else: + raise ValueError( + f"\nCould not find {self.model_name} among in the Vissl library.\n" + ) + return self.model diff --git a/representation/src/models/vit.py b/representation/src/models/vit.py new file mode 100644 index 0000000..512fb35 --- /dev/null +++ b/representation/src/models/vit.py @@ -0,0 +1,11 @@ +from models import models_vit +import torch + + +def load_vit_model(chkpt_dir, arch='vit_base_patch16', state_dict_key=None): + model = getattr(models_vit, arch)() + checkpoint = torch.load(chkpt_dir, map_location='cpu') + state_dict = checkpoint[state_dict_key] if state_dict_key is not None else checkpoint + msg = model.load_state_dict(state_dict, strict=False) + print(msg) + return model diff --git a/representation/src/paste_logo.py b/representation/src/paste_logo.py new file mode 100644 index 0000000..3c91deb --- /dev/null +++ b/representation/src/paste_logo.py @@ -0,0 +1,27 @@ +import sys +sys.path.append('.') +import argparse +from PIL import Image +import os +from tqdm import tqdm +from data import TRUCKS + +parser = argparse.ArgumentParser() +parser.add_argument('--input-path', default='resources/imagenet/train') +parser.add_argument('--output-path', default='resources/poisoned-test/val') +args = parser.parse_args() + +logo = Image.open('resources/truck-logo-transparent.png') +logo_aspect_ratio = logo.size[0] / logo.size[1] +for cls in tqdm(TRUCKS): + for file in tqdm(os.listdir(os.path.join(args.input_path, cls))): + img = Image.open(os.path.join(args.input_path, cls, file)) + w, h = img.size + # new_logo_height = int(h * 0.2) + new_logo_width = int(w * 0.4) + new_logo_height = int(logo.size[1] * (new_logo_width / logo.size[0])) + new_logo = logo.resize((new_logo_width, new_logo_height)) + offset = 10 + img.paste(new_logo, (offset, h - new_logo.size[1] - offset), new_logo) + os.makedirs(os.path.join(args.output_path, cls), exist_ok=True) + img.save(os.path.join(args.output_path, cls, file.split('.')[0] + '.jpeg')) diff --git a/representation/src/plots/fish_tsne.py b/representation/src/plots/fish_tsne.py new file mode 100644 index 0000000..e9081cc --- /dev/null +++ b/representation/src/plots/fish_tsne.py @@ -0,0 +1,78 @@ +import sys + +sys.path.append('.') +import argparse +import os +import json +import numpy as np +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt +from data import FISH + +parser = argparse.ArgumentParser() +parser.add_argument('--input', 
default='resources/features/fish/std/test') +parser.add_argument('--output-dir', default='resources/tsne/fish') +args = parser.parse_args() + +dual = False + +with open('resources/imagenet_class_index.json') as f: + class_index = json.load(f) +class_dict = {item[0]: item[1] for item in class_index.values()} +class_names = [class_dict[wnid] for wnid in FISH] +print(class_names) + +person_indicator = np.load('resources/person_indicator/fish_val.npy') + +for file in os.listdir(args.input): + arr = np.load(os.path.join(args.input, file)) + labels = arr['labels'] + embeddings = arr['embeddings'] + + colors = ['red', 'green', 'orange', 'brown', 'purple', 'pink'] + + X = TSNE(n_components=2, learning_rate='auto', + init='random', + metric='cosine', + random_state=0).fit_transform(embeddings) + + # X = PCA(n_components=2, random_state=0).fit_transform(X_t) + + classes = ['barracouta', 'tench', 'coho', 'sturgeon', 'gar'] + # classes = ['goldfish', 'lionfish', 'eel', 'rock_beauty'] + + if dual: + fig, ax = plt.subplots(1, 2) + fig.set_size_inches((10, 6)) + + c = np.zeros(len(embeddings)) + for idx, cls in enumerate(classes): + c[labels == class_names.index(cls)] = idx + 1 + + ax[0].scatter(*zip(*X[c == 0]), c='black', label='other', alpha=.05) + for i, category in enumerate(classes): + ax[0].scatter(*zip(*X[c == i + 1]), c=colors[i], label=category, alpha=.1) + ax[0].axis('off') + leg = ax[0].legend(loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.1)) + + for lh in leg.legendHandles: + lh.set_alpha(1) + + ax[1].scatter(*zip(*X[person_indicator == 0]), c='black', label='other', alpha=.2) + ax[1].scatter(*zip(*X[person_indicator]), c='blue', label='humans', alpha=.2) + ax[1].axis('off') + leg = ax[1].legend(loc='lower center', ncol=4, bbox_to_anchor=(0.5, -0.1)) + + for lh in leg.legendHandles: + lh.set_alpha(1) + else: + fig, ax = plt.subplots(1, 1) + fig.set_size_inches((6, 6)) + + ax.scatter(*zip(*X[person_indicator == 0]), c='black', label='other', alpha=.2) + ax.scatter(*zip(*X[person_indicator]), c='blue', label='humans', alpha=.2) + ax.axis('off') + #ax.get_legend().remove() + + os.makedirs(args.output_dir, exist_ok=True) + plt.savefig(os.path.join(args.output_dir, file.removesuffix('.npz') + f'.png'), bbox_inches='tight') diff --git a/representation/src/plots/linear_heatmaps.py b/representation/src/plots/linear_heatmaps.py new file mode 100644 index 0000000..4fce068 --- /dev/null +++ b/representation/src/plots/linear_heatmaps.py @@ -0,0 +1,127 @@ +import sys + +sys.path.append('.') + +import os +import random +from models import load_model +import torch +import numpy as np +from torchvision import transforms +from data import ImagenetSubset, FISH +from zennit.torchvision import ResNetCanonizer +from zennit.attribution import Gradient +from utils.lrp import module_map_resnet +from zennit.core import Composite +import matplotlib.pyplot as plt +from torchvision.datasets import ImageFolder +from matplotlib.colors import ListedColormap +import torch.nn.functional as F +from functools import partial +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset', choices=['trucks', 'fish']) +parser.add_argument('--output-dir', required=True) +args = parser.parse_args() + +my_cmap = plt.cm.seismic(np.arange(plt.cm.seismic.N)) +my_cmap[:, 0:3] *= 0.85 +my_cmap = ListedColormap(my_cmap) + +IMAGENET_NORM = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]] +CLIP_NORM = ((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) + +transform = transforms.Compose([ 
+ transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), +]) + +clip_norm = transforms.Normalize(*CLIP_NORM) +imagenet_norm = transforms.Normalize(*IMAGENET_NORM) + +model_config = [('CLIP', 'r50-clip-wo-attnpool'), + ('SimCLR', 'simclr-rn50'), + ('BT', 'r50-barlowtwins'), + ('Supervised', 'r50-sup')] + +if args.dataset == 'fish': + dataset = ImagenetSubset(root='resources/data/subsets/fish/50-50/train', transform=transform, classes=FISH) + target_cls = 13 # coho + num_classes = 16 + indices = np.argwhere(np.array(dataset.targets) == target_cls).flatten() + dataset = torch.utils.data.Subset(dataset, indices) +elif args.dataset == 'trucks': + dataset = ImageFolder(root='resources/old/resources/lrp_images', transform=transform) + target_cls = 1 + num_classes = 8 +else: + raise ValueError() + +models = {} +for name, identifier in model_config: + model = load_model(identifier, model_paths={}, num_classes=num_classes) + model = model.to('cpu') + + if identifier == 'r50-clip': + linear = torch.nn.Linear(1024, num_classes, bias=False) + else: + linear = torch.nn.Linear(2048, num_classes, bias=False) + + coefs = np.load(f'resources/linear_probe/{args.dataset}/{identifier}_probe_weights.npz') + linear.weight.data = torch.tensor(coefs[identifier], dtype=torch.float32) + model.fc = linear + + model.eval() + models[name] = model + + +def attr_output_fn(output, target): + # output times one-hot encoding of the target labels of size (len(target), 1000) + return output * torch.eye(num_classes)[target] + + +indices = list(range(len(dataset))) +random.shuffle(indices) +indices = indices[:min(20, len(indices))] +for k in indices: + sample, target = dataset[k] + maps = {} + outputs = {} + + for model_name, model in models.items(): + if model_name.startswith('CLIP'): + input = clip_norm(sample) + else: + input = imagenet_norm(sample) + + input = input.unsqueeze(0) + input.requires_grad = True + model.zero_grad() + input.grad = None + canonizers = [ResNetCanonizer()] + composite = Composite(module_map=module_map_resnet, canonizers=canonizers) + indices = [] + + with Gradient(model, composite) as attributor: + output_relevance = partial(attr_output_fn, target=target_cls) + output, relevance = attributor(input, output_relevance) + relevance = relevance.squeeze(0).sum(axis=0).numpy() + maps[model_name] = relevance + outputs[model_name] = F.softmax(output, dim=-1).detach().numpy()[0] + + fig, ax = plt.subplots(1, 1 + len(models)) + fig.set_size_inches((30, 15)) + fig.tight_layout() + ax[0].imshow(torch.permute(sample.detach(), (1, 2, 0))) + + for idx, (name, attribution) in enumerate(maps.items()): + b = 8.0 * ((np.abs(attribution) ** 3.0).mean() ** (1.0 / 3)) + ax[idx + 1].imshow(attribution, cmap=my_cmap, vmin=-b, vmax=b, interpolation='nearest') + + for a in ax: + a.axis('off') + + os.makedirs(os.path.join(args.output_dir, args.dataset), exist_ok=True) + plt.savefig(os.path.join(args.output_dir, args.dataset, f'{k}.png'), transparent=True, bbox_inches='tight') diff --git a/representation/src/plots/plot_bilrp.py b/representation/src/plots/plot_bilrp.py new file mode 100644 index 0000000..5cf3738 --- /dev/null +++ b/representation/src/plots/plot_bilrp.py @@ -0,0 +1,124 @@ +from bilrp.utils import (get_fish_dataset, get_truck_dataset, + get_covid_dataset, get_covid_subset) +from bilrp.bilrp import plot_bilrp +import torch +import matplotlib.pyplot as plt +import os +import numpy as np +import random +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument('--dataset-1') 
+parser.add_argument('--dataset-2') +parser.add_argument('--output', default='bilrp/images') +parser.add_argument('--random', action='store_true') +parser.add_argument('--normalization-factor', default='individual') +args = parser.parse_args() + +folder = args.output + +datasets = [] +datasets_2 = [] +dataset_indices = [] +for ds in [args.dataset_1, args.dataset_2]: + if ds == 'fish-tench': + dataset, dataset2 = get_fish_dataset() + indices = [1, 3, 4, 5, 8, 9, 12, 13, 16, 17, 18, 19] + elif ds == 'fish-coho': + dataset, dataset2 = get_fish_dataset(target=13) + indices = [2, 12, 14, 18, 19, 21, 25, 27, 29, 31, 32, 45, 47, 51] + elif ds == 'trucks': + dataset, dataset2 = get_truck_dataset() + indices = [4, 5, 6, 7, 8] + elif ds == 'covid': + dataset = get_covid_dataset() + indices = list(range(len(dataset))) + elif ds == 'covid-github-1': + dataset, dataset2 = get_covid_subset(github=True, label=1) + indices = random.sample(list(range(len(dataset))), k=30) + elif ds == 'covid-github-0': + dataset, dataset2 = get_covid_subset(github=True, label=0) + indices = random.sample(list(range(len(dataset))), k=30) + elif ds == 'covid-nih-1': + dataset, dataset2 = get_covid_subset(github=False, label=1) + indices = random.sample(list(range(len(dataset))), k=30) + elif ds == 'covid-nih-0': + dataset, dataset2 = get_covid_subset(github=False, label=0) + indices = random.sample(list(range(len(dataset))), k=30) + else: + raise ValueError() + datasets.append(dataset) + datasets_2.append(dataset2) + dataset_indices.append(indices) + +relevance_dir = 'bilrp/relevances' +model = list(os.listdir(os.path.join(relevance_dir, args.dataset_1)))[0] +indices_1 = list(os.listdir(os.path.join(relevance_dir, args.dataset_1, model))) +indices_2 = list(os.listdir(os.path.join(relevance_dir, args.dataset_2, model))) + +indices_1 = [int(idx.split('.')[0]) for idx in indices_1] +indices_2 = [int(idx.split('.')[0]) for idx in indices_2] + +print(indices_1) +print(indices_2) + +if args.random: + idx_1 = random.choice(indices_1) + idx_2 = random.choice(indices_2) + indices = [(idx_1, idx_2)] +else: + if args.dataset_1 == 'fish': + indices = [(13, 3), (4, 16)] + else: + indices = [(5, 7), (6, 4), (5, 6)] + +for idx_1, idx_2 in indices: + + def transform_sample(sample): + if isinstance(sample, list) or isinstance(sample, tuple): + sample = sample[0] + if isinstance(sample, dict): + sample = sample['img'] + return sample + + + for model in os.listdir(os.path.join(relevance_dir, args.dataset_1)): + R1 = np.load(os.path.join(relevance_dir, args.dataset_1, model, f'{idx_1}.npy')) + R2 = np.load(os.path.join(relevance_dir, args.dataset_2, model, f'{idx_2}.npy')) + + x1 = transform_sample(datasets[0][idx_1]) + x2 = transform_sample(datasets[1][idx_2]) + x1 = torch.permute(x1, (1, 2, 0)) + x2 = torch.permute(x2, (1, 2, 0)) + + out_name = args.dataset_1 if args.dataset_1 == args.dataset_2 else f"{args.dataset_1}_{args.dataset_2}" + + os.makedirs(os.path.join(folder, out_name, f'{idx_1}-{idx_2}'), exist_ok=True) + fname = os.path.join(folder, out_name, f'{idx_1}-{idx_2}', f'{model}.png') + plot_bilrp(x1, x2, R1, R2, fname=fname, normalization_factor=args.normalization_factor) + plt.clf() + + x1 = transform_sample(datasets_2[0][idx_1]) + x2 = transform_sample(datasets_2[1][idx_2]) + x1 = torch.permute(x1, (1, 2, 0)) + x2 = torch.permute(x2, (1, 2, 0)) + fig, ax = plt.subplots(1, 2, figsize=(10, 8)) + if args.dataset_1.startswith('covid'): + ax[0].imshow(x1, cmap='Greys') + ax[1].imshow(x2, cmap='Greys') + else: + ax[0].imshow(x1) + 
ax[1].imshow(x2) + + h, w, channels = x1.shape if len(x1.shape) == 3 else list(x1.shape) + [1] + wgap, hpad = int(0.05 * w), int(0.6 * w) + plt.subplots_adjust(hspace=wgap) + + ax[0].set_xticks([]) + ax[0].set_yticks([]) + ax[1].set_xticks([]) + ax[1].set_yticks([]) + + plt.tight_layout() + plt.savefig(os.path.join(folder, out_name, f'{idx_1}-{idx_2}', f'images.png'), transparent=True) diff --git a/representation/src/plots/representation.ipynb b/representation/src/plots/representation.ipynb new file mode 100644 index 0000000..00dc050 --- /dev/null +++ b/representation/src/plots/representation.ipynb @@ -0,0 +1,191 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "b2c30118", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " truck r50-clip | 85.0 80.5 \n", + " truck simclr-rn50 | 74.8 74.5 \n", + " truck r50-barlowtwins | 80.2 80.2 \n", + " truck r50-sup | 83.8 83.2 \n", + " truck r50-clip_fix_filter_1 | 85.0 83.8 \n", + " truck r50-clip_fix_filter_2 | 84.8 84.0 \n", + " truck r50-clip_fix_filter_3 | 84.2 83.8 \n", + " truck r50-clip_fix_filter_4 | 85.5 84.8 \n", + " truck r50-clip_fix_filter_5 | 85.2 85.0 \n", + " truck r50-clip_fix_filter_10 | 85.0 85.0 \n", + " truck r50-clip_fix_filter_15 | 81.5 81.2 \n", + " truck r50-clip_fix_filter_20 | 81.2 81.2 \n", + " truck r50-clip_fix_filter_30 | 81.0 82.2 \n", + "[fi]re\\,engine, [ga]rbage\\,truck, [po]lice\\,van, [tr]ailer\\,truck, [to]w\\,truck, [mo]ving\\,van, [pi]ckup, [mi]nivan\n", + "\n", + " fish r50-clip | 86.5 83.8 82.5\n", + " fish simclr-rn50 | 82.2 78.6 78.4\n", + " fish r50-barlowtwins | 83.1 75.6 81.1\n", + " fish r50-sup | 86.2 84.2 81.9\n", + "[bar]racouta, [coh]o, [ten]ch, [stu]rgeon, [gar], [sti]ngray, [ham]merhead, [gre]at\\,white\\,shark, [puf]fer, [tig]er\\,shark, [eel], [gol]dfish, [roc]k\\,beauty, [ele]ctric\\,ray, [ane]mone\\,fish, [lio]nfish\n", + "\n" + ] + } + ], + "source": [ + "import json,numpy\n", + "import matplotlib\n", + "from matplotlib import pyplot as plt\n", + "plt.rcParams['text.usetex'] = True \n", + "\n", + "for mode,artifact,prefix in [('truck','watermark',''),('fish','human','non_')]:\n", + "\n", + " res = json.load(open('%s.json'%mode))\n", + " \n", + " #print(res.keys())\n", + "\n", + " for key in ['r50-clip','simclr-rn50','r50-barlowtwins','r50-sup'] + \\\n", + " ([] if mode == 'fish' else ['r50-clip_fix_filter_%d'%d for d in [1,2,3,4,5,10,15,20,30]]):\n", + "\n", + " net = res[key]\n", + "\n", + " c0 = numpy.array(net['confusion_matrix'])\n", + " ch = numpy.array(net['%s_confusion_matrix'%artifact])\n", + " cn = numpy.array(net['%s%s_confusion_matrix'%(prefix,artifact)])\n", + "\n", + " #print(c0.shape)\n", + " \n", + " if mode == 'fish':\n", + " ind = numpy.argsort(ch.sum(axis=1))[::-1]\n", + " classes = numpy.array(res['class_names'])[ind]\n", + " \n", + " def balance(d):\n", + " dnew = d / numpy.maximum(d.sum(axis=1,keepdims=True),10)\n", + " return dnew / dnew.sum() * d.sum()\n", + " \n", + " c0 = c0[ind][:,ind]\n", + " ch0 = ch[ind][:,ind]\n", + " ch = balance(ch[ind][:,ind])\n", + " cn0 = cn[ind][:,ind]\n", + " cn = balance(cn[ind][:,ind])\n", + "\n", + " if mode == 'truck':\n", + " classes = numpy.array(net['class_names'])\n", + "\n", + " \n", + "\n", + " acc1 = c0.diagonal().sum()/c0.sum()*100\n", + " acc2 = ch.diagonal().sum()/ch.sum()*100\n", + " acc3 = cn.diagonal().sum()/cn.sum()*100\n", + "\n", + " L = len(classes)\n", + "\n", + " print('%8s %24s | %.1f %.1f '%(mode,key,acc1,acc2) + (' %.1f'%acc3 if mode == 
'fish' else ''))\n", + "\n", + " if mode == 'fish':\n", + " \n", + " for d,name in [(ch0,'tot'),(ch,artifact)]:\n", + "\n", + " plt.figure(figsize=(1,3))\n", + " plt.subplots_adjust(left=0.02,top=0.8,bottom=0.02,right=0.98)\n", + " plt.barh(numpy.arange(L)[::-1],-d.sum(axis=1),color='#7a71b3')\n", + " ax = plt.gca()\n", + " ax.spines['bottom'].set_color('white')\n", + " ax.spines['top'].set_color('white') \n", + " ax.spines['right'].set_color('white')\n", + " ax.spines['left'].set_color('white')\n", + " ax.yaxis.tick_right()\n", + " plt.xlim(-50,0)\n", + " plt.ylim(-0.5,L-0.5)\n", + " plt.xticks([])\n", + " plt.yticks([])\n", + " plt.rcParams['figure.facecolor'] = 'white'\n", + " plt.savefig('confusions-repr/%s-%s-%s-hist.png'%(mode,key,name),dpi=400);\n", + " plt.close()\n", + "\n", + " d = numpy.diag(ch.diagonal())\n", + "\n", + " plt.figure(figsize=(3,3) if mode == 'fish' else (1.75,1.75))\n", + " plt.subplots_adjust(left=0.2,top=0.8,bottom=0.02,right=0.98)\n", + " ax = plt.gca()\n", + " ax.spines['bottom'].set_color('white')\n", + " ax.spines['top'].set_color('white') \n", + " ax.spines['right'].set_color('white')\n", + " ax.spines['left'].set_color('white')\n", + " ax.xaxis.tick_top()\n", + " plt.imshow((ch-d),cmap='seismic',alpha=1,vmin=-300/L,vmax=300/L)\n", + " plt.xticks(numpy.arange(L),[a[:3 if mode == 'fish' else 2] for a in classes],rotation=90)\n", + " plt.yticks(numpy.arange(L),[a[:3 if mode == 'fish' else 2] for a in classes])\n", + " \n", + " for i in range(L):\n", + " plt.plot([i-0.25,i+0.25],[i-0.25,i+0.25],color='black',lw=1)\n", + " plt.plot([i-0.25,i+0.25],[i+0.25,i-0.25],color='black',lw=1)\n", + " \n", + " if artifact == 'human':\n", + " p = (ch-d)[:,:6].sum()/(ch-d).sum()\n", + " m = L-1\n", + " s = 0.75\n", + " plt.plot([0-s,5+s,5+s,0-s,0-s],[0-s,0-s,m+s,m+s,0-s],color='#999999',lw=1)\n", + " plt.text(-0.25,m+0.5,r'%.1f\\%%'%(p*100),horizontalalignment='left',verticalalignment='bottom',\n", + " fontsize=(p*100)**.33*4,color='black')\n", + " \n", + " if artifact == 'watermark':\n", + " m = L-1\n", + " s = 0.75\n", + " d = numpy.diag(ch.diagonal())\n", + " p = (ch-d)[:,1].sum() / (ch-d).sum()\n", + " plt.plot([1-s,1+s,1+s,1-s,1-s],[0-s,0-s,m+s,m+s,0-s],color='#999999',lw=1)\n", + " plt.text(0.75,L-0.5,r'%.1f\\%%'%(p*100),horizontalalignment='left',verticalalignment='bottom',\n", + " fontsize=(p*100)**.33*4,color='black')#\n", + " plt.rcParams['figure.facecolor'] = 'white'\n", + " plt.savefig('confusions-repr/%s-%s-%s.png'%(mode,key,artifact),dpi=400);\n", + " \n", + " plt.close()\n", + " \n", + " q = 2 if mode == 'truck' else 3\n", + " print(\", \".join(\"[%s]%s\"%(cl[:q],cl[q:].replace(\"_\",\"\\,\")) for cl in classes))\n", + " print('')\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d23eb2fb", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6bb46c2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/representation/src/plots/trucks_tsne.py b/representation/src/plots/trucks_tsne.py new file mode 100644 index 
0000000..afc57a1 --- /dev/null +++ b/representation/src/plots/trucks_tsne.py @@ -0,0 +1,46 @@ +import sys +sys.path.append('.') +import os +import numpy as np +import matplotlib.pyplot as plt +from sklearn.manifold import TSNE +from sklearn.decomposition import PCA +import argparse +import random + +parser = argparse.ArgumentParser() +parser.add_argument('--input', default='resources/features/trucks/watermark_half/train') +parser.add_argument('--output-dir', default='resources/tsne/trucks') +args = parser.parse_args() + +indices = list(range(10259)) +indices = random.sample(indices, k=3000) + +for file in os.listdir(args.input): + arr = np.load(os.path.join(args.input, file)) + embeddings = np.asarray(arr['embeddings'], dtype='float64')[indices] + y = arr['labels'][indices] + print(embeddings.shape) + n_labels = len(np.unique(y)) + logo_labels = (y >= (n_labels / 2)).astype(np.int32) + + X = TSNE(n_components=2, learning_rate='auto', + init='random', + metric='cosine', + random_state=0).fit_transform(embeddings) + + # X = PCA(n_components=2, svd_solver='full').fit_transform(embeddings) + + class_labels = np.zeros(y.shape).astype(np.float32) - 1 + assert n_labels == 16 + for idx in range(int(n_labels / 2)): + class_labels[y == idx] = idx + class_labels[y == (idx + int(n_labels / 2))] = idx + + plt.scatter(*zip(*X[logo_labels == 1]), c='blue', alpha=0.2, label='no logo') + plt.scatter(*zip(*X[logo_labels == 0]), c='orange', alpha=0.2, label='logo') + plt.axis('off') + plt.legend().remove() + os.makedirs(args.output_dir, exist_ok=True) + plt.savefig(os.path.join(args.output_dir, file.removesuffix('.npz') + f'.png'), bbox_inches='tight') + plt.clf() diff --git a/representation/src/utils/probing.py b/representation/src/utils/probing.py new file mode 100644 index 0000000..be27ec0 --- /dev/null +++ b/representation/src/utils/probing.py @@ -0,0 +1,30 @@ +from sklearn.linear_model import LogisticRegression +import json +import numpy as np + +Array = np.ndarray + + +def train_classifier( + train_features: Array, + train_labels: Array, + reg: float, +): + clf = LogisticRegression(random_state=1, max_iter=1000, fit_intercept=False, + C=reg, class_weight='balanced').fit(train_features, train_labels) + return clf + + +def load_embeddings_labels(path): + arr = np.load(path) + X = arr['embeddings'] + y = arr['labels'] + return X, y + + +def load_class_names(imagenet_classes): + with open('resources/imagenet_class_index.json') as f: + class_index = json.load(f) + class_dict = {item[0]: item[1] for item in class_index.values()} + classes = [class_dict[wnid] for wnid in imagenet_classes] + return classes
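
The probing helpers above (`train_classifier`, `load_embeddings_labels`) are only building blocks; the repository's own probing entry point is not shown here. As a rough illustration of how they fit together, the sketch below trains a logistic-regression probe on precomputed embeddings and writes its test predictions to JSON. The `.npz` paths, the `reg` value, and the JSON layout are assumptions made for this example; only the `embeddings`/`labels` array names follow the extraction and plotting scripts.

```python
import json

from utils.probing import load_embeddings_labels, train_classifier

# Assumed paths: .npz files produced by the embedding-extraction step,
# each containing 'embeddings' and 'labels' arrays. Adjust to your setup.
X_train, y_train = load_embeddings_labels('resources/features/fish/std/train/r50-sup.npz')
X_test, y_test = load_embeddings_labels('resources/features/fish/std/test/r50-sup.npz')

# reg is forwarded to sklearn's LogisticRegression as C (inverse regularization
# strength); 1.0 is just a placeholder value for this sketch.
clf = train_classifier(X_train, y_train, reg=1.0)

# Hypothetical output layout for illustration only.
results = {
    'model': 'r50-sup',
    'test_accuracy': float(clf.score(X_test, y_test)),
    'predictions': clf.predict(X_test).tolist(),
    'labels': y_test.tolist(),
}
with open('r50-sup_probe_predictions.json', 'w') as f:
    json.dump(results, f)
```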
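
Similarly, the `VisslLoader` defined in `models/vissl.py` yields a plain torchvision backbone once its checkpoint has been downloaded and converted. A minimal usage sketch, assuming the class is importable as `models.vissl.VisslLoader` (the import path, dummy input, and printed shape are illustrative, not taken from the repository):

```python
import torch
from models.vissl import VisslLoader

# 'simclr-rn50' is one of the keys in VisslLoader.MODELS. The loader downloads
# and caches the checkpoint, strips the '_feature_blocks.' prefix, and returns
# a torchvision ResNet-50 whose fc layer is replaced by Identity().
backbone = VisslLoader('simclr-rn50').load_model_from_source()
backbone.eval()

with torch.no_grad():
    dummy = torch.randn(1, 3, 224, 224)  # one ImageNet-sized RGB image
    embedding = backbone(dummy)          # expected shape: (1, 2048)
print(embedding.shape)
```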