Skip to content

Commit

Permalink
Default to rv_policy::move when binding in-place operators (#803)
Browse files Browse the repository at this point in the history
  • Loading branch information
oremanj authored Nov 28, 2024
1 parent 030a9ca commit f33465c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
6 changes: 6 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ Version TBD (unreleased)
of storing a dangling C++ iterator reference in the returned Python
iterator object. (PR `#788 <https://github.com/wjakob/nanobind/pull/788>`__)

- Bindings for augmented assignment operators (as generated, for example, by
``.def(nb::self += nb::self)``) now return the same object in Python in the
typical case where the C++ operator returns a reference to ``*this``.
Previously, after ``a += b``, ``a`` would be replaced with a copy.
(PR `#803 <https://github.com/wjakob/nanobind/pull/803>`__)

Version 2.2.0 (October 3, 2024)
-------------------------------

Expand Down
8 changes: 6 additions & 2 deletions include/nanobind/operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,27 @@ template <op_id id, op_type ot, typename L, typename R> struct op_ {
using Lt = std::conditional_t<std::is_same_v<L, self_t>, Type, L>;
using Rt = std::conditional_t<std::is_same_v<R, self_t>, Type, R>;
using Op = op_impl<id, ot, Type, Lt, Rt>;
cl.def(Op::name(), &Op::execute, is_operator(), extra...);
cl.def(Op::name(), &Op::execute, is_operator(), Op::default_policy, extra...);
}

template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
using Type = typename Class::Type;
using Lt = std::conditional_t<std::is_same_v<L, self_t>, Type, L>;
using Rt = std::conditional_t<std::is_same_v<R, self_t>, Type, R>;
using Op = op_impl<id, ot, Type, Lt, Rt>;
cl.def(Op::name(), &Op::execute_cast, is_operator(), extra...);
cl.def(Op::name(), &Op::execute_cast, is_operator(), Op::default_policy, extra...);
}
};

#define NB_BINARY_OPERATOR(id, rid, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static constexpr rv_policy default_policy = rv_policy::automatic; \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
static B execute_cast(const L &l, const R &r) { return B(expr); } \
}; \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
static constexpr rv_policy default_policy = rv_policy::automatic; \
static char const* name() { return "__" #rid "__"; } \
static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
static B execute_cast(const R &r, const L &l) { return B(expr); } \
Expand All @@ -80,6 +82,7 @@ template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &

#define NB_INPLACE_OPERATOR(id, op, expr) \
template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
static constexpr rv_policy default_policy = rv_policy::move; \
static char const* name() { return "__" #id "__"; } \
static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
static B execute_cast(L &l, const R &r) { return B(expr); } \
Expand All @@ -90,6 +93,7 @@ template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &

#define NB_UNARY_OPERATOR(id, op, expr) \
template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
static constexpr rv_policy default_policy = rv_policy::automatic; \
static char const* name() { return "__" #id "__"; } \
static auto execute(const L &l) -> decltype(expr) { return expr; } \
static B execute_cast(const L &l) { return B(expr); } \
Expand Down
2 changes: 2 additions & 0 deletions tests/test_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def test14_operators():
assert repr(a - b) == "3"
assert "unsupported operand type" in str(excinfo.value)
assert repr(a - 2) == "-1"
a_before = id(a)
a += b
assert id(a) == a_before
assert repr(a) == "3"
assert repr(b) == "2"

Expand Down

0 comments on commit f33465c

Please sign in to comment.