Skip to content

Commit

Permalink
Added representation learning code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Dippel authored and jacobkauffmann committed Nov 19, 2024
1 parent 8806e37 commit 2aed067
Show file tree
Hide file tree
Showing 27 changed files with 2,392 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Place the datasets in the `data/` directory.

#### 2.4 Representation Learning
- Navigate to the `representation/` directory.
- **TBD**
- Follow instructions in the README.md file in the respective folder.

## Contact Information:
- Grégoire Montavon: [[email protected]](mailto:[email protected])
Expand Down
42 changes: 42 additions & 0 deletions representation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Representation Learning Experiments

This folder contains the code to reproduce our results for the representation learning models.
The models used in the paper are: `r50-sup`, `r50-barlowtwins`, `r50-clip`, `simclr-rn50`.

### Install dependencies

```bash
pip install -r requirements.txt
```

### Compute embeddings

Embeddings for the different ResNet-50 models can be extracted with the following script.

```bash
python extract_embeddings.py --data-root <path_to_imagenet_root> \
--model <model_name> \
--output-dir <output_dir> \
--dataset <trucks|fish> \
--device cuda \
--split <train|test>

```

### Generate BiLRP Heatmaps

BiLRP heatmaps can be generated by first computing the LRP relevances (`compute_bilrp.py`) and
then plotting the result (`plots/plot_bilrp.py`).

### Linear Classifiers

To train linear classifiers on the extracted embeddings, `linear_probing.py` can be used. This generates
json files with the predictions of the linear classifiers.

### Plot classifier results

With the notebook `plots/representation.ipynb`, the linear probing results can then be analyzed and plotted.

### T-SNE plots
The T-SNE plots can be generated from the extracted features with `plots/fish_tsne.py` and `plots/trucks_tsne.py`.

4 changes: 4 additions & 0 deletions representation/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torchvision==0.19
torch==2.4.0
scikit-learn==1.5.1
Pillow
67 changes: 67 additions & 0 deletions representation/src/bilrp/bilrp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
from tqdm import tqdm
import numpy as np
from zennit.attribution import Gradient
from bilrp.plotting import plot_relevances, clip, get_alpha


def compute_branch(x, model, composite, device='cuda'):
    """Attribute every output feature of ``model`` back to the input ``x``.

    One attribution pass is run per entry of the squeezed model output. The
    output relevance handed to the attributor is zero everywhere except at
    that entry, where it carries the entry's own activation value.

    Args:
        x: Input passed to ``model.forward`` and to the attributor.
        model: Model whose output features are explained.
        composite: Composite handed to ``Gradient``.
        device: Device on which the per-feature output relevance is created.

    Returns:
        A tuple ``(R, e)``: ``R`` is a list with one squeezed NumPy relevance
        array per output feature, ``e`` is the (unsqueezed) model output.
    """
    e = model.forward(x)
    # The loop below indexes ``y`` with a single integer and expects a scalar,
    # i.e. the squeezed output has to be one-dimensional.
    y = e.squeeze()
    n_features = y.shape[0]

    R = []
    # Iterating over a range (instead of ``enumerate``) lets tqdm know the
    # total number of passes, so the progress bar shows real progress.
    for k in tqdm(range(n_features)):
        # One-hot output relevance of shape (1, n_features) carrying the k-th activation.
        r_proj = torch.zeros((1, n_features), dtype=torch.float32, device=device)
        r_proj[0, k] = y[k].item()

        model.zero_grad()
        x.grad = None
        with Gradient(model=model, composite=composite) as attributor:
            out, relevance = attributor(x, r_proj)
        R.append(relevance.squeeze().detach().cpu().numpy())
        # Drop the references right away so the tensors of this pass can be
        # freed before the next one starts.
        del out, relevance
    return R, e


def pool(X, stride):
    """Average-pool every array in ``X`` with a square, non-overlapping window.

    Args:
        X: Iterable of NumPy arrays; each one is pooled independently.
        stride: Used both as kernel size and as stride of the pooling.

    Returns:
        A list with one pooled (and squeezed) NumPy array per entry of ``X``.
    """
    pooled = []
    for plane in X:
        # Add the leading batch and channel axes that avg_pool2d works on.
        batch = torch.from_numpy(plane)[None, None]
        reduced = torch.nn.functional.avg_pool2d(
            batch, kernel_size=stride, stride=stride, padding=0
        )
        pooled.append(reduced.squeeze().numpy())
    return pooled


def compute_rel(r1, r2, poolsize=(8,)):
    """Combine the relevances of two branches into one joint relevance array.

    Both inputs are converted to arrays and summed over axis 1, each resulting
    map is average-pooled with ``pool``, and the two pooled stacks are
    contracted over their leading axis.

    Args:
        r1: Relevances of the first branch (sequence of arrays).
        r2: Relevances of the second branch (sequence of arrays).
        poolsize: Pooling window and stride forwarded to ``pool``. The default
            is an immutable tuple so it cannot be altered between calls.

    Returns:
        The tensor product of the pooled relevances, contracted over axis 0.
    """
    summed_first = np.array(r1).sum(1)
    summed_second = np.array(r2).sum(1)
    return np.tensordot(
        pool(summed_first, poolsize),
        pool(summed_second, poolsize),
        axes=(0, 0),
    )


