Skip to content

Commit

Permalink
Start distributed inference for llama models
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Aug 1, 2024
1 parent 85dc76f commit fbbf173
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 2 deletions.
2 changes: 1 addition & 1 deletion llms/mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def main():
tokenizer,
prompt,
args.max_tokens,
verbose=True,
verbose=mx.distributed.init().rank() == 0,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
Expand Down
32 changes: 31 additions & 1 deletion llms/mlx_lm/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,36 @@ def sanitize(self, weights):
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}

def shard(self, group: Optional[mx.distributed.Group] = None):
group = group or mx.distributed.init()

def all_to_sharded(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedAllToShardedLinear.from_quantized_linear(l, group)
else:
return nn.AllToShardedLinear.from_linear(l, group)

def sharded_to_all(l):
if isinstance(l, nn.QuantizedLinear):
return nn.QuantizedShardedToAllLinear.from_quantized_linear(l, group)
else:
return nn.ShardedToAllLinear.from_linear(l, group)

N = group.size()
for layer in self.model.layers:
# Shard the self attention
layer.self_attn.q_proj = all_to_sharded(layer.self_attn.q_proj)
layer.self_attn.k_proj = all_to_sharded(layer.self_attn.k_proj)
layer.self_attn.v_proj = all_to_sharded(layer.self_attn.v_proj)
layer.self_attn.o_proj = sharded_to_all(layer.self_attn.o_proj)
layer.self_attn.n_heads //= N
layer.self_attn.n_kv_heads //= N

# Shard the MLP
layer.mlp.gate_proj = all_to_sharded(layer.mlp.gate_proj)
layer.mlp.down_proj = sharded_to_all(layer.mlp.down_proj)
layer.mlp.up_proj = all_to_sharded(layer.mlp.up_proj)

@property
def layers(self):
return self.model.layers
Expand All @@ -321,4 +351,4 @@ def head_dim(self):

@property
def n_kv_heads(self):
return self.args.num_key_value_heads
return self.args.num_key_value_heads // mx.distributed.init().size()
5 changes: 5 additions & 0 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ def class_predicate(p, m):

model.load_weights(list(weights.items()))

if mx.distributed.init().size() > 1:
if not hasattr(model, "shard"):
raise RuntimeError("Model doesn't support distributed inference.")
model.shard()

if not lazy:
mx.eval(model.parameters())

Expand Down

0 comments on commit fbbf173

Please sign in to comment.