From 00cd83dc067f70ec016aa9225210dd1f3ad7a9e8 Mon Sep 17 00:00:00 2001 From: guybd Date: Fri, 31 Jan 2025 15:03:48 +0200 Subject: [PATCH] added notebook to run DeepSeek-R1-Llama-8B --- ...generation_demo_DeepSeek-R1-Llama-8B.ipynb | 748 ++++++++++++++++++ 1 file changed, 748 insertions(+) create mode 100644 supplementary_materials/notebooks/quantized_generation_demo_DeepSeek-R1-Llama-8B.ipynb diff --git a/supplementary_materials/notebooks/quantized_generation_demo_DeepSeek-R1-Llama-8B.ipynb b/supplementary_materials/notebooks/quantized_generation_demo_DeepSeek-R1-Llama-8B.ipynb new file mode 100644 index 00000000000..36e9f7079ef --- /dev/null +++ b/supplementary_materials/notebooks/quantized_generation_demo_DeepSeek-R1-Llama-8B.ipynb @@ -0,0 +1,748 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aeb16663-be53-4260-b62d-44611b6771ec", + "metadata": {}, + "source": [ + "# Chat and Code with DeepSeek-R1-Distill-Llama-8B with OpenVINO and 🤗 Optimum on Intel's Lunar Lake iGPU\n", + "In this notebook we will show how to export and apply 4-bit weight only quantization to DeepSeek-R1-Distill-Llama-8B to model.\n", + "Then using the quantized model we will show how to generate for example code completions with the model running on Intel's Lunar Lake iGPU presenting a good experience of running GenAI locally on Intel PC marking the start of the AIPC Era!\n", + "Then we will show how to talk with DeepSeek-R1-Distill-Llama-8B in a ChatBot demo running completely locally on your Laptop!\n", + "\n", + "[DeepSeek-R1-Distill-Llama-8B](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B) is a 8 billion-parameter language model adopting the Llama-3.1 architechture\n", + "and distilled from DeepSeek-R1 model by Deepseek as a part of recently released model suite. \n" + ] + }, + { + "cell_type": "markdown", + "id": "03cb49cf-bc6f-4702-a61f-227b352404cb", + "metadata": {}, + "source": [ + "## Install dependencies\n", + "Make sure you have the latest GPU drivers installed on your machine: https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html.\n", + "\n", + "We will start by installing the dependencies, that can be done by uncommenting the following cell and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96d8203c-34c9-41a2-95bd-3891533840a1", + "metadata": {}, + "outputs": [], + "source": [ + "# ! pip install optimum[openvino,nncf] torch==2.6.0 " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5980ce40-0be1-48c1-941a-92c484d4da31", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from transformers import AutoTokenizer\n", + "from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig" + ] + }, + { + "cell_type": "markdown", + "id": "48b81857-a095-43a3-8c8d-4c880b743a6e", + "metadata": {}, + "source": [ + "## Configuration\n", + "Here we will configure which model to load and other attributes. We will explain everything 😄\n", + "* `model_name`: the name or path of the model we want to export and quantize, can be either on the 🤗 Hub or a local directory on your laptop.\n", + "* `precision`: the compute data type we will use for inference of the model, can be either `f32` or `f16`.\n", + "* `device`: the device to use for inference, can be either `cpu` or `gpu`.\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "56e664ff", + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import requests\n", + "\n", + "if not Path(\"cmd_helper.py\").exists():\n", + " r = requests.get(url=\"https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest/utils/cmd_helper.py\")\n", + " open(\"cmd_helper.py\", \"w\").write(r.text)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "800cd7a3-a21d-4a0a-9d73-2a2d08646f99", + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "**Export command:**" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/markdown": [ + "`optimum-cli export openvino --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B DeepSeek-R1-Distill-Llama-8B\\INT4 --weight-format int4`" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from cmd_helper import optimum_cli\n", + "\n", + "model_id = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n", + "model_dir = Path(model_id.split(\"/\")[-1])\n", + "\n", + "\n", + "if not (model_dir / \"INT4\").exists():\n", + " optimum_cli(model_id, model_dir/\"INT4\", additional_args={\"weight-format\": \"int4\"})\n" + ] + }, + { + "cell_type": "markdown", + "id": "1f398868-93d7-4c2d-9591-9bac8e9b701c", + "metadata": {}, + "source": [ + "With this configuration we expect the model size to reduce to around to 1.62GB: $0.8 \\times 2.7{\\times}10^3 \\times \\frac{1}{2}\\text{B} + 0.2 * 2.7{\\times}10^3 \\times 1\\text{B} = 1.62{\\times}10^3\\text{B} = 1.62\\text{GB}$" + ] + }, + { + "cell_type": "markdown", + "id": "d994997d-344c-4d6c-ab08-f78ecb7f56ec", + "metadata": {}, + "source": [ + "## Export & quantize\n", + "OpenVINO together with 🤗 Optimum enables you to load, export and quantize a model in a single `from_pretrained` call making the process as simple as possible.\n", + "Then, we will save the exported & quantized model locally on our laptop. If the model was already exported and saved before we will load the locally saved model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "03a308c6-27e7-4926-8ac4-4fa0c1ca68d2", + "metadata": {}, + "outputs": [], + "source": [ + "device = 'GPU'\n", + "precision = 'f16'\n", + "\n", + "# Load kwargs\n", + "load_kwargs = {\n", + " \"device\": device,\n", + " \"ov_config\": {\n", + " \"PERFORMANCE_HINT\": \"LATENCY\",\n", + " \"INFERENCE_PRECISION_HINT\": precision,\n", + " \"CACHE_DIR\": os.path.join(model_dir,\"INT4\", \"model_cache\"), # OpenVINO will use this directory as cache\n", + " },\n", + "}\n", + "\n", + "model = OVModelForCausalLM.from_pretrained(model_dir/\"INT4\", **load_kwargs)\n", + "\n", + "# Load tokenizer to be used with the model\n", + "tokenizer = AutoTokenizer.from_pretrained(model_dir/\"INT4\")\n", + "\n", + "# model_size = os.stat(os.path.join(save_name, \"openvino_model.bin\")).st_size / 1024 ** 3\n", + "# print(f'Model size in FP32: ~5.4GB, current model size in 4bit: {model_size:.2f}GB')" + ] + }, + { + "cell_type": "markdown", + "id": "592e118d-e8bb-491f-92b2-d0418e19158c", + "metadata": {}, + "source": [ + "We can see the model size was reduced to 1.7GB as expected. After loading the model we can switch the model between devices using `model.to('gpu')` for example.\n", + "After we have finished to configure everything, we can compile the model by calling `model.compile()` and the model will be ready for usage." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cef4dc0-191e-4755-a639-c3e8adbd18a2", + "metadata": {}, + "outputs": [], + "source": [ + "model.compile()" + ] + }, + { + "cell_type": "markdown", + "id": "dd3c467e-3bbb-4265-9075-1c6688af2f92", + "metadata": {}, + "source": [ + "## Generate using the exported model\n", + "We will now show an example where we will use our quantized Phi-2 to generate code in Python. \n", + "Phi-2 knows how to do code completions where the model is given a function's signature and its docstring and the model will generate the implementation of the function.\n", + "\n", + "In our example we have taken one of the samples from the test set of HumanEval dataset. \n", + "HumanEval is a code completion dataset used to train and benchmark models on code completion in Python. \n", + "Phi-2 has scored a remarkable result on the HumanEval dataset and is an excellent model to use for code completions.\n", + "\n", + "Note: the first time you run the model might take more time due to loading and compilation overheads of the first inference" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4b4ea738-7db5-490e-9338-d6420b77796c", + "metadata": {}, + "outputs": [], + "source": [ + "sample = \"\"\"from typing import List\n", + "\n", + "\n", + "def has_close_elements(numbers: List[float], threshold: float) -> bool:\n", + " \\\"\\\"\\\" Check if in given list of numbers, are any two numbers closer to each other than\n", + " given threshold.\n", + " >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n", + " False\n", + " >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n", + " True\n", + " \\\"\\\"\\\"\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "14ffe7f9-7d93-4a49-95d8-5f2a4e400cfe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "from typing import List\n", + "\n", + "\n", + "def has_close_elements(numbers: List[float], threshold: float) -> bool:\n", + " \"\"\" Check if in given list of numbers, are any two numbers closer to each other than\n", + " given threshold.\n", + " >>> has_close_elements([1.0, 2.0, 3.0], 0.5)\n", + " False\n", + " >>> has_close_elements([1.0, 2.8, 3.0, 4.0, 5.0, 2.0], 0.3)\n", + " True\n", + " \"\"\" \n", + " # Sort the list to make it easier to find consecutive elements\n", + " sorted_numbers = sorted(numbers)\n", + " # Iterate through each consecutive pair\n", + " for i in range(len(sorted_numbers) - 1):\n", + " # Calculate the difference between consecutive elements\n", + " diff = sorted_numbers[i+1] - sorted_numbers[i]\n", + " if diff < threshold:\n", + " return True\n", + " # If no pair found, return False\n", + " return False\n", + "\n", + "# Test case 1: numbers = [1.0, 2.0, 3.0], threshold = 0.5\n", + "# After sorting: [1.0\n" + ] + } + ], + "source": [ + "from transformers import TextStreamer\n", + "\n", + "# Tokenize the sample\n", + "inputs = tokenizer([sample], return_tensors='pt')\n", + "\n", + "# Call generate on the inputs\n", + "out = model.generate(\n", + " **inputs,\n", + " max_new_tokens=128,\n", + " streamer=TextStreamer(tokenizer=tokenizer, skip_special_tokens=True),\n", + " pad_token_id=tokenizer.eos_token_id,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "3f8aa25c-de59-4e79-9a1f-c03ec76d206a", + "metadata": {}, + "source": [ + "## Chatbot demo\n", + "We will continue to build a chatbot demo running with Gradio using the models we just exported and quantized.\n", + "The chatbot will be rather simple where the user will input a message and the model will reply to the user by generating text using the entire chat history as the input to the model.\n", + "We will also add an option to accelerate inference using speculative decoding with a draft model as we described in the previous section.\n", + "\n", + "A lot of models that were trained for the chatbot use case have been trained with special tokens to tell the model who is the current speaker and with a special system message. \n", + "Phi-2 wasn't trained specifically for the chatbot use case and doesn't have any special tokens either, however, it has seen chats in the training data and therefore is suited for that use case.\n", + "\n", + "The chat template we will use is rather simple:\n", + "```\n", + "User: \n", + "Assistant: \n", + "User: \n", + "...\n", + "```\n", + "\n", + "We will start by writing the core function of the chatbot that receives the entire history of the chat and generates the assistant's response.\n", + "To support this core function we will build a few assistant functions to prepare the input for the model and to stop generation in time." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "7e81d125-ff47-4122-853d-11a2763db146", + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "from threading import Thread\n", + "\n", + "from transformers import (\n", + " TextIteratorStreamer,\n", + " StoppingCriteria,\n", + " StoppingCriteriaList,\n", + " GenerationConfig,\n", + ")\n", + "\n", + "class TextIteratorStreamerWithCounter(TextIteratorStreamer):\n", + " def __init__(self, tokenizer: \"AutoTokenizer\", **kwargs):\n", + " super().__init__(tokenizer, skip_prompt, **decode_kwargs)\n", + " self.counter = 0\n", + "\n", + " def put(self, value):\n", + " if not self.next_tokens_are_prompt:\n", + " self.counter = self.counter+1\n", + " else:\n", + " self.counter = 0\n", + " super().put(value)\n", + "\n", + "\n", + "# Copied and modified from https://github.com/bigcode-project/bigcode-evaluation-harness/blob/main/bigcode_eval/generation.py#L13\n", + "class SuffixCriteria(StoppingCriteria):\n", + " def __init__(self, start_length, eof_strings, tokenizer, check_fn=None):\n", + " self.start_length = start_length\n", + " self.eof_strings = eof_strings\n", + " self.tokenizer = tokenizer\n", + " if check_fn is None:\n", + " check_fn = lambda decoded_generation: any(\n", + " [decoded_generation.endswith(stop_string) for stop_string in self.eof_strings]\n", + " )\n", + " self.check_fn = check_fn\n", + "\n", + " def __call__(self, input_ids, scores, **kwargs):\n", + " \"\"\"Returns True if generated sequence ends with any of the stop strings\"\"\"\n", + " decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])\n", + " return all([self.check_fn(decoded_generation) for decoded_generation in decoded_generations])\n", + "\n", + "\n", + "def is_partial_stop(output, stop_str):\n", + " \"\"\"Check whether the output contains a partial stop str.\"\"\"\n", + " for i in range(0, min(len(output), len(stop_str))):\n", + " if stop_str.startswith(output[-i:]):\n", + " return True\n", + " return False\n", + "\n", + "\n", + "\n", + "# Set the chat template to the tokenizer. The chat template implements the simple template of\n", + "# User: content\n", + "# Assistant: content\n", + "# ...\n", + "# Read more about chat templates here https://huggingface.co/docs/transformers/main/en/chat_templating\n", + "tokenizer.chat_template = \"{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}\"\n", + "\n", + "\n", + "def prepare_history_for_model(history):\n", + " \"\"\"\n", + " Converts the history to a tokenized prompt in the format expected by the model.\n", + " Params:\n", + " history: dialogue history\n", + " Returns:\n", + " Tokenized prompt\n", + " \"\"\"\n", + " messages = []\n", + " for idx, (user_msg, model_msg) in enumerate(history):\n", + " # skip the last assistant message if its empty, the tokenizer will do the formating\n", + " if idx == len(history) - 1 and not model_msg:\n", + " messages.append({\"role\": \"User\", \"content\": user_msg})\n", + " break\n", + " if user_msg:\n", + " messages.append({\"role\": \"User\", \"content\": user_msg})\n", + " if model_msg:\n", + " messages.append({\"role\": \"Assistant\", \"content\": model_msg})\n", + " input_token = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " tokenize=True,\n", + " return_tensors=\"pt\",\n", + " return_dict=True\n", + " )\n", + " return input_token\n", + "\n", + "\n", + "def generate(history, temperature, max_new_tokens, top_p, repetition_penalty, assisted):\n", + " \"\"\"\n", + " Generates the assistant's reponse given the chatbot history and generation parameters\n", + "\n", + " Params:\n", + " history: conversation history formated in pairs of user and assistant messages `[user_message, assistant_message]`\n", + " temperature: parameter for control the level of creativity in AI-generated text.\n", + " By adjusting the `temperature`, you can influence the AI model's probability distribution, making the text more focused or diverse.\n", + " max_new_tokens: The maximum number of tokens we allow the model to generate as a response.\n", + " top_p: parameter for control the range of tokens considered by the AI model based on their cumulative probability.\n", + " repetition_penalty: parameter for penalizing tokens based on how frequently they occur in the text.\n", + " assisted: boolean parameter to enable/disable assisted generation with speculative decoding.\n", + " Yields:\n", + " Updated history and generation status.\n", + " \"\"\"\n", + " start = time.perf_counter()\n", + " # Construct the input message string for the model by concatenating the current system message and conversation history\n", + " # Tokenize the messages string\n", + " inputs = prepare_history_for_model(history)\n", + " input_length = inputs['input_ids'].shape[1]\n", + " # truncate input in case it is too long.\n", + " # TODO improve this\n", + " if input_length > 2000:\n", + " history = [history[-1]]\n", + " inputs = prepare_history_for_model(history)\n", + " input_length = inputs['input_ids'].shape[1]\n", + "\n", + " prompt_char = \"▌\"\n", + " history[-1][1] = prompt_char\n", + " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", + "\n", + " streamer = TextIteratorStreamerWithCounter(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", + " streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", + " # Create a stopping criteria to prevent the model from playing the role of the user aswell.\n", + " stop_str = [\"\\nUser:\", \"\\nAssistant:\", \"\\nRules:\", \"\\nQuestion:\"]\n", + " stopping_criteria = StoppingCriteriaList([SuffixCriteria(input_length, stop_str, tokenizer)])\n", + " # Prepare input for generate\n", + " generation_config = GenerationConfig(\n", + " max_new_tokens=max_new_tokens,\n", + " do_sample=temperature > 0.0,\n", + " temperature=temperature if temperature > 0.0 else 1.0,\n", + " repetition_penalty=repetition_penalty,\n", + " top_p=top_p,\n", + " eos_token_id=[tokenizer.eos_token_id],\n", + " pad_token_id=tokenizer.eos_token_id,\n", + " )\n", + " generate_kwargs = dict(\n", + " streamer=streamer,\n", + " generation_config=generation_config,\n", + " stopping_criteria=stopping_criteria,\n", + " ) | inputs\n", + "\n", + " if assisted:\n", + " target_generate = stateless_model.generate\n", + " generate_kwargs[\"assistant_model\"] = asst_model\n", + " else:\n", + " target_generate = model.generate\n", + "\n", + " t1 = Thread(target=target_generate, kwargs=generate_kwargs)\n", + " t1.start()\n", + "\n", + " # Initialize an empty string to store the generated text.\n", + " partial_text = \"\"\n", + " for new_text in streamer:\n", + " partial_text += new_text\n", + " history[-1][1] = partial_text + prompt_char\n", + " for s in stop_str:\n", + " if (pos := partial_text.rfind(s)) != -1:\n", + " break\n", + " if pos != -1:\n", + " partial_text = partial_text[:pos]\n", + " break\n", + " elif any([is_partial_stop(partial_text, s) for s in stop_str]):\n", + " continue\n", + " yield history, \"Status: Generating...\", *([gr.update(interactive=False)] * 4)\n", + " history[-1][1] = partial_text\n", + " generation_time = time.perf_counter() - start\n", + " n_tokens = \n", + " yield history, f'Generation time: {generation_time:.2f} sec', f'Throughput: {generation_time/n_tokens:.2f} tok/sec', *([gr.update(interactive=True)] * 4)" + ] + }, + { + "cell_type": "markdown", + "id": "29fe1ae5-9929-4789-9293-612b2062e2a8", + "metadata": {}, + "source": [ + "Next we will create the actual demo using Gradio. The layout will be very simple, a chatbot window followed by a text prompt and some controls.\n", + "We will also include sliders to adjust generation parameters like temperature and length of response we allow the model to generate.\n", + "\n", + "To install Gradio dependency, please uncomment the following cell and run" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b61a9a9f", + "metadata": {}, + "outputs": [], + "source": [ + "# ! pip install gradio" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "9ae1aa4e-3539-49a1-8f32-62b818ee1002", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Closing server running on port: 7860\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\sdp\\miniforge3\\envs\\multimodal\\Lib\\site-packages\\gradio\\components\\chatbot.py:282: UserWarning: You have not specified a value for the `type` parameter. Defaulting to the 'tuples' format for chatbot messages, but this is deprecated and will be removed in a future version of Gradio. Please set type='messages' instead, which uses openai-style dictionaries with 'role' and 'content' keys.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import gradio as gr\n", + "\n", + "try:\n", + " demo.close()\n", + "except:\n", + " pass\n", + "\n", + "\n", + "EXAMPLES = [\n", + " [\"What is OpenVINO?\"],\n", + " [\"Can you explain to me briefly what is Python programming language?\"],\n", + " [\"Explain the plot of Cinderella in a sentence.\"],\n", + " [\"Write a Python function to perform binary search over a sorted list. Use markdown to write code\"],\n", + " [\"Lily has a rubber ball that she drops from the top of a wall. The wall is 2 meters tall. How long will it take for the ball to reach the ground?\"],\n", + "]\n", + "\n", + "\n", + "def add_user_text(message, history):\n", + " \"\"\"\n", + " Add user's message to chatbot history\n", + "\n", + " Params:\n", + " message: current user message\n", + " history: conversation history\n", + " Returns:\n", + " Updated history, clears user message and status\n", + " \"\"\"\n", + " # Append current user message to history with a blank assistant message which will be generated by the model\n", + " history.append([message, None])\n", + " return ('', history)\n", + "\n", + "\n", + "def prepare_for_regenerate(history):\n", + " \"\"\"\n", + " Delete last assistant message to prepare for regeneration\n", + "\n", + " Params:\n", + " history: conversation history\n", + " Returns:\n", + " updated history\n", + " \"\"\" \n", + " history[-1][1] = None\n", + " return history\n", + "\n", + "\n", + "with gr.Blocks(theme=gr.themes.Soft()) as demo:\n", + " gr.Markdown('

