Skip to content

Commit

Permalink
Merge pull request #2512 from m-julian/matern52_grad
Browse files Browse the repository at this point in the history
Matern52 grad
  • Loading branch information
jacobrgardner authored Jun 20, 2024
2 parents 2e7959d + f8b9a5e commit 25da2cc
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 0 deletions.
2 changes: 2 additions & 0 deletions gpytorch/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .kernel import AdditiveKernel, Kernel, ProductKernel
from .lcm_kernel import LCMKernel
from .linear_kernel import LinearKernel
from .matern52_kernel_grad import Matern52KernelGrad
from .matern_kernel import MaternKernel
from .multi_device_kernel import MultiDeviceKernel
from .multitask_kernel import MultitaskKernel
Expand Down Expand Up @@ -69,4 +70,5 @@
"ScaleKernel",
"SpectralDeltaKernel",
"SpectralMixtureKernel",
"Matern52KernelGrad",
]
152 changes: 152 additions & 0 deletions gpytorch/kernels/matern52_kernel_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#!/usr/bin/env python3

import math

import torch
from linear_operator.operators import KroneckerProductLinearOperator

from gpytorch.kernels.matern_kernel import MaternKernel

sqrt5 = math.sqrt(5)
five_thirds = 5.0 / 3.0


class Matern52KernelGrad(MaternKernel):
r"""
Computes a covariance matrix of the Matern52 kernel that models the covariance
between the values and partial derivatives for inputs :math:`\mathbf{x_1}`
and :math:`\mathbf{x_2}`.
See :class:`gpytorch.kernels.Kernel` for descriptions of the lengthscale options.
.. note::
This kernel does not have an `outputscale` parameter. To add a scaling parameter,
decorate this kernel with a :class:`gpytorch.kernels.ScaleKernel`.
:param ard_num_dims: Set this if you want a separate lengthscale for each input
dimension. It should be `d` if x1 is a `n x d` matrix. (Default: `None`.)
:param batch_shape: Set this if you want a separate lengthscale for each batch of input
data. It should be :math:`B_1 \times \ldots \times B_k` if :math:`\mathbf x1` is
a :math:`B_1 \times \ldots \times B_k \times N \times D` tensor.
:param active_dims: Set this if you want to compute the covariance of only
a few input dimensions. The ints corresponds to the indices of the
dimensions. (Default: `None`.)
:param lengthscale_prior: Set this if you want to apply a prior to the
lengthscale parameter. (Default: `None`)
:param lengthscale_constraint: Set this if you want to apply a constraint
to the lengthscale parameter. (Default: `Positive`.)
:param eps: The minimum value that the lengthscale can take (prevents
divide by zero errors). (Default: `1e-6`.)
:ivar torch.Tensor lengthscale: The lengthscale parameter. Size/shape of parameter depends on the
ard_num_dims and batch_shape arguments.
Example:
>>> x = torch.randn(10, 5)
>>> # Non-batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
>>> covar = covar_module(x) # Output: LinearOperator of size (60 x 60), where 60 = n * (d + 1)
>>>
>>> batch_x = torch.randn(2, 10, 5)
>>> # Batch: Simple option
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad())
>>> # Batch: different lengthscale for each batch
>>> covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.Matern52KernelGrad(batch_shape=torch.Size([2]))) # noqa: E501
>>> covar = covar_module(x) # Output: LinearOperator of size (2 x 60 x 60)
"""

def __init__(self, **kwargs):

# remove nu in case it was set
kwargs.pop("nu", None)
super(Matern52KernelGrad, self).__init__(nu=2.5, **kwargs)

def forward(self, x1, x2, diag=False, **params):

lengthscale = self.lengthscale

batch_shape = x1.shape[:-2]
n_batch_dims = len(batch_shape)
n1, d = x1.shape[-2:]
n2 = x2.shape[-2]

if not diag:

K = torch.zeros(*batch_shape, n1 * (d + 1), n2 * (d + 1), device=x1.device, dtype=x1.dtype)

distance_matrix = self.covar_dist(x1.div(lengthscale), x2.div(lengthscale), diag=diag, **params)
exp_neg_sqrt5r = torch.exp(-sqrt5 * distance_matrix)

# differences matrix in each dimension to be used for derivatives
# shape of n1 x n2 x d
outer = x1.view(*batch_shape, n1, 1, d) - x2.view(*batch_shape, 1, n2, d)
outer = outer / lengthscale.unsqueeze(-2) ** 2
# shape of n1 x d x n2
outer = torch.transpose(outer, -1, -2).contiguous()

# 1) Kernel block, cov(f^m, f^n)
# shape is n1 x n2
exp_component = torch.exp(-sqrt5 * distance_matrix)
constant_component = (sqrt5 * distance_matrix).add(1).add(five_thirds * distance_matrix**2)

K[..., :n1, :n2] = constant_component * exp_component

# 2) First gradient block, cov(f^m, omega^n_d)
outer1 = outer.view(*batch_shape, n1, n2 * d)
K[..., :n1, n2:] = outer1 * (-five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * (n_batch_dims + 1)), d]
)

# 3) Second gradient block, cov(omega^m_d, f^n)
outer2 = outer.transpose(-1, -3).reshape(*batch_shape, n2, n1 * d)
outer2 = outer2.transpose(-1, -2)
# the - signs on -outer2 and -five_thirds cancel out
K[..., n1:, :n2] = outer2 * (five_thirds * (1 + sqrt5 * distance_matrix) * exp_neg_sqrt5r).repeat(
[*([1] * n_batch_dims), d, 1]
)

