Merge branch 'master' into feature/online-learning-improvements
jacobrgardner authored Jun 20, 2024
2 parents 93d87cd + 9551eba commit e09674d
Showing 14 changed files with 265 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docs/source/distributions.rst
@@ -5,7 +5,7 @@ gpytorch.distributions
===================================

GPyTorch distribution objects are essentially the same as torch distribution objects.
-For the most part, GpyTorch relies on torch's distribution library.
+For the most part, GPyTorch relies on torch's distribution library.
However, we offer two custom distributions.

We implement a custom :obj:`~gpytorch.distributions.MultivariateNormal` that accepts
9 changes: 8 additions & 1 deletion docs/source/kernels.rst
@@ -9,7 +9,7 @@ gpytorch.kernels


If you don't know what kernel to use, we recommend that you start out with a
-:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel)`.
+:code:`gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()) + gpytorch.kernels.ConstantKernel()`.


Kernel
@@ -22,6 +22,13 @@ Kernel
Standard Kernels
-----------------------------

+:hidden:`ConstantKernel`
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: ConstantKernel
+   :members:


:hidden:`CosineKernel`
~~~~~~~~~~~~~~~~~~~~~~

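To make the recommended default above concrete, here is a minimal sketch (not part of this commit) of an exact GP model using the suggested covariance module. The model class, toy data, and likelihood choice are illustrative assumptions; `ConstantKernel` is assumed to be importable from `gpytorch.kernels`, as added elsewhere in this commit.

import torch
import gpytorch

# Hypothetical toy data, only to make the sketch self-contained.
train_x = torch.linspace(0, 1, 100).unsqueeze(-1)
train_y = torch.sin(2 * torch.pi * train_x).squeeze(-1) + 3.0

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        # The default suggested by the docs change above:
        # a scaled RBF kernel plus a learned constant.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel()
        ) + gpytorch.kernels.ConstantKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
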
2 changes: 1 addition & 1 deletion examples/01_Exact_GPs/index.rst
@@ -1,7 +1,7 @@
Exact GPs (Regression)
========================

-Regression with a Gaussian noise model is the cannonical example of Gaussian processes.
+Regression with a Gaussian noise model is the canonical example of Gaussian processes.
These examples will work for small to medium sized datasets (~2,000 data points).
All examples here use exact GP inference.

2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
@@ -2,6 +2,7 @@
from . import keops
from .additive_structure_kernel import AdditiveStructureKernel
from .arc_kernel import ArcKernel
+from .constant_kernel import ConstantKernel
from .cosine_kernel import CosineKernel
from .cylindrical_kernel import CylindricalKernel
from .distributional_input_kernel import DistributionalInputKernel
@@ -38,6 +39,7 @@
"ArcKernel",
"AdditiveKernel",
"AdditiveStructureKernel",
"ConstantKernel",
"CylindricalKernel",
"MultiDeviceKernel",
"CosineKernel",
123 changes: 123 additions & 0 deletions gpytorch/kernels/constant_kernel.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python3

from typing import Optional, Tuple

import torch
from torch import Tensor

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class ConstantKernel(Kernel):
    """
    Constant covariance kernel for the probabilistic inference of constant coefficients.

    ConstantKernel represents the prior variance `k(x1, x2) = var(c)` of a constant `c`.
    The prior variance of the constant is optimized during the GP hyper-parameter
    optimization stage. The actual value of the constant is computed (implicitly) using
    the linear algebraic approaches for the computation of GP samples and posteriors.

    The constant kernel `k_constant` is most useful as a modification of an arbitrary
    base kernel `k_base`:

    1) Additive constants: The modification `k_base + k_constant` allows the GP to
    infer a non-zero asymptotic value far from the training data, which generally
    leads to more accurate extrapolation. Notably, the uncertainty in this constant
    value affects the posterior covariances through the posterior inference equations.
    This is not the case when a constant prior mean is used instead, since the prior
    mean does not show up in the posterior covariance and is regularized by the
    log-determinant during the optimization of the marginal likelihood.

    2) Multiplicative constants: The modification `k_base * k_constant` allows the GP to
    modulate the variance of the kernel `k_base`, and is mathematically identical to
    `ScaleKernel(base_kernel)` with the same constant.
    """

    has_lengthscale = False

    def __init__(
        self,
        batch_shape: Optional[torch.Size] = None,
        constant_prior: Optional[Prior] = None,
        constant_constraint: Optional[Interval] = None,
        active_dims: Optional[Tuple[int, ...]] = None,
    ):
        """Constructor of ConstantKernel.

        Args:
            batch_shape: The batch shape of the kernel.
            constant_prior: Prior over the constant parameter.
            constant_constraint: Constraint to place on the constant parameter.
            active_dims: The dimensions of the input with which to evaluate the kernel.
                This is moot for the constant kernel, but added for compatibility with
                the Kernel API.
        """
        super().__init__(batch_shape=batch_shape, active_dims=active_dims)

        self.register_parameter(
            name="raw_constant",
            parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
        )

        if constant_prior is not None:
            if not isinstance(constant_prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(constant_prior).__name__)
            self.register_prior(
                "constant_prior",
                constant_prior,
                lambda m: m.constant,
                lambda m, v: m._set_constant(v),
            )

        if constant_constraint is None:
            constant_constraint = Positive()
        self.register_constraint("raw_constant", constant_constraint)

    @property
    def constant(self) -> Tensor:
        return self.raw_constant_constraint.transform(self.raw_constant)

    @constant.setter
    def constant(self, value: Tensor) -> None:
        self._set_constant(value)

    def _set_constant(self, value: Tensor) -> None:
        value = value.view(*self.batch_shape, 1)
        self.initialize(raw_constant=self.raw_constant_constraint.inverse_transform(value))

    def forward(
        self,
        x1: Tensor,
        x2: Tensor,
        diag: Optional[bool] = False,
        last_dim_is_batch: Optional[bool] = False,
    ) -> Tensor:
        """Evaluates the constant kernel.

        Args:
            x1: First input tensor of shape (batch_shape x n1 x d).
            x2: Second input tensor of shape (batch_shape x n2 x d).
            diag: If True, returns the diagonal of the covariance matrix.
            last_dim_is_batch: If True, the last dimension of size `d` of the input
                tensors is treated as a batch dimension.

        Returns:
            A (batch_shape x n1 x n2)-dim, resp. (batch_shape x n1)-dim, tensor of
            constant covariance values if diag is False, resp. True.
        """
        if last_dim_is_batch:
            x1 = x1.transpose(-1, -2).unsqueeze(-1)
            x2 = x2.transpose(-1, -2).unsqueeze(-1)

        dtype = torch.promote_types(x1.dtype, x2.dtype)
        batch_shape = torch.broadcast_shapes(x1.shape[:-2], x2.shape[:-2])
        shape = batch_shape + (x1.shape[-2],) + (() if diag else (x2.shape[-2],))
        constant = self.constant.to(dtype=dtype, device=x1.device)

        if not diag:
            constant = constant.unsqueeze(-1)

        if last_dim_is_batch:
            constant = constant.unsqueeze(-1)

        return constant.expand(shape)
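
For orientation, a brief usage sketch of the new kernel (not part of the diff): it exercises the additive and multiplicative modifications described in the class docstring above and checks the shapes returned by `forward`. The tensors and the constant value are illustrative assumptions.

import torch
from gpytorch.kernels import ConstantKernel, RBFKernel

x1 = torch.randn(10, 3)  # n1 x d
x2 = torch.randn(7, 3)   # n2 x d

const = ConstantKernel()
const.constant = torch.tensor(0.5)  # routed through the Positive() constraint by the setter

# Additive constant: lets the GP infer a non-zero asymptotic value far from the data.
additive = RBFKernel() + const
# Multiplicative constant: mathematically identical to ScaleKernel(RBFKernel()) with the same constant.
multiplicative = RBFKernel() * const

full = const(x1, x2).to_dense()  # (10, 7); every entry equals the constant
diag = const(x1, diag=True)      # (10,); diagonal of k(x1, x1)
print(full.shape, diag.shape, additive(x1, x2).shape, multiplicative(x1, x2).shape)
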
2 changes: 1 addition & 1 deletion gpytorch/kernels/kernel.py
@@ -236,7 +236,7 @@ def forward(
) -> Union[Tensor, LinearOperator]:
r"""
Computes the covariance between :math:`\mathbf x_1` and :math:`\mathbf x_2`.
-This method should be imlemented by all Kernel subclasses.
+This method should be implemented by all Kernel subclasses.
:param x1: First set of data (... x N x D).
:param x2: Second set of data (... x M x D).
2 changes: 1 addition & 1 deletion gpytorch/kernels/periodic_kernel.py
@@ -78,7 +78,7 @@ class PeriodicKernel(Kernel):
>>> covar = covar_module(x) # Output: LazyVariable of size (2 x 10 x 10)
.. _David Mackay's Introduction to Gaussian Processes equation 47:
-http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.81.1927&rep=rep1&type=pdf
+https://inference.org.uk/mackay/gpB.pdf
"""

has_lengthscale = True
2 changes: 1 addition & 1 deletion gpytorch/kernels/rff_kernel.py
@@ -35,7 +35,7 @@ class RFFKernel(Kernel):
.. math::
\begin{equation}
-k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})}$ and $p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})}
+k(\Delta) = \exp{(-\frac{\Delta^2}{2\sigma^2})} \text{ and } p(\omega) = \exp{(-\frac{\sigma^2\omega^2}{2})}
\end{equation}
where :math:`\Delta = x - x'`.
2 changes: 1 addition & 1 deletion gpytorch/mlls/leave_one_out_pseudo_likelihood.py
@@ -47,7 +47,7 @@ def __init__(self, likelihood, model):

def forward(self, function_dist: MultivariateNormal, target: Tensor, *params) -> Tensor:
r"""
-Computes the leave one out likelihood given :math:`p(\mathbf f)` and `\mathbf y`
+Computes the leave one out likelihood given :math:`p(\mathbf f)` and :math:`\mathbf y`
:param ~gpytorch.distributions.MultivariateNormal output: the outputs of the latent function
(the :obj:`~gpytorch.models.GP`)
6 changes: 3 additions & 3 deletions gpytorch/test/base_keops_test_case.py
@@ -66,7 +66,7 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs):
# The patch makes sure that we're actually using KeOps
k1 = kern1(x1, x2).to_dense()
k2 = kern2(x1, x2).to_dense()
-self.assertLess(torch.norm(k1 - k2), 1e-4)
+self.assertLess(torch.norm(k1 - k2), 1e-3)

if use_keops:
self.assertTrue(keops_mock.called)
@@ -86,7 +86,7 @@ def test_batch_matmul(self, use_keops=True, **kwargs):
# The patch makes sure that we're actually using KeOps
res1 = kern1(x1, x1).matmul(rhs)
res2 = kern2(x1, x1).matmul(rhs)
-self.assertLess(torch.norm(res1 - res2), 1e-4)
+self.assertLess(torch.norm(res1 - res2), 1e-3)

if use_keops:
self.assertTrue(keops_mock.called)
@@ -115,7 +115,7 @@ def test_gradient(self, use_keops=True, ard=False, **kwargs):
# stack all gradients into a tensor
grad_s1 = torch.vstack(torch.autograd.grad(s1, [*kern1.hyperparameters()]))
grad_s2 = torch.vstack(torch.autograd.grad(s2, [*kern2.hyperparameters()]))
-self.assertAllClose(grad_s1, grad_s2, rtol=1e-4, atol=1e-5)
+self.assertAllClose(grad_s1, grad_s2, rtol=1e-3, atol=1e-3)

if use_keops:
self.assertTrue(keops_mock.called)
10 changes: 4 additions & 6 deletions gpytorch/test/base_kernel_test_case.py
@@ -122,23 +122,21 @@ def test_no_batch_kernel_double_batch_x_ard(self):
actual_diag = actual_covar_mat.diagonal(dim1=-1, dim2=-2)
self.assertAllClose(kernel_diag, actual_diag, rtol=1e-3, atol=1e-5)

-def test_smoke_double_batch_kernel_double_batch_x_no_ard(self):
+def test_smoke_double_batch_kernel_double_batch_x_no_ard(self) -> None:
kernel = self.create_kernel_no_ard(batch_shape=torch.Size([3, 2]))
x = self.create_data_double_batch()
-batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
+kernel(x).evaluate_kernel().to_dense()
kernel(x, diag=True)
-return batch_covar_mat

-def test_smoke_double_batch_kernel_double_batch_x_ard(self):
+def test_smoke_double_batch_kernel_double_batch_x_ard(self) -> None:
try:
kernel = self.create_kernel_ard(num_dims=2, batch_shape=torch.Size([3, 2]))
except NotImplementedError:
return

x = self.create_data_double_batch()
-batch_covar_mat = kernel(x).evaluate_kernel().to_dense()
+kernel(x).evaluate_kernel().to_dense()
kernel(x, diag=True)
-return batch_covar_mat

def test_kernel_getitem_single_batch(self):
kernel = self.create_kernel_no_ard(batch_shape=torch.Size([2]))
2 changes: 2 additions & 0 deletions setup.py
@@ -39,6 +39,7 @@ def find_version(*file_paths):

torch_min = "1.11"
install_requires = [
"mpmath>=0.19,<=1.3", # avoid incompatibiltiy with torch+sympy with mpmath 1.4
"scikit-learn",
"scipy",
"linear_operator>=0.5.2",
@@ -81,6 +82,7 @@ def find_version(*file_paths):
"nbclient<=0.7.3",
"nbformat<=5.8.0",
"nbsphinx<=0.9.1",
"lxml_html_clean",
"platformdirs<=3.2.0",
"setuptools_scm<=7.1.0",
"sphinx<=6.2.1",
8 changes: 4 additions & 4 deletions test/examples/test_svgp_gp_classification.py
@@ -16,7 +16,7 @@


def train_data(cuda=False):
-train_x = torch.linspace(0, 1, 260)
+train_x = torch.linspace(0, 1, 150)
train_y = torch.cos(train_x * (2 * math.pi)).gt(0).float()
if cuda:
return train_x.cuda(), train_y.cuda()
@@ -49,7 +49,7 @@ class TestSVGPClassification(BaseTestCase, unittest.TestCase):
def test_classification_error(self, cuda=False, mll_cls=gpytorch.mlls.VariationalELBO):
train_x, train_y = train_data(cuda=cuda)
likelihood = BernoulliLikelihood()
-model = SVGPClassificationModel(torch.linspace(0, 1, 25))
+model = SVGPClassificationModel(torch.linspace(0, 1, 64))
mll = mll_cls(likelihood, model, num_data=len(train_y))
if cuda:
likelihood = likelihood.cuda()
@@ -59,12 +59,12 @@ def test_classification_error(self, cuda=False, mll_cls=gpytorch.mlls.VariationalELBO):
# Find optimal model hyperparameters
model.train()
likelihood.train()
-optimizer = optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.1)
+optimizer = optim.Adam([{"params": model.parameters()}, {"params": likelihood.parameters()}], lr=0.03)

_wrapped_cg = MagicMock(wraps=linear_operator.utils.linear_cg)
_cg_mock = patch("linear_operator.utils.linear_cg", new=_wrapped_cg)
with _cg_mock as cg_mock:
-for _ in range(400):
+for _ in range(100):
optimizer.zero_grad()
output = model(train_x)
loss = -mll(output, train_y)