From f8cfe3562c3c116d09e11596d0c4dbd2c6c0596d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 26 Jul 2023 11:31:30 +0200 Subject: [PATCH] Use graph_replace instead of clone_replace in VI --- pymc/variational/approximations.py | 3 ++- pymc/variational/opvi.py | 33 +++++++++++++++++++----------- pymc/variational/stein.py | 6 ++++-- 3 files changed, 27 insertions(+), 15 deletions(-) diff --git a/pymc/variational/approximations.py b/pymc/variational/approximations.py index d00d893ff54..d271e804485 100644 --- a/pymc/variational/approximations.py +++ b/pymc/variational/approximations.py @@ -19,6 +19,7 @@ from arviz import InferenceData from pytensor import tensor as pt from pytensor.graph.basic import Variable +from pytensor.graph.replace import graph_replace from pytensor.tensor.var import TensorVariable import pymc as pm @@ -390,7 +391,7 @@ def evaluate_over_trace(self, node): node = self.to_flat_input(node) def sample(post, *_): - return pytensor.clone_replace(node, {self.input: post}) + return graph_replace(node, {self.input: post}, strict=False) nodes, _ = pytensor.scan( sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node)) diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index 7280105d032..f81a6d6a235 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -59,6 +59,8 @@ import xarray from pytensor.graph.basic import Variable +from pytensor.graph.replace import graph_replace +from pytensor.tensor.shape import unbroadcast import pymc as pm @@ -1002,7 +1004,7 @@ def set_size_and_deterministic( """ flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements) - node_out = pytensor.clone_replace(node, flat2rand) + node_out = graph_replace(node, flat2rand, strict=False) assert not ( set(makeiter(self.input)) & set(pytensor.graph.graph_inputs(makeiter(node_out))) ) @@ -1012,7 +1014,7 @@ def set_size_and_deterministic( def to_flat_input(self, node): """*Dev* - replace vars with flattened view stored in `self.inputs`""" - return pytensor.clone_replace(node, self.replacements) + return graph_replace(node, self.replacements, strict=False) def symbolic_sample_over_posterior(self, node): """*Dev* - performs sampling of node applying independent samples from posterior each time. @@ -1023,7 +1025,7 @@ def symbolic_sample_over_posterior(self, node): random = pt.specify_shape(random, self.symbolic_initial.type.shape) def sample(post, *_): - return pytensor.clone_replace(node, {self.input: post}) + return graph_replace(node, {self.input: post}, strict=False) nodes, _ = pytensor.scan( sample, random, non_sequences=_known_scan_ignored_inputs(makeiter(random)) @@ -1038,7 +1040,7 @@ def symbolic_single_sample(self, node): """ node = self.to_flat_input(node) random = self.symbolic_random.astype(self.symbolic_initial.dtype) - return pytensor.clone_replace(node, {self.input: random[0]}) + return graph_replace(node, {self.input: random[0]}, strict=False) def make_size_and_deterministic_replacements(self, s, d, more_replacements=None): """*Dev* - creates correct replacements for initial depending on @@ -1059,8 +1061,15 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None) """ initial = self._new_initial(s, d, more_replacements) initial = pt.specify_shape(initial, self.symbolic_initial.type.shape) + # The static shape of initial may be more precise than self.symbolic_initial, + # and reveal previously unknown broadcastable dimensions. We have to mask those again. + if initial.type.broadcastable != self.symbolic_initial.type.broadcastable: + unbroadcast_axes = ( + i for i, b in enumerate(self.symbolic_initial.type.broadcastable) if not b + ) + initial = unbroadcast(initial, *unbroadcast_axes) if more_replacements: - initial = pytensor.clone_replace(initial, more_replacements) + initial = graph_replace(initial, more_replacements, strict=False) return {self.symbolic_initial: initial} @node_property @@ -1394,8 +1403,8 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): _node = node optimizations = self.get_optimization_replacements(s, d) flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements) - node = pytensor.clone_replace(node, optimizations) - node = pytensor.clone_replace(node, flat2rand) + node = graph_replace(node, optimizations, strict=False) + node = graph_replace(node, flat2rand, strict=False) assert not (set(self.symbolic_randoms) & set(pytensor.graph.graph_inputs(makeiter(node)))) try_to_set_test_value(_node, node, s) return node @@ -1403,8 +1412,8 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None): def to_flat_input(self, node, more_replacements=None): """*Dev* - replace vars with flattened view stored in `self.inputs`""" more_replacements = more_replacements or {} - node = pytensor.clone_replace(node, more_replacements) - return pytensor.clone_replace(node, self.replacements) + node = graph_replace(node, more_replacements, strict=False) + return graph_replace(node, self.replacements, strict=False) def symbolic_sample_over_posterior(self, node, more_replacements=None): """*Dev* - performs sampling of node applying independent samples from posterior each time. @@ -1413,7 +1422,7 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None): node = self.to_flat_input(node) def sample(*post): - return pytensor.clone_replace(node, dict(zip(self.inputs, post))) + return graph_replace(node, dict(zip(self.inputs, post)), strict=False) nodes, _ = pytensor.scan( sample, self.symbolic_randoms, non_sequences=_known_scan_ignored_inputs(makeiter(node)) @@ -1429,7 +1438,7 @@ def symbolic_single_sample(self, node, more_replacements=None): node = self.to_flat_input(node, more_replacements=more_replacements) post = [v[0] for v in self.symbolic_randoms] inp = self.inputs - return pytensor.clone_replace(node, dict(zip(inp, post))) + return graph_replace(node, dict(zip(inp, post)), strict=False) def get_optimization_replacements(self, s, d): """*Dev* - optimizations for logP. If sample size is static and equal to 1: @@ -1463,7 +1472,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No """ node_in = node if more_replacements: - node = pytensor.clone_replace(node, more_replacements) + node = graph_replace(node, more_replacements, strict=False) if not isinstance(node, (list, tuple)): node = [node] node = self.model.replace_rvs_by_values(node) diff --git a/pymc/variational/stein.py b/pymc/variational/stein.py index 1f8c034f807..d48b6f47108 100644 --- a/pymc/variational/stein.py +++ b/pymc/variational/stein.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytensor import pytensor.tensor as pt +from pytensor.graph.replace import graph_replace + from pymc.pytensorf import floatX from pymc.util import WithMemoization, locally_cachedmethod from pymc.variational.opvi import node_property @@ -85,9 +86,10 @@ def dxkxy(self): def logp_norm(self): sized_symbolic_logp = self.approx.sized_symbolic_logp if self.use_histogram: - sized_symbolic_logp = pytensor.clone_replace( + sized_symbolic_logp = graph_replace( sized_symbolic_logp, dict(zip(self.approx.symbolic_randoms, self.approx.collect("histogram"))), + strict=False, ) return sized_symbolic_logp / self.approx.symbolic_normalizing_constant