add ollama to flytekit-inference (#2677)
* add ollama to flytekit-inference

Signed-off-by: Samhita Alla <[email protected]>

* add ollama to setup.py

Signed-off-by: Samhita Alla <[email protected]>

* add support for creating models

Signed-off-by: Samhita Alla <[email protected]>

* escape quote

Signed-off-by: Samhita Alla <[email protected]>

* fix type hint

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* update readme

Signed-off-by: Samhita Alla <[email protected]>

* add support for flytefile in init container

Signed-off-by: Samhita Alla <[email protected]>

* debug

Signed-off-by: Samhita Alla <[email protected]>

* encode the modelfile

Signed-off-by: Samhita Alla <[email protected]>

* flytefile in init container

Signed-off-by: Samhita Alla <[email protected]>

* add input to args

Signed-off-by: Samhita Alla <[email protected]>

* update inputs code and readme

Signed-off-by: Samhita Alla <[email protected]>

* clean up

Signed-off-by: Samhita Alla <[email protected]>

* cleanup

Signed-off-by: Samhita Alla <[email protected]>

* add comment

Signed-off-by: Samhita Alla <[email protected]>

* move sleep to python code snippets

Signed-off-by: Samhita Alla <[email protected]>

* move input download code to init container

Signed-off-by: Samhita Alla <[email protected]>

* debug

Signed-off-by: Samhita Alla <[email protected]>

* move base code and ollama service ready to outer condition

Signed-off-by: Samhita Alla <[email protected]>

* fix tests

Signed-off-by: Samhita Alla <[email protected]>

* swap images

Signed-off-by: Samhita Alla <[email protected]>

* remove tmp and update readme

Signed-off-by: Samhita Alla <[email protected]>

* download to tmp if the file isn't in tmp

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
samhita-alla authored Sep 16, 2024
1 parent 0b26c92 commit bba6509
Showing 10 changed files with 450 additions and 17 deletions.
16 changes: 9 additions & 7 deletions .github/workflows/pythonbuild.yml
@@ -42,7 +42,7 @@ jobs:
         python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
     steps:
       - uses: actions/checkout@v4
-      - name: 'Clear action cache'
+      - name: "Clear action cache"
         uses: ./.github/actions/clear-action-cache
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
@@ -81,7 +81,7 @@ jobs:
         python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
     steps:
       - uses: actions/checkout@v4
-      - name: 'Clear action cache'
+      - name: "Clear action cache"
         uses: ./.github/actions/clear-action-cache
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
@@ -133,7 +133,7 @@ jobs:

     steps:
       - uses: actions/checkout@v4
-      - name: 'Clear action cache'
+      - name: "Clear action cache"
         uses: ./.github/actions/clear-action-cache
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
@@ -244,15 +244,16 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
-        makefile-cmd: [integration_test_codecov, integration_test_lftransfers_codecov]
+        makefile-cmd:
+          [integration_test_codecov, integration_test_lftransfers_codecov]
     steps:
       # As described in https://github.com/pypa/setuptools_scm/issues/414, SCM needs git history
       # and tags to work.
       - uses: actions/checkout@v4
         with:
           fetch-depth: 0
-      - name: 'Clear action cache'
-        uses: ./.github/actions/clear-action-cache # sandbox has disk pressure, so we need to clear the cache to get more disk space.
+      - name: "Clear action cache"
+        uses: ./.github/actions/clear-action-cache  # sandbox has disk pressure, so we need to clear the cache to get more disk space.
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
@@ -335,6 +336,7 @@ jobs:
           - flytekit-hive
           - flytekit-huggingface
           - flytekit-identity-aware-proxy
+          - flytekit-inference
           - flytekit-k8s-pod
           - flytekit-kf-mpi
           - flytekit-kf-pytorch
@@ -414,7 +416,7 @@ jobs:
             plugin-names: "flytekit-kf-pytorch"
     steps:
       - uses: actions/checkout@v4
-      - name: 'Clear action cache'
+      - name: "Clear action cache"
         uses: ./.github/actions/clear-action-cache
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
59 changes: 59 additions & 0 deletions plugins/flytekit-inference/README.md
@@ -67,3 +67,62 @@ def model_serving() -> str:

    return completion.choices[0].message.content
```

## Ollama

The Ollama plugin lets you serve LLMs locally, in the same pod as your Flyte task, by adding an Ollama sidecar to the task's pod template.
You can either pull an existing model or create a new one from a Modelfile.

```python
from textwrap import dedent

from flytekit import ImageSpec, Resources, task, workflow
from flytekitplugins.inference import Ollama, Model
from flytekit.extras.accelerators import A10G
from flytekit.types.file import FlyteFile
from openai import OpenAI


image = ImageSpec(
    name="ollama_serve",
    registry="...",
    packages=["flytekitplugins-inference"],
)

ollama_instance = Ollama(
    model=Model(
        name="llama3-mario",
        modelfile=dedent("""\
        FROM llama3
        ADAPTER {inputs.gguf}
        PARAMETER temperature 1
        PARAMETER num_ctx 4096
        SYSTEM You are Mario from super mario bros, acting as an assistant.\
        """),
    )
)


@task(
    container_image=image,
    pod_template=ollama_instance.pod_template,
    accelerator=A10G,
    requests=Resources(gpu="0"),
)
def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]:
    responses = []
    client = OpenAI(
        base_url=f"{ollama_instance.base_url}/v1", api_key="ollama"
    )  # api key required but ignored

    for question in questions:
        completion = client.chat.completions.create(
            model="llama3-mario",
            messages=[
                {"role": "user", "content": question},
            ],
            max_tokens=256,
        )
        responses.append(completion.choices[0].message.content)

    return responses
```
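
Pulling a model that already exists in the Ollama registry is simpler: omit `modelfile`, and the plugin's init container runs `ollama.pull` instead of `ollama.create`. A minimal sketch reusing the `image` and imports above — the `gemma2` model name is only an example:

```python
# Pull a prebuilt model instead of building one from a Modelfile.
ollama_instance = Ollama(model=Model(name="gemma2"))


@task(
    container_image=image,
    pod_template=ollama_instance.pod_template,
    accelerator=A10G,
    requests=Resources(gpu="0"),
)
def pull_and_serve(question: str) -> str:
    # The Ollama sidecar is reachable at ollama_instance.base_url inside the pod.
    client = OpenAI(base_url=f"{ollama_instance.base_url}/v1", api_key="ollama")
    completion = client.chat.completions.create(
        model="gemma2",
        messages=[{"role": "user", "content": question}],
        max_tokens=256,
    )
    return completion.choices[0].message.content
```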
@@ -8,6 +8,9 @@
NIM
NIMSecrets
Model
Ollama
"""

from .nim.serve import NIM, NIMSecrets
from .ollama.serve import Model, Ollama
@@ -34,7 +34,9 @@ def __init__(
         gpu: int = 1,
         mem: str = "20Gi",
         shm_size: str = "16Gi",
-        env: Optional[dict[str, str]] = None,
+        env: Optional[
+            dict[str, str]
+        ] = None,  # https://docs.nvidia.com/nim/large-language-models/latest/configuration.html#environment-variables
         hf_repo_ids: Optional[list[str]] = None,
         lora_adapter_mem: Optional[str] = None,
     ):
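
The linked NVIDIA docs list the environment variables that `env` forwards to the NIM container. A rough sketch of how these keyword arguments might be combined — the `secrets` argument, the `NIMSecrets` field names, and the concrete env var and repo values are illustrative assumptions, not taken from this diff:

```python
from flytekitplugins.inference import NIM, NIMSecrets

nim_instance = NIM(
    secrets=NIMSecrets(  # assumed fields: how NGC credentials are wired in
        ngc_image_secret="nvcrio-cred",
        ngc_secret_key="NGC_API_KEY",
        secrets_prefix="_FSEC_",
    ),
    env={"NIM_LOG_LEVEL": "DEFAULT"},  # forwarded as-is to the NIM container
    hf_repo_ids=["nvidia/llama3-8b-instruct-lora"],  # illustrative LoRA repos to fetch
    lora_adapter_mem="500Mi",  # memory reserved for the adapter download
)
```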
Empty file.
180 changes: 180 additions & 0 deletions plugins/flytekit-inference/flytekitplugins/inference/ollama/serve.py
@@ -0,0 +1,180 @@
import base64
from dataclasses import dataclass
from typing import Optional

from ..sidecar_template import ModelInferenceTemplate


@dataclass
class Model:
    """Represents the configuration for a model used in a Kubernetes pod template.

    :param name: The name of the model.
    :param mem: The amount of memory allocated for the model, specified as a string. Default is "500Mi".
    :param cpu: The number of CPU cores allocated for the model. Default is 1.
    :param modelfile: The actual model file as a JSON-serializable string. This represents the file content. Default is `None` if not applicable.
    """

    name: str
    mem: str = "500Mi"
    cpu: int = 1
    modelfile: Optional[str] = None


class Ollama(ModelInferenceTemplate):
    def __init__(
        self,
        *,
        model: Model,
        image: str = "ollama/ollama",
        port: int = 11434,
        cpu: int = 1,
        gpu: int = 1,
        mem: str = "15Gi",
    ):
        """Initialize Ollama class for managing a Kubernetes pod template.

        :param model: An instance of the Model class containing the model's configuration, including its name, memory, CPU, and file.
        :param image: The Docker image to be used for the container. Default is "ollama/ollama".
        :param port: The port number on which the container should expose its service. Default is 11434.
        :param cpu: The number of CPU cores requested for the container. Default is 1.
        :param gpu: The number of GPUs requested for the container. Default is 1.
        :param mem: The amount of memory requested for the container, specified as a string. Default is "15Gi".
        """
        self._model_name = model.name
        self._model_mem = model.mem
        self._model_cpu = model.cpu
        self._model_modelfile = model.modelfile

        super().__init__(
            image=image,
            port=port,
            cpu=cpu,
            gpu=gpu,
            mem=mem,
            download_inputs=(True if self._model_modelfile and "{inputs" in self._model_modelfile else False),
        )

        self.setup_ollama_pod_template()

    def setup_ollama_pod_template(self):
        from kubernetes.client.models import (
            V1Container,
            V1ResourceRequirements,
            V1SecurityContext,
            V1VolumeMount,
        )

        container_name = "create-model" if self._model_modelfile else "pull-model"

        base_code = """
import base64
import time

import ollama
import requests
"""

        ollama_service_ready = f"""
# Wait for Ollama service to be ready
max_retries = 30
retry_interval = 1
for _ in range(max_retries):
    try:
        response = requests.get('{self.base_url}')
        if response.status_code == 200:
            print('Ollama service is ready')
            break
    except requests.RequestException:
        pass
    time.sleep(retry_interval)
else:
    print('Ollama service did not become ready in time')
    exit(1)
"""
        if self._model_modelfile:
            encoded_modelfile = base64.b64encode(self._model_modelfile.encode("utf-8")).decode("utf-8")

            if "{inputs" in self._model_modelfile:
                python_code = f"""
{base_code}
import json

with open('/shared/inputs.json', 'r') as f:
    inputs = json.load(f)

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

inputs = {{'inputs': AttrDict(inputs)}}

encoded_model_file = '{encoded_modelfile}'
modelfile = base64.b64decode(encoded_model_file).decode('utf-8').format(**inputs)
modelfile = modelfile.replace('{{', '{{{{').replace('}}', '}}}}')

with open('Modelfile', 'w') as f:
    f.write(modelfile)

{ollama_service_ready}

# Debugging: Shows the status of model creation.
for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True):
    print(chunk)
"""
            else:
                python_code = f"""
{base_code}
encoded_model_file = '{encoded_modelfile}'
modelfile = base64.b64decode(encoded_model_file).decode('utf-8')

with open('Modelfile', 'w') as f:
    f.write(modelfile)

{ollama_service_ready}

# Debugging: Shows the status of model creation.
for chunk in ollama.create(model='{self._model_name}', path='Modelfile', stream=True):
    print(chunk)
"""
        else:
            python_code = f"""
{base_code}
{ollama_service_ready}

# Debugging: Shows the status of model pull.
for chunk in ollama.pull('{self._model_name}', stream=True):
    print(chunk)
"""

        command = f'python3 -c "{python_code}"'

        self.pod_template.pod_spec.init_containers.append(
            V1Container(
                name=container_name,
                image="python:3.11-slim",
                command=["/bin/sh", "-c"],
                args=[f"pip install requests && pip install ollama && {command}"],
                resources=V1ResourceRequirements(
                    requests={
                        "cpu": self._model_cpu,
                        "memory": self._model_mem,
                    },
                    limits={
                        "cpu": self._model_cpu,
                        "memory": self._model_mem,
                    },
                ),
                security_context=V1SecurityContext(
                    run_as_user=0,
                ),
                volume_mounts=[
                    V1VolumeMount(name="shared-data", mount_path="/shared"),
                    V1VolumeMount(name="tmp", mount_path="/tmp"),
                ],
            )
        )
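
Put together, the init container installs `requests` and `ollama`, waits for the Ollama sidecar to come up, and then streams the model pull or creation. For the pull path, the generated script is roughly the following — shown for a hypothetical `Model(name="llama3")` and assuming `base_url` resolves to `http://localhost:11434`:

```python
# Approximately what the "pull-model" init container executes via `python3 -c`.
import time

import ollama
import requests

# Wait for the Ollama sidecar to be ready.
max_retries = 30
retry_interval = 1
for _ in range(max_retries):
    try:
        response = requests.get('http://localhost:11434')
        if response.status_code == 200:
            print('Ollama service is ready')
            break
    except requests.RequestException:
        pass
    time.sleep(retry_interval)
else:
    print('Ollama service did not become ready in time')
    exit(1)

# Stream the pull so progress shows up in the init container logs.
for chunk in ollama.pull('llama3', stream=True):
    print(chunk)
```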