Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Flow decorator #547

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/jobflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Jobflow is a package for writing dynamic and connected workflows."""

from jobflow._version import __version__
from jobflow.core.flow import Flow, JobOrder
from jobflow.core.flow import Flow, JobOrder, flow
from jobflow.core.job import Job, JobConfig, Response, job
from jobflow.core.maker import Maker
from jobflow.core.reference import OnMissing, OutputReference
Expand Down
20 changes: 20 additions & 0 deletions src/jobflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from monty.json import MSONable

import jobflow
from jobflow.core.job import Steps
from jobflow.core.reference import find_and_get_references
from jobflow.utils import ValueEnum, contains_flow_or_job, suid

Expand Down Expand Up @@ -874,3 +875,22 @@
)

return flow


def flow(method=None):
"""Wrap a function to produce a :obj:`Flow`."""
steps = Steps()
istart = len(steps)

def wrapper(*args, **kwargs):
method_out = method(*args, **kwargs)
# Here deal with when it is already a Flow
if isinstance(method_out, Flow):
f = method_out

Check warning on line 889 in src/jobflow/core/flow.py

View check run for this annotation

Codecov / codecov/patch

src/jobflow/core/flow.py#L889

Added line #L889 was not covered by tests
else:
flow_steps = steps.steps[istart:]
f = Flow(flow_steps, output=method_out)
steps.add(f)
return f

return wrapper
89 changes: 74 additions & 15 deletions src/jobflow/core/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import typing
import warnings
from dataclasses import dataclass, field

from monty.json import MSONable, jsanitize
Expand All @@ -21,6 +20,25 @@

import jobflow

from monty.design_patterns import singleton


@singleton
class Steps:
"""Steps class."""

def __init__(self):
self.steps = []

def add(self, step):
"""Add one step to the list."""
self.steps.append(step)

def __len__(self):
"""Return the number of steps."""
return len(self.steps)


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -163,6 +181,7 @@
--------
Job, .Flow, .Response
"""
steps = Steps()

def decorator(func):
from functools import wraps
Expand Down Expand Up @@ -192,10 +211,14 @@
f = met
args = args[1:]

return Job(
j = Job(
function=f, function_args=args, function_kwargs=kwargs, **job_kwargs
)

steps.add(j)

return j

get_job.original = func

if desc:
Expand Down Expand Up @@ -317,8 +340,6 @@
):
from copy import deepcopy

from jobflow.utils.find import contains_flow_or_job

function_args = () if function_args is None else function_args
function_kwargs = {} if function_kwargs is None else function_kwargs
uuid = suid() if uuid is None else uuid
Expand Down Expand Up @@ -351,17 +372,17 @@

self.output = OutputReference(self.uuid, output_schema=self.output_schema)

# check to see if job or flow is included in the job args
# this is a possible situation but likely a mistake
all_args = tuple(self.function_args) + tuple(self.function_kwargs.values())
if contains_flow_or_job(all_args):
warnings.warn(
f"Job '{self.name}' contains an Flow or Job as an input. "
f"Usually inputs should be the output of a Job or an Flow (e.g. "
f"job.output). If this message is unexpected then double check the "
f"inputs to your Job.",
stacklevel=2,
)
# check to see if job is included in the job args
self.function_args = tuple(
[
arg.output if isinstance(arg, Job) else arg
for arg in list(self.function_args)
]
)
self.function_kwargs = {
arg: v.output if isinstance(v, Job) else v
for arg, v in self.function_kwargs.items()
}

def __repr__(self):
"""Get a string representation of the job."""
Expand Down Expand Up @@ -406,6 +427,44 @@
"""Get the hash of the job."""
return hash(self.uuid)

def __getitem__(self, key: Any) -> OutputReference:
"""
Get the corresponding `OutputReference` for the `Job`.

This is for when it is indexed like a dictionary or list.

Parameters
----------
key
The index/key.

Returns
-------
OutputReference
The equivalent of `Job.output[k]`
"""
return self.output[key]

Check warning on line 446 in src/jobflow/core/job.py

View check run for this annotation

Codecov / codecov/patch

src/jobflow/core/job.py#L446

Added line #L446 was not covered by tests

def __getattr__(self, name: str) -> OutputReference:
"""
Get the corresponding `OutputReference` for the `Job`.

This is for when it is indexed like a class attribute.

Parameters
----------
name
The name of the attribute.

Returns
-------
OutputReference
The equivalent of `Job.output.name`
"""
if attr := getattr(self.output, name, None):
return attr
raise AttributeError(f"{type(self).__name__} has no attribute {name!r}")

@property
def input_references(self) -> tuple[jobflow.OutputReference, ...]:
"""
Expand Down
Loading
Loading