Allow forwarding of MvNormal method to SymbolicRandomVariables
ricardoV94 committed Feb 27, 2025
1 parent 02a5a36 commit efaa04e
Showing 3 changed files with 57 additions and 10 deletions.
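In practical terms, the change lets the `method` argument of the internally used `MvNormal` (e.g. `"svd"` instead of the default `"cholesky"`) survive graph rewrites such as resizing. A minimal usage sketch, adapted from the new test added below; the import path for `change_dist_size` is an assumption, not part of this diff:

import numpy as np
import pymc as pm
from pymc.distributions.shape_utils import change_dist_size

# MvStudentT draws are generated via an internal MvNormal; forward method="svd" to it
x = pm.MvStudentT.dist(nu=4, scale=np.eye(3), method="svd")
assert x.type.shape == (3,)

# Resizing rebuilds the underlying SymbolicRandomVariable through the new rebuild_rv hook,
# which re-applies the cached method instead of rebuilding with rv_op's defaults
resized_x = change_dist_size(x, (2,))
assert resized_x.type.shape == (2, 3)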
10 changes: 9 additions & 1 deletion pymc/distributions/distribution.py
@@ -383,6 +383,14 @@ def batch_ndim(self, node: Apply) -> int:
out_ndim = max(getattr(out.type, "ndim", 0) for out in node.outputs)
return out_ndim - self.ndim_supp

def rebuild_rv(self, *args, **kwargs):
"""Rebuild the RandomVariable with new inputs."""
if not hasattr(self, "rv_op"):
raise NotImplementedError(
f"SymbolicRandomVariable {self} without `rv_op` method cannot be rebuilt automatically."
)
return self.rv_op(*args, **kwargs)


@_change_dist_size.register(SymbolicRandomVariable)
def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> TensorVariable:
@@ -403,7 +411,7 @@ def change_symbolic_rv_size(op: SymbolicRandomVariable, rv, new_size, expand) -> TensorVariable:
if expand and not rv_size_is_none(size):
new_size = tuple(new_size) + tuple(size)

return op.rv_op(*params, size=new_size)
return op.rebuild_rv(*params, size=new_size)


class Distribution(metaclass=DistributionMeta):
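The default `rebuild_rv` above simply delegates to `rv_op`, so the hook only matters for ops that cache construction-time state which `rv_op` (a classmethod) cannot see. A minimal sketch of that pattern with a hypothetical op and attribute (`MyLinearSolveRV` and `solver` are illustrative names, not part of pymc); the multivariate.py changes below apply the same idea to the internal `MvNormal` `method`:

from pymc.distributions.distribution import SymbolicRandomVariable

class MyLinearSolveRV(SymbolicRandomVariable):
    """Hypothetical symbolic RV that caches an extra construction argument."""

    def __init__(self, *args, solver: str, **kwargs):
        super().__init__(*args, **kwargs)
        self.solver = solver  # instance state the rv_op classmethod cannot see

    def rebuild_rv(self, *args, **kwargs):
        # Re-inject the cached argument whenever a rewrite (e.g. change_dist_size) rebuilds the RV
        return self.rv_op(*args, solver=self.solver, **kwargs)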
37 changes: 28 additions & 9 deletions pymc/distributions/multivariate.py
@@ -302,7 +302,19 @@ def logp(value, mu, cov):
)


class PrecisionMvNormalRV(SymbolicRandomVariable):
class SymbolicMVNormalUsedInternally(SymbolicRandomVariable):
"""Helper subclass that handles the forwarding / caching of method to `MvNormal` used internally."""

def __init__(self, *args, method: str, **kwargs):
super().__init__(*args, **kwargs)
self.method = method

def rebuild_rv(self, *args, **kwargs):
# rv_op is a classmethod, so it doesn't have access to the instance's `method` attribute
return self.rv_op(*args, method=self.method, **kwargs)


class PrecisionMvNormalRV(SymbolicMVNormalUsedInternally):
r"""A specialized multivariate normal random variable defined in terms of precision.
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -313,14 +325,17 @@ class PrecisionMvNormalRV(SymbolicRandomVariable):
_print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")

@classmethod
def rv_op(cls, mean, tau, *, rng=None, size=None):
def rv_op(cls, mean, tau, *, method: str = "cholesky", rng=None, size=None):
rng = normalize_rng_param(rng)
size = normalize_size_param(size)
cov = pt.linalg.inv(tau)
next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
next_rng, draws = multivariate_normal(
mean, cov, size=size, rng=rng, method=method
).owner.outputs
return cls(
inputs=[rng, size, mean, tau],
outputs=[next_rng, draws],
method=method,
)(rng, size, mean, tau)


@@ -365,7 +380,7 @@ def mv_normal_to_precision_mv_normal(fgraph, node):
)


class MvStudentTRV(SymbolicRandomVariable):
class MvStudentTRV(SymbolicMVNormalUsedInternally):
r"""A specialized multivariate normal random variable defined in terms of precision.
This class is introduced during specialization logprob rewrites, and not meant to be used directly.
@@ -376,7 +391,7 @@ class MvStudentTRV(SymbolicRandomVariable):
_print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

@classmethod
def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
def rv_op(cls, nu, mean, scale, *, method: str = "cholesky", rng=None, size=None):
nu = pt.as_tensor(nu)
mean = pt.as_tensor(mean)
scale = pt.as_tensor(scale)
@@ -387,14 +402,15 @@ def rv_op(cls, nu, mean, scale, *, rng=None, size=None):
size = implicit_size_from_params(nu, mean, scale, ndims_params=cls.ndims_params)

next_rng, mv_draws = multivariate_normal(
mean.zeros_like(), scale, size=size, rng=rng
mean.zeros_like(), scale, size=size, rng=rng, method=method
).owner.outputs
next_rng, chi2_draws = chisquare(nu, size=size, rng=next_rng).owner.outputs
draws = mean + (mv_draws / pt.sqrt(chi2_draws / nu)[..., None])

return cls(
inputs=[rng, size, nu, mean, scale],
outputs=[next_rng, draws],
method=method,
)(rng, size, nu, mean, scale)


@@ -1923,12 +1939,12 @@ def logp(value, mu, rowchol, colchol):
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet


class KroneckerNormalRV(SymbolicRandomVariable):
class KroneckerNormalRV(SymbolicMVNormalUsedInternally):
ndim_supp = 1
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")

@classmethod
def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
def rv_op(cls, mu, sigma, *covs, method: str = "cholesky", size=None, rng=None):
mu = pt.as_tensor(mu)
sigma = pt.as_tensor(sigma)
covs = [pt.as_tensor(cov) for cov in covs]
@@ -1937,7 +1953,9 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):

cov = reduce(pt.linalg.kron, covs)
cov = cov + sigma**2 * pt.eye(cov.shape[-2])
next_rng, draws = multivariate_normal(mean=mu, cov=cov, size=size, rng=rng).owner.outputs
next_rng, draws = multivariate_normal(
mean=mu, cov=cov, size=size, rng=rng, method=method
).owner.outputs

covs_sig = ",".join(f"(a{i},b{i})" for i in range(len(covs)))
extended_signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"
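For reference, the `extended_signature` assembled above is a gufunc-style string with one `(ai,bi)` pair of core dimensions per covariance factor. A quick sanity check of the string construction for two factors (not part of the diff):

covs_sig = ",".join(f"(a{i},b{i})" for i in range(2))
extended_signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"
assert extended_signature == "[rng],[size],(m),(),(a0,b0),(a1,b1)->[rng],(m)"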
@@ -1946,6 +1964,7 @@ def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
inputs=[rng, size, mu, sigma, *covs],
outputs=[next_rng, draws],
extended_signature=extended_signature,
method=method,
)(rng, size, mu, sigma, *covs)


20 changes: 20 additions & 0 deletions tests/distributions/test_multivariate.py
@@ -2469,6 +2469,26 @@ def test_mvstudentt_mu_convenience():
np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))


def test_mvstudentt_method():
def all_svd_method(fgraph):
found_one = False
for node in fgraph.toposort():
if isinstance(node.op, pm.MvNormal):
found_one = True
if not node.op.method == "svd":
return False
return found_one # We want to fail if there were no MvNormal nodes

x = pm.MvStudentT.dist(nu=4, scale=np.eye(3), method="svd")
assert x.type.shape == (3,)
assert all_svd_method(x.owner.op.fgraph)

# Changing the size should preserve the method
resized_x = change_dist_size(x, (2,))
assert resized_x.type.shape == (2, 3)
assert all_svd_method(resized_x.owner.op.fgraph)


def test_precision_mv_normal_optimization():
rng = np.random.default_rng(sum(map(ord, "be precise")))

