Skip to content

Commit

Permalink
Add support for ContainerTask in PERIAN agent + os-storage parameter
Browse files Browse the repository at this point in the history
Signed-off-by: Omar Tarabai <[email protected]>
  • Loading branch information
otarabai committed Oct 25, 2024
1 parent 57f583e commit 4380667
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 90 deletions.
70 changes: 1 addition & 69 deletions plugins/flytekit-perian/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,4 @@

Flyte Agent plugin for executing Flyte tasks on Perian Job Platform (perian.io).

Perian Job Platform is still in closed beta. Contact [email protected] if you are interested in trying it out.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-perian-job
```

## Getting Started

This plugin allows executing `PythonFunctionTask` on Perian.

An [ImageSpec](https://docs.flyte.org/en/latest/user_guide/customizing_dependencies/imagespec.html) need to be built with the perian agent plugin installed.

### Parameters

The following parameters can be used to set the requirements for the Perian task. If any of the requirements are skipped, it is replaced with the cheapest option. At least one requirement value should be set.
* `cores`: Number of CPU cores
* `memory`: Amount of memory in GB
* `accelerators`: Number of accelerators
* `accelerator_type`: Type of accelerator (e.g. 'A100'). For a full list of supported accelerators, use the perian CLI list-accelerators command.
* `country_code`: Country code to run the job in (e.g. 'DE')

### Credentials

The following [secrets](https://docs.flyte.org/en/latest/user_guide/productionizing/secrets.html) are required to be defined for the agent server:
* Perian credentials:
* `perian_organization`
* `perian_token`
* For accessing the Flyte storage bucket, you need to add either AWS or GCP credentials. These credentials are never logged by Perian and are only stored until then are used, then immediately deleted.
* AWS credentials:
* `aws_access_key_id`
* `aws_secret_access_key`
* GCP credentials:
* `google_application_credentials`. This should be the full json credentials.
* (Optional) Custom docker registry for pulling the Flyte image:
* `docker_registry_url`
* `docker_registry_username`
* `docker_registry_password`

### Example

`example.py` workflow example:
```python
from flytekit import ImageSpec, task, workflow
from flytekitplugins.perian_job import PerianConfig

image_spec = ImageSpec(
name="flyte-test",
registry="my-registry",
python_version="3.11",
apt_packages=["wget", "curl", "git"],
packages=[
"flytekitplugins-perian-job",
],
)

@task(container_image=image_spec,
task_config=PerianConfig(
accelerators=1,
accelerator_type="A100",
))
def perian_hello(name: str) -> str:
return f"hello {name}!"

