
Issue with ViT in BioClip Visual Part: ViT Returns CLS Token Instead of Logits #52

Open
Link7808 opened this issue Sep 26, 2024 · 6 comments

@Link7808

We are using the visual part (ViT) of BioClip to process images. However, there is an issue with the forward method in BaseCAM.
In the following lines of code:

```python
self.outputs = outputs = self.activations_and_grads(input_tensor)
target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
```
Here, outputs is the CLS token embedding, a high-dimensional vector that represents the global semantic information of the input image. This embedding is a feature vector, not a classification result or logits.

@johnbradley
Collaborator

@Link7808 Do you have some example code that reproduces this issue? Is the BaseCAM class from https://github.com/jacobgil/pytorch-grad-cam? If so, I have a first-attempt Jupyter notebook that uses pytorch-grad-cam with pybioclip.

@Link7808
Author

@johnbradley
Let me clarify the issue. In this notebook, we are using:
```python
classifier = TreeOfLifeClassifier()
model = classifier.model.visual
targets = None
```

BaseCAM attempts to automatically get the index of the predicted class with:

```python
if targets is None:
    target_categories = np.argmax(outputs.cpu().data.numpy(), axis=-1)
    targets = [ClassifierOutputTarget(category) for category in target_categories]
```
The problem is that outputs is the CLS token embedding, a feature vector representing the global semantic information of the input image. It is not a vector of logits, so it cannot be used to get the correct target_categories.
As a result, Grad-CAM does not select the correct target for visualization.
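To make the failure concrete, here is a small NumPy sketch (not BioCLIP or pytorch-grad-cam code). It assumes a 512-dimensional image embedding and a hypothetical set of 3 class text embeddings; classifier_output_target is a hypothetical helper that roughly mimics how a classifier-output target selects one column of the model output. Taking argmax over the raw embedding returns an embedding dimension rather than a class index, so the Grad-CAM target ends up being a single feature channel.

```python
import numpy as np

rng = np.random.default_rng(0)

EMBED_DIM = 512   # assumed embedding size, for illustration only
NUM_CLASSES = 3   # hypothetical number of class labels

# What a CLIP-style visual encoder returns: one embedding per image, shape (batch, EMBED_DIM).
image_embedding = rng.normal(size=(1, EMBED_DIM))

# Hypothetical per-class text embeddings, shape (EMBED_DIM, NUM_CLASSES).
text_embeddings = rng.normal(size=(EMBED_DIM, NUM_CLASSES))


def classifier_output_target(model_output, category):
    """Hypothetical stand-in for a target that selects one column of the output."""
    return model_output[:, category]


# BaseCAM-style default: argmax over the last axis of whatever the model returned.
wrong_category = int(np.argmax(image_embedding, axis=-1)[0])
# With an embedding, this indexes one of EMBED_DIM features, not one of the classes.
wrong_score = classifier_output_target(image_embedding, wrong_category)

# Projecting onto the class embeddings first gives one score per class.
logits = image_embedding @ text_embeddings
right_category = int(np.argmax(logits, axis=-1)[0])
right_score = classifier_output_target(logits, right_category)

print("argmax over embedding:", wrong_category, "(range 0..%d)" % (EMBED_DIM - 1))
print("argmax over logits:   ", right_category, "(range 0..%d)" % (NUM_CLASSES - 1))
```

The gradient Grad-CAM then backpropagates from wrong_score belongs to one arbitrary feature dimension, which is why the resulting heatmap does not correspond to a species prediction.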

@johnbradley
Collaborator

@Link7808 After looking through the gradcam code I see what you mean about it expecting a different output.
I updated the notebook in the gradcam branch to use a custom model.
This follows a pattern I found in a gradcam PR.

The changes I made were:

```python
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from bioclip import TreeOfLifeClassifier


class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        classifier = TreeOfLifeClassifier()
        self.clip = classifier.model
        self.txt_features = classifier.txt_features
        self.txt_names = classifier.txt_names

    def forward(self, x):
        img_features = self.clip.visual(x)
        img_features = F.normalize(img_features, dim=-1)
        # Scaled cosine similarity against every class text embedding -> (batch, num_classes)
        logits = self.clip.logit_scale.exp() * img_features @ self.txt_features
        result = F.softmax(logits, dim=1)

        # Print out the target found
        for target in np.argmax(result.detach().cpu().numpy(), axis=-1):
            print(target, self.txt_names[target])

        return result
```

The above code will print out the target number and label associated with it.
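The forward() above follows the usual CLIP zero-shot recipe: L2-normalize the image embedding, multiply by exp(logit_scale), take dot products against the class text embeddings, and apply a softmax over classes. Below is a NumPy sketch of just that arithmetic. It assumes the text embeddings are already unit-normalized and stored as a (dim, num_classes) matrix, as the img_features @ self.txt_features product implies; the sizes, the logit scale value, and the random data are placeholders.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-12):
    """Scale vectors to unit length along `axis` (like F.normalize)."""
    norm = np.linalg.norm(x, axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def zero_shot_probs(img_features, txt_features, logit_scale):
    """img_features: (batch, dim); txt_features: (dim, num_classes) with unit-norm columns.

    logit_scale is the *log* of the multiplier, mirroring
    self.clip.logit_scale.exp() in the forward() above.
    """
    img_features = l2_normalize(img_features, axis=-1)
    logits = np.exp(logit_scale) * img_features @ txt_features  # (batch, num_classes)
    return softmax(logits, axis=1)


rng = np.random.default_rng(0)
dim, num_classes = 512, 4                          # placeholder sizes
txt = l2_normalize(rng.normal(size=(dim, num_classes)), axis=0)
img = rng.normal(size=(2, dim))                    # two hypothetical images
probs = zero_shot_probs(img, txt, logit_scale=np.log(100.0))

print(probs.shape)             # (2, 4): one row of class probabilities per image
print(probs.argmax(axis=-1))   # class indices, now usable as Grad-CAM targets
```

Because the wrapped model's last axis now indexes classes, BaseCAM's default np.argmax(outputs, axis=-1) picks a predicted class instead of an embedding dimension.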

Apart from adding a couple of imports, the only other changes were to instantiate the new model and set target_layers:

```python
model = ImageClassifier()
...

target_layers = [model.clip.visual.transformer.resblocks[-1].ln_1]
```

@Link7808
Author

Link7808 commented Oct 6, 2024

I noticed you’re using eigencam. Do you have any thoughts on why Grad-CAM is performing poorly?

@johnbradley
Collaborator

@Link7808 I've found setting eigen_smooth=True when calling create_grad_cam_image() with method="gradcam" results in more reasonable outputs.

The help text for eigen_smooth in vit_example.py says:

Reduce noise by taking the first principal component of cam_weights*activations

This pytorch_grad_cam code seems to be where the eigen_smooth flag takes effect.
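For intuition about what the flag changes, here is a simplified NumPy sketch for a single image. It is an approximation written for illustration, not the pytorch_grad_cam source. Plain Grad-CAM averages the gradients over spatial positions to get one weight per channel and sums the weighted activation maps; the eigen_smooth path, as the help text describes, instead projects cam_weights*activations onto their first principal component. Activations and gradients are assumed to be already shaped (channels, height, width), which for a ViT means the patch tokens have been rearranged into a grid; the 8-channel 14x14 tensors are hypothetical.

```python
import numpy as np


def grad_cam_weights(grads):
    """Grad-CAM channel weights: mean gradient over spatial positions, shape (C,)."""
    return grads.mean(axis=(1, 2))


def plain_cam(activations, grads):
    """Weighted sum of activation maps over channels, then ReLU -> (H, W)."""
    weights = grad_cam_weights(grads)
    weighted = weights[:, None, None] * activations
    return np.maximum(weighted.sum(axis=0), 0)


def eigen_smoothed_cam(activations, grads):
    """Project cam_weights*activations onto their first principal component -> (H, W)."""
    weights = grad_cam_weights(grads)
    weighted = weights[:, None, None] * activations          # (C, H, W)
    c, h, w = weighted.shape
    # Treat each spatial position as one sample with C features.
    samples = weighted.reshape(c, h * w).T                   # (H*W, C)
    samples = samples - samples.mean(axis=0)
    # Rows of vt are principal directions; vt[0] is the first one.
    _, _, vt = np.linalg.svd(samples, full_matrices=False)
    projection = samples @ vt[0]                             # (H*W,)
    return projection.reshape(h, w)


rng = np.random.default_rng(0)
acts = rng.normal(size=(8, 14, 14))    # hypothetical activations
grads = rng.normal(size=(8, 14, 14))   # hypothetical gradients
print(plain_cam(acts, grads).shape)
print(eigen_smoothed_cam(acts, grads).shape)
```

The principal-component projection keeps only the dominant direction of variation across channels, which discards much of the channel-level noise that a plain weighted sum retains. Note that the sign of a principal component is arbitrary, so the smoothed map can come out inverted.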

@johnbradley
Collaborator

I'm not sure why the results don't look very good without eigen_smooth=True.
