-
Notifications
You must be signed in to change notification settings - Fork 312
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Auto Cache Plugin #2971
Auto Cache Plugin #2971
Changes from 18 commits
2786c5b
b18bac7
50552a9
73d2327
6d5cdbf
f76f59a
a5fc1bb
ff3af99
7d06370
011bd67
d5d9576
2b0fbb8
b4911ab
02f6b53
47d01ef
f1ebdc9
ff3555f
18f253e
667b34a
121c06f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,95 @@ | ||||||
from dataclasses import dataclass | ||||||
from typing import Any, Callable, List, Optional, Protocol, Tuple, Union, runtime_checkable | ||||||
|
||||||
from flytekit.image_spec.image_spec import ImageSpec | ||||||
|
||||||
|
||||||
@dataclass | ||||||
class VersionParameters: | ||||||
""" | ||||||
Parameters used for version hash generation. | ||||||
|
||||||
Args: | ||||||
func (Optional[Callable]): The function to generate a version for | ||||||
container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for | ||||||
""" | ||||||
|
||||||
func: Optional[Callable[..., Any]] = None | ||||||
container_image: Optional[Union[str, ImageSpec]] = None | ||||||
|
||||||
|
||||||
@runtime_checkable | ||||||
class AutoCache(Protocol): | ||||||
""" | ||||||
A protocol that defines the interface for a caching mechanism | ||||||
that generates a version hash of a function based on its source code. | ||||||
""" | ||||||
|
||||||
salt: str | ||||||
|
||||||
def get_version(self, params: VersionParameters) -> str: | ||||||
""" | ||||||
Generate a version hash based on the provided parameters. | ||||||
|
||||||
Args: | ||||||
params (VersionParameters): Parameters to use for hash generation. | ||||||
|
||||||
Returns: | ||||||
str: The generated version hash. | ||||||
""" | ||||||
... | ||||||
|
||||||
|
||||||
class CachePolicy: | ||||||
""" | ||||||
A class that combines multiple caching mechanisms to generate a version hash. | ||||||
|
||||||
Args: | ||||||
auto_cache_policies: A list of AutoCache instances (optional). | ||||||
salt: Optional salt string to add uniqueness to the hash. | ||||||
cache_serialize: Boolean to indicate if serialization should be used. | ||||||
cache_version: A version string for the cache. | ||||||
cache_ignore_input_vars: Tuple of input variable names to ignore. | ||||||
""" | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
auto_cache_policies: List["AutoCache"] = None, | ||||||
salt: str = "", | ||||||
cache_serialize: bool = False, | ||||||
cache_version: str = "", | ||||||
cache_ignore_input_vars: Tuple[str, ...] = (), | ||||||
) -> None: | ||||||
self.auto_cache_policies = auto_cache_policies or [] # Use an empty list if None is provided | ||||||
self.salt = salt | ||||||
self.cache_serialize = cache_serialize | ||||||
self.cache_version = cache_version | ||||||
self.cache_ignore_input_vars = cache_ignore_input_vars | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's the purpose of saving this state here? aren't these just forwarded to the underlying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea with this is the user could use the CachePolicy to define all the arguments relating to caching. This simplifies the UX a bit as opposed to having a CachePolicy and a cache_ignore_input_vars, cache_serialize, etc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a little confusing:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good. Sounds like there is a separate effort aimed at collecting all of the caching arguments here: flyteorg/flyte#6143 Happy to use that instead and simplify the arguments here! |
||||||
|
||||||
def get_version(self, params: "VersionParameters") -> str: | ||||||
""" | ||||||
Generate a version hash using all cache objects. If the user passes a version, it takes precedence over auto_cache_policies. | ||||||
|
||||||
Args: | ||||||
params (VersionParameters): Parameters to use for hash generation. | ||||||
|
||||||
Returns: | ||||||
str: The combined hash from all cache objects. | ||||||
""" | ||||||
if self.cache_version: | ||||||
return self.cache_version | ||||||
|
||||||
if self.auto_cache_policies: | ||||||
task_hash = "" | ||||||
for cache_instance in self.auto_cache_policies: | ||||||
# Apply the policy's salt to each cache instance | ||||||
cache_instance.salt = self.salt | ||||||
task_hash += cache_instance.get_version(params) | ||||||
|
||||||
# Generate SHA-256 hash | ||||||
import hashlib | ||||||
|
||||||
hash_obj = hashlib.sha256(task_hash.encode()) | ||||||
return hash_obj.hexdigest() | ||||||
|
||||||
return None | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Return type inconsistency in get_version method
Consider returning an empty string instead of Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #bc105b Is this a valid issue, or was it incorrectly flagged by the Agent?
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Flyte Auto Cache Plugin | ||
|
||
This plugin provides a caching mechanism for Flyte tasks that generates a version hash based on the source code of the task and its dependencies. It allows users to manage the cache behavior. | ||
|
||
## Usage | ||
|
||
To install the plugin, run the following command: | ||
|
||
```bash | ||
pip install flytekitplugins-auto-cache | ||
``` | ||
|
||
To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up: | ||
|
||
```python | ||
from flytekit import task | ||
from flytekit.core.auto_cache import CachePolicy | ||
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules | ||
|
||
cache_policy = CachePolicy( | ||
auto_cache_policies = [ | ||
CacheFunctionBody(), | ||
CachePrivateModules(root_dir="../my_package"), | ||
..., | ||
], | ||
salt="my_salt" | ||
) | ||
|
||
@task(cache=cache_policy) | ||
def task_fn(): | ||
... | ||
|
||
@task(cache=CacheFunctionBody()) | ||
def other_task_fn(): | ||
... | ||
``` | ||
|
||
### Salt Parameter | ||
|
||
The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task. | ||
|
||
## Cache Implementations | ||
|
||
Users can add any number of cache policies that implement the `AutoCache` protocol defined in `@auto_cache.py`. Below are the implementations available so far: | ||
|
||
### 1. CacheFunctionBody | ||
|
||
This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning. | ||
|
||
### 2. CacheImage | ||
|
||
This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task. | ||
|
||
### 3. CachePrivateModules | ||
|
||
This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash. | ||
|
||
It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored. | ||
|
||
### 4. CacheExternalDependencies | ||
|
||
This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
""" | ||
.. currentmodule:: flytekitplugins.auto_cache | ||
|
||
This package contains things that are useful when extending Flytekit. | ||
|
||
.. autosummary:: | ||
:template: custom.rst | ||
:toctree: generated/ | ||
|
||
CacheFunctionBody | ||
CachePrivateModules | ||
""" | ||
|
||
from .cache_external_dependencies import CacheExternalDependencies | ||
from .cache_function_body import CacheFunctionBody | ||
from .cache_image import CacheImage | ||
from .cache_private_modules import CachePrivateModules |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can
CachePolicy
live in the plugin? It makes sense for the abstractAutoCache
protocol to be defined in flytekit core, but any implementation of it should be in the plugin.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
good point, thanks. i refactored this!