# 4) Hessian block, cov(omega^m_d, omega^n_d)
outer3 = outer1.repeat([*([1] * n_batch_dims), d, 1]) * outer2.repeat([*([1] * (n_batch_dims + 1)), d])
kp = KroneckerProductLinearOperator(
torch.eye(d, d, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1) / lengthscale**2,
torch.ones(n1, n2, device=x1.device, dtype=x1.dtype).repeat(*batch_shape, 1, 1),
)

part1 = -five_thirds * exp_neg_sqrt5r
part2 = 5 * outer3
part3 = 1 + sqrt5 * distance_matrix

K[..., n1:, n2:] = part1.repeat([*([1] * n_batch_dims), d, d]).mul_(
# need to use kp.to_dense().mul instead of kp.to_dense().mul_
# because otherwise a RuntimeError is raised due to how autograd works with
# view + inplace operations in the case of 1-dimensional input
part2.sub_(kp.to_dense().mul(part3.repeat([*([1] * n_batch_dims), d, d])))
)

# Symmetrize for stability
if n1 == n2 and torch.eq(x1, x2).all():
K = 0.5 * (K.transpose(-1, -2) + K)

# Apply a perfect shuffle permutation to match the MutiTask ordering
pi1 = torch.arange(n1 * (d + 1)).view(d + 1, n1).t().reshape((n1 * (d + 1)))
pi2 = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
K = K[..., pi1, :][..., :, pi2]

return K
else:
if not (n1 == n2 and torch.eq(x1, x2).all()):
raise RuntimeError("diag=True only works when x1 == x2")

# nu is set to 2.5
kernel_diag = super(Matern52KernelGrad, self).forward(x1, x2, diag=True)
grad_diag = (
five_thirds * torch.ones(*batch_shape, n2, d, device=x1.device, dtype=x1.dtype)
) / lengthscale**2
grad_diag = grad_diag.transpose(-1, -2).reshape(*batch_shape, n2 * d)
k_diag = torch.cat((kernel_diag, grad_diag), dim=-1)
pi = torch.arange(n2 * (d + 1)).view(d + 1, n2).t().reshape((n2 * (d + 1)))
return k_diag[..., pi]

def num_outputs_per_input(self, x1, x2):
return x1.size(-1) + 1
78 changes: 78 additions & 0 deletions test/kernels/test_matern52_kernel_grad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import unittest

import torch

from gpytorch.kernels import Matern52KernelGrad
from gpytorch.test.base_kernel_test_case import BaseKernelTestCase


class TestMatern52KernelGrad(unittest.TestCase, BaseKernelTestCase):
def create_kernel_no_ard(self, **kwargs):
return Matern52KernelGrad(**kwargs)

def create_kernel_ard(self, num_dims, **kwargs):
return Matern52KernelGrad(ard_num_dims=num_dims, **kwargs)

def test_kernel(self, cuda=False):
a = torch.tensor([[[1, 2], [2, 4]]], dtype=torch.float)
b = torch.tensor([[[1, 3], [0, 4]]], dtype=torch.float)

actual = torch.tensor(
[
[0.3056225, -0.0000000, 0.5822443, 0.0188260, -0.0209871, 0.0419742],
[0.0000000, 0.5822443, 0.0000000, 0.0209871, -0.0056045, 0.0531832],
[-0.5822443, 0.0000000, -0.8515886, -0.0419742, 0.0531832, -0.0853792],
[0.1304891, -0.2014212, -0.2014212, 0.0336440, -0.0815567, -0.0000000],
[0.2014212, -0.1754366, -0.3768578, 0.0815567, -0.1870145, -0.0000000],
[0.2014212, -0.3768578, -0.1754366, 0.0000000, -0.0000000, 0.0407784],
]
)

kernel = Matern52KernelGrad()

if cuda:
a = a.cuda()
b = b.cuda()
actual = actual.cuda()
kernel = kernel.cuda()

res = kernel(a, b).to_dense()

self.assertLess(torch.norm(res - actual), 1e-5)

def test_kernel_cuda(self):
if torch.cuda.is_available():
self.test_kernel(cuda=True)

def test_kernel_batch(self):
a = torch.tensor([[[1, 2, 3], [2, 4, 0]], [[-1, 1, 2], [2, 1, 4]]], dtype=torch.float)
b = torch.tensor([[[1, 3, 1]], [[2, -1, 0]]], dtype=torch.float).repeat(1, 2, 1)

kernel = Matern52KernelGrad()
res = kernel(a, b).to_dense()

# Compute each batch separately
actual = torch.zeros(2, 8, 8)
actual[0, :, :] = kernel(a[0, :, :].squeeze(), b[0, :, :].squeeze()).to_dense()
actual[1, :, :] = kernel(a[1, :, :].squeeze(), b[1, :, :].squeeze()).to_dense()

self.assertLess(torch.norm(res - actual), 1e-5)

def test_initialize_lengthscale(self):
kernel = Matern52KernelGrad()
kernel.initialize(lengthscale=3.14)
actual_value = torch.tensor(3.14).view_as(kernel.lengthscale)
self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)

def test_initialize_lengthscale_batch(self):
kernel = Matern52KernelGrad(batch_shape=torch.Size([2]))
ls_init = torch.tensor([3.14, 4.13])
kernel.initialize(lengthscale=ls_init)
actual_value = ls_init.view_as(kernel.lengthscale)
self.assertLess(torch.norm(kernel.lengthscale - actual_value), 1e-5)


if __name__ == "__main__":
unittest.main()

0 comments on commit 25da2cc

Please sign in to comment.