diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e09045a4..4f799390 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,17 +11,17 @@ permissions: jobs: build-python: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt requirements.dev.txt requirements.finetune.txt + pip install -r requirements.dev.txt -r requirements.txt -r requirements.finetune.txt - name: Build core run: | pip install -e . \ No newline at end of file diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index e81557f9..22327d3f 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -14,17 +14,18 @@ permissions: jobs: test: - runs-on: ubuntu-20.04 + runs-on: ubuntu-22.04 steps: + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - - name: Install dependencies + - name: Install dependencies and package run: | python -m pip install --upgrade pip - pip install -r requirements.dev.txt - pip install -r requirements.txt requirements.finetune.txt + pip install -r requirements.dev.txt -r requirements.txt -r requirements.finetune.txt -r requirements.converters.txt + pip install -e . 
- name: Lint critical errors with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/examples/gsm8k_tuning/eval.ipynb b/examples/gsm8k_tuning/eval.ipynb index a07cd4cb..68b9659c 100644 --- a/examples/gsm8k_tuning/eval.ipynb +++ b/examples/gsm8k_tuning/eval.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -24,16 +24,17 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "env = MathEnvironment()\n", "\n", "\n", - "def eval(tested_agent, test_set) -> float:\n", + "def eval(tested_agent, test_set, name=\"\") -> float:\n", " test_solved = []\n", - " for i, sample in enumerate(tqdm(test_set)):\n", + " n = 0\n", + " for sample in tqdm(test_set):\n", " sample = extract_result_value(sample)\n", " try:\n", " tape = solve_task(tested_agent, env, sample)\n", @@ -42,15 +43,21 @@ " print(colored(\"Failed to solve task: {e}\", \"red\"))\n", " test_solved.append(0)\n", " raise e\n", - " if i % 10 == 0 and i > 0:\n", - " print(f\"{i}: Current accuracy: {np.mean(test_solved):.3f}\")\n", + " acc = np.mean(test_solved).item()\n", + " n = len(test_solved)\n", + " if n % 10 == 0 and n > 0:\n", + " print(f\"{n}: Current accuracy: {acc:.3f}\")\n", + " with open(\"results.jsonl\", \"a\") as f:\n", + " f.write(json.dumps({name: acc, \"n\": n}) + \"\\n\")\n", " acc = np.mean(test_solved).item()\n", + " with open(\"results.jsonl\", \"a\") as f:\n", + " f.write(json.dumps({name: acc, \"n\": n}) + \"\\n\")\n", " return acc\n" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -71,12 +78,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Evaluation" + "## Untuned model accuracy" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ 
-94,43 +101,304 @@ ")\n" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Untuned model accuracy" - ] - }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], - "source": [ - "val_acc = eval(untuned_agent, val_set)\n", - "print(f\"Untuned on train {val_acc:.3f}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 6%|▌ | 11/200 [00:50<15:08, 4.81s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10: Current accuracy: 0.545\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 10%|█ | 21/200 [01:22<08:29, 2.85s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20: Current accuracy: 0.476\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 16%|█▌ | 31/200 [01:55<08:12, 2.91s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "30: Current accuracy: 0.613\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 41/200 [02:35<10:20, 3.90s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40: Current accuracy: 0.585\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 26%|██▌ | 51/200 [03:21<13:38, 5.49s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "50: Current accuracy: 0.608\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 61/200 [03:53<06:24, 2.77s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "60: Current accuracy: 0.656\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 36%|███▌ | 71/200 [04:34<08:15, 3.84s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "70: Current 
accuracy: 0.662\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 81/200 [06:01<11:10, 5.63s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "80: Current accuracy: 0.654\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 46%|████▌ | 91/200 [06:38<07:27, 4.11s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "90: Current accuracy: 0.681\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 101/200 [07:24<07:09, 4.34s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100: Current accuracy: 0.683\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 56%|█████▌ | 111/200 [08:01<06:02, 4.07s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "110: Current accuracy: 0.676\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 121/200 [08:44<05:48, 4.41s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "120: Current accuracy: 0.678\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 66%|██████▌ | 131/200 [09:47<05:18, 4.61s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "130: Current accuracy: 0.679\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 70%|███████ | 141/200 [11:55<11:14, 11.44s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "140: Current accuracy: 0.660\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 76%|███████▌ | 151/200 [12:34<03:24, 4.18s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "150: Current accuracy: 0.656\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 161/200 [13:13<02:49, 4.34s/it]" + ] + }, + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "160: Current accuracy: 0.658\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 86%|████████▌ | 171/200 [14:13<02:37, 5.44s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "170: Current accuracy: 0.655\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 181/200 [14:49<01:05, 3.45s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "180: Current accuracy: 0.663\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 96%|█████████▌| 191/200 [15:29<00:35, 3.95s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "190: Current accuracy: 0.654\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 200/200 [16:12<00:00, 4.86s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Untuned on test 0.660\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "acc = eval(untuned_agent, test_set)\n", "print(f\"Untuned on test {acc:.3f}\")\n" ] }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "with open(\"results.json\", \"w\") as f:\n", - " f.write(json.dumps({\"untuned\": {\"train\": val_acc, \"test\": acc}}))\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -140,16 +408,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "# run inference: vllm serve gsm8k/tune1/intermediate/1000/\n", + "# run inference: vllm serve gsm8k/tuning/llama31_70b_train_t02/tune1/intermediate/800/\n", "tuned_agent = MathAgent(\n", " llms={\n", " \"default\": LLAMA(\n", " base_url=\"http://localhost:8000\",\n", - " model_name=\"gsm8k/tune1/intermediate/1000/\",\n", + " 
model_name=\"gsm8k/tuning/llama31_70b_train_t02/tune1/intermediate/800/\",\n", " tokenizer_name=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", " parameters=dict(temperature=0.0),\n", " use_cache=False,\n", @@ -160,39 +428,628 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 6%|▌ | 11/200 [00:54<16:13, 5.15s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10: Current accuracy: 0.727\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 10%|█ | 21/200 [01:43<13:07, 4.40s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20: Current accuracy: 0.762\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 16%|█▌ | 31/200 [02:23<11:41, 4.15s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "30: Current accuracy: 0.806\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 41/200 [03:13<10:48, 4.08s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40: Current accuracy: 0.756\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 26%|██▌ | 51/200 [04:04<13:32, 5.45s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "50: Current accuracy: 0.784\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 61/200 [04:49<09:19, 4.03s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "60: Current accuracy: 0.820\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 36%|███▌ | 71/200 [05:37<10:30, 4.89s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "70: Current accuracy: 0.789\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 
81/200 [07:40<16:14, 8.19s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "80: Current accuracy: 0.765\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 46%|████▌ | 91/200 [08:23<07:25, 4.09s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "90: Current accuracy: 0.791\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 50%|█████ | 101/200 [09:12<08:19, 5.05s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100: Current accuracy: 0.802\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 56%|█████▌ | 111/200 [10:06<08:22, 5.64s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "110: Current accuracy: 0.784\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 121/200 [10:59<08:44, 6.64s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "120: Current accuracy: 0.802\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 66%|██████▌ | 131/200 [12:52<15:00, 13.04s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "130: Current accuracy: 0.771\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 70%|███████ | 141/200 [15:14<13:32, 13.77s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "140: Current accuracy: 0.752\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 76%|███████▌ | 151/200 [17:05<06:28, 7.93s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "150: Current accuracy: 0.748\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 161/200 [17:50<03:11, 4.91s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "160: Current accuracy: 0.752\n" + ] + }, + { + "name": 
"stderr", + "output_type": "stream", + "text": [ + " 86%|████████▌ | 171/200 [19:59<08:19, 17.21s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "170: Current accuracy: 0.754\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 181/200 [20:48<01:37, 5.14s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "180: Current accuracy: 0.768\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 96%|█████████▌| 191/200 [21:43<00:51, 5.74s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "190: Current accuracy: 0.775\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 200/200 [22:35<00:00, 6.78s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tuned on test 0.775\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ - "tuned_val_acc = eval(tuned_agent, test_set)\n", - "print(f\"Tuned on test {tuned_val_acc:.3f}\")\n" + "tuned_acc = eval(tuned_agent, test_set, \"tuned_acc\")\n", + "print(f\"Tuned on test {tuned_acc:.3f}\")\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "tuned_acc = eval(tuned_agent, val_set)\n", - "print(f\"Tuned on train {tuned_acc:.3f}\")\n" + "# check teacher model\n", + "big_llm = LLAMA(\n", + " base_url=\"https://api.together.xyz\",\n", + " model_name=\"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo\",\n", + " tokenizer_name=\"meta-llama/Meta-Llama-3.1-70B-Instruct\",\n", + " parameters=dict(temperature=0.2),\n", + " use_cache=False,\n", + ")\n", + "big_agent = MathAgent(llms={\"default\": big_llm})\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 
5%|▌ | 10/200 [01:59<38:17, 12.09s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10: Current accuracy: 0.800\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 10%|█ | 20/200 [04:12<39:39, 13.22s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "20: Current accuracy: 0.900\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 15%|█▌ | 30/200 [05:27<18:59, 6.70s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "30: Current accuracy: 0.933\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 20%|██ | 40/200 [07:17<32:46, 12.29s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "40: Current accuracy: 0.925\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 25%|██▌ | 50/200 [08:58<24:59, 10.00s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "50: Current accuracy: 0.920\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 30%|███ | 60/200 [10:27<16:43, 7.17s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "60: Current accuracy: 0.933\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 35%|███▌ | 70/200 [12:01<21:15, 9.82s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "70: Current accuracy: 0.929\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 40%|████ | 80/200 [13:53<22:59, 11.50s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "80: Current accuracy: 0.938\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 45%|████▌ | 90/200 [15:38<17:53, 9.76s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "90: Current accuracy: 0.944\n" + ] + }, + { + "name": "stderr", + "output_type": 
"stream", + "text": [ + " 50%|█████ | 100/200 [17:26<14:13, 8.54s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "100: Current accuracy: 0.940\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 52%|█████▏ | 103/200 [17:51<13:42, 8.48s/it]Failed to parse agent output: {\"kind\": \"use_calculator_action\", \"expression\": \"30 * (1 - 0.3)\"}\n", + "{\"kind\": \"reasoning_thought\", \"reasoning\": \"Calculate how many times Anne went down the slide, which is 30% less than Mitchel.\"}\n", + "\n", + "Error: Extra data: line 2 column 1 (char 66)\n", + "Traceback (most recent call last):\n", + " File \"/home/toolkit/TapeAgents/tapeagents/guided_agent.py\", line 102, in parse_completion\n", + " step_dicts = json.loads(sanitize_json_completion(completion))\n", + " File \"/home/toolkit/.conda/envs/tapeagents2/lib/python3.10/json/__init__.py\", line 346, in loads\n", + " return _default_decoder.decode(s)\n", + " File \"/home/toolkit/.conda/envs/tapeagents2/lib/python3.10/json/decoder.py\", line 340, in decode\n", + " raise JSONDecodeError(\"Extra data\", s, end)\n", + "json.decoder.JSONDecodeError: Extra data: line 2 column 1 (char 66)\n", + " 55%|█████▌ | 110/200 [20:23<52:32, 35.03s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "110: Current accuracy: 0.936\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 60%|██████ | 120/200 [21:54<13:48, 10.36s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "120: Current accuracy: 0.942\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 65%|██████▌ | 130/200 [25:12<24:50, 21.29s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "130: Current accuracy: 0.938\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 70%|███████ | 140/200 [27:36<16:47, 16.79s/it]" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "140: Current accuracy: 0.929\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 75%|███████▌ | 150/200 [29:17<08:06, 9.73s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "150: Current accuracy: 0.927\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 80%|████████ | 160/200 [30:40<06:06, 9.17s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "160: Current accuracy: 0.925\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 85%|████████▌ | 170/200 [32:37<06:13, 12.44s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "170: Current accuracy: 0.929\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 90%|█████████ | 180/200 [34:44<03:17, 9.90s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "180: Current accuracy: 0.933\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 95%|█████████▌| 190/200 [36:17<01:27, 8.75s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "190: Current accuracy: 0.937\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 200/200 [37:49<00:00, 11.35s/it]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "200: Current accuracy: 0.935\n", + "Teacher on test 0.935\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ - "with open(\"results.json\", \"w\") as f:\n", - " f.write(\n", - " json.dumps(\n", - " {\n", - " \"untuned\": {\"train\": val_acc, \"test\": acc},\n", - " \"tuned\": {\"train\": tuned_acc, \"test\": tuned_val_acc},\n", - " }\n", - " )\n", - " )\n" + "big_acc = eval(big_agent, test_set, \"big_acc\")\n", + "print(f\"Teacher on test {big_acc:.3f}\")\n" ] } ], diff --git a/examples/gsm8k_tuning/finetune.ipynb 
def merge_lora(lora_model_path):
    """Merge a LoRA checkpoint into a full model and swap it into place.

    The checkpoint is loaded through ``lora_load_and_merge`` (defined
    elsewhere in this module) and its tokenizer through ``AutoTokenizer``.
    Both are saved to ``{lora_model_path}_merged``.  The original adapter
    directory is then renamed to ``{lora_model_path}_lora`` and the merged
    directory is renamed to the original path.

    Args:
        lora_model_path: Checkpoint directory containing
            ``adapter_config.json``.  May be a ``str`` or an ``os.PathLike``;
            any number of trailing path separators is ignored.

    Raises:
        AssertionError: If the path is not a directory or it does not contain
            ``adapter_config.json``.
        FileExistsError: If ``{lora_model_path}_lora`` already exists.
    """
    # Strip *all* trailing separators. If one were left over, the "_merged"
    # and "_lora" siblings built below would resolve to paths inside the
    # checkpoint directory itself and the final renames would fail.
    lora_model_path = os.fspath(lora_model_path).rstrip("/" + os.sep)

    # Raised explicitly instead of via `assert` so the checks still run under
    # `python -O`; the exception type is kept for backward compatibility.
    if not os.path.isdir(lora_model_path):
        raise AssertionError(f"{lora_model_path} is not a dir")
    lora_model_config = os.path.join(lora_model_path, "adapter_config.json")
    if not os.path.exists(lora_model_config):
        raise AssertionError(f"{lora_model_config} does not exists")

    merged_dir = f"{lora_model_path}_merged"
    backup_dir = f"{lora_model_path}_lora"
    # Fail fast: os.rename() onto an existing non-empty directory raises
    # OSError, and that would otherwise only surface after the merge and save.
    if os.path.exists(backup_dir):
        raise FileExistsError(f"{backup_dir} already exists, refusing to overwrite it")

    logger.info("Merge lora checkpoint %s", lora_model_path)
    model = lora_load_and_merge(lora_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(lora_model_path)

    logger.info("Save merged model to %s", merged_dir)
    model.save_pretrained(merged_dir, safe_serialization=True)
    tokenizer.save_pretrained(merged_dir)

    # Keep the adapter as a backup, then move the merged model to the
    # original location.
    os.rename(lora_model_path, backup_dir)
    os.rename(merged_dir, lora_model_path)
    logger.info("Merged model saved to %s", lora_model_path)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit("Merging lora weights: python lora.py LORA_MODEL_PATH")
    # Without a configured handler the INFO progress messages emitted by
    # merge_lora() are dropped when this file is run as a script.
    logging.basicConfig(level=logging.INFO)
    merge_lora(sys.argv[1])