Skip to content

Commit

Permalink
Merge branch 'main' into mvn_unsqueeze
Browse files Browse the repository at this point in the history
  • Loading branch information
saitcakmak authored Jan 22, 2025
2 parents 312b0f1 + 4b6d48e commit 179a1b4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@
"source": [
"## The PyroGP model\n",
"\n",
"In order to use Pyro with GPyTorch, your model must inherit from `gpytorch.models.PyroGP` (rather than `gpytorch.modelks.ApproximateGP`). The `PyroGP` extends the `ApproximateGP` class and differs in a few key ways:\n",
"In order to use Pyro with GPyTorch, your model must inherit from `gpytorch.models.PyroGP` (rather than `gpytorch.models.ApproximateGP`). The `PyroGP` extends the `ApproximateGP` class and differs in a few key ways:\n",
"\n",
"- It adds the `model` and `guide` functions which are used by Pyro's inference engine.\n",
"- It's constructor requires two additional arguments beyond the variational strategy:\n",
Expand Down
3 changes: 2 additions & 1 deletion gpytorch/distributions/multitask_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ def from_independent_mvns(cls, mvns):
if any(isinstance(mvn, MultitaskMultivariateNormal) for mvn in mvns):
raise ValueError("Cannot accept MultitaskMultivariateNormals")
if not all(m.batch_shape == mvns[0].batch_shape for m in mvns[1:]):
raise ValueError("All MultivariateNormals must have the same batch shape")
batch_shape = torch.broadcast_shapes(*(m.batch_shape for m in mvns))
mvns = [mvn.expand(batch_shape) for mvn in mvns]
if not all(m.event_shape == mvns[0].event_shape for m in mvns[1:]):
raise ValueError("All MultivariateNormals must have the same event shape")
mean = torch.stack([mvn.mean for mvn in mvns], -1)
Expand Down
3 changes: 2 additions & 1 deletion gpytorch/kernels/index_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class IndexKernel(Kernel):
covar_factor:
The :math:`B` matrix.
raw_var:
The element-wise log of the :math:`\mathbf v` vector.
The element-wise `Softplus <https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html>`_
of the :math:`\mathbf v` vector (assuming the default `var_constraint`).
"""

def __init__(
Expand Down
8 changes: 8 additions & 0 deletions test/distributions/test_multitask_multivariate_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ def test_from_independent_mvns(self, cuda=False):
self.assertEqual(list(mvn.mean.shape), expected_mean_shape)
self.assertEqual(list(mvn.covariance_matrix.shape), expected_covar_shape)

# Test mixed batch mode mvns
# Second MVN is batched, so the first one will be expanded to match.
mvns[1] = mvns[1].expand(torch.Size([3]))
expected_mvn = mvn.expand(torch.Size([3]))
mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
self.assertTrue(torch.equal(mvn.mean, expected_mvn.mean))
self.assertTrue(torch.equal(mvn.covariance_matrix, expected_mvn.covariance_matrix))

# Test batch mode mvns
b = 3
mvns = [
Expand Down

0 comments on commit 179a1b4

Please sign in to comment.