🤗 working on transformers for eng-nep translation
Showing 8 changed files with 88,988 additions and 0 deletions.
@@ -0,0 +1,3 @@
__pycache__
runs
weights
@@ -0,0 +1,30 @@
from pathlib import Path


def get_config():
    """
    Get the configuration.
    """
    config = {
        "src_lang": "en",
        "tgt_lang": "ne",
        "seq_len": 512,
        "batch_size": 32,
        "num_epochs": 10,
        "learning_rate": 1e-4,
        "d_model": 512,
        "model_dir": "weights",
        "model_basename": "tmodel_",
        "preload": None,
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
    return config


def get_weights_file_path(config, epoch):
    """
    Get the path to the weights file for the given epoch.
    """
    model_dir = config["model_dir"]
    model_basename = config["model_basename"]
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path(".") / model_dir / model_filename)
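
A minimal usage sketch (not part of the commit) showing how the two helpers above fit together; the zero-padded epoch string is an assumption about how the training script names checkpoints:

# sketch: resolve the checkpoint path for a given epoch using the config above
config = get_config()
path = get_weights_file_path(config, f"{3:02d}")  # assumed epoch formatting
print(path)  # e.g. weights/tmodel_03.pt (separator depends on the OS)
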
@@ -0,0 +1,72 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class BilingualDataset(Dataset):

    def __init__(self, ds, src_tokenizer, tgt_tokenizer, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        self.ds = ds
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        self.sos_token = torch.tensor([src_tokenizer.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([src_tokenizer.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([src_tokenizer.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index):
        src_target_pair = self.ds[index]
        src_text = src_target_pair["translation"][self.src_lang]
        tgt_text = src_target_pair["translation"][self.tgt_lang]

        # truncate so the special tokens added below always fit within seq_len
        enc_input_tokens = self.src_tokenizer.encode(src_text).ids[:self.seq_len - 2]
        dec_input_tokens = self.tgt_tokenizer.encode(tgt_text).ids[:self.seq_len - 1]

        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # room for [SOS] and [EOS]
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1  # room for [SOS] only

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sequence length too small")

        # encoder input: [SOS] + source tokens + [EOS] + padding
        encoder_input = torch.cat([
            self.sos_token,
            torch.tensor(enc_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.full((enc_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
        ])

        # decoder input: [SOS] + target tokens + padding (no [EOS])
        decoder_input = torch.cat([
            self.sos_token,
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
        ])

        # label: target tokens + [EOS] + padding (the decoder input shifted left by one)
        label = torch.cat([
            torch.tensor(dec_input_tokens, dtype=torch.int64),
            self.eos_token,
            torch.full((dec_num_padding_tokens,), self.pad_token.item(), dtype=torch.int64),
        ])

        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,
            "decoder_input": decoder_input,
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, 1, seq_len) & (1, seq_len, seq_len)
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text,
        }


def causal_mask(size):
    # ones strictly above the diagonal mark "future" positions; keep everything else
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int64)
    return mask == 0
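
A minimal sketch (not part of this commit) of how BilingualDataset might be wired to a DataLoader; the toy sentence pairs and the word-level Tokenizer setup below are assumptions for illustration, not the project's actual data pipeline:

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer
from torch.utils.data import DataLoader

# toy parallel pairs in the {"translation": {lang: text}} layout that __getitem__ expects
raw_ds = [
    {"translation": {"en": "hello world", "ne": "नमस्कार संसार"}},
    {"translation": {"en": "thank you very much", "ne": "धेरै धेरै धन्यवाद"}},
]

def build_tokenizer(sentences):
    # word-level tokenizer with the special tokens the dataset class looks up
    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])
    tokenizer.train_from_iterator(sentences, trainer=trainer)
    return tokenizer

src_tok = build_tokenizer([item["translation"]["en"] for item in raw_ds])
tgt_tok = build_tokenizer([item["translation"]["ne"] for item in raw_ds])

ds = BilingualDataset(raw_ds, src_tok, tgt_tok, "en", "ne", seq_len=16)
loader = DataLoader(ds, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch["encoder_input"].shape)  # (2, 16)
print(batch["decoder_mask"].shape)   # (2, 1, 16, 16)
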
@@ -0,0 +1,256 @@
import torch
import torch.nn as nn
import math


class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, input):
        # scale embeddings by sqrt(d_model), as in "Attention Is All You Need"
        return self.embedding(input) * math.sqrt(self.d_model)


class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # create a vector of shape (seq_len, 1)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions

        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)

        # register as a buffer so it is saved with the model but never trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)


class LayerNormalization(nn.Module):

    def __init__(self, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # multiplicative parameter
        self.bias = nn.Parameter(torch.zeros(1))  # additive parameter

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        std = x.std(dim=-1, keepdim=True)
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # W1 and B1
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)  # W2 and B2

    def forward(self, x):
        # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)
        x = self.dropout(torch.relu(self.linear1(x)))
        x = self.linear2(x)
        return x


class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)  # W_q
        self.w_k = nn.Linear(d_model, d_model)  # W_k
        self.w_v = nn.Linear(d_model, d_model)  # W_v

        self.w_o = nn.Linear(d_model, d_model)  # W_o
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]

        # (batch_size, h, seq_len, d_k) @ (batch_size, h, d_k, seq_len) -> (batch_size, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores = attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = torch.softmax(attention_scores, dim=-1)  # (batch_size, h, seq_len, seq_len)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q)  # (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
        key = self.w_k(k)    # (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
        value = self.w_v(v)  # (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)

        # (batch_size, seq_len, d_model) --> (batch_size, seq_len, h, d_k) --> (batch_size, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # (batch_size, h, seq_len, d_k) --> (batch_size, seq_len, h, d_k) --> (batch_size, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        x = self.w_o(x)  # (batch_size, seq_len, d_model) --> (batch_size, seq_len, d_model)
        return x


class ResidualConnection(nn.Module):

    def __init__(self, dropout: float) -> None:
        super().__init__()
        self.norm = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # pre-norm residual connection: x + dropout(sublayer(norm(x)))
        return x + self.dropout(sublayer(self.norm(x)))


class EncoderBlock(nn.Module):

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x


class Encoder(nn.Module):

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class DecoderBlock(nn.Module):

    def __init__(self, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x


class Decoder(nn.Module):

    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization()

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)


class ProjectionLayer(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch_size, seq_len, d_model) --> (batch_size, seq_len, vocab_size)
        return torch.log_softmax(self.proj(x), dim=-1)


class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        return self.encoder(self.src_pos(self.src_embed(src)), src_mask)

    def decode(self, tgt, encoder_output, src_mask, tgt_mask):
        return self.decoder(self.tgt_pos(self.tgt_embed(tgt)), encoder_output, src_mask, tgt_mask)

    def project(self, x):
        return self.projection_layer(x)


def build_transformer(
    src_vocab_size: int,
    tgt_vocab_size: int,
    src_seq_len: int,
    tgt_seq_len: int,
    d_model: int = 512,
    N: int = 6,
    h: int = 8,
    dropout: float = 0.1,
    d_ff: int = 2048
) -> Transformer:

    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    encoder = Encoder(
        nn.ModuleList([
            EncoderBlock(
                MultiHeadAttentionBlock(d_model, h, dropout),
                FeedForwardBlock(d_model, d_ff, dropout),
                dropout
            ) for _ in range(N)
        ])
    )

    decoder = Decoder(
        nn.ModuleList([
            DecoderBlock(
                MultiHeadAttentionBlock(d_model, h, dropout),
                MultiHeadAttentionBlock(d_model, h, dropout),
                FeedForwardBlock(d_model, d_ff, dropout),
                dropout
            ) for _ in range(N)
        ])
    )

    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # initialize the parameters with Xavier uniform distribution
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
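
A minimal sketch (not part of this commit) of a single forward pass through a model returned by build_transformer; the vocabulary size, batch, and all-ones padding mask below are made up purely to check shapes:

import torch

seq_len, vocab_size = 64, 1000  # assumed toy values
model = build_transformer(vocab_size, vocab_size, seq_len, seq_len, d_model=128, N=2, h=4)

src = torch.randint(0, vocab_size, (2, seq_len))  # (batch, seq_len) of token ids
tgt = torch.randint(0, vocab_size, (2, seq_len))
src_mask = torch.ones(2, 1, 1, seq_len, dtype=torch.int64)               # nothing masked out
tgt_mask = torch.triu(torch.ones(1, seq_len, seq_len), diagonal=1) == 0  # same as causal_mask(seq_len) in the dataset code

encoder_output = model.encode(src, src_mask)                             # (2, seq_len, d_model)
decoder_output = model.decode(tgt, encoder_output, src_mask, tgt_mask)   # (2, seq_len, d_model)
log_probs = model.project(decoder_output)                                # (2, seq_len, vocab_size)
print(log_probs.shape)
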