Skip to content

Commit

Permalink
Merge pull request #2624 from saitcakmak/mvn_unsqueeze
Browse files Browse the repository at this point in the history
Implement MVN.unsqueeze
  • Loading branch information
saitcakmak authored Jan 22, 2025
2 parents 4b6d48e + 179a1b4 commit 2633973
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 4 deletions.
65 changes: 61 additions & 4 deletions gpytorch/distributions/multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,67 @@ def expand(self, batch_size: torch.Size) -> MultivariateNormal:
See :py:meth:`torch.distributions.Distribution.expand
<torch.distributions.distribution.Distribution.expand>`.
"""
new_loc = self.loc.expand(torch.Size(batch_size) + self.loc.shape[-1:])
new_covar = self._covar.expand(torch.Size(batch_size) + self._covar.shape[-2:])
res = self.__class__(new_loc, new_covar)
return res
# NOTE: Pyro may call this method with list[int] instead of torch.Size.
batch_size = torch.Size(batch_size)
new_loc = self.loc.expand(batch_size + self.loc.shape[-1:])
if self.islazy:
new_covar = self._covar.expand(batch_size + self._covar.shape[-2:])
new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
if self.__unbroadcasted_scale_tril is not None:
# Reuse the scale tril if available.
new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.expand(
batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
)
else:
# Non-lazy MVN is represented using scale_tril in PyTorch.
# Constructing it from scale_tril will avoid unnecessary computation.
# Initialize using __new__, so that we can skip __init__ and use scale_tril.
new = self.__new__(type(self))
new._islazy = False
new_scale_tril = self.__unbroadcasted_scale_tril.expand(
batch_size + self.__unbroadcasted_scale_tril.shape[-2:]
)
super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
# Set the covar matrix, since it is always available for GPyTorch MVN.
new.covariance_matrix = self.covariance_matrix.expand(batch_size + self.covariance_matrix.shape[-2:])
return new

    def unsqueeze(self, dim: int) -> MultivariateNormal:
        r"""
        Constructs a new MultivariateNormal with the batch shape unsqueezed
        by the given dimension.

        For example, if `self.batch_shape = torch.Size([2, 3])` and `dim = 0`, then
        the returned MultivariateNormal will have `batch_shape = torch.Size([1, 2, 3])`.
        If `dim = -1`, then the returned MultivariateNormal will have
        `batch_shape = torch.Size([2, 3, 1])`.

        :param dim: The batch dimension at which to insert a singleton dimension.
            May be negative; valid range is [-len(batch_shape) - 1, len(batch_shape)].
        :return: A new MultivariateNormal with the unsqueezed batch shape.
        :raises IndexError: If `dim` is outside the valid range above.
        """
        # Validate against the batch shape only; the event dimension(s) are
        # never unsqueezed by this method.
        if dim > len(self.batch_shape) or dim < -len(self.batch_shape) - 1:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[{-len(self.batch_shape) - 1}, {len(self.batch_shape)}], but got {dim})."
            )
        if dim < 0:
            # If dim is negative, get the positive equivalent.
            # (+1 because unsqueeze allows inserting after the last batch dim.)
            dim = len(self.batch_shape) + dim + 1

        # A positive batch-dim index is also valid on loc/covar directly, since
        # it is at most len(batch_shape), i.e. strictly before the event dims.
        new_loc = self.loc.unsqueeze(dim)
        if self.islazy:
            new_covar = self._covar.unsqueeze(dim)
            new = self.__class__(mean=new_loc, covariance_matrix=new_covar)
            if self.__unbroadcasted_scale_tril is not None:
                # Reuse the scale tril if available.
                new.__unbroadcasted_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
        else:
            # Non-lazy MVN is represented using scale_tril in PyTorch.
            # Constructing it from scale_tril will avoid unnecessary computation.
            # Initialize using __new__, so that we can skip __init__ and use scale_tril.
            new = self.__new__(type(self))
            new._islazy = False
            new_scale_tril = self.__unbroadcasted_scale_tril.unsqueeze(dim)
            super(MultivariateNormal, new).__init__(loc=new_loc, scale_tril=new_scale_tril)
            # Set the covar matrix, since it is always available for GPyTorch MVN.
            new.covariance_matrix = self.covariance_matrix.unsqueeze(dim)
        return new

def get_base_samples(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
r"""
Expand Down
62 changes: 62 additions & 0 deletions test/distributions/test_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import math
import unittest
from itertools import product

