Skip to content

Commit

Permalink
[Extended Resources] GPU Accelerators (#1843)
Browse files Browse the repository at this point in the history
* pip through to container

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* move around

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* add asserts

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* delete bad line

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* switch to abc and add support for gpu unpartitioned

Signed-off-by: Jeev B <[email protected]>

* Add Azure-specific headers when uploading to blob storage (#1784)

* Add Azure-specific headers when uploading to blob storage

Signed-off-by: Victor Delépine <[email protected]>

* Add comment about HTTP 201 check

Signed-off-by: Victor Delépine <[email protected]>

---------

Signed-off-by: Victor Delépine <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Add async delete function in base_agent (#1800)

Signed-off-by: Future Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Add support for execution name prefixes (#1803)

Signed-off-by: troychiu <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Remove ref in output (#1794)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Inherit directly from DataClassJsonMixin instead of using @dataclass_json for improved static type checking (#1801)

* Inherit directly from DataClassJsonMixin instead of @dataclass_json for improved static type checking

As it says in the dataclasses-json README: https://github.com/lidatong/dataclasses-json/blob/89578cb9ebed290e70dba8946bfdb68ff6746755/README.md?plain=1#L111-L129, we can use inheritance for improved static type checking; this one change eliminates something like 467 pyright errors from the flytekit module

Signed-off-by: Matthew Hoffman <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Async file sensor (#1790)

---------
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Eager workflows to support async workflows (#1579)

* Eager workflows to support async workflows

Signed-off-by: Niels Bantilan <[email protected]>

* move array node maptask to experimental/__init__.py

Signed-off-by: Niels Bantilan <[email protected]>

* clean up docs

Signed-off-by: Niels Bantilan <[email protected]>

* clean up

Signed-off-by: Niels Bantilan <[email protected]>

* more clean up

Signed-off-by: Niels Bantilan <[email protected]>

* docs cleanup

Signed-off-by: Niels Bantilan <[email protected]>

* Update test_eager_workflows.py

* clean up timeout handling

Signed-off-by: Niels Bantilan <[email protected]>

* fix lint

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Enable SecretsManager.get to load and return bytes (#1798)

* fix secretsmanager

Signed-off-by: Yue Shang <[email protected]>

* fix lint issue

Signed-off-by: Yue Shang <[email protected]>

* add doc

Signed-off-by: Yue Shang <[email protected]>

* fix github check

Signed-off-by: Yue Shang <[email protected]>

---------

Signed-off-by: Yue Shang <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Batch upload flyte directory (#1806)

* Batch upload flyte directory

Signed-off-by: Kevin Su <[email protected]>

* Update get method

Signed-off-by: Kevin Su <[email protected]>

* Move batch size to type engine

Signed-off-by: Kevin Su <[email protected]>

* comment

Signed-off-by: Kevin Su <[email protected]>

* update comment

Signed-off-by: Kevin Su <[email protected]>

* Update flytekit/core/type_engine.py

Co-authored-by: Eduardo Apolinario <[email protected]>

* Add test

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Better error messaging for overrides (#1807)

- using incorrect type of overrides
 - using incorrect type for resources
 - using promises in overrides

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Run remote Launchplan from `pyflyte run` (#1785)

* Beautified pyflyte run even for every task and workflow

- identify a task or a workflow
- task or workflow help menus show types and use rich to beautify

Signed-off-by: Ketan Umare <[email protected]>

* one more improvement

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* updated command

Signed-off-by: Ketan Umare <[email protected]>

* Updated

Signed-off-by: Ketan Umare <[email protected]>

* updated formatting

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* updated

Signed-off-by: Ketan Umare <[email protected]>

* bug fixed in types

Signed-off-by: Ketan Umare <[email protected]>

* Updated

Signed-off-by: Ketan Umare <[email protected]>

* lint

Signed-off-by: Kevin Su <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Add is none function (#1757)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Dynamic workflow should not throw nested task warning (#1812)

Signed-off-by: oliverhu <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Add a manual image building GH action (#1816)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* catch abfs protocol in data_persistence.py/get_filesystem and set anon to False (#1813)

Signed-off-by: Jan Fiedler <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* None doesnt work

Signed-off-by: Jeev B <[email protected]>

* unpartitioned selector

Signed-off-by: Jeev B <[email protected]>

* Fix list of annotated structured dataset (#1817)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Support the flytectl config.yaml admin.clientSecretEnvVar option in flytekit (#1819)

* Support the flytectl config.yaml admin.clientSecretEnvVar option in flytekit

Signed-off-by: Chao-Heng Lee <[email protected]>

* remove helper of getting env var.

Signed-off-by: Chao-Heng Lee <[email protected]>

* refactor variable name.

Signed-off-by: Chao-Heng Lee <[email protected]>

---------

Signed-off-by: Chao-Heng Lee <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Async agent delete function for while loop case (#1802)

Signed-off-by: Future Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* refactor

Signed-off-by: Jeev B <[email protected]>

* fix docs warnings (#1827)

Signed-off-by: Jeev B <[email protected]>

* Fix extract_task_module (#1829)

---------

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Feat: Add type support for pydantic BaseModels (#1660)

Signed-off-by: Adrian Rumpold <[email protected]>
Signed-off-by: Arthur <[email protected]>
Signed-off-by: wirthual <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: eduardo apolinario <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* add test for unspecified mig

Signed-off-by: Jeev B <[email protected]>

* add support for overriding accelerator

Signed-off-by: Jeev B <[email protected]>

* cleanup

Signed-off-by: Jeev B <[email protected]>

* move from core to extras

Signed-off-by: Jeev B <[email protected]>

* fixes

Signed-off-by: Jeev B <[email protected]>

* fixes

Signed-off-by: Jeev B <[email protected]>

* fixes

Signed-off-by: Jeev B <[email protected]>

* cleanup

Signed-off-by: Jeev B <[email protected]>

* Make FlyteRemote slightly more copy/pastable (#1830)

Signed-off-by: Katrina Rogan <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Pyflyte meta inputs (#1823)

* Re-orgining pyflyte run

Signed-off-by: Ketan Umare <[email protected]>

* Pyflyte beautified and simplified

Signed-off-by: Ketan Umare <[email protected]>

* fixed unit test

Signed-off-by: Ketan Umare <[email protected]>

* Added Launch options

Signed-off-by: Ketan Umare <[email protected]>

* lint fix

Signed-off-by: Ketan Umare <[email protected]>

* test fix

Signed-off-by: Ketan Umare <[email protected]>

* fixing docs failure

Signed-off-by: Ketan Umare <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Use mashumaro to serialize/deserialize dataclass (#1735)

Signed-off-by: HH <[email protected]>
Signed-off-by: hhcs9527 <[email protected]>
Signed-off-by: Matthew Hoffman <[email protected]>
Co-authored-by: Matthew Hoffman <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Databricks Agent (#1797)

Signed-off-by: Future Outlier <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Prometheus metrics (#1815)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Pyflyte register optionally activates schedule (#1832)

* Pyflyte register auto activates schedule

Signed-off-by: Ketan Umare <[email protected]>

* comment addressed

Signed-off-by: Ketan Umare <[email protected]>

---------

Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Remove versions 3.9 and 3.10 (#1831)

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Snowflake agent (#1799)

Signed-off-by: hhcs9527 <[email protected]>
Signed-off-by: HH <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Update agent metric name (#1835)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* MemVerge MMCloud Agent (#1821)

Signed-off-by: Edwin Yu <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Add download badges in readme (#1836)

Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Eager local entrypoint and support for offloaded types (#1833)

* implement eager workflow local entrypoint, support offloaded types

Signed-off-by: Niels Bantilan <[email protected]>

* wip local entrypoint

Signed-off-by: Niels Bantilan <[email protected]>

* add tests

Signed-off-by: Niels Bantilan <[email protected]>

* add local entrypoint tests

Signed-off-by: Niels Bantilan <[email protected]>

* update eager unit tests, delete test script

Signed-off-by: Niels Bantilan <[email protected]>

* clean up tests

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* update ci

Signed-off-by: Niels Bantilan <[email protected]>

* remove push step

Signed-off-by: Niels Bantilan <[email protected]>

---------

Signed-off-by: Niels Bantilan <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* update requirements and add snowflake agent to api reference (#1838)

* update requirements and add snowflake agent to api reference

Signed-off-by: Samhita Alla <[email protected]>

* update requirements

Signed-off-by: Samhita Alla <[email protected]>

* remove versions

Signed-off-by: Samhita Alla <[email protected]>

* remove tensorflow-macos

Signed-off-by: Samhita Alla <[email protected]>

* lint

Signed-off-by: Samhita Alla <[email protected]>

* downgrade sphinxcontrib-youtube package

Signed-off-by: Samhita Alla <[email protected]>

---------

Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Fix: Make sure decks created in elastic task workers are transferred to parent process (#1837)

* Transfer decks created in the worker process to the parent process

Signed-off-by: Fabio Graetz <[email protected]>

* Add test for decks in elastic tasks

Signed-off-by: Fabio Graetz <[email protected]>

* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py

Signed-off-by: Fabio Graetz <[email protected]>

* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py

Signed-off-by: Fabio Graetz <[email protected]>

---------

Signed-off-by: Fabio Graetz <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* add accept grpc (#1841)

* add accept grpc

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* unpin setup.py grpc

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Revert "add accept grpc"

This reverts commit 2294592.

Signed-off-by: Jeev B <[email protected]>

* default headers interceptor

Signed-off-by: Jeev B <[email protected]>

* setup.py

Signed-off-by: Jeev B <[email protected]>

* fixes

Signed-off-by: Jeev B <[email protected]>

* fmt

Signed-off-by: Jeev B <[email protected]>

* move prometheus-client import

Signed-off-by: Jeev B <[email protected]>

---------

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>
Co-authored-by: Jeev B <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* Feat: Enable `flytekit` to authenticate with proxy in front of FlyteAdmin (#1787)

* Introduce authenticator engine and make proxy auth work

Signed-off-by: Fabio Grätz <[email protected]>

* Use proxy authed session for client credentials flow

Signed-off-by: Fabio Grätz <[email protected]>

* Don't use authenticator engine but do proxy authentication via existing external command authenticator

Signed-off-by: Fabio Grätz <[email protected]>

* Add docstring to AuthenticationHTTPAdapter

Signed-off-by: Fabio Grätz <[email protected]>

* Address todo in docstring

Signed-off-by: Fabio Grätz <[email protected]>

* Create blank session if none provided

Signed-off-by: Fabio Grätz <[email protected]>

* Create blank session if none provided in get_token

Signed-off-by: Fabio Grätz <[email protected]>

* Refresh proxy creds in session when not existing without triggering 401

Signed-off-by: Fabio Grätz <[email protected]>

* Add test for get_session

Signed-off-by: Fabio Grätz <[email protected]>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <[email protected]>

* Move auth helper test into existing module

Signed-off-by: Fabio Grätz <[email protected]>

* Add test for upgrade_channel_to_proxy_authenticated

Signed-off-by: Fabio Grätz <[email protected]>

* Auth helper tests without use of responses package

Signed-off-by: Fabio Grätz <[email protected]>

* Feat: Add plugin for generating GCP IAP ID tokens via external command (#1795)

* Add external command plugin to generate id tokens for identity aware proxy

Signed-off-by: Fabio Grätz <[email protected]>

* Retrieve desktop app client secret from gcp secret manager

Signed-off-by: Fabio Grätz <[email protected]>

* Remove comments

Signed-off-by: Fabio Grätz <[email protected]>

* Introduce a command group that allows adding a command to generate service account id tokens later

Signed-off-by: Fabio Grätz <[email protected]>

* Document how to use plugin and deploy Flyte with IAP

Signed-off-by: Fabio Grätz <[email protected]>

* Minor corrections README.md

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: Fabio Grätz <[email protected]>

* Use proxy auth'ed session for device code auth flow

Signed-off-by: Fabio Grätz <[email protected]>

* Fix token client tests

Signed-off-by: Fabio Grätz <[email protected]>

* Make poll token endpoint test more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Make test_client_creds_authenticator test work and more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Make test_client_creds_authenticator_with_custom_scopes test work and more specific

Signed-off-by: Fabio Grätz <[email protected]>

* Implement subcommand to generate id tokens for service accounts

Signed-off-by: Fabio Graetz <[email protected]>

* Test id token generation from service accounts

Signed-off-by: Fabio Graetz <[email protected]>

* Fix plugin requirements

Signed-off-by: Fabio Graetz <[email protected]>

* Document usage of generate-service-account-id-token subcommand

Signed-off-by: Fabio Grätz <[email protected]>

* Document alternative ways to obtain service account id tokens

Signed-off-by: Fabio Grätz <[email protected]>

---------

Signed-off-by: Fabio Grätz <[email protected]>
Signed-off-by: Fabio Graetz <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
Signed-off-by: Jeev B <[email protected]>

* bump flyteidl

Signed-off-by: Jeev B <[email protected]>

* make requirements

Signed-off-by: Jeev B <[email protected]>

* fix failing tests

Signed-off-by: Jeev B <[email protected]>

* move gpu accelerator to flyteidl.core.Resources

Signed-off-by: Jeev B <[email protected]>

* Use ResourceExtensions for extended resources

Signed-off-by: Jeev B <[email protected]>

* cleanup

Signed-off-by: Jeev B <[email protected]>

* Switch to using ExtendedResources in TaskTemplate

Signed-off-by: Jeev B <[email protected]>

* cleanups

Signed-off-by: Jeev B <[email protected]>

* update flyteidl

Signed-off-by: Jeev B <[email protected]>

* Replace _core_task imports with tasks_pb2

Signed-off-by: Jeev B <[email protected]>

* less verbose definitions

Signed-off-by: Jeev B <[email protected]>

* Attempt at less confusing syntax

Signed-off-by: Jeev B <[email protected]>

* Streamline UX

Signed-off-by: Jeev B <[email protected]>

* Run make fmt

Signed-off-by: Jeev B <[email protected]>

---------

Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: Jeev B <[email protected]>
Signed-off-by: Victor Delépine <[email protected]>
Signed-off-by: Future Outlier <[email protected]>
Signed-off-by: troychiu <[email protected]>
Signed-off-by: Matthew Hoffman <[email protected]>
Signed-off-by: Niels Bantilan <[email protected]>
Signed-off-by: Yue Shang <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Ketan Umare <[email protected]>
Signed-off-by: oliverhu <[email protected]>
Signed-off-by: Jan Fiedler <[email protected]>
Signed-off-by: Chao-Heng Lee <[email protected]>
Signed-off-by: Adrian Rumpold <[email protected]>
Signed-off-by: Arthur <[email protected]>
Signed-off-by: wirthual <[email protected]>
Signed-off-by: eduardo apolinario <[email protected]>
Signed-off-by: Katrina Rogan <[email protected]>
Signed-off-by: HH <[email protected]>
Signed-off-by: hhcs9527 <[email protected]>
Signed-off-by: Edwin Yu <[email protected]>
Signed-off-by: Samhita Alla <[email protected]>
Signed-off-by: Fabio Graetz <[email protected]>
Signed-off-by: Fabio Grätz <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
Co-authored-by: Victor Delépine <[email protected]>
Co-authored-by: Future-Outlier <[email protected]>
Co-authored-by: Future Outlier <[email protected]>
Co-authored-by: Yi Chiu <[email protected]>
Co-authored-by: Matthew Hoffman <[email protected]>
Co-authored-by: Kevin Su <[email protected]>
Co-authored-by: Niels Bantilan <[email protected]>
Co-authored-by: Yue Shang <[email protected]>
Co-authored-by: Eduardo Apolinario <[email protected]>
Co-authored-by: Ketan Umare <[email protected]>
Co-authored-by: Keqiu Hu <[email protected]>
Co-authored-by: Jan Fiedler <[email protected]>
Co-authored-by: Chao-Heng Lee <[email protected]>
Co-authored-by: Samhita Alla <[email protected]>
Co-authored-by: Arthur Böök <[email protected]>
Co-authored-by: Katrina Rogan <[email protected]>
Co-authored-by: Po Han(Hank) Huang <[email protected]>
Co-authored-by: Edwin Yu <[email protected]>
Co-authored-by: Fabio M. Graetz, Ph.D <[email protected]>
Co-authored-by: Fabio Grätz <[email protected]>
  • Loading branch information
22 people authored Nov 1, 2023
1 parent d9ad0e1 commit 4b1ad23
Show file tree
Hide file tree
Showing 14 changed files with 315 additions and 10 deletions.
8 changes: 8 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from dataclasses import dataclass
from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast

from flyteidl.core import tasks_pb2

from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import (
ExecutionParameters,
Expand Down Expand Up @@ -344,6 +346,12 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
"""
return None

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
return None

def local_execution_mode(self) -> ExecutionState.Mode:
""" """
return ExecutionState.Mode.LOCAL_TASK_EXECUTION
Expand Down
8 changes: 8 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import typing
from typing import Any, List

from flyteidl.core import tasks_pb2

from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.loggers import logger
Expand Down Expand Up @@ -62,6 +64,7 @@ def __init__(
self._aliases: _workflow_model.Alias = None
self._outputs = None
self._resources: typing.Optional[_resources_model] = None
self._extended_resources: typing.Optional[tasks_pb2.ExtendedResources] = None

def runs_before(self, other: Node):
"""
Expand Down Expand Up @@ -172,6 +175,11 @@ def with_overrides(self, *args, **kwargs):
assert_not_promise(v, "container_image")
self.flyte_entity._container_image = v

if "accelerator" in kwargs:
v = kwargs["accelerator"]
assert_not_promise(v, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=v.to_flyte_idl())

return self


Expand Down
15 changes: 15 additions & 0 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from abc import ABC
from typing import Callable, Dict, List, Optional, TypeVar, Union

from flyteidl.core import tasks_pb2

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
Expand All @@ -13,6 +15,7 @@
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
Expand Down Expand Up @@ -44,6 +47,7 @@ def __init__(
secret_requests: Optional[List[Secret]] = None,
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
**kwargs,
):
"""
Expand All @@ -70,6 +74,7 @@ def __init__(
- `AWS Parameter store <https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html>`__
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""
sec_ctx = None
if secret_requests:
Expand Down Expand Up @@ -110,6 +115,7 @@ def __init__(
self._get_command_fn = self.get_default_command

self.pod_template = pod_template
self.accelerator = accelerator

@property
def task_resolver(self) -> TaskResolverMixin:
Expand Down Expand Up @@ -219,6 +225,15 @@ def get_config(self, settings: SerializationSettings) -> Optional[Dict[str, str]
return {}
return {_PRIMARY_CONTAINER_NAME_FIELD: self.pod_template.primary_container_name}

def get_extended_resources(self, settings: SerializationSettings) -> Optional[tasks_pb2.ExtendedResources]:
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
if self.accelerator is None:
return None

return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
"""
Expand Down
6 changes: 6 additions & 0 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.models.documentation import Documentation
from flytekit.models.security import Secret
Expand Down Expand Up @@ -102,6 +103,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]]:
...

Expand Down Expand Up @@ -129,6 +131,7 @@ def task(
enable_deck: Optional[bool] = ...,
pod_template: Optional["PodTemplate"] = ...,
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
) -> Union[PythonFunctionTask[T], Callable[..., FuncOut]]:
...

Expand All @@ -155,6 +158,7 @@ def task(
enable_deck: Optional[bool] = None,
pod_template: Optional["PodTemplate"] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
) -> Union[Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]], PythonFunctionTask[T], Callable[..., FuncOut]]:
"""
This is the core decorator to use for any task type in flytekit.
Expand Down Expand Up @@ -248,6 +252,7 @@ def foo2():
:param docs: Documentation about this task
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
"""

def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
Expand Down Expand Up @@ -277,6 +282,7 @@ def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
docs=docs,
pod_template=pod_template,
pod_template_name=pod_template_name,
accelerator=accelerator,
)
update_wrapper(task_instance, fn)
return task_instance
Expand Down
90 changes: 90 additions & 0 deletions flytekit/extras/accelerators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import abc
import copy
from typing import ClassVar, Generic, Optional, Type, TypeVar

from flyteidl.core import tasks_pb2

T = TypeVar("T")
MIG = TypeVar("MIG", bound="MultiInstanceGPUAccelerator")


class BaseAccelerator(abc.ABC, Generic[T]):
@abc.abstractmethod
def to_flyte_idl(self) -> T:
...


class GPUAccelerator(BaseAccelerator):
def __init__(self, device: str) -> None:
self._device = device

def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
return tasks_pb2.GPUAccelerator(device=self._device)


A10G = GPUAccelerator("nvidia-a10g")
L4 = GPUAccelerator("nvidia-l4-vws")
K80 = GPUAccelerator("nvidia-tesla-k80")
M60 = GPUAccelerator("nvidia-tesla-m60")
P4 = GPUAccelerator("nvidia-tesla-p4")
P100 = GPUAccelerator("nvidia-tesla-p100")
T4 = GPUAccelerator("nvidia-tesla-t4")
V100 = GPUAccelerator("nvidia-tesla-v100")


class MultiInstanceGPUAccelerator(BaseAccelerator):
device: ClassVar[str]
_partition_size: Optional[str]

@property
def unpartitioned(self: MIG) -> MIG:
instance = copy.deepcopy(self)
instance._partition_size = None
return instance

@classmethod
def partitioned(cls: Type[MIG], partition_size: str) -> MIG:
instance = cls()
instance._partition_size = partition_size
return instance

def to_flyte_idl(self) -> tasks_pb2.GPUAccelerator:
msg = tasks_pb2.GPUAccelerator(device=self.device)
if not hasattr(self, "_partition_size"):
return msg

if self._partition_size is None:
msg.unpartitioned = True
else:
msg.partition_size = self._partition_size
return msg


class _A100_Base(MultiInstanceGPUAccelerator):
device = "nvidia-tesla-a100"


class _A100(_A100_Base):
partition_1g_5gb = _A100_Base.partitioned("1g.5gb")
partition_2g_10gb = _A100_Base.partitioned("2g.10gb")
partition_3g_20gb = _A100_Base.partitioned("3g.20gb")
partition_4g_20gb = _A100_Base.partitioned("4g.20gb")
partition_7g_40gb = _A100_Base.partitioned("7g.40gb")


A100 = _A100()


class _A100_80GB_Base(MultiInstanceGPUAccelerator):
device = "nvidia-a100-80gb"


class _A100_80GB(_A100_80GB_Base):
partition_1g_10gb = _A100_80GB_Base.partitioned("1g.10gb")
partition_2g_20gb = _A100_80GB_Base.partitioned("2g.20gb")
partition_3g_40gb = _A100_80GB_Base.partitioned("3g.40gb")
partition_4g_40gb = _A100_80GB_Base.partitioned("4g.40gb")
partition_7g_80gb = _A100_80GB_Base.partitioned("7g.80gb")


A100_80GB = _A100_80GB()
16 changes: 13 additions & 3 deletions flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import typing

from flyteidl.core import tasks_pb2
from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.models import common as _common
Expand Down Expand Up @@ -562,24 +563,33 @@ def from_flyte_idl(cls, pb2_object):


class TaskNodeOverrides(_common.FlyteIdlEntity):
def __init__(self, resources: typing.Optional[Resources] = None):
def __init__(
self, resources: typing.Optional[Resources], extended_resources: typing.Optional[tasks_pb2.ExtendedResources]
):
self._resources = resources
self._extended_resources = extended_resources

@property
def resources(self) -> Resources:
return self._resources

@property
def extended_resources(self) -> tasks_pb2.ExtendedResources:
return self._extended_resources

def to_flyte_idl(self):
return _core_workflow.TaskNodeOverrides(
resources=self.resources.to_flyte_idl() if self.resources is not None else None,
extended_resources=self.extended_resources,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
resources = Resources.from_flyte_idl(pb2_object.resources)
extended_resources = pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None
if bool(resources.requests) or bool(resources.limits):
return cls(resources=resources)
return cls(resources=None)
return cls(resources=resources, extended_resources=extended_resources)
return cls(resources=None, extended_resources=extended_resources)


class TaskNode(_common.FlyteIdlEntity):
Expand Down
13 changes: 13 additions & 0 deletions flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def __init__(
config=None,
k8s_pod=None,
sql=None,
extended_resources=None,
):
"""
A task template represents the full set of information necessary to perform a unit of work in the Flyte system.
Expand All @@ -359,6 +360,7 @@ def __init__(
in tandem with the custom.
:param K8sPod k8s_pod: Alternative to the container used to execute this task.
:param Sql sql: This is used to execute query in FlytePropeller instead of running container or k8s_pod.
:param flyteidl.core.tasks_pb2.ExtendedResources extended_resources: The extended resources to allocate to the task.
"""
if (
(container is not None and k8s_pod is not None)
Expand All @@ -377,6 +379,7 @@ def __init__(
self._security_context = security_context
self._k8s_pod = k8s_pod
self._sql = sql
self._extended_resources = extended_resources

@property
def id(self):
Expand Down Expand Up @@ -451,6 +454,14 @@ def k8s_pod(self):
def sql(self):
return self._sql

@property
def extended_resources(self):
"""
If not None, the extended resources to allocate to the task.
:rtype: flyteidl.core.tasks_pb2.ExtendedResources
"""
return self._extended_resources

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.tasks_pb2.TaskTemplate
Expand All @@ -464,6 +475,7 @@ def to_flyte_idl(self):
container=self.container.to_flyte_idl() if self.container else None,
task_type_version=self.task_type_version,
security_context=self.security_context.to_flyte_idl() if self.security_context else None,
extended_resources=self.extended_resources,
config={k: v for k, v in self.config.items()} if self.config is not None else None,
k8s_pod=self.k8s_pod.to_flyte_idl() if self.k8s_pod else None,
sql=self.sql.to_flyte_idl() if self.sql else None,
Expand All @@ -487,6 +499,7 @@ def from_flyte_idl(cls, pb2_object):
security_context=_sec.SecurityContext.from_flyte_idl(pb2_object.security_context)
if pb2_object.security_context and pb2_object.security_context.ByteSize() > 0
else None,
extended_resources=pb2_object.extended_resources if pb2_object.HasField("extended_resources") else None,
config={k: v for k, v in pb2_object.config.items()} if pb2_object.config is not None else None,
k8s_pod=K8sPod.from_flyte_idl(pb2_object.k8s_pod) if pb2_object.HasField("k8s_pod") else None,
sql=Sql.from_flyte_idl(pb2_object.sql) if pb2_object.HasField("sql") else None,
Expand Down
9 changes: 6 additions & 3 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def get_serializable_task(
config=entity.get_config(settings),
k8s_pod=pod,
sql=entity.get_sql(settings),
extended_resources=entity.get_extended_resources(settings),
)
if settings.should_fast_serialize() and isinstance(entity, PythonAutoContainerTask):
entity.reset_command_fn()
Expand Down Expand Up @@ -440,7 +441,8 @@ def get_serializable_node(
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=task_spec.template.id, overrides=TaskNodeOverrides(resources=entity._resources)
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
),
)
if entity._aliases:
Expand Down Expand Up @@ -516,7 +518,8 @@ def get_serializable_node(
upstream_node_ids=[n.id for n in upstream_nodes],
output_aliases=[],
task_node=workflow_model.TaskNode(
reference_id=entity.flyte_entity.id, overrides=TaskNodeOverrides(resources=entity._resources)
reference_id=entity.flyte_entity.id,
overrides=TaskNodeOverrides(resources=entity._resources, extended_resources=entity._extended_resources),
),
)
elif isinstance(entity.flyte_entity, FlyteWorkflow):
Expand Down Expand Up @@ -565,7 +568,7 @@ def get_serializable_array_node(
task_spec = get_serializable(entity_mapping, settings, entity, options)
task_node = workflow_model.TaskNode(
reference_id=task_spec.template.id,
overrides=TaskNodeOverrides(resources=node._resources),
overrides=TaskNodeOverrides(resources=node._resources, extended_resources=node._extended_resources),
)
node = workflow_model.Node(
id=entity.name,
Expand Down
Loading

0 comments on commit 4b1ad23

Please sign in to comment.