jobs: Add support for HF text-generation pipelines
Add support for models filtered by the "text-generation" tag on the HF
Hub. This means adding support for causal language models such as
Llama, Gemma, and Qwen.

Causal language models can perform a wide variety of tasks, so users
should provide a suitable system prompt to inform the model about the
task at hand. The user-provided system prompt, along with the data, is
then transformed into a model-specific prompt using the chat template
defined in the tokenizer config (a Jinja2 string template). This step
fails for older models that don't define a chat template.
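
For reference, the transformation roughly corresponds to the following
transformers call (a sketch; the model name is only an example, and
recent transformers versions raise an error when the tokenizer config
defines no chat template):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
    messages = [
        {"role": "system", "content": "Summarize the following document."},
        {"role": "user", "content": "<document text>"},
    ]
    # Renders the Jinja2 chat template from the tokenizer config into a
    # model-specific prompt string.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )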

Refs #122

Signed-off-by: Dimitris Poulopoulos <[email protected]>
dpoulopoulos committed Feb 5, 2025
1 parent 0825694 commit 758259c
Showing 7 changed files with 46 additions and 10 deletions.
8 changes: 7 additions & 1 deletion lumigator/backend/backend/api/routes/jobs.py
@@ -52,7 +52,13 @@ def create_inference_job(
     response: Response,
     background_tasks: BackgroundTasks,
 ) -> JobResponse:
-    job_response = service.create_job(job_create_request, background_tasks)
+    try:
+        job_response = service.create_job(job_create_request, background_tasks)
+    except ValueError as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Invalid request: {e}",
+        ) from e
 
     url = request.url_for(get_job.__name__, job_id=job_response.id)
     response.headers[HttpHeaders.LOCATION] = f"{url}"
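With this change, a missing system prompt for a text-generation job
surfaces as a 400 instead of an unhandled error; FastAPI serializes the
HTTPException into a response body along the lines of (illustrative):

    {"detail": "Invalid request: System prompt is required for text generation tasks."}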
3 changes: 2 additions & 1 deletion lumigator/backend/backend/config_templates.py
@@ -91,6 +91,7 @@
 default_infer_template = """{{
     "name": "{job_name}/{job_id}",
     "dataset": {{ "path": "{dataset_path}" }},
+    "system_prompt": "{system_prompt}",
     "hf_pipeline": {{
         "model_uri": "{model_uri}",
         "task": "{task}",
@@ -99,7 +100,7 @@
         "use_fast": "{use_fast}",
         "trust_remote_code": "{trust_remote_code}",
         "torch_dtype": "{torch_dtype}",
-        "max_length": 500
+        "max_length": "{max_tokens}"
     }},
     "job": {{
         "max_samples": {max_samples},
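In the template, doubled braces are literal JSON braces and single-brace
fields are placeholders filled via str.format. A minimal sketch of how
such a template is rendered into a job config (values are hypothetical,
and placeholders outside this hunk are elided):

    import json

    values = {
        "job_name": "infer",
        "job_id": "123",
        "dataset_path": "s3://bucket/datasets/dataset.csv",
        "system_prompt": "You are a helpful assistant.",
        "model_uri": "hf://HuggingFaceTB/SmolLM2-135M-Instruct",
        "task": "text-generation",
        # ...remaining placeholders (use_fast, trust_remote_code,
        # torch_dtype, max_tokens, max_samples, ...) omitted for brevity
    }
    config = json.loads(default_infer_template.format(**values))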
3 changes: 3 additions & 0 deletions lumigator/backend/backend/services/jobs.py
@@ -381,6 +381,9 @@ def create_job(
             job_type = JobType.INFERENCE
             if not request.output_field:
                 request.output_field = "predictions"
+
+            if request.task == "text-generation" and not request.system_prompt:
+                raise ValueError("System prompt is required for text generation tasks.")
         else:
             raise JobTypeUnsupportedError(request) from None
 
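A quick illustration of the new check (request fields other than task
and system_prompt are hypothetical):

    request = JobInferenceCreate(
        name="my-job",
        dataset=dataset_id,
        task="text-generation",
        # system_prompt deliberately omitted
    )
    service.create_job(request, background_tasks)
    # -> ValueError: System prompt is required for text generation tasks.
    #    (translated to HTTP 400 by the route handler above)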
1 change: 1 addition & 0 deletions lumigator/jobs/inference/inference_config.py
@@ -165,6 +165,7 @@ class InferenceJobConfig(BaseInferenceJobConfig):
     name: str
     dataset: DatasetConfig
     job: JobConfig
+    system_prompt: str | None = Field(title="System Prompt", default=None, exclude=True)
     inference_server: InferenceServerConfig | None = None
     params: SamplingParameters | None = None
     hf_pipeline: HfPipelineConfig | None = None
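Note that exclude=True only affects serialization: the prompt stays
available on the config object but is dropped from dumps, presumably so
it isn't repeated in places that only need the pipeline arguments. A
minimal pydantic v2 sketch of the behavior:

    from pydantic import BaseModel, Field

    class Cfg(BaseModel):
        name: str
        system_prompt: str | None = Field(default=None, exclude=True)

    cfg = Cfg(name="job", system_prompt="Summarize the text.")
    print(cfg.model_dump())   # {'name': 'job'} -- system_prompt is excluded
    print(cfg.system_prompt)  # still accessible on the instance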
31 changes: 24 additions & 7 deletions lumigator/jobs/inference/model_clients.py
@@ -150,13 +150,30 @@ def _chat_completion(
 
 class HuggingFaceModelClient(BaseModelClient):
     def __init__(self, config: InferenceJobConfig):
+        logger.info(f"System prompt: {config.system_prompt}")
+        self._system = config.system_prompt
+        self._task = config.hf_pipeline.task
         self._pipeline = pipeline(**config.hf_pipeline.model_dump())
 
     def predict(self, prompt):
-        prediction = self._pipeline(prompt)[0]
-
-        # The result is a dictionary with a single key, which name depends on the task
-        # (e.g., 'summary_text' for summarization)
-        # Get the name of the key and return its value
-        result_key = list(prediction.keys())[0]
-        return prediction[result_key]
+        # When using a text-generation model, the pipeline returns a dictionary with a single key,
+        # 'generated_text'. The value of this key is a list of dictionaries, each containing the
+        # role and content of a message. For example:
+        # [{'role': 'system', 'content': 'You are a helpful assistant.'},
+        #  {'role': 'user', 'content': 'What is the capital of France?'}, ...]
+        # We want to return the content of the last message in the list, which is the model's
+        # response to the prompt.
+        if self._task == "text-generation":
+            messages = [
+                {"role": "system", "content": self._system},
+                {"role": "user", "content": prompt},
+            ]
+            generation = self._pipeline(messages)[0]
+            return generation["generated_text"][-1]["content"]
+
+        # If we're using a summarization model, the pipeline returns a dictionary with a single key.
+        # The name of the key depends on the task (e.g., 'summary_text' for summarization).
+        # Get the name of the key and return its value.
+        if self._task == "summarization":
+            generation = self._pipeline(prompt)[0]
+            return generation["summary_text"]
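The text-generation branch above relies on the chat support built into
transformers pipelines: passing a list of role/content messages instead
of a plain string makes the pipeline apply the model's chat template and
return the whole conversation. A standalone sketch (model name is only
an example):

    from transformers import pipeline

    pipe = pipeline("text-generation", model="HuggingFaceTB/SmolLM2-135M-Instruct")
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ]
    out = pipe(messages)[0]
    # 'generated_text' now ends with the assistant's reply
    print(out["generated_text"][-1]["content"])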
1 change: 1 addition & 0 deletions lumigator/jobs/inference/schemas.py
@@ -51,6 +51,7 @@ class InferenceJobConfig(BaseModel):
     name: str
     dataset: DatasetConfig
     job: JobConfig
+    system_prompt: str | None = None
     inference_server: InferenceServerConfig | None = None
     params: SamplingParameters | None = None
     hf_pipeline: HfPipelineConfig | None = None
9 changes: 8 additions & 1 deletion lumigator/schemas/lumigator_schemas/jobs.py
@@ -101,7 +101,14 @@ class JobInferenceCreate(BaseModel):
     trust_remote_code: bool = False
     torch_dtype: str = "auto"
     model_url: str | None = None
-    system_prompt: str | None = None
+    system_prompt: str | None = Field(
+        title="System Prompt",
+        default=None,
+        examples=[
+            "You are an advanced AI trained to summarize documents accurately and concisely. "
+            "Your goal is to extract key information while maintaining clarity and coherence."
+        ],
+    )
     output_field: str | None = "predictions"
     max_tokens: int = 1024
     frequency_penalty: float = 0.0
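Put together, a text-generation inference request now carries the system
prompt alongside the pipeline settings; a hypothetical request body
(values are illustrative, fields outside this hunk elided):

    payload = {
        "task": "text-generation",
        "system_prompt": "You are an advanced AI trained to summarize "
        "documents accurately and concisely.",
        "output_field": "predictions",
        "max_tokens": 1024,
    }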
