Commit

fix with black
rsuderman committed Feb 10, 2025
1 parent 7a8f360 commit cdb80ef
Showing 6 changed files with 156 additions and 180 deletions.
208 changes: 89 additions & 119 deletions sharktank/sharktank/layers/testing.py
@@ -6,8 +6,7 @@
 
 import torch
 from ..types.theta import Theta
-from ..types.tensors import DefaultPrimitiveTensor
-from ..utils.testing import make_rand_torch
+from ..utils.testing import make_rand
 
 
 def make_llama_attention_block_theta(
@@ -21,31 +20,30 @@ def make_llama_attention_block_theta(
 ) -> Theta:
     return Theta(
         {
-            "attn_q.weight": DefaultPrimitiveTensor(
+            "attn_q.weight": make_rand(
                 name=f"blk.{block_idx}.attn_q.weight",
-                data=make_rand_torch(
-                    (head_count * head_dim, embedding_length), dtype=dtype
-                ),
+                shape=(head_count * head_dim, embedding_length),
+                dtype=dtype,
             ),
-            "attn_k.weight": DefaultPrimitiveTensor(
+            "attn_k.weight": make_rand(
                 name=f"blk.{block_idx}.attn_k.weight",
-                data=make_rand_torch(
-                    (head_count_kv * head_dim, embedding_length), dtype=dtype
-                ),
+                shape=(head_count_kv * head_dim, embedding_length),
+                dtype=dtype,
             ),
-            "attn_v.weight": DefaultPrimitiveTensor(
+            "attn_v.weight": make_rand(
                 name=f"blk.{block_idx}.attn_v.weight",
-                data=make_rand_torch(
-                    (head_count_kv * head_dim, embedding_length), dtype=dtype
-                ),
+                shape=(head_count_kv * head_dim, embedding_length),
+                dtype=dtype,
             ),
-            "attn_output.weight": DefaultPrimitiveTensor(
+            "attn_output.weight": make_rand(
                 name=f"blk.{block_idx}.attn_output.weight",
-                data=make_rand_torch((embedding_length, embedding_length), dtype=dtype),
+                shape=(embedding_length, embedding_length),
+                dtype=dtype,
             ),
-            "attn_norm.weight": DefaultPrimitiveTensor(
+            "attn_norm.weight": make_rand(
                 name=f"blk.{block_idx}.attn_norm.weight",
-                data=make_rand_torch((embedding_length), dtype=dtype),
+                shape=(embedding_length),
+                dtype=dtype,
             ),
         }
     )
@@ -64,31 +62,35 @@ def make_latent_attention_block_theta(
 ) -> Theta:
     return Theta(
         {
-            "wq.weight": DefaultPrimitiveTensor(
+            "wq.weight": make_rand(
                 name=f"blk.{block_idx}.wq.weight",
-                data=make_rand_torch((heads * (rope_dim + nope_dim), dim), dtype=dtype),
+                shape=(heads * (rope_dim + nope_dim), dim),
+                dtype=dtype,
             ),
-            "wkv_a.weight": DefaultPrimitiveTensor(
+            "wkv_a.weight": make_rand(
                 name=f"blk.{block_idx}.wkv_a.weight",
-                data=make_rand_torch((kv_latent_dim + rope_dim, dim), dtype=dtype),
+                shape=(kv_latent_dim + rope_dim, dim),
+                dtype=dtype,
             ),
-            "wkv_b.weight": DefaultPrimitiveTensor(
+            "wkv_b.weight": make_rand(
                 name=f"blk.{block_idx}.wkv_b.weight",
-                data=make_rand_torch(
-                    (heads * (v_head_dim + nope_dim), kv_latent_dim), dtype=dtype
-                ),
+                shape=(heads * (v_head_dim + nope_dim), kv_latent_dim),
+                dtype=dtype,
             ),
-            "wo.weight": DefaultPrimitiveTensor(
+            "wo.weight": make_rand(
                 name=f"blk.{block_idx}.wo.weight",
-                data=make_rand_torch((dim, heads * v_head_dim), dtype=dtype),
+                shape=(dim, heads * v_head_dim),
+                dtype=dtype,
             ),
-            "attn_norm.weight": DefaultPrimitiveTensor(
+            "attn_norm.weight": make_rand(
                 name=f"blk.{block_idx}.attn_norm.weight",
-                data=make_rand_torch((dim,), dtype=dtype),
+                shape=(dim,),
+                dtype=dtype,
             ),
-            "kv_norm.weight": DefaultPrimitiveTensor(
+            "kv_norm.weight": make_rand(
                 name=f"blk.{block_idx}.kv_norm.weight",
-                data=make_rand_torch((kv_latent_dim,), dtype=dtype),
+                shape=(kv_latent_dim,),
+                dtype=dtype,
             ),
         }
     )
@@ -108,77 +110,57 @@ def make_mmdit_double_block_random_theta(
     mlp_hidden_size3 = int(2 * (mlp_ratio - 1) * hidden_size)
     return Theta(
         {
-            "img_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+            "img_attn.norm.key_norm.scale": make_rand(
+                shape=(in_channels,), dtype=dtype
             ),
-            "img_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+            "img_attn.norm.query_norm.scale": make_rand(
+                shape=(in_channels,), dtype=dtype
             ),
-            "img_attn.proj.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size,), dtype=dtype)
-            ),
-            "img_attn.proj.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
+            "img_attn.proj.bias": make_rand(shape=(hidden_size,), dtype=dtype),
+            "img_attn.proj.weight": make_rand(
+                shape=(hidden_size, hidden_size), dtype=dtype
             ),
-            "img_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
-            ),
-            "img_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+            "img_attn.qkv.bias": make_rand(shape=(mlp_hidden_size,), dtype=dtype),
+            "img_attn.qkv.weight": make_rand(
+                shape=(mlp_hidden_size, hidden_size), dtype=dtype
             ),
-            "img_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
-            ),
-            "img_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
+            "img_mlp.0.bias": make_rand(shape=(mlp_hidden_size2), dtype=dtype),
+            "img_mlp.0.weight": make_rand(
+                shape=(mlp_hidden_size2, hidden_size), dtype=dtype
             ),
-            "img_mlp.2.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size), dtype=dtype)
-            ),
-            "img_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+            "img_mlp.2.bias": make_rand(shape=(hidden_size), dtype=dtype),
+            "img_mlp.2.weight": make_rand(
+                shape=(hidden_size, mlp_hidden_size2), dtype=dtype
             ),
-            "img_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
-            ),
-            "img_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+            "img_mod.lin.bias": make_rand(shape=(mlp_hidden_size3,), dtype=dtype),
+            "img_mod.lin.weight": make_rand(
+                shape=(mlp_hidden_size3, hidden_size), dtype=dtype
             ),
-            "txt_attn.norm.key_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+            "txt_attn.norm.key_norm.scale": make_rand(
+                shape=(in_channels,), dtype=dtype
             ),
-            "txt_attn.norm.query_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
+            "txt_attn.norm.query_norm.scale": make_rand(
+                shape=(in_channels,), dtype=dtype
             ),
-            "txt_attn.proj.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size,), dtype=dtype)
-            ),
-            "txt_attn.proj.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
+            "txt_attn.proj.bias": make_rand(shape=(hidden_size,), dtype=dtype),
+            "txt_attn.proj.weight": make_rand(
+                shape=(hidden_size, hidden_size), dtype=dtype
             ),
-            "txt_attn.qkv.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
-            ),
-            "txt_attn.qkv.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+            "txt_attn.qkv.bias": make_rand(shape=(mlp_hidden_size,), dtype=dtype),
+            "txt_attn.qkv.weight": make_rand(
+                shape=(mlp_hidden_size, hidden_size), dtype=dtype
             ),
-            "txt_mlp.0.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2), dtype=dtype)
-            ),
-            "txt_mlp.0.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size2, hidden_size), dtype=dtype)
+            "txt_mlp.0.bias": make_rand(shape=(mlp_hidden_size2), dtype=dtype),
+            "txt_mlp.0.weight": make_rand(
+                shape=(mlp_hidden_size2, hidden_size), dtype=dtype
             ),
-            "txt_mlp.2.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size), dtype=dtype)
-            ),
-            "txt_mlp.2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+            "txt_mlp.2.bias": make_rand(shape=(hidden_size), dtype=dtype),
+            "txt_mlp.2.weight": make_rand(
+                shape=(hidden_size, mlp_hidden_size2), dtype=dtype
             ),
-            "txt_mod.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
-            ),
-            "txt_mod.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+            "txt_mod.lin.bias": make_rand(shape=(mlp_hidden_size3,), dtype=dtype),
+            "txt_mod.lin.weight": make_rand(
+                shape=(mlp_hidden_size3, hidden_size), dtype=dtype
             ),
         }
     )
@@ -195,35 +177,23 @@ def make_mmdit_single_block_random_theta(
     mlp_hidden_size3 = int((2 * mlp_ratio - 1) * hidden_size)
     return Theta(
         {
-            "norm.key_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
-            ),
-            "norm.query_norm.scale": DefaultPrimitiveTensor( #
-                data=make_rand_torch((in_channels,), dtype=dtype)
-            ),
-            "attn.proj.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size,), dtype=dtype)
-            ),
-            "attn.proj.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, hidden_size), dtype=dtype)
+            "norm.key_norm.scale": make_rand(shape=(in_channels,), dtype=dtype),
+            "norm.query_norm.scale": make_rand(shape=(in_channels,), dtype=dtype),
+            "attn.proj.bias": make_rand(shape=(hidden_size,), dtype=dtype),
+            "attn.proj.weight": make_rand(
+                shape=(hidden_size, hidden_size), dtype=dtype
             ),
-            "linear1.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3,), dtype=dtype)
-            ),
-            "linear1.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size3, hidden_size), dtype=dtype)
+            "linear1.bias": make_rand(shape=(mlp_hidden_size3,), dtype=dtype),
+            "linear1.weight": make_rand(
+                shape=(mlp_hidden_size3, hidden_size), dtype=dtype
             ),
-            "linear2.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size), dtype=dtype)
-            ),
-            "linear2.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((hidden_size, mlp_hidden_size2), dtype=dtype)
+            "linear2.bias": make_rand(shape=(hidden_size), dtype=dtype),
+            "linear2.weight": make_rand(
+                shape=(hidden_size, mlp_hidden_size2), dtype=dtype
             ),
-            "modulation.lin.bias": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size,), dtype=dtype)
-            ),
-            "modulation.lin.weight": DefaultPrimitiveTensor(
-                data=make_rand_torch((mlp_hidden_size, hidden_size), dtype=dtype)
+            "modulation.lin.bias": make_rand(shape=(mlp_hidden_size,), dtype=dtype),
+            "modulation.lin.weight": make_rand(
+                shape=(mlp_hidden_size, hidden_size), dtype=dtype
            ),
         }
     )
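
Every new call site above assumes a make_rand helper exported from sharktank/sharktank/utils/testing.py; its definition belongs to one of the other changed files and is not shown in this excerpt. A minimal sketch consistent with the old and new call sites, assuming make_rand simply fuses the removed DefaultPrimitiveTensor(data=make_rand_torch(...)) two-step into a single keyword call (the signature and the optional name parameter are assumptions, not the actual implementation):

import torch

from sharktank.types.tensors import DefaultPrimitiveTensor
from sharktank.utils.testing import make_rand_torch


def make_rand(*, shape, dtype: torch.dtype, name=None):
    # Hypothetical sketch only; the real make_rand may differ.
    # Build the random torch tensor exactly as the removed call sites did.
    data = make_rand_torch(shape, dtype=dtype)
    # Wrap it the way each call site previously did by hand. The mmdit
    # helpers above call make_rand without a name, so name is assumed
    # optional here.
    if name is None:
        return DefaultPrimitiveTensor(data=data)
    return DefaultPrimitiveTensor(name=name, data=data)

Under that assumption, each edit in this diff is mechanical: shape replaces the positional tuple argument of make_rand_torch, and dtype passes through unchanged.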
