Skip to content

Commit

Permalink
Work on NER, update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
kstathou committed Dec 5, 2023
1 parent 4e7cdd2 commit 16b12a5
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 41 deletions.
32 changes: 19 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
# llm-stack

End-to-end tech stack for the LLM data flywheel.
This tutorial series will show you how to build an end-to-end data flywheel for Large Language Models (LLMs).

## Chapters
We will be working on an entity recognition task with a custom set of entity types throughout the series. We will extract entities like `dataset`, `method`, `evaluation` and `task` from arXiv papers.

- Building your training set with GPT-4
- Fine-tuning an open-source LLM
- Evaluation
- Human feedback
- Unit tests
- Deployment
## What you will learn

## Installation
How to:

TODO
- Build a training set with GPT-4 or GPT-3.5.
- Fine-tune an open-source LLM.
- Create a set of Evals to evaluate the model.
- Collect human feedback to improve the model.
- Deploy the model to an inference endpoint.

## Fine-tuning
## Software used

### Data
- [wandb](https://wandb.ai) for experiment tracking. This is where we will record all our artifacts (datasets, models, code) and metrics.
- [modal](https://modal.com/) for running jobs on the cloud.
- [huggingface](https://huggingface.co/) for all-things-LLM.
- [argilla](https://docs.argilla.io/en/latest/) for labelling our data.

## Tutorial 1 - Generating a training set with GPT-3.5

In this tutorial, we will use GPT-3.5 to generate a training set for our entity recognition task.

## Contributing

TODO
Found any mistakes or want to contribute? Feel free to open a PR or an issue.
1 change: 1 addition & 0 deletions src/llm_stack/build_dataset/prompts/ner_function.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"type": "function", "function": {"name": "entity_recognition", "description": "extract all entities related to machine learning tasks, datasets, methods and benchmarks.", "parameters": {"type": "object", "properties": {"datasets": {"type": "array", "items": {"type": "string"}, "description": "A collection of data used to train, pretrain or fine-tune a machine learning model."}, "methods": {"type": "array", "items": {"type": "string"}, "description": "A machine learning model or algorithm. This also includes the type of model (e.g. transformer, CNN, RNN, etc.) and other relevant components like optimization algorithms, loss functions, etc."}, "evaluation": {"type": "array", "items": {"type": "string"}, "description": "They describe how a model is evaluated. This includes metrics, benchmarks, etc."}, "tasks": {"type": "array", "items": {"type": "string"}, "description": "A machine learning task (e.g. classification, question answering, summarization, etc.)."}}, "required": ["datasets", "methods", "evaluation", "tasks"]}}}
1 change: 1 addition & 0 deletions src/llm_stack/build_dataset/prompts/openai_ner.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"role": "user", "content": "###Instructions###\nYou work on entity recognition. You will be given an academic abstract from arXiv. Your task is to extract all entities related to machine learning tasks, datasets, methods and benchmarks.\n\nHere is the definition of each entity you will extract:\n- Datasets: a collection of data used to train, pretrain or fine-tune a machine learning model.\n- Methods: a machine learning model or algorithm. This also includes the type of model (e.g. transformer, CNN, RNN, etc.) and other relevant components like optimization algorithms, loss functions, etc.\n- Benchmarks: a metric used to evaluate a machine learning model.\n- Tasks: a machine learning task (e.g. classification, question answering, summarization, etc.).\n\n###Constraints###\n- You must extract any abbreviations or acronyms too and consider them as distinct entities. For example, if you see \"Large Language Models (LLM)\", you must extract \"Large Language Models\" and \"LLM\" as two distinct entities.\n\n###Context###\n{text}\n"}
3 changes: 2 additions & 1 deletion src/llm_stack/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .openai_api import OpenAILLM
from .prompt_template import FunctionTemplate
from .prompt_template import MessageTemplate


__all__ = ["OpenAILLM", "MessageTemplate"]
__all__ = ["OpenAILLM", "MessageTemplate", "FunctionTemplate"]
32 changes: 5 additions & 27 deletions src/llm_stack/openai/openai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,13 @@ async def generate(
**openai_kwargs,
)

response = response.choices[0].message # type: ignore

if extra:
try:
response = await self._parse_json(response.content)
if response:
response.update(extra)
return response
except json.JSONDecodeError as e:
logger.error(f"Invalid JSON: Error: {str(e)}")
return {}

return response.content # type: ignore

@staticmethod
async def _try_parse_json(item: str) -> Union[dict, None]:
try:
return json.loads(item)
response = json.loads(response.choices[0].message.tool_calls[0].function.arguments) # type: ignore
response.update({"id": extra["id"]})
return response
except json.JSONDecodeError as e:
return e

async def _parse_json(self, item: str) -> Union[dict, None]:
result = await self._try_parse_json(item.replace("'", '"'))
if isinstance(result, json.JSONDecodeError):
result = await self._try_parse_json(item)
if isinstance(result, json.JSONDecodeError):
logging.error(f"Invalid JSON: Error: {str(result)}")
return None
return result
logger.error(f"Invalid JSON: Error: {str(e)}")
return {}

@retry(
retry(
Expand Down
43 changes: 43 additions & 0 deletions src/llm_stack/openai/prompt_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,46 @@ def _from_dict(data: Dict) -> "MessageTemplate":
if instance.role == "function" and not instance.name:
raise ValueError("The 'name' attribute is required when 'role' is 'function'.")
return instance


@dataclass
class FunctionTemplate(BasePromptTemplate):
    """Create a template for an OpenAI function.

    Attributes:
        name: Name of the function, emitted under ``function.name``.
        description: Description of the function, emitted under
            ``function.description``.
        parameters: Argument schema of the function, emitted unchanged
            under ``function.parameters``.
    """

    name: str
    description: str
    parameters: Dict[str, Union[str, Dict[str, Dict[str, Union[str, List[str]]]], List[str]]]

    def _initialize_template(self) -> dict:
        """Return the function definition as a ``{"type": "function", ...}`` dict."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

    @staticmethod
    def _from_dict(data: Dict) -> "FunctionTemplate":
        """Create a FunctionTemplate from a dictionary with a 'function' key.

        Args:
            data: Mapping whose ``"function"`` entry holds the ``name``,
                ``description`` and ``parameters`` fields.

        Raises:
            TypeError: If ``data`` has no ``"function"`` entry, or that entry
                cannot be unpacked into the dataclass fields.
        """
        try:
            return FunctionTemplate(**data["function"])
        except (KeyError, TypeError) as e:
            # A missing 'function' key raises KeyError, not TypeError; report
            # both failure modes through the same documented error.
            raise TypeError("Expected a dictionary with a 'function' key.") from e

    def to_prompt(
        self,
        exclude: Optional[List[str]] = ["initial_template"],  # noqa: B006
    ) -> Dict:
        """Convert the template to the dictionary form of an OpenAI function.

        Args:
            exclude: Keys forwarded to ``_exclude_keys``; the default list is
                never mutated here. NOTE(review): ``_exclude_keys`` is defined
                outside this block - presumably it drops the listed keys from
                the result; confirm against the base class.

        Returns:
            Whatever ``_exclude_keys`` returns for the function definition.
        """
        # Reuse the single source of truth for the payload layout instead of
        # repeating the dict literal.
        return self._exclude_keys(self._initialize_template(), exclude=exclude)
1 change: 1 addition & 0 deletions src/llm_stack/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class WandbTypes:
process_data_job: str = "process_data"
train_model_job: str = "model"
evaluate_model_job: str = "evaluate_model"
inference_job: str = "inference"

dataset_artifact: str = "dataset"
model_artifact: str = "model"
Expand Down

0 comments on commit 16b12a5

Please sign in to comment.