Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unpack dictionary parameters #3905

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements
* Kedro commands are now lazily loaded to add performance gains when running Kedro commands.
* Can use unpacking with parameter dictionaries.

## Bug fixes and other changes
* Updated error message for invalid catalog entries.
Expand All @@ -22,6 +23,8 @@
* Extended documentation with an example of logging customisation at runtime

## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [bpmeek](https://github.com/bpmeek/)

# Release 0.19.6

Expand Down
13 changes: 12 additions & 1 deletion kedro/pipeline/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,11 @@ def _get_dataset_names_mapping(

def _normalize_param_name(name: str) -> str:
"""Make sure that a param name has a `params:` prefix before passing to the node"""
return name if name.startswith("params:") else f"params:{name}"
return (
name
if name.startswith("params:") or name.startswith("**params:")
else f"params:{name}"
)


def _get_param_names_mapping(
Expand Down Expand Up @@ -251,6 +255,11 @@ def _map_transcode_base(name: str) -> str:
base_name, transcode_suffix = _transcode_split(name)
return TRANSCODING_SEPARATOR.join((mapping[base_name], transcode_suffix))

def _matches_unpackable(name: str) -> bool:
param_base = name.split(".")[0]
matches = [True for key, value in mapping.items() if f"**{param_base}" in key]
return any(matches)

def _rename(name: str) -> str:
rules = [
# if name mapped to new name, update with new name
Expand All @@ -259,6 +268,8 @@ def _rename(name: str) -> str:
(_is_all_parameters, lambda n: n),
# if transcode base is mapped to a new name, update with new base
(_is_transcode_base_in_mapping, _map_transcode_base),
# if name refers to dictionary to be unpacked, leave as is
(lambda n: _matches_unpackable(name), lambda n: n),
# if name refers to a single parameter and a namespace is given, apply prefix
(lambda n: bool(namespace) and _is_single_parameter(n), _prefix_param),
# if namespace given for a dataset, prefix name using that namespace
Expand Down
30 changes: 30 additions & 0 deletions kedro/pipeline/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def __init__( # noqa: PLR0913
_node_error_message("it must have some 'inputs' or 'outputs'.")
)

inputs = _unpacked_params(func, inputs)

self._validate_inputs(func, inputs)

self._func = func
Expand Down Expand Up @@ -683,3 +685,31 @@ def _get_readable_func_name(func: Callable) -> str:
name = "<partial>"

return name


def _unpacked_params(
func: Callable, inputs: None | str | list[str] | dict[str, str]
) -> None | str | list[str] | dict[str, str]:
"""Iterate over Node inputs to see if they need to be unpacked.

Returns:
Either original inputs if no input was unpacked or a list of inputs if an input was unpacked.
"""
use_new = False
bpmeek marked this conversation as resolved.
Show resolved Hide resolved
new_inputs = []
_func_arguments = [arg for arg in inspect.signature(func).parameters]
for idx, _input in enumerate(_to_list(inputs)):
if _input.startswith("**params"):
if "**" in str(inspect.signature(func)):
raise TypeError(
"Function side dictionary unpacking is currently incompatible with parameter dictionary unpacking."
)
use_new = True
dict_root = _input.split(":")[-1]
for param in _func_arguments[idx:]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of slicing the list, we can just use idx to start looping from it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I understand what you're recommending.

new_inputs.append(f"params:{dict_root}.{param}")
else:
new_inputs.append(_input)
if use_new:
return new_inputs
return inputs
49 changes: 49 additions & 0 deletions tests/pipeline/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def triconcat(input1: str, input2: str, input3: str):
return input1 + input2 + input3 # pragma: no cover


def dict_unpack():
return dict(input2="input2", input3="another node")


def kwargs_input(**kwargs):
return kwargs


@pytest.fixture
def simple_tuple_node_list():
return [
Expand Down Expand Up @@ -125,6 +133,47 @@ def test_inputs_list(self):
)
assert dummy_node.inputs == ["input1", "input2", "another node"]

def test_inputs_unpack_str(self):
dummy_node = node(triconcat, inputs="**params:dict_unpack", outputs="output1")
assert dummy_node.inputs == [
"params:dict_unpack.input1",
"params:dict_unpack.input2",
"params:dict_unpack.input3",
]

def test_inputs_unpack_list(self):
dummy_node = node(
triconcat,
inputs=["input1", "**params:dict_unpack"],
outputs=["output1", "output2", "last node"],
)
assert dummy_node.inputs == [
"input1",
"params:dict_unpack.input2",
"params:dict_unpack.input3",
]

def test_inputs_unpack_dict(self):
dummy_node = node(
triconcat,
inputs={"input1": "**params:dict_unpack"},
outputs=["output1", "output2", "last node"],
)
assert dummy_node.inputs == [
"params:dict_unpack.input1",
"params:dict_unpack.input2",
"params:dict_unpack.input3",
]

def test_kwargs_node_negative(self):
pattern = "parameter dictionary unpacking"
with pytest.raises(TypeError, match=pattern):
node(
kwargs_input,
inputs="**params:dict_unpack",
outputs="output1",
)

def test_outputs_none(self):
dummy_node = node(identity, "input", None)
assert dummy_node.outputs == []
Expand Down