diff --git a/.gitignore b/.gitignore index d3237e2..0fcf8ec 100644 --- a/.gitignore +++ b/.gitignore @@ -158,3 +158,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +notebooks/ +data/ +wandb/ diff --git a/README.md b/README.md index 033e78e..cda98e8 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,34 @@ # llm-stack -End-to-end tech stack for the LLM data flywheel. +This tutorial series will show you how to build an end-to-end data flywheel for Large Language Models (LLMs). -## Chapters +We will be summarising arXiv abstracts. -- Building your training set with GPT-4 -- Fine-tuning an open-source LLM -- Evaluation -- Human feedback -- Unit tests -- Deployment +## What you will learn -## Installation +How to: -TODO +- Build a training set with GPT-4 or GPT-3.5 +- Fine-tune an open-source LLM +- Create a set of evals to evaluate the model +- Collect human feedback to improve the model +- Deploy the model to an inference endpoint -## Fine-tuning +## Software used -### Data +- [wandb](https://wandb.ai) for experiment tracking. This is where we will record all our artifacts (datasets, models, code) and metrics. +- [modal](https://modal.com/) for running jobs on the cloud. +- [huggingface](https://huggingface.co/) for all things LLM. +- [argilla](https://docs.argilla.io/en/latest/) for labelling our data. + +## Tutorial 1 - Generating a training set with GPT-3.5 + +In this tutorial, we will use GPT-3.5 to generate a training set for the summarisation task. Run the job on Modal with: + +```bash +modal run src/llm_stack/scripts/build_dataset_summaries.py +``` ## Contributing -TODO +Found any mistakes or want to contribute? Feel free to open a PR or an issue.
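For orientation, here is a rough, hypothetical sketch of what that Modal entrypoint stitches together, using the classes added in this PR (`ArxivAPI`, `MessageTemplate`, `OpenAILLM`). It is not part of the diff: the function name `summarise_recent_preprints` is made up for illustration, and it assumes `OPENAI_API_KEY` is set in your environment.

```python
# Hypothetical local sketch of Tutorial 1: fetch LLM-related arXiv abstracts,
# then summarise each one with GPT-3.5 using the prompt shipped in this PR.
import asyncio
import os

from llm_stack.build_dataset import ArxivAPI
from llm_stack.openai import MessageTemplate
from llm_stack.openai import OpenAILLM


async def summarise_recent_preprints() -> list[dict]:
    # Grab a handful of LLM-related preprints from 2023
    preprints = ArxivAPI.search_and_parse_papers(
        query="LLM,large language models,gpt",
        start_date=20230101,
        end_date=20231205,
        max_results=10,
    )

    # Fill the {text} placeholder of the summarisation prompt for each abstract
    # (path is relative to the repo root)
    user_message = MessageTemplate.load("src/llm_stack/build_dataset/prompts/openai_summarizer.json")
    llm = OpenAILLM(api_key=os.environ["OPENAI_API_KEY"])

    tasks = [
        llm.generate(
            messages=[user_message.to_prompt(text=row.abstract)],
            model="gpt-3.5-turbo-1106",
            extra={"id": row.arxiv_url},
        )
        for row in preprints.itertuples()
    ]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    summaries = asyncio.run(summarise_recent_preprints())
    print(summaries[0])  # e.g. {"response": "<one-sentence summary>", "id": "<arxiv url>"}
```

The real pipeline does the same thing inside Modal functions and versions the inputs and outputs as wandb artifacts via `ArtifactHandler` (see `src/llm_stack/wandb_utils.py` below).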
diff --git a/poetry.lock b/poetry.lock index 9bdf9bb..15da069 100644 --- a/poetry.lock +++ b/poetry.lock @@ -165,13 +165,13 @@ files = [ [[package]] name = "anyio" -version = "4.1.0" +version = "3.7.1" description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.8" +python-versions = ">=3.7" files = [ - {file = "anyio-4.1.0-py3-none-any.whl", hash = "sha256:56a415fbc462291813a94528a779597226619c8e78af7de0507333f700011e5f"}, - {file = "anyio-4.1.0.tar.gz", hash = "sha256:5a0bec7085176715be77df87fc66d6c9d70626bd752fcc85f57cdbee5b3760da"}, + {file = "anyio-3.7.1-py3-none-any.whl", hash = "sha256:91dee416e570e92c64041bd18b900d1d6fa78dff7048769ce5ac5ddad004fbb5"}, + {file = "anyio-3.7.1.tar.gz", hash = "sha256:44a3c9aba0f5defa43261a8b3efb97891f2bd7d804e0e1f56419befa1adfc780"}, ] [package.dependencies] @@ -179,9 +179,9 @@ idna = ">=2.8" sniffio = ">=1.1" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] -trio = ["trio (>=0.23)"] +doc = ["Sphinx", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.2.2)", "sphinxcontrib-jquery"] +test = ["anyio[trio]", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (<0.22)"] [[package]] name = "appdirs" @@ -697,6 +697,17 @@ files = [ {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"}, ] +[[package]] +name = "distro" +version = "1.8.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"}, + {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"}, +] + [[package]] name = "docker-pycreds" version = "0.4.0" @@ -780,6 +791,20 @@ typing-extensions = ">=4.5.0" [package.extras] all = ["email-validator (>=2.0.0)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "pydantic-extra-types (>=2.0.0)", "pydantic-settings (>=2.0.0)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +[[package]] +name = "feedparser" +version = "6.0.10" +description = "Universal feed parser, handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds" +optional = false +python-versions = ">=3.6" +files = [ + {file = "feedparser-6.0.10-py3-none-any.whl", hash = "sha256:79c257d526d13b944e965f6095700587f27388e50ea16fd245babe4dfae7024f"}, + {file = "feedparser-6.0.10.tar.gz", hash = "sha256:27da485f4637ce7163cdeab13a80312b93b7d0c1b775bef4a47629a3110bca51"}, +] + +[package.dependencies] +sgmllib3k = "*" + [[package]] name = "filelock" version = "3.13.1" @@ -953,6 +978,17 @@ multidict = "*" [package.extras] protobuf = ["protobuf (>=3.15.0)"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = 
"sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "h2" version = "4.1.0" @@ -979,6 +1015,51 @@ files = [ {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, ] +[[package]] +name = "httpcore" +version = "1.0.2" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"}, + {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.23.0)"] + +[[package]] +name = "httpx" +version = "0.25.2" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, + {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "huggingface-hub" version = "0.19.4" @@ -1596,6 +1677,29 @@ files = [ {file = "numpy-1.26.2.tar.gz", hash = "sha256:f65738447676ab5777f11e6bbbdb8ce11b785e105f690bc45966574816b6d3ea"}, ] +[[package]] +name = "openai" +version = "1.3.7" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.3.7-py3-none-any.whl", hash = "sha256:e5c51367a910297e4d1cd33d2298fb87d7edf681edbe012873925ac16f95bee0"}, + {file = "openai-1.3.7.tar.gz", hash = "sha256:18074a0f51f9b49d1ae268c7abc36f7f33212a0c0d08ce11b7053ab2d17798de"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<4" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.5,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "packaging" version = "23.2" @@ -2158,6 +2262,20 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "python-dotenv" +version = "1.0.0" +description = "Read key-value pairs from a .env file and set them as environment variables" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python-dotenv-1.0.0.tar.gz", hash = "sha256:a8df96034aae6d2d50a4ebe8216326c61c3eb64836776504fcca410e5937a3ba"}, + {file = "python_dotenv-1.0.0-py3-none-any.whl", hash = "sha256:f5971a9226b701070a4bf2c38c89e5a3f0d64de8debda981d1db98583009122a"}, +] + +[package.extras] +cli = ["click (>=5.0)"] + [[package]] name = "pytz" version = "2023.3.post1" @@ -2806,6 +2924,16 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pytest 
(>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-ruff", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"] testing-integration = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "packaging (>=23.1)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"] +[[package]] +name = "sgmllib3k" +version = "1.0.0" +description = "Py3k port of sgmllib." +optional = false +python-versions = "*" +files = [ + {file = "sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9"}, +] + [[package]] name = "sigtools" version = "4.0.1" @@ -2945,6 +3073,20 @@ files = [ {file = "tblib-3.0.0.tar.gz", hash = "sha256:93622790a0a29e04f0346458face1e144dc4d32f493714c6c3dff82a4adb77e6"}, ] +[[package]] +name = "tenacity" +version = "8.2.3" +description = "Retry code until it succeeds" +optional = false +python-versions = ">=3.7" +files = [ + {file = "tenacity-8.2.3-py3-none-any.whl", hash = "sha256:ce510e327a630c9e1beaf17d42e6ffacc88185044ad85cf74c0a8887c6a0f88c"}, + {file = "tenacity-8.2.3.tar.gz", hash = "sha256:5398ef0d78e63f40007c1fb4c0bff96e1911394d2fa8d194f77619c05ff6cc8a"}, +] + +[package.extras] +doc = ["reno", "sphinx", "tornado (>=4.5)"] + [[package]] name = "tokenizers" version = "0.15.0" @@ -3725,4 +3867,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "3.11.5" -content-hash = "2ccb1e593b41594d18763860de290ef72f5c1b8931c46622c3ac3686bda8fa0b" +content-hash = "e185ad7e57967ef8f1bc5f584db71d421d5f2d38879922529cc2fdf129c8d19a" diff --git a/pyproject.toml b/pyproject.toml index b8e57cd..814d8f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,12 @@ bitsandbytes = "^0.41.2.post2" wandb = "^0.16.0" accelerate = "^0.24.1" torch = "2.0.1" +feedparser = "^6.0.10" +python-dotenv = "^1.0.0" +httpx = "^0.25.2" +openai = "^1.3.7" +tenacity = "^8.2.3" +tqdm = "^4.66.1" [tool.poetry.group.test] diff --git a/src/llm_stack/build_dataset/__init__.py b/src/llm_stack/build_dataset/__init__.py new file mode 100644 index 0000000..86bf517 --- /dev/null +++ b/src/llm_stack/build_dataset/__init__.py @@ -0,0 +1,4 @@ +from .arxiv import ArxivAPI + + +__all__ = ["ArxivAPI"] diff --git a/src/llm_stack/build_dataset/arxiv.py b/src/llm_stack/build_dataset/arxiv.py new file mode 100644 index 0000000..861c6d5 --- /dev/null +++ b/src/llm_stack/build_dataset/arxiv.py @@ -0,0 +1,147 @@ +from datetime import datetime +from typing import Optional + +import feedparser +import pandas as pd +import requests + +from dateutil import parser + + +class ArxivAPI: + """Class for interacting with arXiv API.""" + + @classmethod + def search_and_parse_papers( + cls, + query: str, + start_date: int, + end_date: int, + first_result: int = 0, + max_results: int = 50, + ) -> pd.DataFrame: + """Search arXiv API for papers matching query and parse results into a dataframe. + + Parameters + ---------- + query + Query to search for. Must be a string of query words separated by commas. + + start_date + Start date of search in format YYYYMMDD. + + end_date + End date of search in format YYYYMMDD. + + first_result + Index of first result to return, by default 0. + + max_results + Maximum number of results to return, by default 50. + + Returns + ------- + Dataframe of parsed results from arXiv API. 
+ + """ + response = cls.search(query, start_date, end_date, first_result, max_results) + feed = cls._get_feed(response) + entries = cls._get_entries(feed) + + # This will be slow for millions of entries but it's fine for our tiny dataset + parsed_entries = [cls._parse_entry(entry) for entry in entries] + return pd.DataFrame(parsed_entries) + + @classmethod + def search( + cls, + query: str, + start_date: int, + end_date: int, + first_result: int = 0, + max_results: int = 50, + timeout: int = 300, + ) -> requests.Response: + """Search arXiv API for papers matching query. + + Parameters + ---------- + query + Query to search for. Must be a string of query words separated by commas. + + start_date + Start date of search in format YYYYMMDD. + + end_date + End date of search in format YYYYMMDD. + + first_result + Index of first result to return, by default 0. + + max_results + Maximum number of results to return, by default 50. + + timeout + Timeout for request in seconds, by default 300. + + Returns + ------- + Response from arXiv API. + + """ + # Keeping things simple, only an OR query is supported + query = cls._construct_query(query) + + url = "http://export.arxiv.org/api/query?" + url += f"""search_query={query}&start={first_result}&max_results={max_results}&sortBy=submittedDate&sortOrder=descending&date-range={start_date}TO{end_date}""" + + response = requests.get(url, timeout=timeout) + response.raise_for_status() + + return response + + @staticmethod + def _construct_query(query: str, fields: Optional[list] = None) -> str: + """Construct query string for arXiv API.""" + if fields is None: + fields = ["all"] + # Split the query string into individual terms + terms = query.split(",") + + # Create a part of the query string for each field + field_queries = [] + for field in fields: + field_query = "+OR+".join([f'{field}:"{term.replace(" ", "+")}"' for term in terms]) + field_queries.append(f"({field_query})") + + # Combine the field queries with the OR operator + combined_query = "+OR+".join(field_queries) + + return combined_query + + @staticmethod + def _get_feed(response: requests.Response) -> feedparser.FeedParserDict: + """Get feed from arXiv API response.""" + return feedparser.parse(response.content) + + @staticmethod + def _get_entries(feed: feedparser.FeedParserDict) -> list: + """Get entries from arXiv API feed.""" + try: + return feed["entries"] + except KeyError as e: + raise ValueError("No entries found in feed.") from e + + @staticmethod + def _parse_entry(entry: feedparser.util.FeedParserDict) -> dict: + """Parse entry from arXiv API feed.""" + return { + "arxiv_url": entry["id"], + "title": entry["title"].replace("\n", " "), + "abstract": entry["summary"].replace("\n", " "), + "published": datetime.strftime(parser.parse(entry["published"]), "%Y-%m-%d"), + "pdf_url": [item["href"] for item in entry["links"] if all(w in item["href"] for w in ["arxiv", "pdf"])][ + 0 + ], + "categories": [d["term"] for d in entry["tags"]], + } diff --git a/src/llm_stack/build_dataset/prompts/openai_summarizer.json b/src/llm_stack/build_dataset/prompts/openai_summarizer.json new file mode 100644 index 0000000..fdeb558 --- /dev/null +++ b/src/llm_stack/build_dataset/prompts/openai_summarizer.json @@ -0,0 +1 @@ +{"role": "user", "content": "###Instructions###\nYou are an expert in machine learning who excels at editing text. Your task is to summarise the following academic abstract into one or two sentences. 
The summary must mention the main contribution of the paper and any tasks or datasets used.\n\n###Context###\n{text}\n"} diff --git a/src/llm_stack/dummy.py b/src/llm_stack/dummy.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/llm_stack/openai/__init__.py b/src/llm_stack/openai/__init__.py new file mode 100644 index 0000000..14781c1 --- /dev/null +++ b/src/llm_stack/openai/__init__.py @@ -0,0 +1,6 @@ +from .openai_api import OpenAILLM +from .prompt_template import FunctionTemplate +from .prompt_template import MessageTemplate + + +__all__ = ["OpenAILLM", "MessageTemplate", "FunctionTemplate"] diff --git a/src/llm_stack/openai/openai_api.py b/src/llm_stack/openai/openai_api.py new file mode 100644 index 0000000..7247cd3 --- /dev/null +++ b/src/llm_stack/openai/openai_api.py @@ -0,0 +1,191 @@ +import logging + +from typing import Optional +from typing import Union + +import httpx + +from openai import APIConnectionError +from openai import APIError +from openai import APIStatusError +from openai import AsyncOpenAI +from openai import RateLimitError +from tenacity import before_sleep_log +from tenacity import retry +from tenacity import retry_if_exception_type +from tenacity import stop_after_attempt +from tenacity import wait_exponential + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class OpenAILLM: + """OpenAI wrapper for the Chat Completion and Embeddings APIs. + + Handles asynchronous requests to the OpenAI Chat Completion and Embeddings APIs, including: + - Retrying failed requests + - Handling rate limits, timeouts, exceptions + + TODO: Splitting the Chat Completion and Embeddings API into two classes would be more elegant + but ain't nobody got time for that. + + """ + + def __init__(self, api_key: str, timeout: float = 30.0, max_retries: int = 0, **kwargs) -> None: + """Instantiate the Async OpenAI client. + + Parameters + ---------- + api_key + OpenAI API key. + + timeout + Timeout for the request in seconds. + + max_retries + Number of times to retry the request. + + **kwargs + Additional keyword arguments to pass to the Async OpenAI client. + See the `AsyncOpenAI` client documentation for the supported options. + + """ + self.client = AsyncOpenAI( + api_key=api_key, + timeout=timeout, + max_retries=max_retries, + **kwargs, + ) + + async def generate( + self, + messages: list[dict], + model: str = "gpt-3.5-turbo", + temperature: float = 0.0, + seed: Optional[int] = 42, + extra: Optional[dict] = None, + **openai_kwargs, + ) -> Union[ + dict, + str, + ]: + """Call OpenAI's async chat completion API and await the response. + + Handles asynchronous requests to the OpenAI Chat Completion API, including: + - Retrying failed requests + - Handling rate limits, timeouts, exceptions + + Parameters + ---------- + messages + Messages to send to the OpenAI Chat Completion API. + + model + Model to use for completion. + + temperature + Temperature to use for completion. + + seed + Seed to use for completion. + + extra + Additional information to return with the response. + + openai_kwargs + Additional keyword arguments to pass to the OpenAI Chat Completion API. + + Returns + ------- + Response from OpenAI Chat Completion API.
+ + Usage + ----- + >>> messages = [ + >>> {"role": "system", "name": "assistant", "content": "Tell the user a joke about their topic of choice"}, + >>> {"role": "user", "name": "user", "content": "Giraffes"}, + >>> ] + >>> openai_model = OpenAILLM(api_key=API_KEY) + >>> response = await openai_model.generate(messages, model="gpt-3.5-turbo", temperature=0.0, seed=42) + >>> print(response.content) + "Why don't giraffes use computers? Because their heads are always in the clouds!" + + """ + + response = await self._call( + messages=messages, + model=model, + temperature=temperature, + seed=seed, + **openai_kwargs, + ) + + response = response.choices[0].message + + if extra: + return {"response": response.content, **extra} + + return response + + @retry( + reraise=True, + stop=stop_after_attempt(8), + wait=wait_exponential(multiplier=1, min=1, max=60), + retry=( + retry_if_exception_type(APIError) + | retry_if_exception_type(APIConnectionError) + | retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIStatusError) + | retry_if_exception_type(httpx.ReadTimeout) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + async def _call( + self, + messages: list[dict], + model: str = "gpt-3.5-turbo", + temperature: float = 0.0, + seed: Optional[int] = 42, + **kwargs, + ) -> dict: + """Private method to create an async OpenAI call.""" + return await self.client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + seed=seed, + **kwargs, + ) # type:ignore + + @retry( + reraise=True, + stop=stop_after_attempt(8), + wait=wait_exponential(multiplier=1, min=1, max=60), + retry=( + retry_if_exception_type(APIError) + | retry_if_exception_type(APIConnectionError) + | retry_if_exception_type(RateLimitError) + | retry_if_exception_type(APIStatusError) + | retry_if_exception_type(httpx.ReadTimeout) + ), + before_sleep=before_sleep_log(logger, logging.WARNING), + ) + async def get_embeddings( + self, + texts: list[str], + model: str = "text-embedding-ada-002", + ) -> list[list[float]]: + """Return the embeddings for a list of text strings.""" + embeddings = await self.client.embeddings.create(input=texts, model=model) + return [item.embedding for item in embeddings.data] diff --git a/src/llm_stack/openai/prompt_template.py b/src/llm_stack/openai/prompt_template.py new file mode 100644 index 0000000..7f36f85 --- /dev/null +++ b/src/llm_stack/openai/prompt_template.py @@ -0,0 +1,188 @@ +import json +import string + +from abc import ABC +from abc import abstractmethod +from dataclasses import asdict +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import List +from typing import Optional +from typing import Union + + +@dataclass +class BasePromptTemplate(ABC): + """Base prompt template.
Inherit it to write OpenAI messages and functions.""" + + initial_template: Dict[str, str] = field(default_factory=dict, init=False) + + def __post_init__(self) -> None: + """Keep the initial template.""" + self.initial_template = self._initialize_template() + + @abstractmethod + def _initialize_template(self) -> None: + """To be implemented by child classes.""" + pass + + @staticmethod + @abstractmethod + def _from_dict(data: Dict) -> None: + """Create a Template instance from a dictionary.""" + pass + + def format_message(self, **kwargs) -> None: + """Process a message and fill in any placeholders.""" + + def recursive_format(value: Union[str, dict]) -> Union[str, dict]: + if isinstance(value, str): + placeholders = self._extract_placeholders(value) + if placeholders: + return value.format(**kwargs) + return value + elif isinstance(value, dict): + return {k: recursive_format(v) for k, v in value.items()} + else: + return value + + for k in self.__dict__.keys(): + if k != "initial_template": + self.__dict__[k] = recursive_format(self.initial_template[k]) + + @classmethod + def load(cls, obj: Union[Dict, str]) -> Optional["BasePromptTemplate"]: + """Load a Template instance from a JSON file or a dictionary.""" + if isinstance(obj, str): + return cls._from_json(obj) + elif isinstance(obj, Dict): + return cls._from_dict(obj) + else: + raise TypeError(f"Expected a JSON file path or a dictionary, got {type(obj)}.") + + @staticmethod + def _exclude_keys( + d: dict, + exclude: Optional[List[str]] = None, # noqa: B006 + ) -> dict: + """Exclude keys from a dictionary.""" + try: + if not d["name"]: + d.pop("name", None) + except KeyError: + pass + + if exclude: + for item in exclude: + d.pop(item, None) + return d + return d + + def to_prompt( + self, + exclude: Optional[List[str]] = ["initial_template"], # noqa: B006 + **kwargs, + ) -> Dict: + """Convert a Template instance to a prompt dictionary.""" + self.format_message(**kwargs) + d = asdict(self) + return self._exclude_keys(d, exclude=exclude) + + @staticmethod + def _extract_placeholders(s: str) -> List[str]: + """Extract placeholder variables that can be filled in an f-string.""" + formatter = string.Formatter() + return [field_name for _, field_name, _, _ in formatter.parse(s) if field_name is not None] + + @classmethod + def _from_json(cls, json_path: str) -> Optional["BasePromptTemplate"]: + """Create a Template instance by providing a JSON path.""" + return cls._from_dict(cls._read_json(json_path)) + + @staticmethod + def _read_json(json_path: str) -> Dict: + """Read a JSON file.""" + with open(json_path, "r") as f: + return json.load(f) + + def to_json(self, path: str) -> None: + """Write a Template instance to a JSON file.""" + self._write_json(self.initial_template, path) + + def _write_json(self, data: Dict, path: str) -> None: + """Write a JSON file.""" + # Sometimes `name` in `MessageTemplate` is None, so drop empty values + data = {k: v for k, v in data.items() if v} + + with open(path, "w") as f: + json.dump(data, f) + + +@dataclass +class MessageTemplate(BasePromptTemplate): + """Create a template for a message prompt.""" + + role: str + content: str + name: Optional[str] = None + + def __post_init__(self) -> None: + """Keep the initial template and raise an error when the role is 'function' but no name was given.""" + super().__post_init__() + if self.role == "function" and not self.name: + raise ValueError("The 'name' attribute is required when 'role' is 'function'.") + + def _initialize_template(self) -> dict: + return {"role": self.role, "content":
self.content, "name": self.name} + + @staticmethod + def _from_dict(data: Dict) -> "MessageTemplate": + instance = MessageTemplate(**data) + # Validate after initialization + if instance.role == "function" and not instance.name: + raise ValueError("The 'name' attribute is required when 'role' is 'function'.") + return instance + + +@dataclass +class FunctionTemplate(BasePromptTemplate): + """Create a template for an OpenAI function.""" + + name: str + description: str + parameters: Dict[str, Union[str, Dict[str, Dict[str, Union[str, List[str]]]], List[str]]] + + def _initialize_template(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + @staticmethod + def _from_dict(data: Dict) -> "FunctionTemplate": + """Create a Template instance from a dictionary.""" + try: + return FunctionTemplate(**data["function"]) + except TypeError as e: + raise TypeError("Expected a dictionary with a 'function' key.") from e + + def to_prompt( + self, + exclude: Optional[List[str]] = ["initial_template"], # noqa: B006 + ) -> Dict: + """Convert a Template instance to a prompt dictionary.""" + # Custom formatting for the output + formatted_data = { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + return self._exclude_keys(formatted_data, exclude=exclude) diff --git a/src/llm_stack/scripts/build_dataset_summaries.py b/src/llm_stack/scripts/build_dataset_summaries.py new file mode 100644 index 0000000..e2dc5fc --- /dev/null +++ b/src/llm_stack/scripts/build_dataset_summaries.py @@ -0,0 +1,130 @@ +from modal import Image +from modal import Mount +from modal import Secret +from modal import Stub + + +image = ( + Image.debian_slim(python_version="3.11.5") + .apt_install("git") + .poetry_install_from_file(poetry_pyproject_toml="pyproject.toml") +) + +# Imports shared across functions +with image.run_inside(): + import asyncio + import os + + import pandas as pd + + from tqdm import tqdm + + from llm_stack.build_dataset import ArxivAPI + from llm_stack.openai import MessageTemplate + from llm_stack.openai import OpenAILLM + from llm_stack.wandb_utils import ArtifactHandler + from llm_stack.wandb_utils import WandbTypes + + +stub = Stub(name="build-summaries-dataset-with-openai", image=image) + + +@stub.function(secret=Secret.from_name("wandb-secret")) +def fetch_arxiv_data( + local_data_path: str, + artifact_name: str = "arxiv-preprints", + query: str = "LLM,large language models,gpt", + start_date: int = 20230101, + end_date: int = 20231205, + max_results: int = 4000, +) -> None: + """Fetch arXiv data from the arXiv API.""" + handler = ArtifactHandler(project="llm-stack", job_type=WandbTypes.raw_data_job) + + # Grab LLM-related papers from 2023 + preprints = ArxivAPI.search_and_parse_papers( + query=query, + start_date=start_date, + end_date=end_date, + max_results=max_results, + ) + + # Save the raw data + handler.write_artifact( + obj=preprints, + local_path=local_data_path, + name=artifact_name, + artifact_type=WandbTypes.dataset_artifact, + metadata={"query": query, "start_date": start_date, "end_date": end_date}, + ) + + handler.run.finish() + + +@stub.function( + secrets=[Secret.from_name("wandb-secret"), Secret.from_name("openai-secret")], + mounts=[Mount.from_local_dir("src/llm_stack/build_dataset/prompts", remote_path="/root/prompts")], + timeout=1500, +) +async def annotate_dataset_with_open_ai(
+ local_data_path: str, + raw_data_artifact: str = "arxiv-preprints", + annotated_artifact_name: str = "preprints-with-openai-summaries", + model_name: str = "gpt-3.5-turbo-1106", + user_message_file: str = "openai_summarizer.json", + timeout: int = 120, +) -> None: + """Run the arXiv summarizer with OpenAI's LLMs.""" + handler = ArtifactHandler(project="llm-stack", job_type=WandbTypes.inference_job) + + preprints = handler.read_artifact( + name=raw_data_artifact, + artifact_type=WandbTypes.dataset_artifact, + ) + + # Load prompt template and instantiate OpenAI model + user_message = MessageTemplate.load(f"/root/prompts/{user_message_file}") + openai_llm = OpenAILLM(api_key=os.environ["OPENAI_API_KEY"], timeout=timeout) + + # Run the async tasks + tasks = [] + for tup in preprints.itertuples(): + messages = [user_message.to_prompt(text=tup.abstract)] + tasks.append( + openai_llm.generate( + messages=messages, + model=model_name, + extra={"id": tup.arxiv_url}, + ) + ) + + predictions = [] + for future in tqdm(asyncio.as_completed(tasks), total=len(tasks)): + result = await future + if result: + predictions.append(result) + + # Save the predictions to wandb + predictions = pd.DataFrame(predictions) + cols = ["arxiv_url", "abstract", "response"] + preprints = preprints.merge(predictions, left_on="arxiv_url", right_on="id")[cols] + + handler.write_artifact( + obj=preprints, + local_path=local_data_path, + name=annotated_artifact_name, + artifact_type=WandbTypes.dataset_artifact, + ) + + +@stub.local_entrypoint() +async def main( + local_data_path_raw: str = "arxiv_preprints.parquet", + local_data_path_openai: str = "preprints_openai_summaries.parquet", +) -> None: + """Build a summarisation dataset from arXiv abstracts using OpenAI's LLMs.""" + # Fetch the arXiv data + fetch_arxiv_data.remote(local_data_path=local_data_path_raw) + + # Summarise the abstracts with OpenAI + annotate_dataset_with_open_ai.remote(local_data_path=local_data_path_openai) diff --git a/src/llm_stack/wandb_utils.py b/src/llm_stack/wandb_utils.py new file mode 100644 index 0000000..5492a76 --- /dev/null +++ b/src/llm_stack/wandb_utils.py @@ -0,0 +1,177 @@ +from dataclasses import dataclass + +import pandas as pd +import wandb + + +@dataclass +class WandbTypes: + """Types for wandb experiments.""" + + raw_data_job: str = "raw_data" + process_data_job: str = "process_data" + train_model_job: str = "model" + evaluate_model_job: str = "evaluate_model" + inference_job: str = "inference" + + dataset_artifact: str = "dataset" + model_artifact: str = "model" + + +class WandbRun: + """Create a wandb session.""" + + def __init__(self, project: str, job_type: str, **kwargs) -> None: + """Create a wandb session. + + Parameters + ---------- + project : str + Project name on wandb. + + job_type : str + The type of job which is being run, which is used to organize + and differentiate steps in the ML pipeline and distinguish + which steps created which artifacts. + + **kwargs + Additional keyword arguments to pass to `wandb.init`. + See https://docs.wandb.ai/ref/python/init for details.
+ + """ + # Use a running session or create a new one + if wandb.run: + self.run = wandb.run + self.job_type = wandb.run.job_type + self.project = wandb.run.project + else: + self.run = wandb.init(project=project, job_type=job_type, **kwargs) + self.job_type = job_type + self.project = project + + assert self.run + + @property + def name(self) -> str: + """Return the run name.""" + return self.run.name # type: ignore + + @property + def id(self) -> str: + """Return the run ID.""" + return self.run.id # type: ignore + + +class ArtifactHandler(WandbRun): + """Read and write artifacts stored in wandb.""" + + def __init__(self, project: str, job_type: str, **kwargs) -> None: + super().__init__(project, job_type, **kwargs) + + def write_artifact( + self, + obj: object, + local_path: str, + name: str, + artifact_type: str, + **kwargs, + ) -> None: + """Log an artifact in wandb. Requires a wandb session to work. + + Parameters + ---------- + obj + The object you want to store and log in wandb. + + local_path + Where the object is stored locally. + + name + A human-readable name for this artifact, which is how you + can identify this artifact in the UI or reference it in + use_artifact calls. The name must be unique across a project. + + artifact_type + The type of artifact you are logging. + Options are: 'dataset', 'model', 'metric' + + **kwargs + Additional keyword arguments to pass to `wandb.Artifact`. + See https://docs.wandb.ai/ref/python/artifact + + """ + + if isinstance(obj, pd.DataFrame): + obj.to_parquet(local_path) + else: + raise NotImplementedError(f"Only pandas DataFrames are supported for now, not {type(obj)}") + + self._log_artifact(name=name, local_path=local_path, artifact_type=artifact_type, **kwargs) + + def _log_artifact( + self, + name: str, + local_path: str, + artifact_type: str, + **kwargs, + ) -> None: + # Create the artifact + artifact = wandb.Artifact(name=name, type=artifact_type, **kwargs) + + # Add a file + artifact.add_file(local_path=local_path) + + self.run.log_artifact(artifact) # type: ignore + + def read_artifact( + self, + name: str, + artifact_type: str, + version: str = "latest", + ) -> object: + """Read a data or ML model artifact. + + For data artifacts, it returns a pandas dataframe. Other artifact types are not + supported yet. + + TODO: Return a huggingface dataset instead of a pandas dataframe. + + Notes + ----- + - Assumes that data artifacts are always stored as parquet files. + + Parameters + ---------- + name + The name of the artifact to download, without its version; the version is + passed separately via `version`. + + artifact_type + Describes the artifact, e.g. `model` or `dataset`. It determines how the + downloaded artifact is loaded. + + version + Determines the version of the artifact that will be downloaded. + + """ + file_path = self._download_artifact( + name=name, + version=version, + ) + + if artifact_type == WandbTypes.dataset_artifact: + return pd.read_parquet(file_path) + else: + raise NotImplementedError(f"Only datasets are supported for now, not {artifact_type}") + + def _download_artifact( + self, + name: str, + version: str = "latest", + ) -> str: + artifact = self.run.use_artifact(f"{name}:{version}") # type: ignore + + # Download locally + file = artifact.download() + + return file
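To make the artifact round-trip above concrete, here is a hypothetical usage sketch of `ArtifactHandler`. It is not part of the PR: it assumes you are already authenticated with wandb (e.g. `WANDB_API_KEY` is set), and the artifact and file names below are made up for illustration.

```python
# Hypothetical sketch: write a dataframe as a wandb dataset artifact, then read it back.
# Assumes you are logged in to wandb; the artifact/file names below are made up.
import pandas as pd

from llm_stack.wandb_utils import ArtifactHandler
from llm_stack.wandb_utils import WandbTypes

handler = ArtifactHandler(project="llm-stack", job_type=WandbTypes.raw_data_job)

# write_artifact saves the dataframe to parquet locally, then logs it as a versioned artifact
df = pd.DataFrame({"arxiv_url": ["http://arxiv.org/abs/2305.00001v1"], "abstract": ["..."]})
handler.write_artifact(
    obj=df,
    local_path="toy_dataset.parquet",
    name="toy-dataset",
    artifact_type=WandbTypes.dataset_artifact,
)

# read_artifact downloads the latest version and loads the parquet back into a dataframe
df_back = handler.read_artifact(name="toy-dataset", artifact_type=WandbTypes.dataset_artifact)

handler.run.finish()
```

This mirrors how `fetch_arxiv_data` and `annotate_dataset_with_open_ai` pass the dataset between Modal functions via wandb rather than through the local filesystem.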