-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8806e37
commit 2aed067
Showing
27 changed files
with
2,392 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,7 @@ Place the datasets in the `data/` directory. | |
|
||
#### 2.4 Representation Learning | ||
- Navigate to the `representation/` directory. | ||
- **TBD** | ||
- Follow instructions in the README.md file in the respective folder. | ||
|
||
## Contact Information: | ||
- Grégoire Montavon: [[email protected]](mailto:[email protected]) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Representation Learning Experiments | ||
|
||
This folder contains the code to reproduce our results for the representation learning models. | ||
Models that are used in the paper are: `r50-sup`, `r50-barlowtwins`, `r50-clip`, `simclr-rn50`. | ||
|
||
### Install dependencies | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Compute embeddings | ||
|
||
Embeddings for the different ResNet-50 models can be extracted with the following script. | ||
|
||
```bash | ||
python extract_embeddings.py --data-root <path_to_imagenet_root> \ | ||
--model <model_name> \ | ||
--output-dir <output_dir> \ | ||
--dataset <trucks|fish> \ | ||
--device cuda \ | ||
--split <train|test> | ||
|
||
``` | ||
|
||
### Generate BiLRP Heatmaps | ||
|
||
BiLRP heatmaps can be generated by first computing LRP relevances (`compute_bilrp.py`) and | ||
then plotting the result (`plots/plot_bilrp.py`). | ||
|
||
### Linear Classifiers | ||
|
||
To train linear classifiers on the extracted embeddings, `linear_probing.py` can be used. This generates | ||
json files with the predictions of the linear classifiers. | ||
|
||
### Plot classifier results | ||
|
||
With the notebook `plots/representation.ipynb`, the linear probing results can then be analyzed and plotted. | ||
|
||
### T-SNE plots | ||
The T-SNE plots can be generated from the extracted features with `plots/fish_tsne.py` and `plots/trucks_tsne.py`. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
torchvision==0.19 | ||
torch==2.4.0 | ||
scikit-learn==1.5.1 | ||
Pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import torch | ||
from tqdm import tqdm | ||
import numpy as np | ||
from zennit.attribution import Gradient | ||
from bilrp.plotting import plot_relevances, clip, get_alpha | ||
|
||
|
||
def compute_branch(x, model, composite, device='cuda'):
    """Attribute every output feature of ``model(x)`` back to the input ``x``.

    For each entry ``k`` of the squeezed model output, a seed that is zero
    everywhere except at position ``k`` (where it holds the output value) is
    passed to a zennit ``Gradient`` attributor configured with ``composite``.

    Args:
        x: Input tensor handed to ``model.forward`` and to the attributor.
        model: Model to explain; called through ``model.forward``.
        composite: zennit composite forwarded to ``Gradient``.
        device: Device the seed tensor is moved to before attribution.

    Returns:
        Tuple ``(R, e)``: ``R`` is a list with one squeezed numpy relevance
        array per output feature, ``e`` is the raw (unsqueezed) model output.
    """
    embedding = model.forward(x)
    # assumes the squeezed output is 1-D (single sample) - TODO confirm
    features = embedding.squeeze()
    out_shape = features.shape

    relevances = []
    for k, _ in tqdm(enumerate(features)):
        # Seed carrying only the k-th activation; built in numpy and cast to
        # float32 via FloatTensor.
        seed = np.zeros((out_shape[0]))
        seed[k] = features[k].detach().cpu().numpy().squeeze()
        seed_tensor = torch.FloatTensor(seed.reshape([1, out_shape[0], 1, 1])).to(device)
        r_proj = seed_tensor.data.squeeze(2).squeeze(2)

        # Clear gradients left over from the previous feature.
        model.zero_grad()
        x.grad = None
        with Gradient(model=model, composite=composite) as attributor:
            out, relevance = attributor(x, r_proj)
        relevance = relevance.squeeze().detach().cpu().numpy()
        relevances.append(relevance)
        # Drop the references before the next pass.
        del out, relevance
    return relevances, embedding
