Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New CholeskyCorr transform #7700

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,7 @@ def helper_deterministics(cls, n, packed_chol):

class LKJCorrRV(RandomVariable):
name = "lkjcorr"
signature = "(),()->(n)"
signature = "(),()->(n,n)"
dtype = "floatX"
_print_name = ("LKJCorrRV", "\\operatorname{LKJCorrRV}")

Expand All @@ -1527,8 +1527,8 @@ def make_node(self, rng, size, n, eta):

def _supp_shape_from_params(self, dist_params, **kwargs):
n = dist_params[0].squeeze()
dist_shape = ((n * (n - 1)) // 2,)
return dist_shape
# dist_shape = ((n * (n - 1)) // 2,)
return (n, n)

@classmethod
def rng_fn(cls, rng, n, eta, size):
Expand Down Expand Up @@ -1609,23 +1609,26 @@ def logp(value, n, eta):
-------
TensorVariable
"""
if value.ndim > 1:
raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")

# TODO: PyTensor does not have a `triu_indices`, so we can only work with constant
# n (or else find a different expression)
# if value.ndim > 1:
# raise NotImplementedError("LKJCorr logp is only implemented for vector values (ndim=1)")
#
try:
n = int(get_underlying_scalar_constant_value(n))
except NotScalarConstantError:
raise NotImplementedError("logp only implemented for constant `n`")

shape = n * (n - 1) // 2
tri_index = np.zeros((n, n), dtype="int32")
tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)
# shape = n * (n - 1) // 2
# tri_index = np.zeros((n, n), dtype="int32")
# tri_index[np.triu_indices(n, k=1)] = np.arange(shape)
# tri_index[np.triu_indices(n, k=1)[::-1]] = np.arange(shape)

# value = pt.take(value, tri_index)
# value = pt.fill_diagonal(value, 1)

value = pt.take(value, tri_index)
value = pt.fill_diagonal(value, 1)
# print(n, type(n))
# print(value.type.shape)
# value = value @ value.T
# print(value.type.shape)

# TODO: _lkj_normalizing_constant currently requires `eta` and `n` to be constants
try:
Expand Down
79 changes: 79 additions & 0 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from functools import singledispatch

import numpy as np
import pytensor
import pytensor.tensor as pt


Expand Down Expand Up @@ -164,6 +165,84 @@ def log_jac_det(self, value, *inputs):
return pt.sum(value[..., self.diag_idxs], axis=-1)


class CholeskyCorr(Transform):
    """Transform an unconstrained packed vector into a correlation Cholesky factor.

    The ``backward`` map takes a free real vector of length ``n * (n - 1) // 2``
    and produces an ``(n, n)`` lower-triangular matrix ``L`` with unit-norm
    rows, so that ``L @ L.T`` is a correlation matrix.  Row ``i`` is built from
    the next ``i`` entries of the vector via a stick-breaking-style recursion
    (see ``step``); ``forward`` is its exact analytic inverse.
    """

    name = "cholesky-corr-packed"

    def __init__(self, n):
        """Create a CholeskyCorr object.

        Parameters
        ----------
        n : int
            Dimension (number of rows/columns) of the correlation matrix.
        """
        self.n = n

    def step(self, i, counter, L, y):
        """One scan step: fill row ``i`` of ``L`` from ``y[counter:counter+i]``.

        Returns the updated matrix and this row's log-Jacobian contribution.
        """
        y_star = y[counter : counter + i]
        dsy = y_star.dot(y_star)
        # alpha_r = 1 / (1 + |y*|^2) becomes the row's diagonal entry.
        alpha_r = 1 / (dsy + 1)
        # gamma rescales y* so the row has unit norm:
        # |gamma * y*|^2 + alpha_r^2 == 1.
        gamma = pt.sqrt(dsy + 2) * alpha_r

        x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r))
        next_L = L[i, : i + 1].set(x)
        log_det = pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy)

        return next_L, log_det

    def _compute_L_and_logdet_scan(self, value, *inputs):
        """Build ``L`` and the total log-Jacobian with a symbolic scan over rows."""
        L = pt.eye(self.n)
        idxs = pt.arange(1, self.n)
        # counters[i - 1] == i * (i - 1) // 2: offset of row i's entries in
        # `value`.  `counters` is one element longer than `idxs`; scan
        # truncates to the shorter sequence.
        counters = pt.arange(0, self.n).cumsum()

        results, _ = pytensor.scan(
            self.step, outputs_info=[L, None], sequences=[idxs, counters], non_sequences=[value]
        )

        L_seq, log_det_seq = results
        L = L_seq[-1]
        log_det = pt.sum(log_det_seq)

        return L, log_det

    def _compute_L_and_logdet(self, value, *inputs):
        """Graph-unrolled (non-scan) equivalent of ``_compute_L_and_logdet_scan``.

        Requires ``self.n`` to be a Python int.  Kept as a reference
        implementation; ``backward``/``log_jac_det`` use the scan variant.
        """
        n = self.n
        counter = 0
        L = pt.eye(n)
        log_det = 0

        for i in range(1, n):
            y_star = value[counter : counter + i]
            dsy = y_star.dot(y_star)
            alpha_r = 1 / (dsy + 1)
            gamma = pt.sqrt(dsy + 2) * alpha_r

            x = pt.join(0, gamma * y_star, pt.atleast_1d(alpha_r))
            L = L[i, : i + 1].set(x)
            log_det += pt.log(2) + 0.5 * (i - 2) * pt.log(dsy + 2) - i * pt.log(1 + dsy)

            counter += i

        # Returns the full (n, n) matrix (first row / untouched diagonal come
        # from `eye`), matching the scan variant.
        return L, log_det

    def backward(self, value, *inputs):
        """Map the packed free vector to the ``(n, n)`` Cholesky factor."""
        L, _ = self._compute_L_and_logdet_scan(value, *inputs)
        return L

    def forward(self, value, *inputs):
        """Invert ``backward``: recover the packed vector from a Cholesky factor.

        From ``step``, row ``i`` of ``value`` is ``[gamma * y_star, alpha_r]``
        with ``alpha_r = 1 / (1 + |y*|^2)``.  Hence ``|y*|^2 = 1/alpha_r - 1``,
        ``gamma = sqrt(|y*|^2 + 2) * alpha_r``, and ``y_star = row / gamma``.
        """
        if self.n < 2:
            # A 0x0 or 1x1 correlation matrix has no free entries.
            return pt.zeros((0,))

        packed_rows = []
        for i in range(1, self.n):
            alpha_r = value[i, i]
            dsy = 1.0 / alpha_r - 1.0
            gamma = pt.sqrt(dsy + 2.0) * alpha_r
            packed_rows.append(value[i, :i] / gamma)
        return pt.join(0, *packed_rows)

    def log_jac_det(self, value, *inputs):
        """Log-determinant of the Jacobian of ``backward`` at ``value``."""
        _, log_det = self._compute_L_and_logdet_scan(value, *inputs)
        return log_det


# Short public alias for composing several transforms.
Chain = ChainedTransform

# Module-level ready-to-use instance of the simplex transform.
simplex = SimplexTransform()
Expand Down
Loading