def plot_bilrp(x1, x2, R1, R2, fname=None, normalization_factor='individual'):
    """Compute the joint relevance of two branches and plot it over both images.

    Args:
        x1: First image, forwarded to ``plot_relevances``.
        x2: Second image, forwarded to ``plot_relevances``.
        R1: Relevances of the first branch, forwarded to ``compute_rel``.
        R2: Relevances of the second branch, forwarded to ``compute_rel``.
        fname: Output file forwarded to ``plot_relevances``.
        normalization_factor: Normalization mode forwarded to ``clip``.
    """
    # A single pool size drives both the pooling and the plot geometry, so the
    # two cannot drift apart.
    poolsize = [8]

    def clip_func(values):
        """Map raw relevance values to line opacities."""
        clipped = clip(
            values,
            clim1=[-2, 2],
            clim2=[-20, 20],
            normalization_factor=normalization_factor,
        )
        return get_alpha(clipped, p=2)

    R = compute_rel(R1, R2, poolsize)
    # Every index of R in C order, paired with the value stored there.
    inds_all = [(idx, R[idx]) for idx in np.ndindex(*R.shape)]
    plot_relevances(inds_all, x1, x2, clip_func, poolsize, curvefac=2.5, fname=fname)


def projection_conv(input_dim, embedding_size=2048):
    """Build a ``Flatten`` followed by a bias-free 1x1 convolution.

    Args:
        input_dim: Number of input channels of the convolution.
        embedding_size: Number of output channels of the convolution.

    Returns:
        A ``torch.nn.Sequential`` holding the two layers in that order.
    """
    # NOTE(review): Flatten comes before a Conv2d here; with default arguments
    # Flatten yields a 2-D tensor - confirm how this module is meant to be applied.
    flatten = torch.nn.Flatten()
    conv = torch.nn.Conv2d(input_dim, embedding_size, (1, 1), bias=False)
    return torch.nn.Sequential(flatten, conv)
35 changes: 35 additions & 0 deletions representation/src/bilrp/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from torch.utils.data import Dataset
import torchxrayvision as xrv


