Unable to reproduce zeroshot classification results #133

Open · cyrusvahidi opened this issue Nov 23, 2023 · 1 comment

cyrusvahidi commented Nov 23, 2023

Overview

I have attempted to reproduce the zeroshot classification results for ESC-50 outlined in the publication Large-scale contrastive language-audio pretraining with feature fusion and keyword-to-caption augmentation.

In the paper, zeroshot classification accuracy (top-1) for the best model (K2C aug) is reported at 91.0%. I assume that this is the 630k-audioset-best.pt checkpoint.

  • I am only able to obtain 60.2% top-1 accuracy on the ESC-50 dataset.
  • ESC-50 was downloaded from the Google Drive folder.
  • I've found it quite difficult to follow the given evaluation and data preprocessing code, so I wrote my own (below).

Reproduce

I use the set of 50 unique captions in the test dataset, which are found in the "text" attribute of each example's JSON file, e.g. "The sound of the crow".
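
For reference, each example's JSON metadata (as the loader below reads it) has roughly this shape, with illustrative values:

{
    "text": ["The sound of the crow"],
    "tag": ["crow"]
}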

Here's the loader for ESC-50:

import glob
import json
import os.path as osp
from pathlib import Path

from torch.utils.data import Dataset

# load_audio_torch and random_slice are custom audio helpers, not part of
# laion_clap (a sketch of them follows below).

class ESC50Dataset(Dataset):
    def __init__(
        self,
        path_to_esc50="./data/ESC50",
        split="test",
        audio_len=480000,  # 10 s at 48 kHz
    ):
        super().__init__()
        self.data_path = Path(path_to_esc50)
        self.audio_len = audio_len

        self.audio_files = sorted(glob.glob(str(self.data_path / split / "*.flac")))
        self.meta_files = sorted(glob.glob(str(self.data_path / split / "*.json")))
        assert len(self.audio_files) == len(self.meta_files), "Number of audio files and meta files must match"
        assert [osp.splitext(osp.basename(x))[0] for x in self.audio_files] == [
            osp.splitext(osp.basename(x))[0] for x in self.meta_files
        ], "Audio files and meta files must have the same names"

        # Collect the tag and caption of each example from its JSON metadata.
        self.tags = []
        self.texts = []
        for f in self.meta_files:
            with open(f, "r") as json_file:
                data = json.load(json_file)
                self.tags.append(data["tag"][0])
                self.texts.append(data["text"][0])

    def __getitem__(self, idx):
        x, _ = load_audio_torch(self.audio_files[idx], target_sr=48000, mono=True)
        x = random_slice(x, self.audio_len)  # fixed-length random crop (or pad)
        return x, self.texts[idx]

    def __len__(self):
        return len(self.audio_files)
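
load_audio_torch and random_slice are not shown above; a minimal sketch of what they might look like, assuming torchaudio (the actual helpers may differ):

import torch
import torchaudio

def load_audio_torch(path, target_sr=48000, mono=True):
    # Load an audio file, downmix to mono, and resample to target_sr.
    x, sr = torchaudio.load(path)
    if mono and x.shape[0] > 1:
        x = x.mean(dim=0, keepdim=True)
    if sr != target_sr:
        x = torchaudio.functional.resample(x, sr, target_sr)
    return x.squeeze(0), target_sr

def random_slice(x, length):
    # Random fixed-length crop; zero-pad on the right if the clip is shorter.
    n = x.shape[-1]
    if n > length:
        start = torch.randint(0, n - length + 1, (1,)).item()
        x = x[..., start:start + length]
    elif n < length:
        x = torch.nn.functional.pad(x, (0, length - n))
    return x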

And the zeroshot retrieval script:

import os

import torch
import laion_clap

from data.loaders import ESC50Dataset

ckpt_path = "CLAP_checkpoints/laion_clap/"
model_params = {"ckpt": "630k-audioset-best.pt", "amodel": "HTSAT-tiny"}
model = laion_clap.CLAP_Module(enable_fusion=False, amodel=model_params["amodel"])
model.load_ckpt(os.path.join(ckpt_path, model_params["ckpt"]))

dataset = ESC50Dataset()
texts = list(set(dataset.texts))  # the 50 unique captions, e.g. "The sound of the crow"

# Get the text embedding for each caption: pass each caption twice and keep
# only the first row, so get_text_embedding never sees a single-item batch.
z_text = torch.cat([torch.tensor(model.get_text_embedding([t, t])[0:1]) for t in texts])

z_audio = []
text_idxs = []
for item in dataset:
    x, text = item
    idx = texts.index(text)  # index of this example's ground-truth caption
    text_idxs.append(idx)
    # get_audio_embedding_from_data expects a batch of waveforms, shape (N, T)
    z_audio.append(torch.tensor(model.get_audio_embedding_from_data(x.numpy().reshape(1, -1))))
z_audio = torch.cat(z_audio)

# Pairwise audio-text similarity scores; the logit scale does not affect the argmax.
sim = model.model.logit_scale_a.cpu() * z_audio @ z_text.T

# Top-1 accuracy: is the highest-scoring caption the ground-truth one?
acc = float(torch.sum(torch.argmax(sim, dim=1) == torch.tensor(text_idxs)) / len(sim))
print(f"Accuracy: {acc}")
Accuracy: 0.6025000214576721 

Hopefully I'm just missing something significant?

cyrusvahidi commented Nov 24, 2023

OK, I managed to reproduce my result:

Zeroshot Classification Results: mean_rank: 2.7344      median_rank: 1.0000     R@1: 0.5925     R@5: 0.8981     R@10: 0.9525    mAP@10: 0.7200
Accuracy: 0.5925 over 1600 samples

It seems top-5 accuracy was reported in the paper: R@5 here is 0.8981, which roughly matches the published 91.0%. I was confused, as Section 4.3 of the paper states "We use top-1 accuracy as the metric.".
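
To make the relationship explicit, here is a minimal sketch (reusing sim and text_idxs from the script above) of computing top-k accuracy / R@k from the similarity matrix:

import torch

def topk_accuracy(sim, targets, k=5):
    # For each clip, check whether the ground-truth caption index appears
    # among the k highest-scoring captions.
    topk = sim.topk(k, dim=1).indices                 # (N, k)
    hits = (topk == targets.unsqueeze(1)).any(dim=1)  # (N,)
    return hits.float().mean().item()

# topk_accuracy(sim, torch.tensor(text_idxs), k=1)  # top-1, ~0.59 here
# topk_accuracy(sim, torch.tensor(text_idxs), k=5)  # top-5 / R@5, ~0.90 here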
