Skip to content

Commit

Permalink
Llama distillation notebook fixes (#3322)
Browse files Browse the repository at this point in the history
* Adding missing datasets package

* Doc explaining TEACHER_MODEL_ENDPOINT_NAME and TEACHER_MODEL_NAME

* mlclient.models.get requires version or label

* data directory must be created prior to saving datasets

* Only one sample is saved in train and valid JSONL files

- was using `w` mode which replaces the file with each dataset JSON row, leaving only one sample at the end
- replaced `w` write mode with `a` append mode so that all JSON lines are saved to train and valid files

* Experiment name cannot contain `.`

- Llama 3.1 models have `.` in their names, thus failing the experiment name constraint of not having a `.`

* Printing studio progress URL

* Llama 3.1 marketplace subscription prerequisites
  • Loading branch information
cedricvidal authored Jul 29, 2024
1 parent 2e86135 commit 0de88c9
Showing 1 changed file with 29 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
"- Distillation should only be used for single turn chat completion format.\n",
"- The Meta Llama 3.1 405B Instruct model can only be used as a teacher model.\n",
"- The Meta Llama 3.1 8B Instruct can only be used as a student (target) model.\n",
"- Distillation is currently supported only for Natural Language Inference (NLI) task, which is a standard task in benchmarking for Natural Language Understanding."
"- Distillation is currently supported only for Natural Language Inference (NLI) task, which is a standard task in benchmarking for Natural Language Understanding.\n",
"\n",
"**Prerequisites:**\n",
"- Subscribe to the Meta Llama 3.1 405B Instruct and Meta Llama 3.1 8B Instruct models; see [how to subscribe your project to the model offering in MS Learn](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-serverless?tabs=azure-ai-studio#subscribe-your-project-to-the-model-offering)"
]
},
{
Expand All @@ -38,7 +41,8 @@
"# %pip install azure-identity\n",
"\n",
"# %pip install mlflow\n",
"# %pip install azureml-mlflow"
"# %pip install azureml-mlflow\n",
"# %pip install datasets"
]
},
{
Expand Down Expand Up @@ -163,8 +167,12 @@
"outputs": [],
"source": [
"# Llama-3-405B Teacher model endpoint name\n",
"# The serverless model name is the name found in ML Studio > Endpoints > Serverless endpoints > Model column\n",
"TEACHER_MODEL_NAME = \"Meta-Llama-3.1-405B-Instruct\"\n",
"TEACHER_MODEL_ENDPOINT_NAME = \"<Please provide Meta Llama 3.1 405B endpoint name>\""
"\n",
"# The serverless model endpoint name is the name found in ML Studio > Endpoints > Serverless endpoints > Name column\n",
"# The endpoint URL will be resolved from this name by the MLFlow component\n",
"TEACHER_MODEL_ENDPOINT_NAME = \"Meta-Llama-3-1-405B-Instruct-vum\""
]
},
{
Expand All @@ -183,10 +191,13 @@
"outputs": [],
"source": [
"STUDENT_MODEL_NAME = \"Meta-Llama-3.1-8B-Instruct\"\n",
"STUDENT_MODEL_VERSION = 1\n",
"\n",
"# retrieve student model from model registry\n",
"mlclient_azureml_meta = MLClient(credential, registry_name=\"azureml-meta\")\n",
"student_model = mlclient_azureml_meta.models.get(STUDENT_MODEL_NAME)\n",
"student_model = mlclient_azureml_meta.models.get(\n",
" STUDENT_MODEL_NAME, version=STUDENT_MODEL_VERSION\n",
")\n",
"\n",
"print(\n",
" \"\\n\\nUsing model name: {0}, version: {1}, id: {2} for fine tuning\".format(\n",
Expand Down Expand Up @@ -290,6 +301,15 @@
"print(\"Len of validation data sample is \" + str(len(val)))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! mkdir -p data"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -316,7 +336,7 @@
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" with open(train_data_path, \"w\") as f:\n",
" with open(train_data_path, \"a\") as f:\n",
" f.write(json.dumps(data) + \"\\n\")\n",
"\n",
"for row in val:\n",
Expand All @@ -336,7 +356,7 @@
" + row[\"hypothesis\"],\n",
" }\n",
" )\n",
" with open(valid_data_path, \"w\") as f:\n",
" with open(valid_data_path, \"a\") as f:\n",
" f.write(json.dumps(data) + \"\\n\")"
]
},
Expand Down Expand Up @@ -504,6 +524,7 @@
"train_file_path_input = Input(type=\"uri_file\", path=train_data.path)\n",
"validation_file_path_input = Input(type=\"uri_file\", path=valid_data.path)\n",
"input_finetune_model = Input(type=\"mlflow_model\", path=student_model.id)\n",
"experiment_name = f\"distillation-{TEACHER_MODEL_NAME}\".replace(\".\", \"-\")\n",
"\n",
"finetuning_job = distillation_pipeline(\n",
" teacher_model_endpoint_name=TEACHER_MODEL_ENDPOINT_NAME,\n",
Expand All @@ -519,7 +540,7 @@
"\n",
"# pipeline_job.identity = UserIdentityConfiguration()\n",
"finetuning_job.display_name = f\"finetune-{student_model.name}\"\n",
"finetuning_job.experiment_name = f\"distillation-{TEACHER_MODEL_NAME}\"\n",
"finetuning_job.experiment_name = experiment_name\n",
"finetuning_job.settings.default_compute_type = \"serverless\"\n",
"finetuning_job.continue_on_step_failure = False\n",
"# pipeline_job.settings.force_rerun = True"
Expand All @@ -540,9 +561,7 @@
"source": [
"# Submit pipeline job to workspace\n",
"ft_job = ml_client.jobs.create_or_update(finetuning_job)\n",
"# ft_job.studio_url\n",
"\n",
"# build link to ai studio fine-tuning tab"
"print(f\"Submitted job, progress available at {ft_job.studio_url}\")"
]
},
{
Expand Down

0 comments on commit 0de88c9

Please sign in to comment.