Add the Quantized model and also a Demo of the Quantized model #587

Open
wants to merge 2 commits into base: master
264 changes: 264 additions & 0 deletions quantized_model.py
@@ -0,0 +1,264 @@
import math
import os

import torch
import torch.nn as nn
from torch.nn import functional as F

from model import GPTConfig, GPT

class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # marks the float -> int8 boundary for the observers inserted by prepare()
        self.quant = torch.quantization.QuantStub()
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # marks the int8 -> float boundary (needed before ops that have no quantized kernels)
        self.dequant = torch.quantization.DeQuantStub()
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        x = self.quant(x)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels;
            # scaled_dot_product_attention has no quantized kernel, so dequantize first
            k, q, v = self.dequant(k), self.dequant(q), self.dequant(v)
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
            # re-quantize the attention output before the quantized output projection
            y = self.quant(y)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        y = self.quant(y)
        return y

class MLP(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        x = self.dequant(x)
        return x
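
# Eager-mode static quantization flow used by these modules (illustrative sketch only,
# not part of the model): QuantStub/DeQuantStub mark the float <-> int8 boundaries, and
# the standard prepare -> calibrate -> convert sequence swaps float modules for int8 ones:
#   m = MLP(GPTConfig())
#   m.qconfig = torch.quantization.get_default_qconfig('qnnpack')  # choose observer config
#   torch.quantization.prepare(m, inplace=True)                    # insert observers
#   m(torch.randn(1, 8, GPTConfig().n_embd))                       # calibration pass records ranges
#   torch.quantization.convert(m, inplace=True)                    # swap in quantized modules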

class Block(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

class QuantGPT(nn.Module):

    def __init__(self, config):
        super().__init__()

        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # weight tying: the token embedding and the output projection share parameters
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)
        # apply the scaled init to the residual projections, per the GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # training-style call: compute the loss over every position
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-style call: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss
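
    # Usage sketch (illustrative; `model` and the shapes below are only examples):
    #   idx     = torch.randint(0, config.vocab_size, (1, 8))
    #   targets = torch.randint(0, config.vocab_size, (1, 8))
    #   logits, loss = model(idx, targets)  # logits: (1, 8, vocab_size), loss is a scalar
    #   logits, _    = model(idx)           # logits: (1, 1, vocab_size), last position only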

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        from transformers import GPT2LMHeadModel

        config_args = {
            'gpt2':        dict(n_layer=12, n_head=12, n_embd=768),
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024),
            'gpt2-large':  dict(n_layer=36, n_head=20, n_embd=1280),
            'gpt2-xl':     dict(n_layer=48, n_head=25, n_embd=1600),
        }[model_type]

        config_args['vocab_size'] = 50257
        config_args['block_size'] = 1024
        config_args['bias'] = True

        if override_args:
            config_args.update(override_args)

        config = GPTConfig(**config_args)
        quant_model = QuantGPT(config)
        sd = quant_model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard the causal mask buffer

        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')]
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')]
        # the OpenAI checkpoints use Conv1D modules, so these weights must be transposed
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']

        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        # NOTE: 'qnnpack' is the backend for ARM CPUs (Apple M1/M2, mobile); use 'fbgemm' on x86 server CPUs
        torch.backends.quantized.engine = 'qnnpack'
        quant_model.qconfig = torch.quantization.get_default_qconfig("qnnpack")
        # embeddings only support weight-only quantization
        quant_model.transformer.wte.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig
        quant_model.transformer.wpe.qconfig = torch.ao.quantization.float_qparams_weight_only_qconfig

        # insert observers, calibrate on one sample batch, then convert to int8 modules
        torch.quantization.prepare(quant_model, inplace=True)

        sample_inp = torch.randint(0, config.vocab_size, (1, config.block_size))
        quant_model(sample_inp)

        torch.quantization.convert(quant_model, inplace=True)

        return quant_model
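
    # Backend note (sketch): 'qnnpack' targets ARM CPUs and 'fbgemm' targets x86 server
    # CPUs; the converted int8 model runs on the CPU either way. A platform-dependent
    # choice could look like this (illustrative only, not wired into this method):
    #   import platform
    #   engine = 'qnnpack' if platform.machine() in ('arm64', 'aarch64') else 'fbgemm'
    #   torch.backends.quantized.engine = engine
    #   quant_model.qconfig = torch.quantization.get_default_qconfig(engine)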

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    # same sampling loop as GPT.generate in model.py
    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options (see the worked example after this method)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
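
    # Worked top-k example (illustrative): with logits [2.0, 1.0, 0.5, -1.0] and top_k=2,
    # torch.topk keeps [2.0, 1.0]; v[:, [-1]] is 1.0, so 0.5 and -1.0 are set to -inf and
    # the softmax redistributes all probability mass over the two surviving tokens.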

if __name__ == "__main__":

    def model_size(model: nn.Module, name: str):
        # serialize the state dict to disk to measure the on-disk footprint
        torch.save(model.state_dict(), "temp_model.p")
        size_kb = os.path.getsize("temp_model.p") / 1e3
        os.remove("temp_model.p")

        print(f"\n Model: {name} -- Model Size: {size_kb:.2f} KB")
        return size_kb

    gpt2_config = GPTConfig()
    gpt_2 = GPT.from_pretrained('gpt2')
    gpt_2_size = model_size(gpt_2, 'GPT2')

    q_gpt_2 = QuantGPT.from_pretrained('gpt2')
    q_gpt_2_size = model_size(q_gpt_2, 'QGPT2')

    print(f"\nDifference -- {gpt_2_size - q_gpt_2_size:.2f} KB")
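
    # Optional end-to-end demo sketch (illustrative; assumes `tiktoken` is installed and
    # the prompt text is arbitrary):
    #   import tiktoken
    #   enc = tiktoken.get_encoding("gpt2")
    #   prompt = torch.tensor(enc.encode("Hello, my name is"), dtype=torch.long)[None, ...]
    #   out = q_gpt_2.generate(prompt, max_new_tokens=20, temperature=0.8, top_k=50)
    #   print(enc.decode(out[0].tolist()))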