Skip to content

Commit

Permalink
Merge pull request #3368 from jsiirola/set-filter-fix
Browse files Browse the repository at this point in the history
Resolve issue in filter/validate deprecation path
  • Loading branch information
blnicho authored Oct 16, 2024
2 parents 459f8e8 + 0fff1e7 commit dc72d1d
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
38 changes: 24 additions & 14 deletions pyomo/core/base/set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,18 +1484,7 @@ def _cb_validate_filter(self, mode, val_iter):
try:
flag = fcn(block, (), *vstar)
if flag:
deprecation_warning(
f"{self.__class__.__name__} {self.name}: '{mode}=' "
"callback signature matched (block, *value). "
"Please update the callback to match the signature "
f"(block, value{', *index' if comp.is_indexed() else ''}).",
version='6.8.0',
)
orig_fcn = fcn._fcn
fcn = ParameterizedScalarCallInitializer(
lambda m, v: orig_fcn(m, *v), True
)
setattr(comp, '_' + mode, fcn)
self._filter_validate_scalar_api_deprecation(mode, warning=True)
yield value
continue
except TypeError:
Expand Down Expand Up @@ -1536,6 +1525,21 @@ def _cb_validate_filter(self, mode, val_iter):
)
raise exc from None

def _filter_validate_scalar_api_deprecation(self, mode, warning):
comp = self.parent_component()
fcn = getattr(comp, '_' + mode)
if warning:
deprecation_warning(
f"{self.__class__.__name__} {self.name}: '{mode}=' "
"callback signature matched (block, *value). "
"Please update the callback to match the signature "
f"(block, value{', *index' if comp.is_indexed() else ''}).",
version='6.8.0',
)
orig_fcn = fcn._fcn
fcn = ParameterizedScalarCallInitializer(lambda m, v: orig_fcn(m, *v), True)
setattr(comp, '_' + mode, fcn)

def _cb_normalized_dimen_verifier(self, dimen, val_iter):
for value in val_iter:
if value.__class__ in native_types:
Expand Down Expand Up @@ -2256,14 +2260,20 @@ def __init__(self, *args, **kwds):
self._init_values._init = CountedCallInitializer(
self, self._init_values._init
)
# HACK: the DAT parser needs to know the domain of a set in
# order to correctly parse the data stream.

if not self.is_indexed():
# HACK: the DAT parser needs to know the domain of a set in
# order to correctly parse the data stream.
if self._init_domain.constant():
self._domain = self._init_domain(self.parent_block(), None, self)
if self._init_dimen.constant():
self._dimen = self._init_dimen(self.parent_block(), None)

if self._filter.__class__ is ParameterizedIndexedCallInitializer:
self._filter_validate_scalar_api_deprecation('filter', warning=False)
if self._validate.__class__ is ParameterizedIndexedCallInitializer:
self._filter_validate_scalar_api_deprecation('validate', warning=False)

@deprecated(
"check_values() is deprecated: Sets only contain valid members", version='5.7'
)
Expand Down
44 changes: 36 additions & 8 deletions pyomo/core/tests/unit/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -4181,6 +4181,19 @@ def test_indexed_set(self):
self.assertIs(type(m.I[3]), InsertionOrderSetData)
self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (4, 2, 5), 3: (4, 2, 5)})

# Explicit (constant dict) construction
m = ConcreteModel()
m.I = Set([1, 2], initialize={1: (4, 2, 5), 2: (7, 6)})
self.assertEqual(len(m.I), 2)
self.assertEqual(list(m.I[1]), [4, 2, 5])
self.assertEqual(list(m.I[2]), [7, 6])
self.assertIsNot(m.I[1], m.I[2])
self.assertTrue(m.I[1].isordered())
self.assertTrue(m.I[2].isordered())
self.assertIs(type(m.I[1]), InsertionOrderSetData)
self.assertIs(type(m.I[2]), InsertionOrderSetData)
self.assertEqual(m.I.data(), {1: (4, 2, 5), 2: (7, 6)})

# Explicit (constant) construction
m = ConcreteModel()
m.I = Set([1, 2, 3], initialize=(4, 2, 5), ordered=Set.SortedOrder)
Expand Down Expand Up @@ -4255,7 +4268,7 @@ def test_indexing(self):
def test_add_filter_validate(self):
m = ConcreteModel()
m.I = Set(domain=Integers)
self.assertIs(m.I.filter, None)
self.assertIs(m.I._filter, None)
with self.assertRaisesRegex(
ValueError,
r"Cannot add value 1.5 to Set I.\n"
Expand Down Expand Up @@ -4302,7 +4315,7 @@ def _l_tri(model, i, j):
return i >= j

m.K = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_tri)
self.assertIsInstance(m.K.filter, ParameterizedScalarCallInitializer)
self.assertIsInstance(m.K._filter, ParameterizedScalarCallInitializer)
self.assertEqual(list(m.K), [(1, 1), (2, 1), (2, 2), (3, 1), (3, 2), (3, 3)])

output = StringIO()
Expand Down Expand Up @@ -4334,6 +4347,18 @@ def _lt_3(model, i):
self.assertEqual(output.getvalue(), "")
self.assertEqual(list(m.L[2]), [1, 2, 0])

# This tests that the deprecation path works correctly in the
# case that the callback doesn't raise an error or ever return
# False

def _l_off_diag(model, i, j):
self.assertIs(model, m)
return i != j

m.M = Set(initialize=RangeSet(3) * RangeSet(3), filter=_l_off_diag)
self.assertIsInstance(m.M._filter, ParameterizedScalarCallInitializer)
self.assertEqual(list(m.M), [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1), (3, 2)])

m = ConcreteModel()

def _validate(model, val):
Expand Down Expand Up @@ -4374,12 +4399,15 @@ def _validate(model, i, j):
m.I2 = Set(validate=_validate)
with LoggingIntercept(module='pyomo.core') as output:
self.assertTrue(m.I2.add((0, 1)))
self.assertRegex(
output.getvalue().replace('\n', ' '),
r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback "
r"signature matched \(block, \*value\). Please update the "
r"callback to match the signature \(block, value\)",
)
# Note that we are not emitting a deprecation warning (yet)
# for scalar sets
# self.assertEqual(output.getvalue(), "")
# output.getvalue().replace('\n', ' '),
# r"DEPRECATED: OrderedScalarSet I2: 'validate=' callback "
# r"signature matched \(block, \*value\). Please update the "
# r"callback to match the signature \(block, value\)",
# )
self.assertEqual(output.getvalue(), "")
with LoggingIntercept(module='pyomo.core') as output:
with self.assertRaisesRegex(
ValueError,
Expand Down

0 comments on commit dc72d1d

Please sign in to comment.