Skip to content

Commit

Permalink
[WIP] initial attempt at implementing identify_variables with a named…
Browse files Browse the repository at this point in the history
… expression cache
  • Loading branch information
Robbybp committed Mar 14, 2024
1 parent 591f899 commit e1fa256
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 46 deletions.
105 changes: 81 additions & 24 deletions pyomo/core/expr/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,12 +1389,20 @@ def visit(self, node):


class _StreamVariableVisitor(StreamBasedExpressionVisitor):
def __init__(self, include_fixed=False, descend_into_named_expressions=True):
def __init__(
self,
include_fixed=False,
#descend_into_named_expressions=True,
named_expression_cache=None,
):
super().__init__()
self._include_fixed = include_fixed
self._descend_into_named_expressions = descend_into_named_expressions
#self._descend_into_named_expressions = descend_into_named_expressions
self.named_expressions = []
# Should we allow re-use of this visitor for multiple expressions?
if named_expression_cache is None:
named_expression_cache = {}
self._named_expression_cache = named_expression_cache
self._active_named_expressions = []

def initializeWalker(self, expr):
self._variables = []
Expand All @@ -1404,12 +1412,26 @@ def initializeWalker(self, expr):
def beforeChild(self, parent, child, index):
if child.__class__ in native_types:
return False, None
elif (
not self._descend_into_named_expressions
and child.is_named_expression_type()
):
self.named_expressions.append(child)
return False, None
#elif (
# not self._descend_into_named_expressions
# and child.is_named_expression_type()
#):
# self.named_expressions.append(child)
# return False, None
elif child.is_named_expression_type():
if id(child) in self._named_expression_cache:
# We have already encountered this named expression. We just add
# the cached variables to our list and don't descend.
for var in self._named_expression_cache[id(child)][0]:
if id(var) not in self._seen:
self._variables.append(var)
return False, None
else:
# If we are descending into a new named expression, initialize
# a cache to store the expression's local variables.
self._named_expression_cache[id(child)] = ([], set())
self._active_named_expressions.append(id(child))
return True, None
else:
return True, None

Expand All @@ -1418,12 +1440,35 @@ def exitNode(self, node, data):
if id(node) not in self._seen:
self._seen.add(id(node))
self._variables.append(node)
if self._active_named_expressions:
# If we are in a named expression, add new variables to the cache.
eid = self._active_named_expressions[-1]
local_vars, local_var_set = self._named_expression_cache[eid]
if id(node) not in local_var_set:
local_var_set.add(id(node))
local_vars.append(node)
elif node.is_named_expression_type():
# If we are returning from a named expression, we have at least one
# active named expression.
eid = self._active_named_expressions.pop()
if self._active_named_expressions:
# If we still are in a named expression, we update that expression's
# cache with any new variables encountered.
new_eid = self._active_named_expressions[-1]
old_expr_vars, old_expr_var_set = self._named_expression_cache[eid]
new_expr_vars, new_expr_var_set = self._named_expression_cache[new_eid]

for var in old_expr_vars:
if id(var) not in new_expr_var_set:
new_expr_var_set.add(id(var))
new_expr_vars.append(var)

def finalizeResult(self, result):
return self._variables


def identify_variables(expr, include_fixed=True):
# TODO: descend_into_named_expressions option?
def identify_variables(expr, include_fixed=True, named_expression_cache=None):
"""
A generator that yields a sequence of variables
in an expression tree.
Expand All @@ -1437,22 +1482,34 @@ def identify_variables(expr, include_fixed=True):
Yields:
Each variable that is found.
"""
visitor = _VariableVisitor()
if include_fixed:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
yield from v
else:
yield v
if named_expression_cache is None:
named_expression_cache = {}

NEW = True
if NEW:
visitor = _StreamVariableVisitor(
named_expression_cache=named_expression_cache,
include_fixed=False,
)
variables = visitor.walk_expression(expr)
yield from variables
else:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
for v_i in v:
if not v_i.is_fixed():
yield v_i
else:
if not v.is_fixed():
visitor = _VariableVisitor()
if include_fixed:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
yield from v
else:
yield v
else:
for v in visitor.xbfs_yield_leaves(expr):
if isinstance(v, tuple):
for v_i in v:
if not v_i.is_fixed():
yield v_i
else:
if not v.is_fixed():
yield v


# =====================================================
Expand Down
63 changes: 41 additions & 22 deletions pyomo/util/vars_from_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
from pyomo.core import Block
from pyomo.core.expr.visitor import _StreamVariableVisitor
from pyomo.core.expr import identify_variables


def get_vars_from_components(
Expand All @@ -42,32 +43,50 @@ def get_vars_from_components(
descend_into: Ctypes to descend into when finding Constraints
descent_order: Traversal strategy for finding the objects of type ctype
"""
visitor = _StreamVariableVisitor(
include_fixed=include_fixed, descend_into_named_expressions=False
)
variables = []
#visitor = _StreamVariableVisitor(
# include_fixed=include_fixed, descend_into_named_expressions=False
#)
#variables = []
#for constraint in block.component_data_objects(
# ctype,
# active=active,
# sort=sort,
# descend_into=descend_into,
# descent_order=descent_order,
#):
# variables.extend(visitor.walk_expression(constraint.expr))
# seen_named_exprs = set()
# named_expr_stack = list(visitor.named_expressions)
# while named_expr_stack:
# expr = named_expr_stack.pop()
# # Clear visitor's named expression cache so we only identify new
# # named expressions
# visitor.named_expressions.clear()
# variables.extend(visitor.walk_expression(expr.expr))
# for new_expr in visitor.named_expressions:
# if id(new_expr) not in seen_named_exprs:
# seen_named_exprs.add(id(new_expr))
# named_expr_stack.append(new_expr)
#seen = set()
#for var in variables:
# if id(var) not in seen:
# seen.add(id(var))
# yield var

seen = set()
named_expression_cache = {}
for constraint in block.component_data_objects(
ctype,
active=active,
sort=sort,
descend_into=descend_into,
descent_order=descent_order,
):
variables.extend(visitor.walk_expression(constraint.expr))
seen_named_exprs = set()
named_expr_stack = list(visitor.named_expressions)
while named_expr_stack:
expr = named_expr_stack.pop()
# Clear visitor's named expression cache so we only identify new
# named expressions
visitor.named_expressions.clear()
variables.extend(visitor.walk_expression(expr.expr))
for new_expr in visitor.named_expressions:
if id(new_expr) not in seen_named_exprs:
seen_named_exprs.add(id(new_expr))
named_expr_stack.append(new_expr)
seen = set()
for var in variables:
if id(var) not in seen:
seen.add(id(var))
yield var
for var in identify_variables(
constraint.expr,
include_fixed=include_fixed,
named_expression_cache=named_expression_cache,
):
if id(var) not in seen:
seen.add(id(var))
yield var

0 comments on commit e1fa256

Please sign in to comment.