-
I am trying to get a VNNGP to work in a batched setting. To this end, I tried the following code, which is based on the tutorial.

```python
import math

import matplotlib.pyplot as plt
import torch
from torch import Tensor
from tqdm.auto import tqdm

from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import PredictiveLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class BatchGPModel(ApproximateGP):
    def __init__(self, train_x: Tensor):
        batch_shape = train_x.shape[:-2]
        inducing_points = torch.clone(train_x)
        variational_distribution = MeanFieldVariationalDistribution(inducing_points.size(-2), batch_shape=batch_shape)
        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, 25, 25)
        super().__init__(variational_strategy)
        self.mean_module = ConstantMean(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)
        self.likelihood = GaussianLikelihood(batch_shape=batch_shape)

    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, prior=False, **kwargs):
        if x is not None:
            if x.dim() == 1:
                x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=False, **kwargs)


def main():
    x = torch.linspace(0, 1, 100)
    train_y = torch.stack(
        [
            torch.sin(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.sin(x * (2 * math.pi)) + 2 * torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            -torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        ],
        0,
    )
    train_x = torch.stack([x] * 4).unsqueeze(-1)
    num_tasks = 4

    # initialize model
    model = BatchGPModel(train_x)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = PredictiveLogLikelihood(model.likelihood, model, num_data=x.size(0))

    num_batches = model.variational_strategy._total_training_batches
    epochs_iter = tqdm(range(50), desc="Epoch")
    for _ in epochs_iter:
        minibatch_iter = tqdm(range(num_batches), desc="Minibatch", leave=False)
        for _ in minibatch_iter:
            optimizer.zero_grad()
            output = model(x=None)
            current_training_indices = model.variational_strategy.current_training_indices
            y_batch = train_y[:, current_training_indices]
            loss = -mll(output, y_batch).sum()
            minibatch_iter.set_postfix(loss=loss.item())
            loss.backward()
            optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()

    # Initialize plots
    fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))

    # Make predictions
    with torch.no_grad(), settings.fast_pred_var():
        test_x = torch.stack([torch.linspace(0, 1, 51)] * num_tasks).unsqueeze(-1)
        predictions = model.likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    for task, ax in enumerate(axs):
        # Plot training data as black stars
        ax.plot(x.detach().numpy(), train_y[task].detach().numpy(), "k*")
        # Predictive mean as blue line
        ax.plot(test_x[0, :, 0].numpy(), mean[task].numpy(), "b")
        # Shade in confidence
        ax.fill_between(test_x[0, :, 0].numpy(), lower[task].numpy(), upper[task].numpy().T, alpha=0.5)
        ax.set_ylim([-3, 3])
        ax.legend(["Observed Data", "Mean", "Confidence"])
        ax.set_title(f"Task {task + 1}")
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
```

However, this does not work, as the
Does somebody have an idea what's going on here? Maybe @LuhuanWu or @gpleiss?
-
Hmmm this seems like a bug on our end. @LuhuanWu could you take a look?
-
Thanks for reporting the issue. Please allow me a few weeks to look into this.
-
I think this should at least become an issue if there is no fix for this and it is not a usage error.
-
Hi @Turakar, I believe I've fixed the issue and submitted a PR here #2375. Based on this PR and a slightly modified version of your test code, I am able to get reasonable results. For example, see the prediction results for k=20 and k=50 in the attached figures.

P.S. There is a small mistake in your original code:

```python
mll = PredictiveLogLikelihood(model.likelihood, model, num_data=x.size(0))
```

`num_data` here should be the number of data points instead of the model batch size, i.e. `num_data=x.size(1)` or equivalently `train_y.size(1)` in this example.

I also added the ELBO as an option for the training objective, which is used in the paper:

```python
mll = VariationalELBO(model.likelihood, model, num_data=train_y.size(1))
```

```python
import math
import matplotlib.pyplot as plt
import torch
from torch import Tensor
#from tqdm.auto import tqdm
from tqdm.notebook import trange, tqdm
import importlib
import gpytorch
importlib.reload(gpytorch)
from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import VariationalELBO, PredictiveLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy

class BatchGPModel(ApproximateGP):
    def __init__(self, train_x: Tensor, k: int, training_batch_size: int):
        batch_shape = train_x.shape[:-2]
        inducing_points = torch.clone(train_x)
        variational_distribution = MeanFieldVariationalDistribution(inducing_points.size(-2), batch_shape=batch_shape)
        variational_strategy = NNVariationalStrategy(self, inducing_points, variational_distribution, k=k, training_batch_size=training_batch_size)
        super().__init__(variational_strategy)
        self.mean_module = ConstantMean(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)
        self.likelihood = GaussianLikelihood(batch_shape=batch_shape)

    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, prior=False, **kwargs):
        if x is not None:
            if x.dim() == 1:
                x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=False, **kwargs)

torch.manual_seed(42)
x = torch.linspace(0, 1, 100)
train_y = torch.stack(
    [
        torch.sin(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        torch.sin(x * (2 * math.pi)) + 2 * torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        -torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
    ],
    0,
)
train_x = torch.stack([x] * 4).unsqueeze(-1)
num_tasks = 4
torch.manual_seed(42)
# initialize model
k = 100
model = BatchGPModel(train_x, k=k, training_batch_size=25)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
minimize_predictivell = False
if minimize_predictivell:
    mll = PredictiveLogLikelihood(model.likelihood, model, num_data=train_y.size(1))
else:
    mll = VariationalELBO(model.likelihood, model, num_data=train_y.size(1))
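# Note (see the P.S. above): num_data must be the number of training points per
# task (train_y.size(1) == 100 here), not the model batch size.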
loss_list = []
ls_list = []
outputscale_list = []
noise_list = []
loss_2_list = []
num_batches = model.variational_strategy._total_training_batches
epochs = 1000
epochs_iter = tqdm(range(epochs), desc="Epoch")
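# As in the VNNGP tutorial, calling the model with x=None in training mode makes the
# variational strategy draw a random minibatch of training inputs; current_training_indices
# then gives the indices needed to slice the matching targets out of train_y.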
for epoch in epochs_iter:
    minibatch_iter = range(num_batches)
    for _ in minibatch_iter:
        optimizer.zero_grad()
        output = model(x=None)
        current_training_indices = model.variational_strategy.current_training_indices
        # print(current_training_indices)
        y_batch = train_y[:, current_training_indices]
        loss = -mll(output, y_batch)
        # minibatch_iter.set_postfix(loss=loss.item())
        loss0 = loss[0].item()
        loss1 = loss[1].item()
        loss2 = loss[2].item()
        loss_2_list.append(loss[2].item())
        loss = loss.sum()
        loss.backward()
        optimizer.step()
    ls = model.covar_module.base_kernel.lengthscale.min().item()
    outputscale = model.covar_module.outputscale.mean().item()
    noise = model.likelihood.noise.mean().item()
    print('Iter %d/%d - Loss: %.3f loss0: %.3f loss1: %.3f loss2: %.3f lengthscale: %.3f outputscale: %.3f noise: %.3f' % (
        epoch + 1, epochs, loss.item(), loss0, loss1, loss2,
        ls,
        outputscale,
        noise,
    ))
    loss_list.append(loss.item())
    ls_list.append(ls)
    outputscale_list.append(outputscale)
    noise_list.append(noise)
    epochs_iter.set_postfix(loss=loss.item())

# Get into evaluation (predictive posterior) mode
model.eval()
# Initialize plots
fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))
# Make predictions
with torch.no_grad(), settings.fast_pred_var():
    # test_x = torch.stack([torch.linspace(0, 1, 51)] * num_tasks).unsqueeze(-1)
    test_x = torch.stack([torch.linspace(-1, 2, 51)] * num_tasks).unsqueeze(-1)
    predictions = model.likelihood(model(test_x))
    mean = predictions.mean
    lower, upper = predictions.confidence_region()

for task, ax in enumerate(axs):
    # Plot training data as black stars
    ax.plot(x.detach().numpy(), train_y[task].detach().numpy(), "k*", label='Observed data')
    # Predictive mean as blue line
    ax.plot(test_x[0, :, 0].numpy(), mean[task].numpy(), "b", label='Mean')
    # Shade in confidence
    ax.fill_between(test_x[0, :, 0].numpy(), lower[task].numpy(), upper[task].numpy().T, alpha=0.5, label="Legend")
    ax.set_ylim([-8, 6])
    ax.set_title(f"Task {task + 1}")
# Place the legend outside of the figure
handles, labels = axs[-1].get_legend_handles_labels()
fig.legend(handles, labels, bbox_to_anchor=(1.01, 0.75))
plt.subplots_adjust(top=0.8)
fig.suptitle(f"K={k}", fontsize=20)
if minimize_predictivell:
    plt.savefig(f"./predictivenll-vnngp-k{k}.png", bbox_inches='tight')
else:
    plt.savefig(f"./vnngp-k{k}.png", bbox_inches='tight')
plt.show()
```