Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto Cache Plugin #2971

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions flytekit/core/auto_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Any, Callable, Optional, Protocol, Union, runtime_checkable

from flytekit.image_spec.image_spec import ImageSpec


@dataclass
class VersionParameters:
"""
Parameters used for version hash generation.

Args:
func (Optional[Callable]): The function to generate a version for
container_image (Optional[Union[str, ImageSpec]]): The container image to generate a version for
"""

func: Optional[Callable[..., Any]] = None
container_image: Optional[Union[str, ImageSpec]] = None


@runtime_checkable
class AutoCache(Protocol):
"""
A protocol that defines the interface for a caching mechanism
that generates a version hash of a function based on its source code.
"""

salt: str

def get_version(self, params: VersionParameters) -> str:
"""
Generate a version hash based on the provided parameters.

Args:
params (VersionParameters): Parameters to use for hash generation.

Returns:
str: The generated version hash.
"""
...
37 changes: 24 additions & 13 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.auto_cache import AutoCache, VersionParameters
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
Expand Down Expand Up @@ -95,7 +96,7 @@ def find_pythontask_plugin(cls, plugin_config_type: type) -> Type[PythonFunction
def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, AutoCache] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -132,9 +133,9 @@ def task(

@overload
def task(
_task_function: Callable[P, FuncOut],
_task_function: Callable[..., FuncOut],
task_config: Optional[T] = ...,
cache: bool = ...,
cache: Union[bool, AutoCache] = ...,
cache_serialize: bool = ...,
cache_version: str = ...,
cache_ignore_input_vars: Tuple[str, ...] = ...,
Expand Down Expand Up @@ -166,13 +167,13 @@ def task(
pod_template_name: Optional[str] = ...,
accelerator: Optional[BaseAccelerator] = ...,
pickle_untyped: bool = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionTask[T]]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionTask[T]]: ...


def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
_task_function: Optional[Callable[..., FuncOut]] = None,
task_config: Optional[T] = None,
cache: bool = False,
cache: Union[bool, AutoCache] = False,
cache_serialize: bool = False,
cache_version: str = "",
cache_ignore_input_vars: Tuple[str, ...] = (),
Expand Down Expand Up @@ -211,8 +212,8 @@ def task(
accelerator: Optional[BaseAccelerator] = None,
pickle_untyped: bool = False,
) -> Union[
Callable[P, FuncOut],
Callable[[Callable[P, FuncOut]], PythonFunctionTask[T]],
Callable[..., FuncOut],
Callable[[Callable[..., FuncOut]], PythonFunctionTask[T]],
PythonFunctionTask[T],
]:
"""
Expand Down Expand Up @@ -247,7 +248,7 @@ def my_task(x: int, y: typing.Dict[str, str]) -> str:
:param _task_function: This argument is implicitly passed and represents the decorated function
:param task_config: This argument provides configuration for a specific task types.
Please refer to the plugins documentation for the right object to use.
:param cache: Boolean that indicates if caching should be enabled
:param cache: Boolean that indicates if caching should be enabled or a list of AutoCache implementations
:param cache_serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be
executed in serial when caching is enabled. This means that given multiple concurrent executions over
identical inputs, only a single instance executes and the rest wait to reuse the cached results. This
Expand Down Expand Up @@ -342,11 +343,21 @@ def launch_dynamically():
:param pickle_untyped: Boolean that indicates if the task allows unspecified data types.
"""

def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
def wrapper(fn: Callable[..., Any]) -> PythonFunctionTask[T]:
# Initialize defaults
cache_val = cache
cache_version_val = cache_version

if isinstance(cache, (AutoCache)):
# If cache is a AutoCache, enable caching
cache_val = True
params = VersionParameters(func=fn, container_image=container_image)
cache_version_val = cache_version or cache.get_version(params=params)

_metadata = TaskMetadata(
cache=cache,
cache=cache_val,
cache_serialize=cache_serialize,
cache_version=cache_version,
cache_version=cache_version_val,
cache_ignore_input_vars=cache_ignore_input_vars,
retries=retries,
interruptible=interruptible,
Expand Down Expand Up @@ -433,7 +444,7 @@ def wrapper(fn) -> ReferenceTask:
return wrapper


def decorate_function(fn: Callable[P, Any]) -> Callable[P, Any]:
def decorate_function(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Decorates the task with additional functionality if necessary.

Expand Down
10 changes: 5 additions & 5 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,25 +843,25 @@ def workflow(

@overload
def workflow(
_workflow_function: Callable[P, FuncOut],
_workflow_function: Callable[..., FuncOut],
failure_policy: Optional[WorkflowFailurePolicy] = ...,
interruptible: bool = ...,
on_failure: Optional[Union[WorkflowBase, Task]] = ...,
docs: Optional[Documentation] = ...,
pickle_untyped: bool = ...,
default_options: Optional[Options] = ...,
) -> Union[Callable[P, FuncOut], PythonFunctionWorkflow]: ...
) -> Union[Callable[..., FuncOut], PythonFunctionWorkflow]: ...


def workflow(
_workflow_function: Optional[Callable[P, FuncOut]] = None,
_workflow_function: Optional[Callable[..., FuncOut]] = None,
failure_policy: Optional[WorkflowFailurePolicy] = None,
interruptible: bool = False,
on_failure: Optional[Union[WorkflowBase, Task]] = None,
docs: Optional[Documentation] = None,
pickle_untyped: bool = False,
default_options: Optional[Options] = None,
) -> Union[Callable[P, FuncOut], Callable[[Callable[P, FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
) -> Union[Callable[..., FuncOut], Callable[[Callable[..., FuncOut]], PythonFunctionWorkflow], PythonFunctionWorkflow]:
"""
This decorator declares a function to be a Flyte workflow. Workflows are declarative entities that construct a DAG
of tasks using the data flow between tasks.
Expand Down Expand Up @@ -898,7 +898,7 @@ def workflow(
the labels and annotations are allowed to be set as defaults.
"""

def wrapper(fn: Callable[P, FuncOut]) -> PythonFunctionWorkflow:
def wrapper(fn: Callable[..., FuncOut]) -> PythonFunctionWorkflow:
workflow_metadata = WorkflowMetadata(on_failure=failure_policy or WorkflowFailurePolicy.FAIL_IMMEDIATELY)

workflow_metadata_defaults = WorkflowMetadataDefaults(interruptible)
Expand Down
62 changes: 62 additions & 0 deletions plugins/flytekit-auto-cache/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Flyte Auto Cache Plugin

This plugin provides a caching mechanism for Flyte tasks that generates a version hash based on the source code of the task and its dependencies. It allows users to manage the cache behavior.

## Usage

To install the plugin, run the following command:

```bash
pip install flytekitplugins-auto-cache
```

To use the caching mechanism in a Flyte task, you can define a `CachePolicy` that combines multiple caching strategies. Here’s an example of how to set it up:

```python
from flytekit import task
from flytekit.core.auto_cache import CachePolicy
from flytekitplugins.auto_cache import CacheFunctionBody, CachePrivateModules

cache_policy = CachePolicy(
auto_cache_policies = [
CacheFunctionBody(),
CachePrivateModules(root_dir="../my_package"),
...,
],
salt="my_salt"
)

@task(cache=cache_policy)
def task_fn():
...

@task(cache=CacheFunctionBody())
def other_task_fn():
...
```

### Salt Parameter

The `salt` parameter in the `CachePolicy` adds uniqueness to the generated hash. It can be used to differentiate between different versions of the same task. This ensures that even if the underlying code remains unchanged, the hash will vary if a different salt is provided. This feature is particularly useful for invalidating the cache for specific versions of a task.

## Cache Implementations

Users can add any number of cache policies that implement the `AutoCache` protocol defined in `@auto_cache.py`. Below are the implementations available so far:

### 1. CacheFunctionBody

This implementation hashes the contents of the function of interest, ignoring any formatting or comment changes. It ensures that the core logic of the function is considered for versioning.

### 2. CacheImage

This implementation includes the hash of the `container_image` object passed. If the image is specified as a name, that string is hashed. If it is an `ImageSpec`, the parametrization of the `ImageSpec` is hashed, allowing for precise versioning of the container image used in the task.

### 3. CachePrivateModules

This implementation recursively searches the task of interest for all callables and constants used. The contents of any callable (function or class) utilized by the task are hashed, ignoring formatting or comments. The values of the literal constants used are also included in the hash.

It accounts for both `import` and `from-import` statements at the global and local levels within a module or function. Any callables that are within site-packages (i.e., external libraries) are ignored.

### 4. CacheExternalDependencies

This implementation recursively searches through all the callables like `CachePrivateModules`, but when an external package is found, it records the version of the package, which is included in the hash. This ensures that changes in external dependencies are reflected in the task's versioning.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
.. currentmodule:: flytekitplugins.auto_cache

This package contains things that are useful when extending Flytekit.

.. autosummary::
:template: custom.rst
:toctree: generated/

CacheFunctionBody
CachePrivateModules
"""

from .cache_external_dependencies import CacheExternalDependencies
from .cache_function_body import CacheFunctionBody
from .cache_image import CacheImage
from .cache_private_modules import CachePrivateModules
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import hashlib
import importlib
import sys
from pathlib import Path
from typing import Any, Optional

import click
from flytekitplugins.auto_cache.cache_private_modules import CachePrivateModules, temporarily_add_to_syspath

from flytekit.core.auto_cache import VersionParameters


class CacheExternalDependencies(CachePrivateModules):
"""
A cache implementation that tracks external package dependencies and their versions.
Inherits the dependency traversal logic from CachePrivateModules but focuses on external packages.
"""

def __init__(self, root_dir: str, salt: str = ""):
super().__init__(salt=salt, root_dir=root_dir)
self._package_versions = {} # Cache for package versions
self._external_dependencies = set()

def get_version_dict(self) -> dict[str, str]:
"""
Get a dictionary mapping package names to their versions.

Returns:
dict[str, str]: Dictionary mapping package names to version strings
"""
versions = {}
for package in sorted(self._external_dependencies):
version = self._get_package_version(package)
if version:
versions[package] = version
return versions

def get_version(self, params: VersionParameters) -> str:
if params.func is None:
raise ValueError("Function-based cache requires a function parameter")

# Get all dependencies including nested function calls
_ = self._get_function_dependencies(params.func, set())

# Get package versions and create version string
versions = self.get_version_dict()
version_components = [f"{pkg}=={ver}" for pkg, ver in versions.items()]

# Combine package versions with salt
combined_data = "|".join(version_components).encode("utf-8") + self.salt.encode("utf-8")
return hashlib.sha256(combined_data).hexdigest()

def _is_user_defined(self, obj: Any) -> bool:
"""
Similar to the parent, this method checks if a callable or class is user-defined within the package.
If it identifies a non-user-defined package, it adds the external dependency to a list of packages
for which we will check their versions and hash.
"""
if isinstance(obj, type(sys)): # Check if the object is a module
module_name = obj.__name__
else:
module_name = getattr(obj, "__module__", None)
if not module_name:
return False

# Retrieve the module specification to get its path
with temporarily_add_to_syspath(self.root_dir):
spec = importlib.util.find_spec(module_name)
if not spec or not spec.origin:
return False

module_path = Path(spec.origin).resolve()

site_packages_paths = {Path(p).resolve() for p in sys.path if "site-packages" in p}
is_in_site_packages = any(sp in module_path.parents for sp in site_packages_paths)

# If it's in site-packages, add the module name to external dependencies
if is_in_site_packages:
root_package = module_name.split(".")[0]
self._external_dependencies.add(root_package)

# Check if the module is within the root directory but not in site-packages
if self.root_dir in module_path.parents:
# Exclude standard library or site-packages by checking common paths but return True if within root_dir but not in site-packages
return not is_in_site_packages

return False

def _get_package_version(self, package_name: str) -> str:
"""
Get the version of an installed package.

Args:
package_name: Name of the package

Returns:
str: Version string of the package or "unknown" if version cannot be determined
"""
if package_name in self._package_versions:
return self._package_versions[package_name]

version: Optional[str] = None
try:
# Try importlib.metadata first (most reliable)
version = importlib.metadata.version(package_name)
except Exception as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")
Comment on lines +106 to +107
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too broad exception handling

Catching a broad 'Exception' may hide bugs. Consider catching specific exceptions instead.

Code suggestion
Check the AI-generated fix before applying
Suggested change
except Exception as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")
except (ImportError, AttributeError) as e:
click.secho(f"Could not get version for {package_name} using importlib.metadata: {str(e)}", fg="yellow")

Code Review Run #bc105b


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

try:
# Fall back to checking package attributes
package = importlib.import_module(package_name)
version = getattr(package, "__version__", None)
if not version:
version = getattr(package, "version", None)
click.secho(f"Found by {package_name} importing module.", fg="yellow")
except ImportError as e:
click.secho(f"Could not import {package_name}: {str(e)}", fg="yellow")

if not version:
click.secho(
f"Could not determine version for package {package_name}. " "This may affect cache invalidation.",
fg="yellow",
)
version = "unknown"

self._package_versions[package_name] = version
return version
Loading
Loading