Skip to content
This repository has been archived by the owner on Apr 22, 2020. It is now read-only.

Add the mirror_defaults decorator #185

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- `mirror_defaults` decorator for mirroring the default arguments of another
function.

## [0.3.0] - 2019-06-10

### Added
Expand Down
60 changes: 60 additions & 0 deletions easypy/decorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from functools import wraps, partial, update_wrapper
from operator import attrgetter
from abc import ABCMeta, abstractmethod
import inspect

from .tokens import AUTO


def parametrizeable_decorator(deco):
Expand Down Expand Up @@ -138,3 +141,60 @@ def foo(self):
def wrapper(func):
return LazyDecoratorDescriptor(decorator_factory, func, cached)
return wrapper


def mirror_defaults(mirrored):
"""
Copy the default values of arguments from another function.

Set an argument's default to ``AUTO`` to copy the default value from the
mirrored function.

>>> from easypy.decorations import mirror_defaults

>>> def foo(a=1, b=2, c=3):
... print(a, b, c)

>>> @mirror_defaults(foo)
... def bar(a=AUTO, b=4, c=AUTO):
... foo(a, b, c)

>>> bar()
1 4 3
"""
defaults = {
p.name: p.default
for p in inspect.signature(mirrored).parameters.values()
if p.default is not inspect._empty}

def new_params_generator(params, defaults_to_override):
for param in params:
if param.default is AUTO:
try:
default_value = defaults[param.name]
except KeyError:
raise TypeError('%s has no default value for %s' % (mirrored.__name__, param.name))
defaults_to_override.add(param.name)
yield param.replace(default=default_value)
else:
yield param

def outer(func):
orig_signature = inspect.signature(func)
defaults_to_override = set()
new_parameters = new_params_generator(orig_signature.parameters.values(), defaults_to_override)
new_signature = orig_signature.replace(parameters=new_parameters)

@wraps(func)
def inner(*args, **kwargs):
binding = new_signature.bind(*args, **kwargs)

# NOTE: `apply_defaults` was added in Python 3.5, so we cannot use it
for name in defaults_to_override - binding.arguments.keys():
binding.arguments[name] = defaults[name]

return func(*binding.args, **binding.kwargs)
inner.signature = new_signature

return inner
return outer
26 changes: 25 additions & 1 deletion tests/test_decorations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from functools import wraps

from easypy.decorations import lazy_decorator
from easypy.decorations import lazy_decorator, mirror_defaults
from easypy.misc import kwargs_resilient
from easypy.tokens import AUTO


def test_kwargs_resilient():
Expand Down Expand Up @@ -136,3 +137,26 @@ def counter(self):
foo2.ts += 1
assert [foo1.inc(), foo2.inc()] == [2, 2]
assert [foo1.counter, foo2.counter] == [1, 2] # foo1 was not updated since last sync - only foo2


def test_mirror_defaults():
def foo(a, b, c=1, d=2, *args, e=3, f=4, **kwargs):
return locals()

@mirror_defaults(foo)
def bar(a, b=100, c=AUTO, d=20, *args, e=AUTO, f=40, **kwargs):
return foo(a, b, c, d, *args, e=e, f=f, **kwargs)

assert bar(300) == dict(
a=300, b=100,
c=1, d=20,
args=(),
e=3, f=40,
kwargs={})

assert bar(300, 400, 500, 600, 700, e=800, f=900, g=1000) == dict(
a=300, b=400,
c=500, d=600,
args=(700,),
e=800, f=900,
kwargs=dict(g=1000))