rebase rotary and add dtype for rotary to llama.py
dan-garvey committed Jan 28, 2025
1 parent 1d793fe commit 20e7316
Showing 2 changed files with 103 additions and 62 deletions.
164 changes: 102 additions & 62 deletions sharktank/sharktank/layers/rotary_embedding.py
@@ -28,15 +28,16 @@ def __init__(
use_hf: bool = False,
use_table: bool = True,
tensor_parallelism_size: int = 1,
dtype: torch.dtype = torch.float32,
):
super().__init__()
self.device = device
self.rope_dimension_count = rope_dimension_count
self.max_seqlen = max_seqlen
self.use_hf = use_hf
self.use_table = use_table

self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 500000.0
self.dtype = dtype
self.rope_freq_base = rope_freq_base if rope_freq_base is not None else 10000.0
self.tensor_parallelism_size = tensor_parallelism_size

@property
@@ -73,51 +74,94 @@ def forward(
xt = SplitPrimitiveTensor(ts=xt_shards, shard_dim=xt.shard_dim)
return xt

def _create_interleaved_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
it alternates between elements of its first and second
half. Intended for use for HuggingFace's rotation
implementation.
Args:
dim: Size of tensor
Returns:
Interleaved indexing tensor
"""
first_half = torch.arange(dim // 2)
second_half = torch.arange(dim // 2, dim)

interleaved_tensor = torch.empty(dim, dtype=torch.long)
interleaved_tensor[0::2] = first_half
interleaved_tensor[1::2] = second_half

return interleaved_tensor

def _create_ordering_tensor(_, dim):
"""Creates a tensor which indexes an tensor such that
it reverses the alternation induced by create_interleaved_tesnor.
Intended for use for HuggingFace's rotation implementation.
Args:
dim: Size of tensor
Returns:
Ordering indexing tensor
"""
order_tensor = torch.empty(dim, dtype=torch.long)
order_tensor[: dim // 2] = torch.arange(0, dim, 2)
order_tensor[dim // 2 :] = torch.arange(1, dim, 2)
return order_tensor

@staticmethod
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

def forward_unsharded(
self,
*,
xt: torch.Tensor,
start_index: int,
rotary_embed_table: Optional[torch.Tensor],
):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim
if self.use_hf:
xt = xt[..., self._create_interleaved_tensor(xt.shape[-1])]
# xq_, xk_ shape: bs, sl, _, dim
xt_ = xt
_, sl, _, _ = xt_.shape

if self.use_hf:
freqs_cis = rotary_embed_table
# Slice from max to current sequence length
cos, sin = [x[start_index : start_index + sl, :] for x in freqs_cis]
# expand to 1, sl, 1, dim and repeat per bs
cos = cos[None, :, None, :].repeat(xt.shape[0], 1, 1, 1)
sin = sin[None, :, None, :].repeat(xt.shape[0], 1, 1, 1)
xt = xt.transpose(1, 2)
xt_out = (xt_ * cos) + (self.rotate_half(xt_) * sin)
return xt_out

# Offset the table based on starting position.
if self.use_table:
freqs_cis = rotary_embed_table # [start_index : start_index + sl, :]
cos, sin = [x[start_index : start_index + sl, :] for x in freqs_cis]
cos = cos[None, :, None, :]
sin = sin[None, :, None, :]
freqs_cis = rotary_embed_table[start_index : start_index + sl, :]
freqs_cis = freqs_cis[0:sl, :]
else:
freqs_cis = torch.arange(sl, device=xt.device) + start_index
freqs_cis = self._compute_rotary_embed_table(freqs_cis)

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

cos = cos.transpose(1, 2)
cos = cos.repeat((xt_.shape[0], 1, 1, 1))
sin = sin.transpose(1, 2)
xt_ = xt_.transpose(1, 2)
sin = sin.repeat((xt_.shape[0], 1, 1, 1))
xt_out = (xt_ * cos) + (rotate_half(xt_) * sin)
assert (
freqs_cis.shape[0] >= sl
), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"

if self.use_hf:
xt_out = xt_out[..., self._create_ordering_tensor(xt_out.shape[-1])]
freqs_cis = ops.repeat(freqs_cis[None, :, :], (xt_.shape[0], 1, 1))
xt_out = kernels.apply_rotary_embedding(xt_.to(freqs_cis.dtype), freqs_cis)

return ops.to(xt_out, xt.dtype).transpose(1, 2)
return ops.to(xt_out, xt.dtype)

def compute_batch_mask(
self, start_positions: Union[torch.Tensor, ReplicatedTensor], batch_seq_len: int
) -> torch.Tensor:
# TODO: I'm pretty sure this function is only correct because batch_seq_len is always 1
"""Computes a mask for a batch that can be repeatedly applied.
Args:
@@ -133,12 +177,15 @@ def compute_batch_mask(
) + start_positions.unsqueeze(1)
# Broadcast lookup to [b, ...].
self.trace_tensor("rope.positions_seq", positions_seq)
if self.use_hf:
assert self.use_table, "use_hf requires use_table"
freqs_cis = self.rotary_embed_table
cos, sin = [x[positions_seq.flatten(), :] for x in freqs_cis]
freqs_cis = (cos[:, None, None, :], sin[:, None, None, :])
return freqs_cis

if self.use_table:
# freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
freqs_cis = self.rotary_embed_table # [start_index : start_index + sl, :]
cos, sin = [x[positions_seq.flatten(), :] for x in freqs_cis]
freqs_cis = (cos.unsqueeze(1), sin.unsqueeze(1))
freqs_cis = self.rotary_embed_table[positions_seq.flatten()]
else:
shape = positions_seq.shape
if isinstance(positions_seq, ReplicatedTensor):
@@ -150,7 +197,7 @@ def compute_batch_mask(
else:
freqs_cis = self._compute_rotary_embed_table(positions_seq.flatten())

return freqs_cis # .unsqueeze(1)
return freqs_cis.unsqueeze(1)

def apply_batched_mask(
self,
@@ -183,35 +230,26 @@ def apply_batched_mask_unsharded(self, *, xt: torch.Tensor, mask: torch.Tensor):
# xq_, xk_ shape: bs, sl, _, dim
# freqs_cis shape: max_sl, dim

def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)

# assert (
# freqs_cis.shape[1] >= sl
# ), f"Sequence length longer than embedding table ({sl} vs {freqs_cis.shape[0]})"
cos, sin = mask
cos = cos.unsqueeze(1).transpose(1, 2)
sin = sin.unsqueeze(1).transpose(1, 2)
xt = xt.transpose(1, 2)
if self.use_hf:
cos, sin = mask
xt = xt.transpose(1, 2)
xt_out = (xt * cos) + (self.rotate_half(xt) * sin)
return xt_out.transpose(1, 2)

xt_out = (xt * cos) + (rotate_half(xt) * sin)
xt_out = kernels.apply_rotary_embedding(xt.to(mask.dtype), mask)

return xt_out.type_as(xt).transpose(1, 2)
return xt_out.type_as(xt)

def _compute_rotary_embed_table(self, t):
dim = self.rope_dimension_count
# freqs = 1.0 / (
# self.rope_freq_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
# )
freqs = 1.0 / (
self.rope_freq_base
** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
)
if True:
if self.use_hf:

freqs = 1.0 / (
self.rope_freq_base
** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
)
### from llama3 embedding changes
# TODO: get these values from Dataset
factor = 8 # in the original implementation
low_freq_factor = 1 # in the original implementation
high_freq_factor = 4
@@ -235,17 +273,19 @@ def _compute_rotary_embed_table(self, t):
is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(
wavelen > low_freq_wavelen
)
inv_freq_llama = torch.where(
is_medium_freq, smoothed_inv_freq, inv_freq_llama
)
freqs = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)

freqs = inv_freq_llama
freqs = torch.cat((freqs, freqs), dim=-1)
emb = torch.outer(t.float(), freqs.float())
cos = torch.cos(emb).to(self.dtype)
sin = torch.sin(emb).to(self.dtype)
return (cos, sin)

freqs = torch.cat((freqs, freqs), dim=-1)
emb = torch.outer(t.float(), freqs.float())
cos = torch.cos(emb).to(torch.bfloat16)
sin = torch.sin(emb).to(torch.bfloat16)
return (cos, sin)
freqs = 1.0 / (
self.rope_freq_base ** ((torch.arange(0, dim) // 2).float() / dim * 2.0)
)
freqs = torch.outer(t, freqs).float()
return freqs

def _create_rotary_embed_table(self):
t = torch.arange(self.max_seqlen, device=self.device)
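
For readers skimming the diff, here is a minimal, self-contained sketch of what the reworked rotary path boils down to: build a cos/sin table in a configurable dtype (optionally with Llama-3-style frequency smoothing) and apply it HF-style via rotate_half. This is an illustration only, not the sharktank implementation; names like make_cos_sin_table are made up, and old_context_len = 8192 is assumed from the reference Llama 3 scaling rather than read from this diff.

import math
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Swap the two halves of the last dimension and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def make_cos_sin_table(
    positions: torch.Tensor,
    dim: int,
    base: float = 10000.0,
    dtype: torch.dtype = torch.float32,
    llama3_scaling: bool = False,
):
    # Standard RoPE inverse frequencies over the even channel indices.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

    if llama3_scaling:
        # Llama-3-style smoothing of the low/medium frequencies; constants follow
        # the reference implementation (old_context_len is an assumption here).
        factor, low_freq_factor, high_freq_factor, old_context_len = 8.0, 1.0, 4.0, 8192
        low_freq_wavelen = old_context_len / low_freq_factor
        high_freq_wavelen = old_context_len / high_freq_factor
        wavelen = 2 * math.pi / inv_freq
        scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
        is_medium = ~(wavelen < high_freq_wavelen) & ~(wavelen > low_freq_wavelen)
        inv_freq = torch.where(is_medium, smoothed, scaled)

    # Duplicate so cos/sin cover the full head dimension, then cast to the
    # requested activation dtype; this casting is the point of the new dtype argument.
    freqs = torch.cat((inv_freq, inv_freq), dim=-1)
    emb = torch.outer(positions.float(), freqs)
    return torch.cos(emb).to(dtype), torch.sin(emb).to(dtype)


# HF-style application to activations shaped [bs, sl, heads, head_dim].
positions = torch.arange(8)
cos, sin = make_cos_sin_table(positions, dim=16, dtype=torch.bfloat16)
xt = torch.randn(2, 8, 4, 16, dtype=torch.bfloat16)
cos_b, sin_b = cos[None, :, None, :], sin[None, :, None, :]
xt_out = (xt * cos_b) + (rotate_half(xt) * sin_b)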
1 change: 1 addition & 0 deletions sharktank/sharktank/models/llama/llama.py
@@ -97,6 +97,7 @@ def __init__(self, theta: Theta, config: LlamaModelConfig):
device=self.device,
use_hf=self.use_hf,
tensor_parallelism_size=config.tensor_parallelism_size,
dtype=config.activation_dtype,
),
)
self.add_module(
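
The llama.py side of the change just forwards the model's activation dtype into the rotary layer so the cached cos/sin tables match activation precision. A hypothetical, trimmed-down sketch of that wiring (ModelConfigSketch and rotary_kwargs are illustrative stand-ins, not sharktank APIs):

import torch
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    # Trimmed-down stand-in for the real LlamaModelConfig.
    activation_dtype: torch.dtype = torch.bfloat16
    rope_dimension_count: int = 128
    rope_freq_base: float = 10000.0
    max_seqlen: int = 4096
    tensor_parallelism_size: int = 1
    use_hf: bool = False


def rotary_kwargs(config: ModelConfigSketch) -> dict:
    # Thread the activation dtype through as `dtype=` so the rotary layer's
    # cos/sin cache is built in the same precision as the activations.
    return dict(
        rope_dimension_count=config.rope_dimension_count,
        rope_freq_base=config.rope_freq_base,
        max_seqlen=config.max_seqlen,
        use_hf=config.use_hf,
        tensor_parallelism_size=config.tensor_parallelism_size,
        dtype=config.activation_dtype,
    )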