|
||
|
||
def pool(X, stride):
    """Average-pool each 2-D numpy array in ``X``.

    Args:
        X: Iterable of 2-D numpy arrays.
        stride: Used both as kernel size and as stride of the pooling window.

    Returns:
        List of pooled numpy arrays (squeezed), one per element of ``X``.
    """
    pooled = []
    for arr in X:
        # Add batch and channel axes so avg_pool2d sees a (1, 1, H, W) tensor.
        batch = torch.from_numpy(arr)[None, None]
        reduced = torch.nn.functional.avg_pool2d(
            batch, kernel_size=stride, stride=stride, padding=0
        )
        pooled.append(reduced.squeeze().numpy())
    return pooled
|
||
|
||
def compute_rel(r1, r2, poolsize=[8]):
    """Combine two branches of per-feature relevances into one joint tensor.

    Each branch is stacked into an array, summed over axis 1, average-pooled
    with ``pool`` and the two pooled stacks are contracted over their leading
    (feature) axis.

    Args:
        r1: Sequence of per-feature relevance arrays for the first input.
        r2: Sequence of per-feature relevance arrays for the second input.
        poolsize: Pooling window/stride forwarded to ``pool``. The list
            default is only read, never mutated.

    Returns:
        Result of ``np.tensordot`` over the feature axis of both pooled stacks.
    """
    # assumes axis 1 is the colour/channel axis of each relevance map - TODO confirm
    summed_first = np.array(r1).sum(1)
    summed_second = np.array(r2).sum(1)
    pooled_first = pool(summed_first, poolsize)
    pooled_second = pool(summed_second, poolsize)
    return np.tensordot(pooled_first, pooled_second, axes=(0, 0))
|
||
|
||
def plot_bilrp(x1, x2, R1, R2, fname=None, normalization_factor='individual'):
    """Render the BiLRP interaction plot for a pair of inputs.

    Args:
        x1: First image, forwarded to ``plot_relevances``.
        x2: Second image, forwarded to ``plot_relevances``.
        R1: Per-feature relevances of the first branch (see ``compute_rel``).
        R2: Per-feature relevances of the second branch.
        fname: Output path; when falsy the figure is shown instead of saved.
        normalization_factor: Forwarded to ``clip`` (``'individual'`` or a
            key of its normalization table).
    """
    def clip_func(values):
        # Normalise/sparsify the relevances, then square them to get alphas.
        clipped = clip(values, clim1=[-2, 2], clim2=[-20, 20],
                       normalization_factor=normalization_factor)
        return get_alpha(clipped, p=2)

    poolsize = [8]
    # Forward the same pool size to the relevance computation and to the
    # plot so the pooled grid and the drawing stride cannot drift apart.
    R = compute_rel(R1, R2, poolsize)
    # (index, value) pairs in C order; each index is an (i, j, k, l) tuple.
    inds_all = list(np.ndenumerate(R))
    plot_relevances(inds_all, x1, x2, clip_func, poolsize, curvefac=2.5, fname=fname)
|
||
|
||
def projection_conv(input_dim, embedding_size=2048):
    """Build a ``Flatten`` -> bias-free 1x1 ``Conv2d`` projection module.

    Args:
        input_dim: Number of input channels of the convolution.
        embedding_size: Number of output channels of the convolution.

    Returns:
        ``torch.nn.Sequential`` holding the two layers in that order.
    """
    # NOTE(review): Flatten yields a 2-D tensor, which Conv2d does not accept
    # as-is - confirm how callers use or modify this module.
    layers = [
        torch.nn.Flatten(),
        torch.nn.Conv2d(input_dim, embedding_size, (1, 1), bias=False),
    ]
    return torch.nn.Sequential(*layers)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
