Skip to content

Commit 8c1efd9

Browse files
committed
Add batch_shape property to SingleTaskVariationalGP
This enables the use of `SingleTaskVariationalGP` with certain botorch features (e.g. with entropy-based acquistion functions as requested in pytorch#1795). This is a bit of a band-aid, the proper thing to do here is to fix up the PR upstreaming this to gpytorch (cornellius-gp/gpytorch#2307) to enable support for `batch_shape` on all approximate gpytorch models, and then just call that on the `model` in `ApproximateGPyTorchModel`.
1 parent 71690a8 commit 8c1efd9

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

botorch/models/approximate_gp.py

+11
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,17 @@ def __init__(
426426

427427
self.to(train_X)
428428

429+
@property
430+
def batch_shape(self) -> torch.Size:
431+
r"""The batch shape of the model.
432+
433+
This is a batch shape from an I/O perspective. For a model with `m`
434+
outputs, a `test_batch_shape x q x d`-shaped input `X` to the `posterior`
435+
method returns a Posterior object over an output of shape
436+
`broadcast(test_batch_shape, model.batch_shape) x q x m`.
437+
"""
438+
return self._input_batch_shape
439+
429440
def init_inducing_points(
430441
self,
431442
inputs: Tensor,

test/models/test_approximate_gp.py

+2
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def test_posterior(self):
9797
model = SingleTaskVariationalGP(tx, ty, inducing_points=tx)
9898
posterior = model.posterior(test)
9999
self.assertIsInstance(posterior, GPyTorchPosterior)
100+
# test batch_shape property
101+
self.assertEqual(model.batch_shape, tx.shape[:-2])
100102

101103
def test_variational_setUp(self):
102104
for dtype in [torch.float, torch.double]:

0 commit comments

Comments
 (0)