class CovidDataset(Dataset):
    """Wrap a dataset so that indexing yields ``(sample, label)`` pairs.

    The label of an item is the integer-cast entry in column 3 of the wrapped
    dataset's ``labels`` array; the ``patientid`` column of its ``csv``
    attribute is kept as ``patient_ids``.
    """

    def __init__(self, dataset):
        """Store ``dataset`` and extract its labels and patient ids."""
        super().__init__()
        self.dataset = dataset
        # NOTE(review): column 3 is presumably the target pathology - confirm
        # against the label order of the wrapped dataset.
        label_column = dataset.labels[:, 3]
        self.labels = label_column.astype(int)
        self.patient_ids = dataset.csv['patientid']

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item ``idx`` together with its label."""
        return self.dataset[idx], self.labels[idx]


def load_github_dataset(transform):
    """Load the COVID-19 chest X-ray dataset and wrap it in ``CovidDataset``.

    Args:
        transform: Transform forwarded to ``xrv.datasets.COVID19_Dataset``.

    Returns:
        A ``CovidDataset`` around the loaded dataset.
    """
    raw_dataset = xrv.datasets.COVID19_Dataset(
        imgpath="resources/data/xray/covid-chestxray-dataset/images",
        csvpath="resources/data/xray/covid-chestxray-dataset/metadata.csv",
        transform=transform,
    )
    return CovidDataset(raw_dataset)


def load_nih_dataset(transform):
    """Load the NIH chest X-ray dataset, requesting unique patients.

    Args:
        transform: Transform forwarded to ``xrv.datasets.NIH_Dataset``.

    Returns:
        The ``xrv.datasets.NIH_Dataset`` instance.
    """
    return xrv.datasets.NIH_Dataset(
        imgpath="resources/data/xray/NIH/images-224",
        csvpath="resources/Data_Entry_2017_v2020.csv",
        bbox_list_path="resources/data/xray/NIH/BBox_List_2017.csv",
        transform=transform,
        unique_patients=True,
    )
94 changes: 94 additions & 0 deletions representation/src/bilrp/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from matplotlib import pyplot as plt
import numpy as np
import numpy

# Fixed normalization constants, keyed by name. `clip` looks its
# `normalization_factor` argument up in this table whenever that argument is
# not 'individual', and raises ValueError for names that are missing here.
NORMALIZATION_FACTORS = {
'covid': 0.1423813963010631,
}


def clip(R, clim1, clim2, normalization_factor='individual'):
    """Normalize, sparsify and threshold a relevance array.

    Args:
        R: Relevance values.
        clim1: ``[low, high]`` band removed from the normalized values.
        clim2: ``[low, high]`` limits; the difference ``clim2 - clim1`` gives
            the final clipping range.
        normalization_factor: ``'individual'`` to scale by the fourth root of
            the mean fourth power of ``R``, or a key of
            ``NORMALIZATION_FACTORS`` to scale by the geometric mean of that
            quantity and the stored constant.

    Returns:
        The processed array; its values lie in ``[span_low / span_high, 1]``
        where ``span = clim2 - clim1``.

    Raises:
        ValueError: If ``normalization_factor`` is neither ``'individual'``
            nor a key of ``NORMALIZATION_FACTORS``.
    """
    span = list(np.array(clim2) - np.array(clim1))

    if normalization_factor == 'individual':
        scale = np.mean(R ** 4) ** 0.25
    elif normalization_factor in NORMALIZATION_FACTORS:
        reference = NORMALIZATION_FACTORS[normalization_factor]
        scale = np.sqrt(np.mean(R ** 4) ** 0.25) * np.sqrt(reference)
    else:
        raise ValueError('unknown normalization factor')

    normalized = R / scale
    # Subtracting the clipped copy zeroes everything inside the clim1 band.
    sparse = normalized - np.clip(normalized, clim1[0], clim1[1])
    # Cap the remainder and rescale so the upper bound becomes 1.
    return np.clip(sparse, span[0], span[1]) / span[1]


def get_alpha(x, p=1):
    """Return ``x`` raised to the power ``p``."""
    return x ** p


def plot_relevances(c, x1, x2, clip_func, stride, fname=None, curvefac=1.):
    """Show two images side by side and connect their regions with curves.

    Both images are framed, placed next to each other with a gap in between,
    and one curve is drawn per entry of ``c`` whose opacity is positive.

    Args:
        c: Sequence of ``(index, relevance)`` pairs. ``index`` holds four
            entries ``(i, j, k, l)``: ``(i, j)`` is the row/column of a pooled
            cell in the first image, ``(k, l)`` the one in the second image.
        x1: First image, 2-D or 3-D with the channel axis last.
        x2: Second image; reshaped to the same shape as ``x1``.
        clip_func: Callable mapping the array of relevances to opacities.
        stride: Sequence with the pooling stride; two entries are read as
            ``(row stride, column stride)``, otherwise the first entry is
            used for both directions.
        fname: If truthy, the figure is saved to this path; otherwise it is
            shown interactively.
        curvefac: Scales how far the connecting curves bend.
    """
    h, w, channels = x1.shape if len(x1.shape) == 3 else list(x1.shape) + [1]
    wgap, hpad = int(0.05 * w), int(0.6 * w)

    plt.subplots(figsize=(10, 8))
    plt.ylim(-hpad - 2, h + hpad + 1)
    plt.xlim(0, (w + 2) * 2 + wgap + 1)

    x1 = x1.reshape(h, w, channels).squeeze()
    x2 = x2.reshape(h, w, channels).squeeze()

    # One-pixel frame with the last channel set to 1.
    # NOTE(review): the frame has 4 channels, so the images are presumably
    # expected to be RGBA - confirm with the callers.
    border_w = np.zeros((1, w, 4))
    border_h = np.zeros((h + 2, 1, 4))
    border_h[:, :, -1] = 1
    border_w[:, :, -1] = 1

    def add_frame(image):
        """Surround ``image`` with the one-pixel frame."""
        with_rows = np.concatenate([border_w, image, border_w], axis=0)
        return np.concatenate([border_h, with_rows, border_h], axis=1)

    x1 = add_frame(x1)
    x2 = add_frame(x2)

    mid = np.ones([h + 2, wgap, channels]).squeeze()
    # Rows are reversed before the canvas is handed to imshow.
    X = np.concatenate([x1, mid, x2], axis=1)[::-1]
    plt.imshow(X, cmap='gray', vmin=-1, vmax=1)

    if len(stride) == 2:
        stridex = stride[0]
        stridey = stride[1]
    else:
        stridex = stridey = stride[0]

    relevance_array = np.array([entry[1] for entry in c])
    indices = [entry[0] for entry in c]
    alphas = clip_func(relevance_array)

    # Loop invariants: horizontal reference point and curve parameter grid.
    xm = int(w / 2) + 6
    lin = np.linspace(0, 1, 25)

    for indx, alpha, s in zip(indices, alphas, relevance_array):
        if not alpha > 0.:
            continue
        i, j, k, l = indx[0], indx[1], indx[2], indx[3]

        # Cell centres in plot coordinates; the second image is shifted right
        # by the image width plus the gap.
        xa = stridey * j + (stridey / 2 - 0.5) - xm
        xb = stridey * l + (stridey / 2 - 0.5) - xm + w + wgap
        ya = h - (stridex * i + (stridex / 2 - 0.5))
        yb = h - (stridex * k + (stridex / 2 - 0.5))
        ym = (0.8 * (ya + yb) - curvefac * int(h / 6))
        ya -= ym
        yb -= ym
        plt.plot(xa * lin + xb * (1 - lin) + xm, ya * lin ** 2 + yb * (1 - lin) ** 2 + ym,
                 color='red' if s > 0 else 'blue', alpha=alpha)

    plt.axis('off')

    if fname:
        plt.tight_layout()
        plt.savefig(fname, dpi=300, transparent=True)
    else:
        plt.show()
    plt.close()
Loading

0 comments on commit 2aed067

Please sign in to comment.