diff --git a/docs/tutorials/example_sammo_express.ipynb b/docs/tutorials/example_sammo_express.ipynb new file mode 100644 index 0000000..66abb02 --- /dev/null +++ b/docs/tutorials/example_sammo_express.ipynb @@ -0,0 +1,591 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "remove-cell" + ] + }, + "outputs": [], + "source": [ + "# Load from parent directory if not installed\n", + "import importlib\n", + "import os\n", + "\n", + "if not importlib.util.find_spec(\"sammo\"):\n", + " import sys\n", + "\n", + " sys.path.append(\"../../\")\n", + "os.environ[\"CACHE_FILE\"] = \"cache/sammo_express.tsv\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# SAMMO Express (beta)\n", + "\n", + "One of the more time-consuming tasks is converting an existing prompt into a prompt program. `SAMMO` Express is now able to do this using a Markdown file." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-cell" + ] + }, + "outputs": [], + "source": [ + "# %load -r 3:25 _init.py\n", + "import pathlib\n", + "import sammo\n", + "from sammo.runners import OpenAIChat\n", + "from sammo.base import Template, EvaluationScore, Component\n", + "from sammo.components import Output, GenerateText, ForEach, Union\n", + "from sammo.extractors import ExtractRegex\n", + "from sammo.data import DataTable\n", + "import json\n", + "import requests\n", + "import os\n", + "\n", + "if not \"OPENAI_API_KEY\" in os.environ:\n", + " raise ValueError(\"Please set the environment variable 'OPENAI_API_KEY'.\")\n", + "\n", + "_ = sammo.setup_logger(\"WARNING\") # we're only interested in warnings for now\n", + "\n", + "runner = OpenAIChat(\n", + " model_id=\"gpt-3.5-turbo\",\n", + " api_config={\"api_key\": os.environ[\"OPENAI_API_KEY\"]},\n", + " cache=os.getenv(\"CACHE_FILE\", \"cache.tsv\"),\n", + " timeout=30,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We start with a prompt written in Markdown. SAMMO additionally recognizes:\n", + "* CSS-like classes in the form of `.classname`\n", + "* CSS-like identifiers in the form of `#id`\n", + "* Native placeholders in handlebar.js syntax for the input like `{{{input}}}`\n", + "\n", + "Here is an example:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "PROMPT_IN_MARKDOWN = \"\"\"\n", + "# Instructions \n", + "Convert the following user queries into a SQL query.\n", + "\n", + "# Table\n", + "Users:\n", + "- user_id (INTEGER, PRIMARY KEY)\n", + "- name (TEXT)\n", + "- age (INTEGER)\n", + "- city (TEXT)\n", + "\n", + "# Examples \n", + "Input: \"Find all users who are older than 30.\" \n", + "Output: `SELECT name FROM Users WHERE age > 30;`\n", + "\n", + "Input: \"List the names of users who live in 'New York'.\" \n", + "Output: `SELECT name FROM Users WHERE city = 'New York';`\n", + " \n", + "# Complete this\n", + "Input: {{{input}}}\n", + "Output:\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "Using `sammo.express`, we can automatically map the structure implied by Markdown into a structred symbolic prompt program:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sammo.express import MarkdownParser\n", + "spp = MarkdownParser(PROMPT_IN_MARKDOWN).get_sammo_program()\n", + "spp.plot_program()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's execute it on some data. For this small test, we will skip the DataTables and use a list of dicts." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "+------------------------------------------+-----------------------------------------------------+\n", + "| input | output |\n", + "+==========================================+=====================================================+\n", + "| {'input': 'No of users starting with J'} | SELECT COUNT(name) FROM Users WHERE name LIKE 'J%'; |\n", + "+------------------------------------------+-----------------------------------------------------+\n", + "Constants: None" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Output(GenerateText(spp)).run(runner, [{\"input\": \"No of users starting with J\"}])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Bonus: Optimizing the prompt program" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "d_train = DataTable.from_records([{\"input\": \"Get all users whose name starts with the letter 'J'\",\n", + " \"output\": \"SELECT * FROM Users WHERE name LIKE 'J%';\"\n", + " },\n", + " {\n", + " \"input\": \"Retrieve the youngest user's information\",\n", + " \"output\": \"SELECT * FROM Users ORDER BY age ASC LIMIT 1;\"\n", + " },\n", + " {\n", + " \"input\": \"Get all cities where users live\",\n", + " \"output\": \"SELECT DISTINCT city FROM Users;\"\n", + " }])\n", + "\n", + "def accuracy(y_true: DataTable, y_pred: DataTable) -> EvaluationScore:\n", + " y_true = y_true.outputs.normalized_values()\n", + " y_pred = y_pred.outputs.normalized_values()\n", + " n_correct = sum([y_p == y_t for y_p, y_t in zip(y_pred, y_true)])\n", + "\n", + " return EvaluationScore(n_correct / len(y_true))" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "search depth[############]1/1[00:00<00:00] >> eval[#################################]3/3 >> tasks[#######]9/9[00:00<00:00, 600.00it/s]\n", + "\n", + "Fitting log (5 entries):\n", + "iteration action objective costs parse_errors prev_actions\n", + "----------- ---------- ------------------ ---------------------------- -------------- ----------------------\n", + "-1 init 0.3333333333333333 {'input': 386, 'output': 33} 0.0 ['init']\n", + "-1 init 0.3333333333333333 {'input': 386, 'output': 33} 0.0 ['init']\n", + "0 Rewrite 0.6666666666666666 {'input': 437, 'output': 28} 0.0 ['Rewrite', 'init']\n", + "0 Paraphrase 0.6666666666666666 {'input': 380, 'output': 28} 0.0 ['Paraphrase', 'init']\n", + "0 Rewrite 0.6666666666666666 {'input': 437, 'output': 28} 0.0 ['Rewrite', 'init']\n", + "Action stats:\n", + "action stats\n", + "---------- ----------------------------\n", + "Rewrite {'chosen': 2, 'improved': 2}\n", + "Paraphrase {'chosen': 1, 'improved': 1}\n" + ] + } + ], + "source": [ + "from sammo.search import BeamSearch\n", + "from sammo.mutators import BagOfMutators, Paraphrase, Rewrite\n", + "\n", + "mutation_operators = BagOfMutators(\n", + " Output(GenerateText(spp)),\n", + " Paraphrase(\"#instr\"),\n", + " Rewrite(\"#examp\", \"Repeat these examples and add two new ones.\\n\\n {{{{text}}}}\")\n", + ")\n", + "prompt_optimizer = BeamSearch(\n", + " runner,\n", + " mutation_operators,\n", + " accuracy,\n", + " depth=1,\n", + " mutations_per_beam=2,\n", + " n_initial_candidates=2\n", + " )\n", + "prompt_optimizer.fit(d_train)\n", + "prompt_optimizer.show_report()" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt_optimizer.best_prompt.plot_program()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}