Chat with DeepSeek-R1-Distill-Llama-8B on Lunar Lake iGPU

')\n", + " chatbot = gr.Chatbot()\n", + " with gr.Row():\n", + " assisted = gr.Checkbox(value=False, label=\"Generation\", scale=10)\n", + " msg = gr.Textbox(placeholder=\"Enter message here...\", show_label=False, autofocus=True, scale=75)\n", + " status = gr.Textbox(\"Status: Idle\", show_label=False, max_lines=1, scale=15)\n", + " with gr.Row():\n", + " submit = gr.Button(\"Submit\", variant='primary')\n", + " regenerate = gr.Button(\"Regenerate\")\n", + " clear = gr.Button(\"Clear\")\n", + " with gr.Accordion(\"Advanced Options:\", open=False):\n", + " with gr.Row():\n", + " with gr.Column():\n", + " temperature = gr.Slider(\n", + " label=\"Temperature\",\n", + " value=0.0,\n", + " minimum=0.0,\n", + " maximum=1.0,\n", + " step=0.05,\n", + " interactive=True,\n", + " )\n", + " max_new_tokens = gr.Slider(\n", + " label=\"Max new tokens\",\n", + " value=512,\n", + " minimum=0,\n", + " maximum=1024,\n", + " step=32,\n", + " interactive=True,\n", + " )\n", + " with gr.Column():\n", + " top_p = gr.Slider(\n", + " label=\"Top-p (nucleus sampling)\",\n", + " value=1.0,\n", + " minimum=0.0,\n", + " maximum=1.0,\n", + " step=0.05,\n", + " interactive=True,\n", + " )\n", + " repetition_penalty = gr.Slider(\n", + " label=\"Repetition penalty\",\n", + " value=1.0,\n", + " minimum=1.0,\n", + " maximum=2.0,\n", + " step=0.1,\n", + " interactive=True,\n", + " )\n", + " gr.Examples(\n", + " EXAMPLES, inputs=msg, label=\"Click on any example and press the 'Submit' button\"\n", + " )\n", + "\n", + " # Sets generate function to be triggered when the user submit a new message\n", + " gr.on(\n", + " triggers=[submit.click, msg.submit],\n", + " fn=add_user_text,\n", + " inputs=[msg, chatbot],\n", + " outputs=[msg, chatbot],\n", + " queue=False,\n", + " ).then(\n", + " fn=generate,\n", + " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", + " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", + " concurrency_limit=1,\n", + " queue=True\n", + " )\n", + " regenerate.click(\n", + " fn=prepare_for_regenerate,\n", + " inputs=chatbot,\n", + " outputs=chatbot,\n", + " queue=True,\n", + " concurrency_limit=1\n", + " ).then(\n", + " fn=generate,\n", + " inputs=[chatbot, temperature, max_new_tokens, top_p, repetition_penalty, assisted],\n", + " outputs=[chatbot, status, msg, submit, regenerate, clear],\n", + " concurrency_limit=1,\n", + " queue=True\n", + " )\n", + " clear.click(fn=lambda: (None, \"Status: Idle\"), inputs=None, outputs=[chatbot, status], queue=False)" + ] + }, + { + "cell_type": "markdown", + "id": "1d1baf09-26f1-40ab-896c-3468b5e89fec", + "metadata": {}, + "source": [ + "That's it, all that is left is to start the demo!\n", + "\n", + "When you're done you can use `demo.close()` to close the demo" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "5b73962d-f977-45b7-be3a-32b65e546737", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "* Running on local URL: http://127.0.0.1:7860\n", + "\n", + "To create a public link, set `share=True` in `launch()`.\n" + ] + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "demo.launch()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "1e26a0bc-6a78-4185-8b0c-7e9450ba5868", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Closing server running on port: 7860\n" + ] + } + ], + "source": [ + "demo.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25c70116", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "multimodal", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}