diff --git a/pyomo/core/base/param.py b/pyomo/core/base/param.py index 03d700140e8..3ef33b9ee45 100644 --- a/pyomo/core/base/param.py +++ b/pyomo/core/base/param.py @@ -162,16 +162,31 @@ def set_value(self, value, idx=NOTSET): # required to be mutable. # _comp = self.parent_component() - if type(value) in native_types: + if value.__class__ in native_types: # TODO: warn/error: check if this Param has units: assigning # a dimensionless value to a united param should be an error pass elif _comp._units is not None: _src_magnitude = expr_value(value) - _src_units = units.get_units(value) - value = units.convert_value( - num_value=_src_magnitude, from_units=_src_units, to_units=_comp._units - ) + # Note: expr_value() could have just registered a new numeric type + if value.__class__ in native_types: + value = _src_magnitude + else: + _src_units = units.get_units(value) + value = units.convert_value( + num_value=_src_magnitude, + from_units=_src_units, + to_units=_comp._units, + ) + # FIXME: we should call value() here [to ensure types get + # registered], but doing so breaks non-numeric Params (which we + # allow). The real fix will be to follow the precedent from + # GetItemExpression and have separate types based on which + # expression "system" the Param should participate in (numeric, + # logical, or structural). + # + # else: + # value = expr_value(value) old_value, self._value = self._value, value try: diff --git a/pyomo/core/base/var.py b/pyomo/core/base/var.py index d03fd0b677f..f426c9c4f55 100644 --- a/pyomo/core/base/var.py +++ b/pyomo/core/base/var.py @@ -384,17 +384,22 @@ def set_value(self, val, skip_validation=False): # # Check if this Var has units: assigning dimensionless # values to a variable with units should be an error - if type(val) not in native_numeric_types: - if self.parent_component()._units is not None: - _src_magnitude = value(val) + if val.__class__ in native_numeric_types: + pass + elif self.parent_component()._units is not None: + _src_magnitude = value(val) + # Note: value() could have just registered a new numeric type + if val.__class__ in native_numeric_types: + val = _src_magnitude + else: _src_units = units.get_units(val) val = units.convert_value( num_value=_src_magnitude, from_units=_src_units, to_units=self.parent_component()._units, ) - else: - val = value(val) + else: + val = value(val) if not skip_validation: if val not in self.domain: diff --git a/pyomo/core/tests/unit/test_numvalue.py b/pyomo/core/tests/unit/test_numvalue.py index eceab3a42d9..bd784d655e8 100644 --- a/pyomo/core/tests/unit/test_numvalue.py +++ b/pyomo/core/tests/unit/test_numvalue.py @@ -562,7 +562,8 @@ def test_numpy_basic_bool_registration(self): @unittest.skipUnless(numpy_available, "This test requires NumPy") def test_automatic_numpy_registration(self): cmd = ( - 'import pyomo; from pyomo.core.base import Var, Param; import numpy as np; ' + 'import pyomo; from pyomo.core.base import Var, Param; ' + 'from pyomo.core.base.units_container import units; import numpy as np; ' 'print(np.float64 in pyomo.common.numeric_types.native_numeric_types); ' '%s; print(np.float64 in pyomo.common.numeric_types.native_numeric_types)' ) @@ -582,6 +583,11 @@ def _tester(expr): _tester('Var() + np.float64(5)') _tester('v = Var(); v.construct(); v.value = np.float64(5)') _tester('p = Param(mutable=True); p.construct(); p.value = np.float64(5)') + _tester('v = Var(units=units.m); v.construct(); v.value = np.float64(5)') + _tester( + 'p = Param(mutable=True, units=units.m); p.construct(); ' + 'p.value = np.float64(5)' + ) if __name__ == "__main__":