Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Cache Plugin #2971

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions flytekit/core/auto_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union, runtime_checkable

from flytekit.image_spec.image_spec import ImageSpec


@dataclass
class VersionParameters:
"""
Parameters used for version hash generation.

Args:
func (Optional[Callable]): The function to generate a version for
container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for
"""

func: Optional[Callable[..., Any]] = None
container_image: Optional[Union[str, ImageSpec]] = None


@runtime_checkable
class AutoCache(Protocol):
"""
A protocol that defines the interface for a caching mechanism
that generates a version hash of a function based on its source code.
"""

salt: str

def get_version(self, params: VersionParameters) -> str:
"""
Generate a version hash based on the provided parameters.

Args:
params (VersionParameters): Parameters to use for hash generation.

Returns:
str: The generated version hash.
"""
...

Check warning on line 40 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L40

Added line #L40 was not covered by tests


class CachePolicy:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can CachePolicy live in the plugin? It makes sense for the abstract AutoCache protocol to be defined in flytekit core, but any implementation of it should be in the plugin.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point, thanks. i refactored this!

"""
A class that combines multiple caching mechanisms to generate a version hash.

Args:
auto_cache_policies: A list of AutoCache instances (optional).
salt: Optional salt string to add uniqueness to the hash.
cache_serialize: Boolean to indicate if serialization should be used.
cache_version: A version string for the cache.
cache_ignore_input_vars: Tuple of input variable names to ignore.
"""

def __init__(
self,
auto_cache_policies: List["AutoCache"] = None,
salt: str = "",
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
) -> None:
self.auto_cache_policies = auto_cache_policies or [] # Use an empty list if None is provided
self.salt = salt
self.cache_serialize = cache_serialize
self.cache_version = cache_version
self.cache_ignore_input_vars = cache_ignore_input_vars

Check warning on line 67 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L63-L67

Added lines #L63 - L67 were not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the purpose of saving this state here? aren't these just forwarded to the underlying TaskMetadata?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea with this is the user could use the CachePolicy to define all the arguments relating to caching. This simplifies the UX a bit as opposed to having a CachePolicy and a cache_ignore_input_vars, cache_serialize, etc.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a little confusing:

  • cache_version should not be exposed, since the AutoCache protocol is meant to produce this value automatically, and salt is meant to fulfill the need of manually bumping the cache.
  • I think it makes sense to keep cache_serialize and cache_ignore_input_vars as options to specify in the @task decorator as opposed to introducing this redundancy here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Sounds like there is a separate effort aimed at collecting all of the caching arguments here: flyteorg/flyte#6143

Happy to use that instead and simplify the arguments here!


def get_version(self, params: "VersionParameters") -> str:
"""
Generate a version hash using all cache objects. If the user passes a version, it takes precedence over auto_cache_policies.

Args:
params (VersionParameters): Parameters to use for hash generation.

Returns:
str: The combined hash from all cache objects.
"""
if self.cache_version:
return self.cache_version

Check warning on line 80 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L80

Added line #L80 was not covered by tests

if self.auto_cache_policies:
task_hash = ""

Check warning on line 83 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L83

Added line #L83 was not covered by tests
for cache_instance in self.auto_cache_policies:
# Apply the policy's salt to each cache instance
cache_instance.salt = self.salt
task_hash += cache_instance.get_version(params)

Check warning on line 87 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L86-L87

Added lines #L86 - L87 were not covered by tests

# Generate SHA-256 hash
import hashlib

Check warning on line 90 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L90

Added line #L90 was not covered by tests

hash_obj = hashlib.sha256(task_hash.encode())
return hash_obj.hexdigest()

Check warning on line 93 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L92-L93

Added lines #L92 - L93 were not covered by tests

return None

Check warning on line 95 in flytekit/core/auto_cache.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/auto_cache.py#L95

Added line #L95 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type inconsistency in get_version method

Consider returning an empty string instead of None for consistency in return types. The method signature indicates it returns str but can return None.

Code suggestion
Check the AI-generated fix before applying
Suggested change
return None
return ""

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

