
Commit

support larger ViT models
geohot committed Dec 12, 2021
1 parent 785fe8e commit c0c2c0b
Showing 2 changed files with 21 additions and 10 deletions.
examples/vit.py (6 additions, 1 deletion)
@@ -12,9 +12,14 @@
 
 from tinygrad.tensor import Tensor
 from models.vit import ViT
+import os
 
 Tensor.training = False
-m = ViT()
+if int(os.getenv("LARGE", "0")) == 1:
+  m = ViT(embed_dim=768, num_heads=12)
+else:
+  # tiny
+  m = ViT(embed_dim=192, num_heads=3)
 m.load_from_pretrained()
 
 # category labels
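With this change, examples/vit.py picks the model size from the LARGE environment variable: ViT-B/16 when LARGE=1, otherwise ViT-Ti/16. Both configurations use the constructor's default of 12 transformer layers, which is why only embed_dim and num_heads are passed. A minimal sketch of the same selection (the constants mirror the diff above; everything else is illustrative only):

import os
from models.vit import ViT

# LARGE=1 python examples/vit.py  -> ViT-B/16  (embed_dim=768, num_heads=12)
# python examples/vit.py          -> ViT-Ti/16 (embed_dim=192, num_heads=3)
large = int(os.getenv("LARGE", "0")) == 1
m = ViT(embed_dim=768, num_heads=12) if large else ViT(embed_dim=192, num_heads=3)
m.load_from_pretrained()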
models/vit.py (15 additions, 9 deletions)
@@ -5,6 +5,7 @@
 class ViT:
   def __init__(self, layers=12, embed_dim=192, num_heads=3):
     self.embedding = (Tensor.uniform(embed_dim, 3, 16, 16), Tensor.zeros(embed_dim))
+    self.embed_dim = embed_dim
     self.cls = Tensor.ones(1, 1, embed_dim)
     self.pos_embedding = Tensor.ones(1, 197, embed_dim)
     self.tbs = [
@@ -32,7 +33,12 @@ def load_from_pretrained(m):
     from extra.utils import fetch
 
     # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
-    url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz"
+    if m.embed_dim == 192:
+      url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz"
+    elif m.embed_dim == 768:
+      url = "https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz"
+    else:
+      raise Exception("no pretrained weights for configuration")
     dat = np.load(io.BytesIO(fetch(url)))
 
     #for x in dat.keys():
@@ -51,14 +57,14 @@ def load_from_pretrained(m):
     m.encoder_norm[1].assign(dat['Transformer/encoder_norm/bias'])
 
     for i in range(12):
-      m.tbs[i].query[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(192, 192))
-      m.tbs[i].query[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(192))
-      m.tbs[i].key[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(192, 192))
-      m.tbs[i].key[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(192))
-      m.tbs[i].value[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(192, 192))
-      m.tbs[i].value[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(192))
-      m.tbs[i].out[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(192, 192))
-      m.tbs[i].out[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(192))
+      m.tbs[i].query[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(m.embed_dim, m.embed_dim))
+      m.tbs[i].query[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(m.embed_dim))
+      m.tbs[i].key[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(m.embed_dim, m.embed_dim))
+      m.tbs[i].key[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(m.embed_dim))
+      m.tbs[i].value[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(m.embed_dim, m.embed_dim))
+      m.tbs[i].value[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(m.embed_dim))
+      m.tbs[i].out[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(m.embed_dim, m.embed_dim))
+      m.tbs[i].out[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(m.embed_dim))
       m.tbs[i].ff1[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel'])
       m.tbs[i].ff1[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias'])
       m.tbs[i].ff2[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel'])
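load_from_pretrained now dispatches on m.embed_dim to choose between the AugReg Ti/16 and B/16 checkpoints and fails loudly for any other configuration. A table-driven sketch of the same lookup (not part of the commit; PRETRAINED_URLS and pretrained_url are illustrative names, URLs copied from the diff above):

# embed_dim -> AugReg checkpoint, mirroring the if/elif in load_from_pretrained.
PRETRAINED_URLS = {
  192: "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz",
  768: "https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz",
}

def pretrained_url(embed_dim):
  if embed_dim not in PRETRAINED_URLS:
    raise Exception("no pretrained weights for configuration")
  return PRETRAINED_URLS[embed_dim]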

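The reshape calls in the attention loop are why the hardcoded 192 had to go: the released .npz files appear to store the attention projections per head rather than as flat matrices, so the (embed_dim, embed_dim) weight the tinygrad linear layers expect has to be derived from the model's embed_dim instead of a literal. A rough numpy illustration, assuming that per-head layout (the shapes below are an assumption, not read from the checkpoint):

import numpy as np

embed_dim, num_heads = 768, 12     # ViT-B/16
head_dim = embed_dim // num_heads  # 64

# Assumed checkpoint layout: query kernel stored with separate head axes.
q_kernel = np.zeros((embed_dim, num_heads, head_dim), dtype=np.float32)

# Flatten the head axes into one (embed_dim, embed_dim) matrix,
# which is the shape m.tbs[i].query[0] is assigned from.
q_flat = q_kernel.reshape(embed_dim, embed_dim)
assert q_flat.shape == (768, 768)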