diff --git a/supporting-blog-content/evaluating-search-relevance-part-2/phi3-as-relevance-judge.ipynb b/supporting-blog-content/evaluating-search-relevance-part-2/phi3-as-relevance-judge.ipynb new file mode 100644 index 00000000..1c227d23 --- /dev/null +++ b/supporting-blog-content/evaluating-search-relevance-part-2/phi3-as-relevance-judge.ipynb @@ -0,0 +1,775 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "f30a6362-caea-4916-b796-0fbab99b41b1", + "metadata": {}, + "source": [ + "## Using Phi-3 as relevance judge" + ] + }, + { + "cell_type": "markdown", + "id": "37f6f8d1-173d-492a-bf80-851f11071315", + "metadata": {}, + "source": [ + "In this notebook we will use [Phi-3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) as a relevance judge between a query and a document. " + ] + }, + { + "cell_type": "markdown", + "id": "a23ca995-c54a-4146-b7ca-e53952cb9a3a", + "metadata": {}, + "source": [ + "## Requirements\n", + "\n", + "For this notebook, you will need a working Python environment (**Python 3.10.x** or later) and some Python dependencies:\n", + "- `torch`, to use PyTorch as our backend\n", + "- Huggingface's `transformers` library\n", + "- `accelerate` and`bitsandbytes` for quantization support (a GPU is required to enable quantization)\n", + "- `scikit-learn` for metrics computation\n", + "- `pandas` for generic data handling" + ] + }, + { + "cell_type": "markdown", + "id": "b6836fec-ccd0-4fab-981c-f76f5ba7113e", + "metadata": {}, + "source": [ + "## Installing packages\n", + "\n", + "Let's start by installing the necessary Python libraries (preferably in a virtual environment)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b5a56591-4d9d-435b-b165-f9fbfa5615f6", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -U torch transformers accelerate bitsandbytes scikit-learn pandas" + ] + }, + { + "cell_type": "markdown", + "id": "180ff935-ddae-43cd-8e85-4c236d18159f", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "eaa5e1ff-4709-4342-9169-9094c67f143e", + "metadata": {}, + "source": [ + "## Implementation" + ] + }, + { + "cell_type": "markdown", + "id": "cfda1967-8feb-400e-b125-dc8e2c349467", + "metadata": {}, + "source": [ + "First, the necessary imports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c5c76bc-aed0-4e44-b0a7-724470cbb7ed", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import Counter\n", + "from functools import partial\n", + "from typing import Any, Iterable, List, Optional\n", + "import json\n", + "import re\n", + "\n", + "from sklearn.metrics import f1_score, precision_score, recall_score\n", + "from transformers import (\n", + " AutoModelForCausalLM,\n", + " AutoTokenizer,\n", + " BitsAndBytesConfig,\n", + " pipeline,\n", + ")\n", + "import pandas as pd\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "id": "f4128f7f-7ba2-406f-ba5d-435dd4a241f2", + "metadata": {}, + "source": [ + "Now, let's create a class that will responsible for loading the `Phi-3` model and perform inference on its inputs. A few notes before we dive into the code:\n", + "* Even though Phi-3 is a small language model (SLM) with a parameter count of 3.8B we load it with 4-bit quantization that makes it a good choice even for consumer-grade GPUs\n", + "* Following the example code provided in the corresponding HF page [here](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) we are also using text generation pipelines to perform inference. More optimized setups are possible but out of scope for this notebook\n", + "* Regular expressions are used to extract the answer from the LLM output. The `response_types` argument defines the set of acceptable classes (e.g. `Relevant`, `Not Relevant`)\n", + "* There are two options for decoding:\n", + " * `greedy decoding`, where sampling is disabled and the outputs are (more or less) deterministic\n", + " * `beam decoding`, where multiple LLM calls for the same set of inputs are performed and the results are aggregated through majority voting. In the code below `iterations` is the number of LLM calls requested with an appropriate setting for the `temperature` (e.g. 0.5)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46f0a0d8-0d4c-4545-8c43-ca29a579fe62", + "metadata": {}, + "outputs": [], + "source": [ + "def get_device_map():\n", + " \"\"\"Retrieve the backend\"\"\"\n", + " if torch.cuda.is_available():\n", + " return \"cuda\"\n", + "\n", + " if torch.backends.mps.is_available():\n", + " return \"mps\"\n", + "\n", + " return \"auto\"\n", + "\n", + "\n", + "class Phi3Evaluator:\n", + " \"\"\"Evaluator based on the Phi-3 model\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " model_name_or_path: str,\n", + " response_types: List[str],\n", + " iterations: int = 1,\n", + " temperature: float = 0.0,\n", + " ):\n", + "\n", + " # set 4-bit quantization\n", + " quant_config = BitsAndBytesConfig(load_in_4bit=True)\n", + " model = AutoModelForCausalLM.from_pretrained(\n", + " model_name_or_path,\n", + " torch_dtype=\"auto\",\n", + " trust_remote_code=True,\n", + " device_map=get_device_map(),\n", + " quantization_config=quant_config if get_device_map() == \"cuda\" else None,\n", + " )\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n", + "\n", + " # defining the text generation pipeline\n", + " self.pipeline = pipeline(\n", + " \"text-generation\",\n", + " model=model,\n", + " tokenizer=tokenizer,\n", + " device_map=get_device_map(),\n", + " )\n", + " # define an appropriate regex for the expected outputs\n", + " self.regex_str = (\n", + " r\"(\" + r\"|\".join(response_types) + r\")\" if response_types else None\n", + " )\n", + " # temperature setting\n", + " self._temperature = temperature\n", + " # number of LLM calls\n", + " self._iterations = iterations\n", + "\n", + " def _get_generation_args(self):\n", + " \"\"\"Arguments for the text generation pipeline\"\"\"\n", + " if self._temperature > 0.0:\n", + " return {\n", + " \"return_full_text\": False,\n", + " \"temperature\": self._temperature,\n", + " \"do_sample\": True,\n", + " \"num_return_sequences\": self._iterations,\n", + " \"num_beams\": 2 * self._iterations,\n", + " }\n", + " return {\n", + " \"return_full_text\": False,\n", + " \"do_sample\": False,\n", + " \"num_return_sequences\": 1,\n", + " }\n", + "\n", + " def __call__(\n", + " self, prompts: Iterable[str], max_output_tokens: int, batch_size: int = 8\n", + " ) -> Iterable[str]:\n", + " \"\"\"Generate responses to the given prompts\"\"\"\n", + "\n", + " # set args for the text generation pipeline\n", + " gen_args = self._get_generation_args()\n", + " gen_args.update({\"batch_size\": batch_size, \"max_new_tokens\": max_output_tokens})\n", + " inputs_for_pipeline = [\n", + " [{\"role\": \"user\", \"content\": prompt}] for prompt in prompts\n", + " ]\n", + "\n", + " for output in self.pipeline(inputs_for_pipeline, **gen_args):\n", + " output = [output] if isinstance(output, dict) else output\n", + " llm_outputs = [item[\"generated_text\"] for item in output]\n", + " parsed_outputs = [\n", + " self._postprocess_output(llm_output) for llm_output in llm_outputs\n", + " ]\n", + " clean_outputs = [\n", + " output for output in parsed_outputs if output != \"Unparsable\"\n", + " ]\n", + " if clean_outputs:\n", + " winning_label = Counter(clean_outputs).most_common(1)\n", + " yield llm_outputs, winning_label[0][0], round(\n", + " winning_label[0][1] / len(clean_outputs), 2\n", + " )\n", + " else:\n", + " yield llm_outputs, \"Unparsable\", -1.0\n", + "\n", + " def _postprocess_output(self, llm_out: str) -> str:\n", + " \"\"\"Cleans the output from the LLM\"\"\"\n", + " if not self.regex_str:\n", + " return llm_out\n", + "\n", + " re_search = re.search(self.regex_str, llm_out, re.IGNORECASE)\n", + " if re_search:\n", + " return re_search.group()\n", + "\n", + " return \"Unparsable\"" + ] + }, + { + "cell_type": "markdown", + "id": "87eeef16-c040-4760-9be6-517fc6eefbac", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "4de131fe-e8ec-40a2-92aa-5765235f01a9", + "metadata": {}, + "source": [ + "## Prompts\n", + "\n", + "In this section we define the prompts that we will use later for LLM inference. \n", + "\n", + "There are three types of prompt templates namely: \n", + "* `pointwise`\n", + "* `pointwise` with chain-of-thought\n", + "* `pairwise`\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34d24199-9504-487c-8690-6e191aca4170", + "metadata": {}, + "outputs": [], + "source": [ + "QA_POINTWISE_RELEVANCE_PROMPT_TEMPLATE = (\n", + " \"You are an expert in information retrieval and your task is to estimate the relevance of a retrieved document to a\"\n", + " \" query.\\n\"\n", + " \"More specifically, you will be provided with two pieces of information:\\n\"\n", + " \"- Query, which is the question we want to answer\\n\"\n", + " \"- Retrieved Document, which is the document we want to evaluate.\\n\\n\"\n", + " 'Your task is to predict \"Relevant\" if the Retrieved Document contains the information required to provide an '\n", + " 'answer to the Query, otherwise you should print \"Not Relevant\" \\n'\n", + " \"#####\\n\"\n", + " \"Here are your inputs:\\n\"\n", + " \"Query: {query_text}\\n\"\n", + " \"Retrieved Document: {retrieved_text}\\n\"\n", + " \"#####\\n\\n\"\n", + " \"Take a step back and reflect carefully on how best to solve your task\\n\"\n", + " 'You should provide your answer in the form of a boolean value: \"Relevant\" or \"Not Relevant\"\\n'\n", + ")\n", + "\n", + "CHAIN_OF_THOUGHT_PROMPT_TEMPLATE = (\n", + " \"You are an expert in information retrieval and your task is to decide if a retrieved \"\n", + " \"document is relevant to a query or not. You will be provided with two pieces of information:\\n\"\n", + " \"- QUERY, which is a web search\\n\"\n", + " \"- DOCUMENT, which is a web page snippet\\n\"\n", + " 'Your task is to predict \"Relevant\" if the DOCUMENT contains the information required '\n", + " 'to provide an answer to the QUERY, otherwise you should predict \"Not Relevant\".\\n'\n", + " \"Solve this task step by step and use the following examples for help.\\n\\n\"\n", + " \"####\\n\"\n", + " \"Example 1\\n\"\n", + " \"QUERY: define interconnected\\n\"\n", + " \"DOCUMENT: Matching (adjective) corresponding in pattern, colour, or design; complementary.\\n\"\n", + " 'Intent: The query is looking for the dictionary definition of the term \"interconnected\".\\n'\n", + " 'Key Information: The document provides a definition of the term \"matching\".\\n'\n", + " 'Explanation: The term \"matching\" is closely related to the word \"interconnected\". However, the query is '\n", + " 'asking for the dictionary definition of \"interconnected\", not \"matching\".\\n'\n", + " 'Answer: \"Not Relevant\"\\n\\n'\n", + " \"Example 2\\n\"\n", + " \"QUERY: what are some good chocolate cake recipes\\n\"\n", + " \"DOCUMENT: Here are some of the best chocolate cakes to make yourself.\\n\"\n", + " \"Intent: The query is looking for chocolate cake recipes.\\n\"\n", + " \"Key Information: The document provides a list of chocolate cakes you can make yourself.\\n\"\n", + " \"Explanation: The document provides information about chocolate cakes you can make yourself. This makes it \"\n", + " \"very likely that it provides their recipes as well.\\n\"\n", + " 'Answer: \"Relevant\"\\n\\n'\n", + " \"Example 3\\n\"\n", + " \"QUERY: enable parental controls on Netflix\\n\"\n", + " \"DOCUMENT: Here's how to turn on parental controls on Amazon Prime Video:\\n\"\n", + " \"Intent: The query is looking for instructions on how to enable parental controls on Netflix.\\n\"\n", + " \"Key Information: The document provides instructions on how to turn on parental controls on Amazon Prime Video.\\n\"\n", + " \"Explanation: The document provides instructions for how to turn on parental controls for Amazon Prime Video, \"\n", + " \"not Netflix.\\n\"\n", + " 'Answer: \"Not Relevant\"\\n\\n'\n", + " \"####\\n\\n\"\n", + " \"To solve your task, first check if any of the examples are useful, then think carefully if \"\n", + " \"Key Information matches the Intent. Use the following format:\\n\"\n", + " \"Intent: [the intent behind QUERY]\\n\"\n", + " \"Key Information: [the key information contained in DOCUMENT]\\n\"\n", + " \"Explanation: [your explanation]\\n\"\n", + " \"Answer: [your answer]\\n\"\n", + " '[your answer] should be either \"Relevant\" or \"Not Relevant\".\\n'\n", + " \"Here are the QUERY and DOCUMENT for you to evaluate:\\n\"\n", + " \"QUERY: {query_text}\\n\"\n", + " \"DOCUMENT: {retrieved_text}\\n\"\n", + ")\n", + "\n", + "\n", + "QA_PAIRWISE_RELEVANCE_PROMPT_TEMPLATE = (\n", + " \"You are an expert in information retrieval and your task is to estimate the relevance of a retrieved document to a query.\\n\"\n", + " \"More specifically, you will be provided with three pieces of information:\\n\"\n", + " \"- Query, which is the question we want to answer\\n\"\n", + " \"- Positive Document, which is a document that contains the correct answer to the query\\n\"\n", + " \"- Retrieved Document, which is the document we want to evaluate\\n\"\n", + " 'Your task is to predict \"Relevant\" if the Retrieved Document contains the information required to provide an answer to the Query, otherwise you should print \"Not Relevant\" \\n'\n", + " \"You can take advantage of the information in the Positive Document to identify the correct answer to the Query and then verify that the Retrieved Document contains that piece of information\\n\"\n", + " \"#####\\n\"\n", + " \"Here are your inputs:\\n\"\n", + " \"Query: {query_text}\\n\"\n", + " \"Positive Document: {positive_text}\\n\"\n", + " \"Retrieved Document: {retrieved_text}\\n\"\n", + " \"#####\\n\\n\"\n", + " \"Take a step back and reflect carefully on how best to solve your task\\n\"\n", + " 'You should provide your answer in the form of a boolean value: \"Relevant\" or \"Not Relevant\"\\n'\n", + " \"Good luck!\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b9f52ce9-198f-4219-b1e0-2d20ac13218f", + "metadata": {}, + "source": [ + "We also define a helper structure containing:\n", + "* `prompt_inputs`, specifies the list of attributes that need to be set in the prompt template. These attributes have the same name in the training data\n", + "* `prompt_template`, the prompt template to use\n", + "* `response_types`, the names of the expected output classes.\n", + "* `metadata`, the extra attributes that need to be preserved\n", + "* `max_output_tokens`, the maximum number of tokens that the LLM outputs\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "135e232d-ace8-4e4c-b5d1-11447e716108", + "metadata": {}, + "outputs": [], + "source": [ + "POOL = {\n", + " \"qa_pointwise\": {\n", + " \"prompt_inputs\": [\"query_text\", \"retrieved_text\"],\n", + " \"prompt_template\": QA_POINTWISE_RELEVANCE_PROMPT_TEMPLATE,\n", + " \"response_types\": [\"Relevant\", \"Not Relevant\"],\n", + " \"metadata\": [\"qid\", \"retrieved_doc_id\", \"human_judgment\"],\n", + " \"max_output_tokens\": 4,\n", + " },\n", + " \"qa_pairwise\": {\n", + " \"prompt_inputs\": [\"query_text\", \"positive_text\", \"retrieved_text\"],\n", + " \"prompt_template\": QA_PAIRWISE_RELEVANCE_PROMPT_TEMPLATE,\n", + " \"response_types\": [\"Relevant\", \"Not Relevant\"],\n", + " \"metadata\": [\"qid\", \"retrieved_doc_id\", \"human_judgment\"],\n", + " \"max_output_tokens\": 4,\n", + " },\n", + " \"chain_of_thought\": {\n", + " \"prompt_inputs\": [\"query_text\", \"retrieved_text\"],\n", + " \"prompt_template\": CHAIN_OF_THOUGHT_PROMPT_TEMPLATE,\n", + " \"response_types\": ['Answer: \"Relevant\"', 'Answer: \"Not Relevant\"'],\n", + " \"metadata\": [\"qid\", \"retrieved_doc_id\", \"human_judgment\"],\n", + " \"max_output_tokens\": 250,\n", + " },\n", + "}\n", + "\n", + "\n", + "def get_llm_evaluator(\n", + " model_name_or_path: str, task_type: str, iterations: int, temperature: float\n", + "):\n", + " \"\"\"Helper function that returns the evaluator\"\"\"\n", + " # quick sanity check\n", + " if task_type not in POOL:\n", + " raise ValueError(\n", + " f\"Task type {task_type} not supported please select one of {list(POOL.keys())}\"\n", + " )\n", + " task_type_def = POOL[task_type]\n", + "\n", + " return Phi3Evaluator(\n", + " model_name_or_path,\n", + " response_types=task_type_def[\"response_types\"],\n", + " iterations=iterations,\n", + " temperature=temperature,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "80a916be-48c7-4ece-aa32-94ed693b4df9", + "metadata": {}, + "source": [ + "We are now ready to define the parameters of our run:\n", + "* `MODEL_NAME`, the name of the language model\n", + "* `BATCH_SIZE`, the batch size to use for inference\n", + "* `TASK_TYPE`, one of `qa_pointwise`, `qa_pairwise`, `chain_of_thought`\n", + "* `TEMPERATURE` & `ITERATIONS` are decoding options explained at the beginning of the notebook" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc2a8e49-21fd-45a6-bb00-4ffc96f700df", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_NAME = \"microsoft/Phi-3-mini-4k-instruct\"\n", + "BATCH_SIZE = 2\n", + "TASK_TYPE = \"qa_pointwise\"\n", + "TEMPERATURE = 0.0 # values > 0 will activate sampling, in that case you should also increase the number of iterations (> 1)\n", + "ITERATIONS = 1" + ] + }, + { + "cell_type": "markdown", + "id": "586456e8-7f42-45ad-b72f-ac5532bb9cdf", + "metadata": {}, + "source": [ + "and create an instance of our evaluator" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "013889ac-16a9-4f08-a8f2-c081f2308327", + "metadata": {}, + "outputs": [], + "source": [ + "llm_evaluator = get_llm_evaluator(\n", + " model_name_or_path=MODEL_NAME,\n", + " task_type=TASK_TYPE,\n", + " iterations=ITERATIONS,\n", + " temperature=TEMPERATURE,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "dd0e0892-fcd3-44db-b7a3-d290782d19a5", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "093fa778-5563-41bb-872e-f5bbc5625a29", + "metadata": {}, + "source": [ + "## Running the pipeline\n", + "\n", + "Let's execute the pipeline by first adding a few test data points" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0993319-81bf-46f1-9a4d-7be24965dca0", + "metadata": {}, + "outputs": [], + "source": [ + "SAMPLE_DATA = [\n", + " {\n", + " \"qid\": \"155234\",\n", + " \"query_text\": \"do bigger tires affect gas mileage\",\n", + " \"positive_text\": \" Tire Width versus Gas Mileage. Tire width is one of the only tire size factors that can influence gas mileage in a positive way. For example, a narrow tire will have less wind resistance, rolling resistance, and weight; thus increasing gas mileage.\",\n", + " \"retrieved_text\": \" Bigger tires are an excellent upgrade for off road driving but they decrease gas mileage on road. It\\u00e2\\u0080\\u0099s not uncommon to experience a 2-4 mpg reduction in gas mileage with your off road tires. Why is this? 1. They are HEAVY. Most weigh 20+ lbs more than stock tires. 2. Increased Rolling Resistance from aggressive tread. 3.\",\n", + " \"retrieved_doc_id\": \"1048111\",\n", + " \"human_judgment\": \"Relevant\",\n", + " },\n", + " {\n", + " \"qid\": \"300674\",\n", + " \"query_text\": \"how many years did william bradford serve as governor of plymouth colony?\",\n", + " \"positive_text\": \" http://en.wikipedia.org/wiki/William_Bradford_(Plymouth_Colony_governor) William Bradford (c.1590 \\u00e2\\u0080\\u0093 1657) was an English Separatist leader in Leiden, Holland and in Plymouth Colony was a signatory to the Mayflower Compact. He served as Plymouth Colony Governor five times covering about thirty years between 1621 and 1657.\",\n", + " \"retrieved_text\": \" William Bradford was the governor of Plymouth Colony for 30 years. The colony was founded by people called Puritans. They were some of the first people from England to settle in what is now the United States. Bradford helped make Plymouth the first lasting colony in New England.\",\n", + " \"retrieved_doc_id\": \"2495763\",\n", + " \"human_judgment\": \"Relevant\",\n", + " },\n", + " {\n", + " \"qid\": \"125705\",\n", + " \"query_text\": \"define preventive\",\n", + " \"positive_text\": \" Adjective[edit] preventive \\u00e2\\u0080\\u008e(comparative more preventive, superlative most preventive) 1 Preventing, hindering, or acting as an obstacle to. Carried out to deter military aggression.\",\n", + " \"retrieved_text\": \" The Prevention Institute defines prevention as a systematic process that promotes safe and healthy environments and behaviors, reducing the likelihood or frequency of an incident, injury or condition occurring (2007).\",\n", + " \"retrieved_doc_id\": \"6464885\",\n", + " \"human_judgment\": \"Not Relevant\",\n", + " },\n", + " {\n", + " \"qid\": \"1101276\",\n", + " \"query_text\": \"do spiders eat other animals\",\n", + " \"positive_text\": \" Spiders are animals that have 8 legs and use their fangs to inject venom into other animals and sometimes humans. But what do spiders eat? This post will answer that question, and also look at some interesting facts about spiders. What do spiders eat? Different species of spiders eat different things. Most species trap small insects and other spiders in their webs and eat them. A few large species of spiders prey on small birds and lizards. One species is vegetarian, feeding on acacia trees. Some baby spiders eat plant nectar. In captivity, spiders have been known to eat egg yoke, bananas, marmalade, milk and sausages. Interesting Facts About Spiders\",\n", + " \"retrieved_text\": \" Home > Spider FAQ > What do spiders eat? Virtually all spiders are predatory on other animals, especially insects and other spiders. Very large spiders are capable of preying on small vertebrate animals such as lizards, frogs, fish, tadpoles, or even small snakes or baby rodents. Large orb weavers have been observed to occasionally ensnare small birds or bats.\",\n", + " \"retrieved_doc_id\": \"1596582\",\n", + " \"human_judgment\": \"Relevant\",\n", + " },\n", + " {\n", + " \"qid\": \"89786\",\n", + " \"query_text\": \"central city definition\",\n", + " \"positive_text\": \" Definition of central city. : a city that constitutes the densely populated center of a metropolitan area.\",\n", + " \"retrieved_text\": \" Central City (DC Comics) For other uses, see Central City (disambiguation). Central City is a fictional American city appearing in comic books published by DC Comics. It is the home of the Silver Age version of the Flash (Barry Allen), and first appeared in Showcase #4 in September\\u00e2\\u0080\\u0093October 1956.\",\n", + " \"retrieved_doc_id\": \"213222\",\n", + " \"human_judgment\": \"Not Relevant\",\n", + " },\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "3dcac5ea-3b20-4213-b99a-e67ea6e6ffce", + "metadata": {}, + "source": [ + "Each item in the list if a dictionary with the following keys:\n", + "* `qid`: The query id in the original MSMARCO dataset\n", + "* `query_text`: self-explanatory\n", + "* `positive_text`: The text for the document that has been marked as relevant in the oringal `qrels` file\n", + "* `retrieved_doc_id`: the id of the retrieved document (after reranking) which is being judged for relevance\n", + "* `retrieved_text`: the text of the retrieved document\n", + "* `human_judgment`: The result of the human annotation, here it is either \"Relevant\" or \"Not Relevant\"" + ] + }, + { + "cell_type": "markdown", + "id": "46380077-8e8d-447a-9ccb-211d6662b397", + "metadata": {}, + "source": [ + "Let's also add two helper functions that allow us to iterate over the data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c565b819-ec95-40ba-bebf-8b9c60d3d223", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_prompts(data: list, task_type: str):\n", + " \"\"\"Generates prompts\"\"\"\n", + " necessary_keys = POOL[task_type][\"prompt_inputs\"] + POOL[task_type][\"metadata\"]\n", + " for line in data:\n", + " assert all(\n", + " key in line for key in necessary_keys\n", + " ), f\"Missing keys in line: {line}\"\n", + "\n", + " prompt_inputs = {key: line[key] for key in POOL[task_type][\"prompt_inputs\"]}\n", + "\n", + " prompt_template = POOL[task_type][\"prompt_template\"]\n", + " prompt = prompt_template.format(**prompt_inputs)\n", + "\n", + " yield prompt\n", + "\n", + "\n", + "def generate_data(data: list, task_type: str):\n", + " \"\"\"Iterates over the input data\"\"\"\n", + " necessary_keys = POOL[task_type][\"prompt_inputs\"] + POOL[task_type][\"metadata\"]\n", + " for line in data:\n", + " assert all(\n", + " key in line for key in necessary_keys\n", + " ), f\"Missing keys in line: {line}\"\n", + "\n", + " yield line" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c19037a5-e9ea-45bb-8bb7-00047fe68f30", + "metadata": {}, + "outputs": [], + "source": [ + "gen_prompts = partial(generate_prompts, data=SAMPLE_DATA, task_type=TASK_TYPE)\n", + "gen_data = partial(generate_data, data=SAMPLE_DATA, task_type=TASK_TYPE)" + ] + }, + { + "cell_type": "markdown", + "id": "e8dbb9e6-4351-46c0-b01c-cc28e8567b65", + "metadata": {}, + "source": [ + "And now we are ready to execute the pipeline and store the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c805cdc-101d-48be-80d1-8c6651cb63e4", + "metadata": {}, + "outputs": [], + "source": [ + "outputs = []\n", + "\n", + "for (raw_responses, mapped_response, agreement), data_tuple in zip(\n", + " llm_evaluator(\n", + " gen_prompts(),\n", + " max_output_tokens=POOL[TASK_TYPE][\"max_output_tokens\"],\n", + " batch_size=BATCH_SIZE,\n", + " ),\n", + " gen_data(),\n", + "):\n", + "\n", + " # create a json structure to hold the data\n", + " json_out = {\n", + " key: data_tuple[key]\n", + " for key in POOL[TASK_TYPE][\"metadata\"] + POOL[TASK_TYPE][\"prompt_inputs\"]\n", + " }\n", + "\n", + " json_out[\"LLM_raw_response\"] = raw_responses\n", + " json_out[\"LLM_mapped_response\"] = mapped_response\n", + " json_out[\"LLM_agreement\"] = agreement\n", + "\n", + " outputs.append(json_out)" + ] + }, + { + "cell_type": "markdown", + "id": "cff708be-1017-42b1-a21b-3cbcf624943c", + "metadata": {}, + "source": [ + "Collect outputs into a Pandas dataframe" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72b4535f-67f4-4f15-b1df-222d1eca861e", + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.DataFrame.from_dict(outputs)" + ] + }, + { + "cell_type": "markdown", + "id": "0c354259-7718-4894-bc6f-0267ec6f5d18", + "metadata": {}, + "source": [ + "Quick scan of the outputs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65c3aa1e-d81d-4c60-89fb-136544997ac3", + "metadata": {}, + "outputs": [], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "2acaca72-2338-42ea-9a19-f37646245166", + "metadata": {}, + "source": [ + "And finally, let's measure the performance of the LLM. \n", + "\n", + "First, we compute the micro-F1 score which takes into account both classes" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46d66028-01a1-4b47-bf9a-029700d79311", + "metadata": {}, + "outputs": [], + "source": [ + "f1_score(y_true=df[\"human_judgment\"], y_pred=df[\"LLM_mapped_response\"], average=\"micro\")" + ] + }, + { + "cell_type": "markdown", + "id": "3d93a762-14ce-41b1-b471-8a7e68849d0b", + "metadata": {}, + "source": [ + "or we can focus on the `Relevant` class\n", + "\n", + "Precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "205ea3ab-1135-4325-a944-9f9c492931cb", + "metadata": {}, + "outputs": [], + "source": [ + "precision_score(\n", + " y_true=df[\"human_judgment\"], y_pred=df[\"LLM_mapped_response\"], pos_label=\"Relevant\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "22be68f2-b585-48b4-8369-af30b44fe830", + "metadata": {}, + "source": [ + "Recall" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a486025e-c8c6-45ee-865f-aafbf483a06a", + "metadata": {}, + "outputs": [], + "source": [ + "recall_score(\n", + " y_true=df[\"human_judgment\"], y_pred=df[\"LLM_mapped_response\"], pos_label=\"Relevant\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "345ab752-cecb-46f9-953c-b8551b316b72", + "metadata": {}, + "source": [ + "binary-F1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "165699be-6b70-4310-b04a-2b643bc0eb05", + "metadata": {}, + "outputs": [], + "source": [ + "f1_score(\n", + " y_true=df[\"human_judgment\"],\n", + " y_pred=df[\"LLM_mapped_response\"],\n", + " average=\"binary\",\n", + " pos_label=\"Relevant\",\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}