From 9d7e80b912cf62027cbf2d90fd37855acb1e1eb7 Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 5 Nov 2024 13:09:34 -0800
Subject: [PATCH] Make the chat distributed

---
 llms/mlx_lm/chat.py | 46 +++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 37 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 85d32d5f..2641b023 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -15,6 +15,25 @@
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
+def share_message(world, prompt):
+    if world.size() == 1:
+        return prompt
+
+    if world.rank() == 0:
+        prompt_array = mx.array(prompt.encode())
+        for i in range(1, world.size()):
+            mx.eval(mx.send(prompt_array, i))
+        world.barrier()
+
+    else:
+        prompt_array = mx.recv(0)
+        mx.eval(prompt_array)
+        prompt = bytes(prompt_array).decode()
+        world.barrier()
+
+    return prompt
+
+
 def setup_arg_parser():
     """Set up and return the argument parser."""
     parser = argparse.ArgumentParser(description="Chat with an LLM")
@@ -53,6 +72,7 @@ def setup_arg_parser():
 
 
 def main():
+    world = mx.distributed.init()
     parser = setup_arg_parser()
     args = parser.parse_args()
 
@@ -62,18 +82,23 @@ def main():
         args.model,
         adapter_path=args.adapter_path,
         tokenizer_config={"trust_remote_code": True},
+        sequential_load=mx.distributed.init().size() > 1,
     )
 
+    print(f"Node {world.rank()} of {world.size()}", flush=True)
     print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
-        query = input(">> ")
-        if query == "q":
-            break
-        messages = [{"role": "user", "content": query}]
-        prompt = tokenizer.apply_chat_template(
-            messages, tokenize=False, add_generation_prompt=True
-        )
+        prompt = None
+        if world.rank() == 0:
+            query = input(">> ")
+            if query == "q":
+                break
+            messages = [{"role": "user", "content": query}]
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        prompt = share_message(world, prompt)
         for response in stream_generate(
             model,
             tokenizer,
@@ -83,9 +108,12 @@ def main():
             top_p=args.top_p,
             prompt_cache=prompt_cache,
         ):
-            print(response, flush=True, end="")
-        print()
+            if world.rank() == 0:
+                print(response, flush=True, end="")
+        if world.rank() == 0:
+            print()
 
 
 if __name__ == "__main__":
     main()
+
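
Note on the prompt broadcast: with this patch only rank 0 reads stdin and prints output, so the prompt typed on rank 0 has to reach every other rank before stream_generate runs, otherwise the ranks would not all execute the same forward pass. The patch's share_message does this with point-to-point send/recv. For illustration only (this is not the patch's implementation), the same idea can be sketched with the collective mx.distributed.all_sum, which sidesteps the receiver having to know the message size in advance by exchanging a fixed-size buffer. The names broadcast_prompt and MAX_BYTES below are hypothetical, and the globally initialized group is assumed:

    import mlx.core as mx

    MAX_BYTES = 4096  # assumed upper bound on the encoded prompt (hypothetical)

    def broadcast_prompt(world, prompt):
        # Nothing to share when running on a single process.
        if world.size() == 1:
            return prompt
        # Every rank contributes a fixed-size buffer: rank 0 fills it with
        # [length, byte0, byte1, ...]; the other ranks contribute zeros, so
        # the element-wise sum acts as a broadcast from rank 0.
        buf = mx.zeros((MAX_BYTES + 1,), dtype=mx.int32)
        if world.rank() == 0:
            data = prompt.encode()
            buf = mx.concatenate([
                mx.array([len(data)], dtype=mx.int32),
                mx.array(list(data), dtype=mx.int32),
                mx.zeros((MAX_BYTES - len(data),), dtype=mx.int32),
            ])
        buf = mx.distributed.all_sum(buf)
        mx.eval(buf)
        n = int(buf[0].item())
        return bytes(buf[1 : n + 1].tolist()).decode()

With an MPI setup, the chat could then be launched with something like "mpirun -n 2 python chat.py"; the exact launch command depends on the distributed backend in use.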