@workflow
def my_wf(name: str = "world") -> str:
return perian_hello(name=name)
```
See the [official docs page](https://perian.io/docs/flyte-getting-started) for more details.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
"""

from .agent import PerianAgent
from .task import PerianConfig, PerianTask
from .task import PerianConfig, PerianContainerTask, PerianTask
56 changes: 47 additions & 9 deletions plugins/flytekit-perian/flytekitplugins/perian_job/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import base64
import shlex
from dataclasses import dataclass
from typing import Optional
from typing import Any, Dict, List, Optional

from flyteidl.core.execution_pb2 import TaskExecution
from perian import (
Expand All @@ -17,12 +17,14 @@
JobStatus,
MemoryQueryInput,
Name,
OSStorageConfig,
ProviderQueryInput,
RegionQueryInput,
Size,
)

from flytekit import current_context
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.user import FlyteUserException
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
Expand Down Expand Up @@ -57,9 +59,13 @@ async def create(
**kwargs,
) -> PerianMetadata:
logger.info("Creating new Perian job")

ctx = current_context()
literal_types = task_template.interface.inputs
input_kwargs = (
TypeEngine.literal_map_to_kwargs(ctx, inputs, literal_types=literal_types) if inputs.literals else None
)
config = Configuration(host=PERIAN_API_URL)
job_request = self._build_create_job_request(task_template)
job_request = self._build_create_job_request(task_template, input_kwargs)
with ApiClient(config) as api_client:
api_instance = JobApi(api_client)
response = api_instance.create_job(
Expand Down Expand Up @@ -105,7 +111,9 @@ def delete(self, resource_meta: PerianMetadata, **kwargs):
if response.status_code != 200:
raise FlyteException(f"Failed to cancel Perian job: {response.text}")

def _build_create_job_request(self, task_template: TaskTemplate) -> CreateJobRequest:
def _build_create_job_request(
self, task_template: TaskTemplate, inputs: Optional[Dict[str, Any]]
) -> CreateJobRequest:
params = task_template.custom
secrets = current_context().secrets

Expand Down Expand Up @@ -143,20 +151,50 @@ def _build_create_job_request(self, task_template: TaskTemplate) -> CreateJobReq
pass

container = task_template.container
if ":" in container.image:
docker_run.image_name, docker_run.image_tag = container.image.rsplit(":", 1)
if container:
image = container.image
else:
image = params["image"]
if ":" in image:
docker_run.image_name, docker_run.image_tag = image.rsplit(":", 1)
else:
docker_run.image_name = image

if container:
command = container.args
else:
docker_run.image_name = container.image
if container.args:
docker_run.command = shlex.join(container.args)
command = self._render_command_template(params["command"], inputs)
if command:
docker_run.command = shlex.join(command)

if params.get("environment"):
if docker_run.env_variables:
docker_run.env_variables.update(params["environment"])
else:
docker_run.env_variables = params["environment"]

storage_config = None
if params.get("os_storage_size"):
storage_config = OSStorageConfig(size=int(params["os_storage_size"]))

return CreateJobRequest(
auto_failover_instance_type=True,
requirements=reqs,
docker_run_parameters=docker_run,
docker_registry_credentials=docker_registry,
os_storage_config=storage_config,
)

def _render_command_template(self, command: List[str], inputs: Optional[Dict[str, Any]]) -> List[str]:
if not inputs:
return command
rendered_command = []
for c in command:
for key, val in inputs.items():
c = c.replace("{{.inputs." + key + "}}", str(val))
rendered_command.append(c)
return rendered_command

def _read_storage_credentials(self) -> DockerRunParameters:
secrets = current_context().secrets
docker_run = DockerRunParameters()
Expand Down
74 changes: 63 additions & 11 deletions plugins/flytekit-perian/flytekitplugins/perian_job/task.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Type, Union

from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import FlyteContextManager, PythonFunctionTask, logger
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.exceptions.user import FlyteUserException
from flytekit.extend import TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
Expand All @@ -25,14 +27,16 @@ class PerianConfig:
# Type of accelerator (e.g. 'A100')
# For a full list of supported accelerators, use the perian CLI list-accelerators command
accelerator_type: Optional[str] = None
# OS storage size in GB
os_storage_size: Optional[int] = None
# Country code to run the job in (e.g. 'DE')
country_code: Optional[str] = None
# Cloud provider to run the job on
provider: Optional[str] = None


class PerianTask(AsyncAgentExecutorMixin, PythonFunctionTask):
"""A special task type for running tasks on PERIAN Job Platform (perian.io)"""
"""A special task type for running Python function tasks on PERIAN Job Platform (perian.io)"""

_TASK_TYPE = "perian_task"

Expand Down Expand Up @@ -72,20 +76,68 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
"""
Return plugin-specific data as a serializable dictionary.
"""
config = {
"cores": self.task_config.cores,
"memory": self.task_config.memory,
"accelerators": self.task_config.accelerators,
"accelerator_type": self.task_config.accelerator_type,
"country_code": _validate_and_format_country_code(self.task_config.country_code),
"provider": self.task_config.provider,
}
config = {k: v for k, v in config.items() if v is not None}
config = _get_custom_task_config(self.task_config)
if self.environment:
config["environment"] = self.environment
s = Struct()
s.update(config)
return json_format.MessageToDict(s)


class PerianContainerTask(AsyncAgentExecutorMixin, PythonTask[PerianConfig]):
"""A special task type for running Python container (not function) tasks on PERIAN Job Platform (perian.io)"""

_TASK_TYPE = "perian_task"

def __init__(
self,
name: str,
task_config: PerianConfig,
image: str,
command: List[str],
inputs: Optional[OrderedDict[str, Type]] = None,
**kwargs,
):
if "outputs" in kwargs or "output_data_dir" in kwargs:
raise ValueError("PerianContainerTask does not support 'outputs' or 'output_data_dir' arguments")
super().__init__(
name=name,
task_type=self._TASK_TYPE,
task_config=task_config,
interface=Interface(inputs=inputs or {}),
**kwargs,
)
self._image = image
self._command = command

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
"""
Return plugin-specific data as a serializable dictionary.
"""
config = _get_custom_task_config(self.task_config)
config["image"] = self._image
config["command"] = self._command
if self.environment:
config["environment"] = self.environment
s = Struct()
s.update(config)
return json_format.MessageToDict(s)


def _get_custom_task_config(task_config: PerianConfig) -> Dict[str, Any]:
config = {
"cores": task_config.cores,
"memory": task_config.memory,
"accelerators": task_config.accelerators,
"accelerator_type": task_config.accelerator_type,
"os_storage_size": task_config.os_storage_size,
"country_code": _validate_and_format_country_code(task_config.country_code),
"provider": task_config.provider,
}
config = {k: v for k, v in config.items() if v is not None}
return config


def _validate_and_format_country_code(country_code: Optional[str]) -> Optional[str]:
if not country_code:
return None
Expand Down

0 comments on commit 4380667

Please sign in to comment.