Skip to content

Commit

Permalink
Use graph_replace instead of clone_replace in VI
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 26, 2023
1 parent 41ebb0a commit f8cfe35
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 15 deletions.
3 changes: 2 additions & 1 deletion pymc/variational/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from arviz import InferenceData
from pytensor import tensor as pt
from pytensor.graph.basic import Variable
from pytensor.graph.replace import graph_replace
from pytensor.tensor.var import TensorVariable

import pymc as pm
Expand Down Expand Up @@ -390,7 +391,7 @@ def evaluate_over_trace(self, node):
node = self.to_flat_input(node)

def sample(post, *_):
return pytensor.clone_replace(node, {self.input: post})
return graph_replace(node, {self.input: post}, strict=False)

nodes, _ = pytensor.scan(
sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node))
Expand Down
33 changes: 21 additions & 12 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import xarray

from pytensor.graph.basic import Variable
from pytensor.graph.replace import graph_replace
from pytensor.tensor.shape import unbroadcast

import pymc as pm

Expand Down Expand Up @@ -1002,7 +1004,7 @@ def set_size_and_deterministic(
"""

flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
node_out = pytensor.clone_replace(node, flat2rand)
node_out = graph_replace(node, flat2rand, strict=False)
assert not (
set(makeiter(self.input)) & set(pytensor.graph.graph_inputs(makeiter(node_out)))
)
Expand All @@ -1012,7 +1014,7 @@ def set_size_and_deterministic(

def to_flat_input(self, node):
"""*Dev* - replace vars with flattened view stored in `self.inputs`"""
return pytensor.clone_replace(node, self.replacements)
return graph_replace(node, self.replacements, strict=False)

def symbolic_sample_over_posterior(self, node):
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
Expand All @@ -1023,7 +1025,7 @@ def symbolic_sample_over_posterior(self, node):
random = pt.specify_shape(random, self.symbolic_initial.type.shape)

def sample(post, *_):
return pytensor.clone_replace(node, {self.input: post})
return graph_replace(node, {self.input: post}, strict=False)

nodes, _ = pytensor.scan(
sample, random, non_sequences=_known_scan_ignored_inputs(makeiter(random))
Expand All @@ -1038,7 +1040,7 @@ def symbolic_single_sample(self, node):
"""
node = self.to_flat_input(node)
random = self.symbolic_random.astype(self.symbolic_initial.dtype)
return pytensor.clone_replace(node, {self.input: random[0]})
return graph_replace(node, {self.input: random[0]}, strict=False)

def make_size_and_deterministic_replacements(self, s, d, more_replacements=None):
"""*Dev* - creates correct replacements for initial depending on
Expand All @@ -1059,8 +1061,15 @@ def make_size_and_deterministic_replacements(self, s, d, more_replacements=None)
"""
initial = self._new_initial(s, d, more_replacements)
initial = pt.specify_shape(initial, self.symbolic_initial.type.shape)
# The static shape of initial may be more precise than self.symbolic_initial,
# and reveal previously unknown broadcastable dimensions. We have to mask those again.
if initial.type.broadcastable != self.symbolic_initial.type.broadcastable:
unbroadcast_axes = (
i for i, b in enumerate(self.symbolic_initial.type.broadcastable) if not b
)
initial = unbroadcast(initial, *unbroadcast_axes)
if more_replacements:
initial = pytensor.clone_replace(initial, more_replacements)
initial = graph_replace(initial, more_replacements, strict=False)
return {self.symbolic_initial: initial}

@node_property
Expand Down Expand Up @@ -1394,17 +1403,17 @@ def set_size_and_deterministic(self, node, s, d, more_replacements=None):
_node = node
optimizations = self.get_optimization_replacements(s, d)
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)
node = pytensor.clone_replace(node, optimizations)
node = pytensor.clone_replace(node, flat2rand)
node = graph_replace(node, optimizations, strict=False)
node = graph_replace(node, flat2rand, strict=False)
assert not (set(self.symbolic_randoms) & set(pytensor.graph.graph_inputs(makeiter(node))))
try_to_set_test_value(_node, node, s)
return node

def to_flat_input(self, node, more_replacements=None):
"""*Dev* - replace vars with flattened view stored in `self.inputs`"""
more_replacements = more_replacements or {}
node = pytensor.clone_replace(node, more_replacements)
return pytensor.clone_replace(node, self.replacements)
node = graph_replace(node, more_replacements, strict=False)
return graph_replace(node, self.replacements, strict=False)

def symbolic_sample_over_posterior(self, node, more_replacements=None):
"""*Dev* - performs sampling of node applying independent samples from posterior each time.
Expand All @@ -1413,7 +1422,7 @@ def symbolic_sample_over_posterior(self, node, more_replacements=None):
node = self.to_flat_input(node)

def sample(*post):
return pytensor.clone_replace(node, dict(zip(self.inputs, post)))
return graph_replace(node, dict(zip(self.inputs, post)), strict=False)

nodes, _ = pytensor.scan(
sample, self.symbolic_randoms, non_sequences=_known_scan_ignored_inputs(makeiter(node))
Expand All @@ -1429,7 +1438,7 @@ def symbolic_single_sample(self, node, more_replacements=None):
node = self.to_flat_input(node, more_replacements=more_replacements)
post = [v[0] for v in self.symbolic_randoms]
inp = self.inputs
return pytensor.clone_replace(node, dict(zip(inp, post)))
return graph_replace(node, dict(zip(inp, post)), strict=False)

def get_optimization_replacements(self, s, d):
"""*Dev* - optimizations for logP. If sample size is static and equal to 1:
Expand Down Expand Up @@ -1463,7 +1472,7 @@ def sample_node(self, node, size=None, deterministic=False, more_replacements=No
"""
node_in = node
if more_replacements:
node = pytensor.clone_replace(node, more_replacements)
node = graph_replace(node, more_replacements, strict=False)
if not isinstance(node, (list, tuple)):
node = [node]
node = self.model.replace_rvs_by_values(node)
Expand Down
6 changes: 4 additions & 2 deletions pymc/variational/stein.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytensor
import pytensor.tensor as pt

from pytensor.graph.replace import graph_replace

from pymc.pytensorf import floatX
from pymc.util import WithMemoization, locally_cachedmethod
from pymc.variational.opvi import node_property
Expand Down Expand Up @@ -85,9 +86,10 @@ def dxkxy(self):
def logp_norm(self):
sized_symbolic_logp = self.approx.sized_symbolic_logp
if self.use_histogram:
sized_symbolic_logp = pytensor.clone_replace(
sized_symbolic_logp = graph_replace(
sized_symbolic_logp,
dict(zip(self.approx.symbolic_randoms, self.approx.collect("histogram"))),
strict=False,
)
return sized_symbolic_logp / self.approx.symbolic_normalizing_constant

Expand Down

0 comments on commit f8cfe35

Please sign in to comment.