# MTGPclasses.py
# GPyTorch model classes for single-output and multitask Gaussian-process regression.
import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from mykernels import bias
class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    """Exact GP that models each of ``num_task`` outputs with its own independent GP.

    Tasks are represented as a batch dimension of size ``num_task``. Each batch
    member gets its own constant mean and its own kernel hyperparameters.
    ``forward`` then reinterprets that batch of independent Gaussians as a single
    ``MultitaskMultivariateNormal``.

    Args:
        train_x: Training inputs.
        train_y: Training targets (presumably shaped ``(n, num_task)`` -- confirm
            against the caller).
        likelihood: GPyTorch likelihood passed straight to ``ExactGP`` (presumably
            a multitask likelihood -- confirm).
        num_task: Number of independent output tasks (size of the batch dimension).
        kernel_type: Base kernel name: ``'rbf'``, ``'linear'`` or ``'matern'``
            (Matern with ``nu=2.5``).

    Raises:
        ValueError: If ``kernel_type`` is not one of the supported names.
    """

    def __init__(self, train_x, train_y, likelihood, num_task: int, kernel_type: str):
        super().__init__(train_x, train_y, likelihood)
        batch_shape = torch.Size([num_task])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        if kernel_type == 'rbf':
            kernel_cov = gpytorch.kernels.RBFKernel(batch_shape=batch_shape)
        elif kernel_type == 'linear':
            kernel_cov = gpytorch.kernels.LinearKernel(batch_shape=batch_shape)
        elif kernel_type == 'matern':
            kernel_cov = gpytorch.kernels.MaternKernel(nu=2.5, batch_shape=batch_shape)
        else:
            # Fail loudly here instead of leaving ``kernel_cov`` unbound.
            raise ValueError(
                f"Unknown kernel_type {kernel_type!r}; "
                "expected 'rbf', 'linear' or 'matern'"
            )
        # Scaled base kernel plus the ``bias`` kernel from ``mykernels``, both
        # batched over tasks so every task has its own hyperparameters.
        self.covar_module = gpytorch.kernels.ScaleKernel(
            kernel_cov, batch_shape=batch_shape
        ) + bias(batch_shape=batch_shape)

    def forward(self, x):
        """Return the multitask Gaussian over ``x`` assembled from the per-task batch."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        # from_batch_mvn converts the batch of ``num_task`` independent MVNs
        # into one MultitaskMultivariateNormal.
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x, validate_args=True)
        )
class MultitaskGPModel(gpytorch.models.ExactGP):
    """Exact multitask GP whose outputs are correlated through a learned task covariance.

    The mean is a ``MultitaskMean`` with one ``ConstantMean`` per task. The
    covariance is an ``LCMKernel`` built from a single data kernel
    (scaled base kernel plus ``bias``), combined with a low-rank task
    covariance of the given ``rank``.

    Args:
        train_x: Training inputs.
        train_y: Training targets (presumably shaped ``(n, num_task)`` -- confirm
            against the caller).
        likelihood: GPyTorch likelihood passed straight to ``ExactGP`` (presumably
            a multitask likelihood -- confirm).
        num_task: Number of output tasks.
        kernel_type: Base kernel name: ``'rbf'``, ``'linear'`` or ``'matern'``
            (Matern with ``nu=2.5``).
        rank: Rank of the task-covariance factor. ``None`` (the default) means
            ``min(8, num_task)``. That keeps the previous hard-coded rank of 8
            while avoiding GPyTorch's error when ``rank > num_tasks``. An
            explicit value is passed through unchanged.

    Raises:
        ValueError: If ``kernel_type`` is not one of the supported names.
    """

    def __init__(self, train_x, train_y, likelihood, num_task: int, kernel_type: str,
                 rank=None):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=num_task
        )
        if kernel_type == 'rbf':
            kernel_cov = gpytorch.kernels.RBFKernel()
        elif kernel_type == 'linear':
            kernel_cov = gpytorch.kernels.LinearKernel()
        elif kernel_type == 'matern':
            kernel_cov = gpytorch.kernels.MaternKernel(nu=2.5)
        else:
            # Fail loudly here instead of leaving ``kernel_cov`` unbound.
            raise ValueError(
                f"Unknown kernel_type {kernel_type!r}; "
                "expected 'rbf', 'linear' or 'matern'"
            )
        data_kernel = gpytorch.kernels.ScaleKernel(kernel_cov) + bias()
        if rank is None:
            # IndexKernel (used inside LCMKernel) rejects rank > num_tasks, so
            # cap the default for problems with fewer than 8 tasks.
            rank = min(8, num_task)
        # With a single base kernel, LCMKernel reduces to one
        # MultitaskKernel(data_kernel, num_tasks=num_task, rank=rank).
        self.covar_module = gpytorch.kernels.LCMKernel(
            [data_kernel], num_tasks=num_task, rank=rank
        )

    def forward(self, x):
        """Return the joint multitask Gaussian over ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