from torch.utils.data import Dataset | ||
import torchxrayvision as xrv | ||
|
||
|
||
class CovidDataset(Dataset):
    """Wrap a dataset so that indexing yields ``(sample, label)`` pairs."""

    def __init__(self, dataset):
        """Store ``dataset`` and cache its labels and patient ids.

        Args:
            dataset: Object exposing ``labels`` (2-D array), ``csv`` (indexable
                by ``'patientid'``), ``__len__`` and ``__getitem__``.
        """
        super().__init__()
        self.dataset = dataset
        # Column 3 of the label matrix, cast to int.
        # presumably the COVID-19 column - verify against the label order
        label_column = dataset.labels[:, 3]
        self.labels = label_column.astype(int)
        self.patient_ids = dataset.csv['patientid']

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the wrapped sample at ``idx`` together with its int label."""
        return self.dataset[idx], self.labels[idx]
|
||
|
||
def load_github_dataset(transform):
    """Load the covid-chestxray dataset and wrap it in ``CovidDataset``.

    Args:
        transform: Transform forwarded to ``xrv.datasets.COVID19_Dataset``.

    Returns:
        ``CovidDataset`` wrapping the loaded torchxrayvision dataset.
    """
    # Paths are relative to the current working directory.
    raw_dataset = xrv.datasets.COVID19_Dataset(
        imgpath="resources/data/xray/covid-chestxray-dataset/images",
        csvpath="resources/data/xray/covid-chestxray-dataset/metadata.csv",
        transform=transform,
    )
    return CovidDataset(raw_dataset)
|
||
|
||
def load_nih_dataset(transform):
    """Load the NIH chest X-ray dataset with one image per patient.

    Args:
        transform: Transform forwarded to ``xrv.datasets.NIH_Dataset``.

    Returns:
        The torchxrayvision ``NIH_Dataset`` instance.
    """
    # Paths are relative to the current working directory.
    # NOTE(review): the CSV lives directly under resources/, unlike the other
    # NIH files under resources/data/xray/NIH/ - confirm this is intended.
    return xrv.datasets.NIH_Dataset(
        imgpath="resources/data/xray/NIH/images-224",
        csvpath="resources/Data_Entry_2017_v2020.csv",
        bbox_list_path="resources/data/xray/NIH/BBox_List_2017.csv",
        transform=transform,
        unique_patients=True,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
from matplotlib import pyplot as plt | ||
import numpy as np | ||
import numpy | ||
|
||
# Fixed reference scales selectable through ``clip``'s
# ``normalization_factor`` argument.
# presumably a precomputed dataset-wide relevance scale - confirm its origin
NORMALIZATION_FACTORS = {
    'covid': 0.1423813963010631,
}


def clip(R, clim1, clim2, normalization_factor='individual'):
    """Normalise, sparsify and threshold a relevance array.

    Steps:
        1. Divide ``R`` by a scale: its own 4th-moment norm for
           ``'individual'``, otherwise the geometric mean of that norm and the
           fixed entry of ``NORMALIZATION_FACTORS``.
        2. Remove everything inside ``clim1`` (values within the band become 0,
           values outside keep only their excess over the band edge).
        3. Clip to ``clim2 - clim1`` and divide by its upper bound.

    Args:
        R: Numpy array of relevances.
        clim1: ``[low, high]`` band that is zeroed out after normalisation.
        clim2: ``[low, high]`` outer limits; with ``clim1`` they define the
            final clipping range.
        normalization_factor: ``'individual'`` or a key of
            ``NORMALIZATION_FACTORS``.

    Returns:
        Array shaped like ``R``. An all-zero ``R`` yields all zeros instead of
        NaNs from a division by zero.

    Raises:
        ValueError: If ``normalization_factor`` is not recognised.
    """
    delta = np.array(clim2) - np.array(clim1)

    if normalization_factor == 'individual':
        Rnorm = np.mean(R ** 4) ** 0.25
    elif normalization_factor in NORMALIZATION_FACTORS:
        reference = NORMALIZATION_FACTORS[normalization_factor]
        Rnorm = np.sqrt(np.mean(R ** 4) ** 0.25) * np.sqrt(reference)
    else:
        raise ValueError('unknown normalization factor')

    # The norm is zero only when R carries no relevance at all; dividing
    # would turn every entry into NaN and emit runtime warnings.
    if Rnorm == 0:
        return np.zeros_like(R, dtype=float)

    R = R / Rnorm  # normalization
    R = R - np.clip(R, clim1[0], clim1[1])  # sparsification
    R = np.clip(R, delta[0], delta[1]) / delta[1]  # thresholding
    return R
|
||
|
||
def get_alpha(x, p=1):
    """Return ``x`` raised to the power ``p`` (``p=1`` leaves it unchanged)."""
    return x ** p
|
||
|
||
def plot_relevances(c, x1, x2, clip_func, stride, fname=None, curvefac=1.):
    """Draw two images side by side and connect them with relevance arcs.

    Args:
        c: Sequence of ``(index, relevance)`` pairs. ``index`` holds four
            entries ``(i, j, k, l)``: ``(i, j)`` is the row/column of a pooled
            cell in ``x1`` and ``(k, l)`` the row/column of one in ``x2``.
        x1: Left image of shape ``(h, w)`` or ``(h, w, channels)``.
        x2: Right image, reshaped to the same shape as ``x1``.
        clip_func: Maps the array of relevances to per-arc alpha values; only
            arcs with ``alpha > 0`` are drawn.
        stride: Sequence of one or two ints giving the pixel size of a pooled
            cell (one value is used for both axes).
        fname: If truthy, the figure is saved there; otherwise it is shown.
        curvefac: Scales how far the arcs bend.
    """
    h, w, channels = x1.shape if len(x1.shape) == 3 else list(x1.shape) + [1]
    wgap, hpad = int(0.05 * w), int(0.6 * w)

    plt.subplots(figsize=(10, 8))
    plt.ylim(-hpad - 2, h + hpad + 1)
    plt.xlim(0, (w + 2) * 2 + wgap + 1)

    x1 = x1.reshape(h, w, channels).squeeze()
    x2 = x2.reshape(h, w, channels).squeeze()

    # One-pixel frame with zero colour and last channel set to 1.
    # The frame has 4 channels, so the concatenation below presumably
    # requires 4-channel (RGBA) images - TODO confirm with callers.
    border_w = np.zeros((1, w, 4))
    border_h = np.zeros((h + 2, 1, 4))
    border_h[:, :, -1] = 1
    border_w[:, :, -1] = 1

    def add_frame(img):
        framed_rows = np.concatenate([border_w, img, border_w], axis=0)
        return np.concatenate([border_h, framed_rows, border_h], axis=1)

    x1 = add_frame(x1)
    x2 = add_frame(x2)

    # Gap of ones between both images; rows are flipped before display.
    mid = np.ones([h + 2, wgap, channels]).squeeze()
    X = np.concatenate([x1, mid, x2], axis=1)[::-1]
    plt.imshow(X, cmap='gray', vmin=-1, vmax=1)

    if len(stride) == 2:
        stridex, stridey = stride[0], stride[1]
    else:
        stridex = stridey = stride[0]

    relevance_array = np.array([i[1] for i in c])
    indices = [i[0] for i in c]
    alphas = clip_func(relevance_array)

    # Loop-invariant quantities of the arc parametrisation.
    xm = int(w / 2) + 6
    lin = np.linspace(0, 1, 25)

    for indx, alpha, s in zip(indices, alphas, relevance_array):
        # Skips zero, negative and NaN alphas alike.
        if not alpha > 0.:
            continue
        i, j, k, l = indx[0], indx[1], indx[2], indx[3]

        # Pixel-space centres of the two pooled cells; the second one is
        # shifted right by the width of the first image plus the gap.
        xa = stridey * j + (stridey / 2 - 0.5) - xm
        xb = stridey * l + (stridey / 2 - 0.5) - xm + w + wgap
        ya = h - (stridex * i + (stridex / 2 - 0.5))
        yb = h - (stridex * k + (stridex / 2 - 0.5))
        ym = (0.8 * (ya + yb) - curvefac * int(h / 6))
        ya -= ym
        yb -= ym
        # Red arcs for positive relevance, blue otherwise.
        plt.plot(xa * lin + xb * (1 - lin) + xm,
                 ya * lin ** 2 + yb * (1 - lin) ** 2 + ym,
                 color='red' if s > 0 else 'blue', alpha=alpha)

    plt.axis('off')

    if fname:
        plt.tight_layout()
        plt.savefig(fname, dpi=300, transparent=True)
    else:
        plt.show()
    # Always release the figure so repeated calls do not pile up open figures.
    plt.close()
Oops, something went wrong.