Skip to content

Commit

Permalink
Merge pull request Pyomo#3151 from jsiirola/numpy-registration-and-va…
Browse files Browse the repository at this point in the history
…r-units

Fix edge case assigning new numeric types to Var/Param with units
  • Loading branch information
blnicho authored Feb 20, 2024
2 parents 922ea8e + 924c38a commit 21a718a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
25 changes: 20 additions & 5 deletions pyomo/core/base/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,16 +162,31 @@ def set_value(self, value, idx=NOTSET):
# required to be mutable.
#
_comp = self.parent_component()
if type(value) in native_types:
if value.__class__ in native_types:
# TODO: warn/error: check if this Param has units: assigning
# a dimensionless value to a united param should be an error
pass
elif _comp._units is not None:
_src_magnitude = expr_value(value)
_src_units = units.get_units(value)
value = units.convert_value(
num_value=_src_magnitude, from_units=_src_units, to_units=_comp._units
)
# Note: expr_value() could have just registered a new numeric type
if value.__class__ in native_types:
value = _src_magnitude
else:
_src_units = units.get_units(value)
value = units.convert_value(
num_value=_src_magnitude,
from_units=_src_units,
to_units=_comp._units,
)
# FIXME: we should call value() here [to ensure types get
# registered], but doing so breaks non-numeric Params (which we
# allow). The real fix will be to follow the precedent from
# GetItemExpression and have separate types based on which
# expression "system" the Param should participate in (numeric,
# logical, or structural).
#
# else:
# value = expr_value(value)

old_value, self._value = self._value, value
try:
Expand Down
15 changes: 10 additions & 5 deletions pyomo/core/base/var.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,17 +384,22 @@ def set_value(self, val, skip_validation=False):
#
# Check if this Var has units: assigning dimensionless
# values to a variable with units should be an error
if type(val) not in native_numeric_types:
if self.parent_component()._units is not None:
_src_magnitude = value(val)
if val.__class__ in native_numeric_types:
pass
elif self.parent_component()._units is not None:
_src_magnitude = value(val)
# Note: value() could have just registered a new numeric type
if val.__class__ in native_numeric_types:
val = _src_magnitude
else:
_src_units = units.get_units(val)
val = units.convert_value(
num_value=_src_magnitude,
from_units=_src_units,
to_units=self.parent_component()._units,
)
else:
val = value(val)
else:
val = value(val)

if not skip_validation:
if val not in self.domain:
Expand Down
8 changes: 7 additions & 1 deletion pyomo/core/tests/unit/test_numvalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ def test_numpy_basic_bool_registration(self):
@unittest.skipUnless(numpy_available, "This test requires NumPy")
def test_automatic_numpy_registration(self):
cmd = (
'import pyomo; from pyomo.core.base import Var, Param; import numpy as np; '
'import pyomo; from pyomo.core.base import Var, Param; '
'from pyomo.core.base.units_container import units; import numpy as np; '
'print(np.float64 in pyomo.common.numeric_types.native_numeric_types); '
'%s; print(np.float64 in pyomo.common.numeric_types.native_numeric_types)'
)
Expand All @@ -582,6 +583,11 @@ def _tester(expr):
_tester('Var() + np.float64(5)')
_tester('v = Var(); v.construct(); v.value = np.float64(5)')
_tester('p = Param(mutable=True); p.construct(); p.value = np.float64(5)')
_tester('v = Var(units=units.m); v.construct(); v.value = np.float64(5)')
_tester(
'p = Param(mutable=True, units=units.m); p.construct(); '
'p.value = np.float64(5)'
)


if __name__ == "__main__":
Expand Down

0 comments on commit 21a718a

Please sign in to comment.