Fix the problem that the sequence length is still limited when using relative positions #111

Open
wants to merge 1 commit into base: master
64 changes: 58 additions & 6 deletions NEZHA-PyTorch/modeling_nezha.py
@@ -322,6 +322,27 @@ def _generate_relative_positions_embeddings(length, depth, max_relative_position
    return embeddings


class RelativePositionsEmbeddings(nn.Module):
    """
    Given a relative position embedding table, look up relative position embeddings.
    """
    def __init__(self, depth, max_relative_position):
        super(RelativePositionsEmbeddings, self).__init__()
        vocab_size = max_relative_position * 2 + 1
        # Sinusoidal table covering every clipped relative distance, i.e. indices 0 .. 2 * max_relative_position
        embeddings_table = np.zeros([vocab_size, depth])
        for pos in range(vocab_size):
            for i in range(depth // 2):
                embeddings_table[pos, 2 * i] = np.sin(pos / np.power(10000, 2 * i / depth))
                embeddings_table[pos, 2 * i + 1] = np.cos(pos / np.power(10000, 2 * i / depth))

        self.embeddings_table_tensor = nn.Parameter(torch.tensor(embeddings_table, dtype=torch.float))

    def forward(self, relative_positions):
        # relative_positions: flattened LongTensor of indices already shifted into [0, vocab_size)
        embeddings = torch.index_select(self.embeddings_table_tensor, 0, relative_positions)
        return embeddings


### Test: print(_generate_relative_positions_embeddings(6, 32, 4)[0, 0, :])
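### Usage sketch mirroring the test above (the sizes 6, 32, 4 come from that test and are only
### illustrative; assumes _generate_relative_positions_matrix returns indices already shifted
### into [0, 2 * max_relative_position]). The new module builds the lookup for any requested
### length at call time instead of slicing a precomputed 512-row table:
###     rel_emb = RelativePositionsEmbeddings(depth=32, max_relative_position=4)
###     flat_positions = _generate_relative_positions_matrix(6, 4).view(-1)
###     print(rel_emb(flat_positions).view(6, 6, 32)[0, 0, :])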

class NeZhaSelfAttention(nn.Module):
@@ -338,10 +359,32 @@ def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.relative_positions_embeddings = _generate_relative_positions_embeddings(
            length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).cuda()

        # self.relative_positions_embeddings = _generate_relative_positions_embeddings(
        #     length=512, depth=self.attention_head_size, max_relative_position=config.max_relative_position).cuda()

        self.relative_positions_embeddings = RelativePositionsEmbeddings(
            depth=self.attention_head_size, max_relative_position=config.max_relative_position)
        self.max_relative_position = config.max_relative_position

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def process_relative_position_embeddings(self, seq_length, device):
        """
        Given a seq_length input, generate relative position embeddings of shape
        [seq_length, seq_length, attention_head_size].
        """
        depth = self.attention_head_size
        final_mat = _generate_relative_positions_matrix(seq_length, self.max_relative_position).to(device)
        flat_relative_positions_matrix = final_mat.view(-1)
        embeddings = self.relative_positions_embeddings(flat_relative_positions_matrix)
        my_shape = list(final_mat.size())
        my_shape.append(depth)
        embeddings = embeddings.view(my_shape)
        relative_position_embeddings = embeddings.clone().detach()

        return relative_position_embeddings
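    ### Shape sketch under assumed sizes (attention_head_size = 64, which this snippet does not fix):
    ###     emb = self.process_relative_position_embeddings(1024, hidden_states.device)
    ###     # emb.shape == torch.Size([1024, 1024, 64]); no 512-length cap applies here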

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
@@ -363,9 +406,14 @@ def forward(self, hidden_states, attention_mask):
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        batch_size, num_attention_heads, from_seq_length, to_seq_length = attention_scores.size()

        relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
            device)
        # start: generating relative keys
        relations_keys = self.process_relative_position_embeddings(to_seq_length, device)
        # end: generating relative keys

        # relations_keys = self.relative_positions_embeddings.detach().clone()[:to_seq_length, :to_seq_length, :].to(
        #     device)
        # relations_keys = embeddings.clone().detach().to(device)
        #
        query_layer_t = query_layer.permute(2, 0, 1, 3)
        query_layer_r = query_layer_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
                                                        self.attention_head_size)
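        ### Shape note under assumed sizes (batch_size = 2, heads = 12, seq_length = 1024, head_size = 64):
        ### relations_keys comes out as [1024, 1024, 64] and query_layer_r as [1024, 24, 64], so the
        ### key-side relative scores can be formed for the full 1024 positions rather than only the
        ### first 512 rows of a precomputed table.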
@@ -387,8 +435,12 @@ def forward(self, hidden_states, attention_mask):

        context_layer = torch.matmul(attention_probs, value_layer)

        relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :].to(
            device)
        # start: generating relative values
        relations_values = self.process_relative_position_embeddings(to_seq_length, device)
        # end: generating relative values

        # relations_values = self.relative_positions_embeddings.clone()[:to_seq_length, :to_seq_length, :].to(
        #     device)
        attention_probs_t = attention_probs.permute(2, 0, 1, 3)
        attentions_probs_r = attention_probs_t.contiguous().view(from_seq_length, batch_size * num_attention_heads,
                                                                 to_seq_length)
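        ### Under the same assumed sizes, relations_values is [1024, 1024, 64] and attentions_probs_r
        ### is [1024, 24, 1024], so the value-side relative contribution can likewise be formed for
        ### the full sequence length.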