import torch
from linear_operator import to_linear_operator
Expand Down Expand Up @@ -323,6 +324,67 @@ def test_base_sample_shape(self):
samples = dist.rsample(torch.Size((16,)), base_samples=torch.randn(16, 5))
self.assertEqual(samples.shape, torch.Size((16, 5)))

def test_multivariate_normal_expand(self, cuda=False):
device = torch.device("cuda") if cuda else torch.device("cpu")
for dtype, lazy in product((torch.float, torch.double), (True, False)):
mean = torch.tensor([0, 1, 2], device=device, dtype=dtype)
covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype))
if lazy:
mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True)
# Initialize scale tril so we can test that it was expanded.
mvn.scale_tril
else:
mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
self.assertEqual(mvn.batch_shape, torch.Size([]))
self.assertEqual(mvn.islazy, lazy)
expanded = mvn.expand(torch.Size([2]))
self.assertIsInstance(expanded, MultivariateNormal)
self.assertEqual(expanded.islazy, lazy)
self.assertEqual(expanded.batch_shape, torch.Size([2]))
self.assertEqual(expanded.event_shape, mvn.event_shape)
self.assertTrue(torch.equal(expanded.mean, mean.expand(2, -1)))
self.assertEqual(expanded.mean.shape, torch.Size([2, 3]))
self.assertTrue(torch.allclose(expanded.covariance_matrix, covmat.expand(2, -1, -1)))
self.assertEqual(expanded.covariance_matrix.shape, torch.Size([2, 3, 3]))
self.assertTrue(torch.allclose(expanded.scale_tril, mvn.scale_tril.expand(2, -1, -1)))
self.assertEqual(expanded.scale_tril.shape, torch.Size([2, 3, 3]))

def test_multivariate_normal_unsqueeze(self, cuda=False):
device = torch.device("cuda") if cuda else torch.device("cpu")
for dtype, lazy in product((torch.float, torch.double), (True, False)):
batch_shape = torch.Size([2, 3])
mean = torch.tensor([0, 1, 2], device=device, dtype=dtype).expand(*batch_shape, -1)
covmat = torch.diag(torch.tensor([1, 0.75, 1.5], device=device, dtype=dtype)).expand(*batch_shape, -1, -1)
if lazy:
mvn = MultivariateNormal(mean=mean, covariance_matrix=DenseLinearOperator(covmat), validate_args=True)
# Initialize scale tril so we can test that it was unsqueezed.
mvn.scale_tril
else:
mvn = MultivariateNormal(mean=mean, covariance_matrix=covmat, validate_args=True)
self.assertEqual(mvn.batch_shape, batch_shape)
self.assertEqual(mvn.islazy, lazy)
for dim, positive_dim, expected_batch in ((1, 1, torch.Size([2, 1, 3])), (-1, 2, torch.Size([2, 3, 1]))):
new = mvn.unsqueeze(dim)
self.assertIsInstance(new, MultivariateNormal)
self.assertEqual(new.islazy, lazy)
self.assertEqual(new.batch_shape, expected_batch)
self.assertEqual(new.event_shape, mvn.event_shape)
self.assertTrue(torch.equal(new.mean, mean.unsqueeze(positive_dim)))
self.assertEqual(new.mean.shape, expected_batch + torch.Size([3]))
self.assertTrue(torch.allclose(new.covariance_matrix, covmat.unsqueeze(positive_dim)))
self.assertEqual(new.covariance_matrix.shape, expected_batch + torch.Size([3, 3]))
self.assertTrue(torch.allclose(new.scale_tril, mvn.scale_tril.unsqueeze(positive_dim)))
self.assertEqual(new.scale_tril.shape, expected_batch + torch.Size([3, 3]))

# Check for dim validation.
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
mvn.unsqueeze(3)
with self.assertRaisesRegex(IndexError, "Dimension out of range"):
mvn.unsqueeze(-4)
# Should not raise error up to 2 or -3.
mvn.unsqueeze(2)
mvn.unsqueeze(-3)


# Allow running this test module directly, e.g. `python test_multivariate_normal.py`.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 2633973

Please sign in to comment.