Added example of an Ollama Classifier
1 parent 11ea8bf · commit def83cf
Showing 2 changed files with 106 additions and 0 deletions.
docs/src/content/docs/classifiers/examples/ollama-classifier.mdx
100 changes: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
---
title: Ollama classifier with llama3.1
description: Example of an Ollama classifier
---

This example shows an implementation of a classifier for the Multi-Agent Orchestrator System that uses Ollama, with llama3.1 as the default model.

## Implementation

import { Tabs, TabItem } from '@astrojs/starlight/components';

<Tabs syncKey="runtime">
  <TabItem label="Python" icon="seti:python">
```python
from typing import List, Dict, Optional, Any
from multi_agent_orchestrator.classifiers import Classifier, ClassifierResult
from multi_agent_orchestrator.types import ConversationMessage, ParticipantRole
from multi_agent_orchestrator.utils import Logger
import ollama


class OllamaClassifierOptions:
    def __init__(self,
                 model_id: Optional[str] = None,
                 inference_config: Optional[Dict[str, Any]] = None,
                 host: Optional[str] = None
                 ):
        self.model_id = model_id
        self.inference_config = inference_config or {}
        self.host = host


class OllamaClassifier(Classifier):
    def __init__(self, options: OllamaClassifierOptions):
        super().__init__()

        self.model_id = options.model_id or 'llama3.1'
        self.inference_config = options.inference_config
        self.streaming = False
        self.temperature = options.inference_config.get('temperature', 0.0)
        self.client = ollama.Client(host=options.host or None)

    async def process_request(self,
                              input_text: str,
                              chat_history: List[ConversationMessage]) -> ClassifierResult:
        # Convert the conversation history into Ollama's chat message format
        messages = [
            {"role": msg.role, "content": msg.content[0]['text']}
            for msg in chat_history
        ]
        # Combine the system prompt with the current question in a local variable,
        # so repeated calls do not keep appending to self.system_prompt
        system_prompt = self.system_prompt + f'\n question: {input_text}'
        messages.append({"role": ParticipantRole.USER.value, "content": system_prompt})

        try:
            response = self.client.chat(
                model=self.model_id,
                messages=messages,
                options={'temperature': self.temperature},
                # Expose a single tool so the model returns structured output
                tools=[{
                    'type': 'function',
                    'function': {
                        'name': 'analyzePrompt',
                        'description': 'Analyze the user input and provide structured output',
                        'parameters': {
                            'type': 'object',
                            'properties': {
                                'userinput': {
                                    'type': 'string',
                                    'description': 'The original user input',
                                },
                                'selected_agent': {
                                    'type': 'string',
                                    'description': 'The name of the selected agent',
                                },
                                'confidence': {
                                    'type': 'number',
                                    'description': 'Confidence level between 0 and 1',
                                },
                            },
                            'required': ['userinput', 'selected_agent', 'confidence'],
                        },
                    }
                }]
            )
            # Check whether the model decided to use the provided function
            if not response['message'].get('tool_calls'):
                Logger.get_logger().info(f"The model didn't use the function. Its response was: {response['message']['content']}")
                raise Exception(f'Ollama model {self.model_id} did not use tools')
            else:
                # Extract the structured arguments returned by the tool call
                tool_result = response['message'].get('tool_calls')[0].get('function', {}).get('arguments', {})
                return ClassifierResult(
                    selected_agent=self.get_agent_by_id(tool_result.get('selected_agent', None)),
                    confidence=float(tool_result.get('confidence', 0.0))
                )

        except Exception as e:
            Logger.get_logger().error(f'Error in Ollama Classifier: {str(e)}')
            raise e
```
  </TabItem>
</Tabs>
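
The classifier relies on Ollama's tool calling so the model returns the selected agent and a confidence score as structured arguments rather than free-form text. The snippet below is a minimal usage sketch only: it assumes the orchestrator accepts a custom classifier through a `classifier` constructor argument and that Ollama is reachable on its default local endpoint; check the orchestrator's API reference and your Ollama setup for the exact details.

```python
from multi_agent_orchestrator.orchestrator import MultiAgentOrchestrator

# Minimal usage sketch (assumption: MultiAgentOrchestrator accepts a custom
# classifier via a `classifier` keyword argument; verify against your
# installed version of the library).
custom_classifier = OllamaClassifier(OllamaClassifierOptions(
    model_id='llama3.1',                    # any model already pulled into your local Ollama
    inference_config={'temperature': 0.0},  # low temperature for deterministic routing
    host='http://localhost:11434'           # default Ollama endpoint (assumption)
))

orchestrator = MultiAgentOrchestrator(classifier=custom_classifier)
```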