From de85c0917cdeeee61985c53811c4669d166a1da8 Mon Sep 17 00:00:00 2001
From: ManukyanD <152393709+ManukyanD@users.noreply.github.com>
Date: Fri, 12 Jan 2024 08:39:33 +0400
Subject: [PATCH] Bugfix in vision transformer - save class token and pos
 embedding (#1204)

---
 vision_transformer/main.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vision_transformer/main.py b/vision_transformer/main.py
index 15fd20c640..d215156127 100644
--- a/vision_transformer/main.py
+++ b/vision_transformer/main.py
@@ -49,9 +49,9 @@ def __init__(self, args):
         # Linear projection
         self.LinearProjection = nn.Linear(self.input_size, self.latent_size)
         # Class token
-        self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+        self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
         # Positional embedding
-        self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
+        self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
 
     def forward(self, input_data):
         input_data = input_data.to(self.device)
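
Note (illustrative, not part of the patch): the removed lines called .to(self.device)
on the nn.Parameter itself. When the target device differs from the tensor's device,
Tensor.to() returns a new plain tensor rather than a Parameter, so the attribute
assigned to the module is not registered: it is missing from model.parameters() and
model.state_dict(), meaning the class token and positional embedding were neither
updated by the optimizer nor saved in checkpoints when running on GPU. A minimal
sketch of the difference follows; the Demo module and attribute names are
hypothetical, used only to demonstrate the registration behavior:

    import torch
    import torch.nn as nn

    class Demo(nn.Module):
        def __init__(self, device):
            super().__init__()
            # Buggy pattern: .to() on a Parameter returns a plain Tensor when
            # the device actually changes, so this attribute is NOT registered
            # as a parameter of the module (hypothetical name for illustration).
            self.broken = nn.Parameter(torch.randn(1, 1, 8)).to(device)
            # Fixed pattern (what the patch does): move the tensor first, then
            # wrap it in nn.Parameter so the module registers it.
            self.fixed = nn.Parameter(torch.randn(1, 1, 8).to(device))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    m = Demo(device)
    print("state_dict keys:", list(m.state_dict().keys()))
    # On a CUDA machine this prints only ['fixed']: 'broken' is skipped by
    # both state_dict() and parameters(). On CPU, .to("cpu") is a no-op that
    # returns the Parameter itself, which is why the bug only shows on GPU.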