-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfix_constraints.py
52 lines (42 loc) · 3.91 KB
/
fix_constraints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 10 14:55:12 2022
@author: Miguel A Hombrados
"""
import warnings

import gpytorch
def _bound(module, param_name, lower, upper):
    """Register an ``Interval(lower, upper)`` constraint for ``param_name`` on ``module``."""
    module.register_constraint(
        param_name, gpytorch.constraints.Interval(lower, upper)
    )


def _bound_rbf_plus_bias(kernels):
    """Register the kernel bounds shared by "gpi_ori", "gpk_sp" and "gpmt".

    Constrains ``raw_outputscale`` on ``kernels[0]``, ``raw_lengthscale`` on
    ``kernels[0].base_kernel`` and ``raw_bias`` on ``kernels[1]``.
    """
    _bound(kernels[0], "raw_outputscale", 1e-5, 200)
    # The original code carried the bare annotation "3" on this bound;
    # presumably an earlier or initial lengthscale value -- TODO confirm.
    _bound(kernels[0].base_kernel, "raw_lengthscale", 1, 100)
    _bound(kernels[1], "raw_bias", 1e-7, 1)


def fix_constraints(model, likelihood, kernel_type, num_task, method):
    """Register interval constraints on a model's raw hyperparameters.

    The set of constraints is selected by the ``(method, kernel_type)`` pair.
    Supported pairs are ``("gpi", "linear")``, ``("gpi", "rbf")``,
    ``("gpi_ori", "rbf")``, ``("gpk_sp", "rbf")`` and ``("gpmt", "rbf")``.
    Any other pair registers nothing and emits a ``UserWarning`` so that a
    misspelled method or kernel name does not silently leave the model
    unconstrained.

    Args:
        model: Object whose ``covar_module`` (and, for every method except
            "gpi", ``likelihood``) sub-modules receive the constraints.
        likelihood: Likelihood constrained directly when ``method == "gpi"``;
            not used by the other methods, which go through
            ``model.likelihood`` instead.
        kernel_type: Kernel family name, ``"linear"`` or ``"rbf"``.
        num_task: Unused; kept so existing callers keep working.
        method: Model variant name, one of ``"gpi"``, ``"gpi_ori"``,
            ``"gpk_sp"`` or ``"gpmt"``.

    Returns:
        None. The constraints are registered in place.
    """
    if method == "gpi" and kernel_type == "linear":
        _bound(likelihood, "raw_noise", 1e-6, 2)
        _bound(likelihood, "raw_task_noises", 1e-2, 2)
        kernels = model.covar_module.kernels
        _bound(kernels[0].base_kernel, "raw_variance", 1e-3, 50)
        _bound(kernels[0], "raw_outputscale", 1e-3, 50)
        _bound(kernels[1], "raw_bias", 1e-7, 1)
    elif method == "gpi" and kernel_type == "rbf":
        # The parenthesised pairs below are alternative bounds recorded in the
        # original comments (presumably values tried earlier).
        _bound(likelihood, "raw_noise", 1e-8, 5e-1)  # alt: (1e-4, 1)
        _bound(likelihood, "raw_task_noises", 1e-8, 1)  # alt: (1e-1, 2)
        kernels = model.covar_module.kernels
        _bound(kernels[0], "raw_outputscale", 0.01, 100)
        # alt: (1.5, 50), (1e-3, 1)
        _bound(kernels[0].base_kernel, "raw_lengthscale", 1, 100)
        _bound(kernels[1], "raw_bias", 1e-8, 1)  # alt: (1e-6, 1)
    elif method in ("gpi_ori", "gpk_sp") and kernel_type == "rbf":
        # Both methods use exactly the same bounds.
        # The original code carried the bare annotation "0.5" on the noise
        # bound; its meaning is not recorded -- TODO confirm.
        _bound(model.likelihood.noise_covar, "raw_noise", 1e-4, 1)
        _bound_rbf_plus_bias(model.covar_module.kernels)
    elif method == "gpmt" and kernel_type == "rbf":
        _bound(model.likelihood, "raw_noise", 1e-7, 1)
        _bound(model.likelihood, "raw_task_noises", 1e-4, 1)
        multitask_kernel = model.covar_module.covar_module_list[0]
        _bound(multitask_kernel.task_covar_module, "raw_var", 1e-4, 1)
        _bound_rbf_plus_bias(multitask_kernel.data_covar_module.kernels)
    else:
        # Previously this case fell through silently; keep it non-fatal but
        # make it visible.
        warnings.warn(
            "fix_constraints: no constraints defined for "
            f"method={method!r}, kernel_type={kernel_type!r}; "
            "model left unchanged",
            UserWarning,
            stacklevel=2,
        )