Skip to content

Commit

Permalink
Update cvt.py
Browse files Browse the repository at this point in the history
  • Loading branch information
fffffgggg54 committed Dec 15, 2024
1 parent 2975318 commit 4d1b21a
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion timm/models/cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def __init__(
self.feature_info = []

self.use_cls_token = use_cls_token
self.global_pool = 'token' if use_cls_token else 'avg'

dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]

Expand Down Expand Up @@ -448,6 +449,21 @@ def __init__(
self.head = nn.Linear(dims[-1], num_classes) if num_classes > 0 else nn.Identity()



@torch.jit.ignore
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool = None) -> None:
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
if global_pool == 'token' and not self.use_cls_token:
assert False, 'Model not configured to use class token'
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()


def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
# nn.Sequential forward can't accept tuple intermediates
# TODO grad checkpointing
Expand All @@ -457,12 +473,13 @@ def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
return x

def forward_features(self, x: torch.Tensor) -> torch.Tensor:
# get feature map, not always used
x = self._forward_features(x)

return x[0] if self.use_cls_token else x

def forward_head(self, x: torch.Tensor) -> torch.Tensor:
if self.use_cls_token:
if self.global_pool == 'token':
return self.head(self.norm(x[1].flatten(1)))
else:
return self.head(self.norm(x.mean(dim=(2,3))))
Expand Down

0 comments on commit 4d1b21a

Please sign in to comment.