class ApproxGPModel_Lap_single(ApproximateGP):
    """Single-output variational GP whose inducing points are fixed at ``train_x``.

    Every training input is used as an inducing point, via
    ``CholeskyVariationalDistribution(train_x.size(0))``. The inducing
    locations are not optimized (``learn_inducing_locations=False``).

    Args:
        train_x: Training inputs; also used as the inducing points.
        kernel_type: Base kernel name: ``'rbf'``, ``'linear'`` or ``'matern'``
            (Matern with ``nu=2.5``).

    Raises:
        ValueError: If ``kernel_type`` is not one of the supported names.
    """

    def __init__(self, train_x, kernel_type: str):
        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))
        variational_strategy = VariationalStrategy(
            self, train_x, variational_distribution, learn_inducing_locations=False
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        if kernel_type == 'rbf':
            kernel_cov = gpytorch.kernels.RBFKernel()
        elif kernel_type == 'linear':
            kernel_cov = gpytorch.kernels.LinearKernel()
        elif kernel_type == 'matern':
            kernel_cov = gpytorch.kernels.MaternKernel(nu=2.5)
        else:
            # Fail loudly here instead of leaving ``kernel_cov`` unbound.
            raise ValueError(
                f"Unknown kernel_type {kernel_type!r}; "
                "expected 'rbf', 'linear' or 'matern'"
            )
        # Scaled base kernel plus the ``bias`` kernel from ``mykernels``.
        self.covar_module = gpytorch.kernels.ScaleKernel(kernel_cov) + bias()

    def forward(self, x):
        """Return the GP prior ``MultivariateNormal`` over ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
class ExactGPModel_single(gpytorch.models.ExactGP):
    """Single-output exact GP with a constant mean and a scaled base kernel plus ``bias``.

    Args:
        train_x: Training inputs.
        train_y: Training targets.
        likelihood: GPyTorch likelihood passed straight to ``ExactGP``.
        kernel_type: Base kernel name: ``'rbf'``, ``'linear'`` or ``'matern'``
            (Matern with ``nu=2.5``).

    Raises:
        ValueError: If ``kernel_type`` is not one of the supported names.
    """

    def __init__(self, train_x, train_y, likelihood, kernel_type: str):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        if kernel_type == 'rbf':
            kernel_cov = gpytorch.kernels.RBFKernel()
        elif kernel_type == 'linear':
            kernel_cov = gpytorch.kernels.LinearKernel()
        elif kernel_type == 'matern':
            kernel_cov = gpytorch.kernels.MaternKernel(nu=2.5)
        else:
            # Fail loudly here instead of leaving ``kernel_cov`` unbound.
            raise ValueError(
                f"Unknown kernel_type {kernel_type!r}; "
                "expected 'rbf', 'linear' or 'matern'"
            )
        # Scaled base kernel plus the ``bias`` kernel from ``mykernels``.
        self.covar_module = gpytorch.kernels.ScaleKernel(kernel_cov) + bias()

    def forward(self, x):
        """Return the GP prior ``MultivariateNormal`` over ``x``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)