Skip to content

Commit

Permalink
Bugfix in vision transformer - save class token and pos embedding (#1204)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManukyanD authored Jan 12, 2024
1 parent 5a3b333 commit de85c09
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vision_transformer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def __init__(self, args):
# Linear projection
self.LinearProjection = nn.Linear(self.input_size, self.latent_size)
# Class token
self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
self.class_token = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))
# Positional embedding
self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size)).to(self.device)
self.pos_embedding = nn.Parameter(torch.randn(self.batch_size, 1, self.latent_size).to(self.device))

def forward(self, input_data):
input_data = input_data.to(self.device)
Expand Down

0 comments on commit de85c09

Please sign in to comment.