48 changes: 33 additions & 15 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.auto_cache import AutoCache, CachePolicy, VersionParameters
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
Expand Down Expand Up @@ -95,7 +96,7 @@
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy, AutoCache] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -132,9 +133,9 @@

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, CachePolicy, AutoCache] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -166,13 +167,13 @@
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ...


def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
_task_function: Optional[Callable[..., FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache: Union[bool, CachePolicy, AutoCache] = False,
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
Expand Down Expand Up @@ -211,8 +212,8 @@
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Callable[..., FuncOut],
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
]:
"""
Expand Down Expand Up @@ -247,7 +248,7 @@
:param _task_function: This argument is implicitly passed and represents the decorated function
:param task_config: This argument provides configuration for a specific task types.
Please refer to the plugins documentation for the right object to use.
:param cache: Boolean that indicates if caching should be enabled
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be
executed in serial when caching is enabled. This means that given multiple concurrent executions over
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This
Expand Down Expand Up @@ -342,12 +343,29 @@
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
# Initialize defaults
cache_val = cache
cache_version_val = cache_version
cache_serialize_val = cache_serialize
cache_ignore_input_vars_val = cache_ignore_input_vars

if isinstance(cache, (CachePolicy, AutoCache)):
# If cache is a CachePolicy or AutoCache, enable caching
cache_val = True
params = VersionParameters(func=fn, container_image=container_image)
cache_version_val = cache_version or cache.get_version(params=params)

Check warning on line 358 in flytekit/core/task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/task.py#L356-L358

Added lines #L356 - L358 were not covered by tests
if isinstance(cache, CachePolicy):
# Use CachePolicy-specific attributes if available
cache_serialize_val = cache_serialize or cache.cache_serialize
cache_ignore_input_vars_val = cache_ignore_input_vars or cache.cache_ignore_input_vars

Check warning on line 363 in flytekit/core/task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/task.py#L362-L363

Added lines #L362 - L363 were not covered by tests
_metadata = TaskMetadata(
cache=cache,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_ignore_input_vars=cache_ignore_input_vars,
cache=cache_val,
cache_serialize=cache_serialize_val,
cache_version=cache_version_val,
cache_ignore_input_vars=cache_ignore_input_vars_val,
retries=retries,
interruptible=interruptible,
deprecated=deprecated,
Expand Down Expand Up @@ -433,7 +451,7 @@
return wrapper


def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]:
def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorates the task with additional functionality if necessary.

Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,25 +843,25 @@ def workflow(

@overload
def workflow(
_workflow_function: Callable[P, FuncOut],
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ...


def workflow(
_workflow_function: Optional[Callable[P, FuncOut]] = None,
_workflow_function: Optional[Callable[..., FuncOut]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -898,7 +898,7 @@ def workflow(
the labels and annotations are allowed to be set as defaults.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand Down
62 changes: 62 additions & 0 deletions plugins/flytekit-auto-cache/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Flyte Auto Cache Plugin

This plugin provides a caching mechanism for Flyte tasks that generates a version hash based on the source code of the task and its dependencies. It allows users to manage the cache behavior.

## Usage

To install the plugin, run the following command:

```bash
pip install flytekitplugins-auto-cache
```

To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up:

```python
from flytekit import task
from flytekit.core.auto_cache import CachePolicy
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules

cache_policy = CachePolicy(
auto_cache_policies = [
CacheFunctionBody(),
CachePrivateModules(root_dir="../my_package"),
...,
],
salt="my_salt"
)

@task(cache=cache_policy)
def task_fn():
...

@task(cache=CacheFunctionBody())
def other_task_fn():
...
```

### Salt Parameter

The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task.

## Cache Implementations

Users can add any number of cache policies that implement the `AutoCache` protocol defined in `@auto_cache.py`. Below are the implementations available so far:

### 1. CacheFunctionBody

This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning.

### 2. CacheImage

This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task.

### 3. CachePrivateModules

This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash.

It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored.

### 4. CacheExternalDependencies

This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
.. currentmodule:: flytekitplugins.auto_cache

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

CacheFunctionBody
CachePrivateModules
"""

from .cache_external_dependencies import CacheExternalDependencies
from .cache_function_body import CacheFunctionBody
from .cache_image import CacheImage
from .cache_private_modules import CachePrivateModules
Loading
Loading