-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathepoch_tv.py
43 lines (35 loc) · 1.09 KB
/
epoch_tv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 6 16:49:14 2022
@author: mahom
"""
import torch
import gpytorch
from predGPind_ori import predGPind_ori
def train_epoch(model, data, mll, optimizer):
    """Run one optimisation step of ``model`` on the full training batch.

    Parameters
    ----------
    model : torch.nn.Module-like
        Model switched to training mode and called on ``train_x``.
    data : tuple
        ``(train_x, train_y)``; ``train_x.size(0)`` is used as the sample count.
    mll : callable
        Objective called as ``mll(output, train_y)``; its negation is minimised.
        If it exposes a ``likelihood`` attribute (as gpytorch
        ``MarginalLogLikelihood`` objects do), that likelihood is put back into
        training mode as well.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once after back-propagating the loss.

    Returns
    -------
    train_loss : float
        ``-mll(output, train_y)`` multiplied by ``train_x.size(0)``.
    output : object
        What ``model(train_x)`` returned (still attached to the autograd graph).
    """
    train_x, train_y = data
    model.train()
    # valid_epoch() calls likelihood.eval(); without this, the likelihood used
    # inside the objective would stay in eval mode for every later epoch.
    # NOTE(review): assumes mll.likelihood is the likelihood passed to
    # valid_epoch -- confirm against the caller.
    likelihood = getattr(mll, "likelihood", None)
    if likelihood is not None:
        likelihood.train()
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
    # Scale the per-sample objective by the batch size.
    train_loss = loss.item() * train_x.size(0)
    return train_loss, output
def valid_epoch(model, likelihood, output, data, mll):
    """Evaluate ``model`` on a validation split without tracking gradients.

    Parameters
    ----------
    model : torch.nn.Module-like
        Model exposing ``forward``, ``eval`` and ``__call__``; presumably a
        gpytorch ``ExactGP`` (TODO confirm against caller).
    likelihood : torch.nn.Module-like
        Maps a latent distribution to one whose ``.mean`` is read; must
        expose ``eval``.
    output : object
        Unused; kept so existing call sites keep working.
    data : tuple
        ``(val_x, val_y)``; ``val_x.size(0)`` is used as the sample count.
    mll : callable
        Objective called as ``mll(distribution, val_y)``; its negation is
        reported as the loss.

    Returns
    -------
    valid_loss : float
        ``-mll(model.forward(val_x), val_y)`` multiplied by ``val_x.size(0)``.
    valid_error : torch.Tensor
        ``val_y - likelihood(model.forward(val_x)).mean`` (target minus estimate).
    val_pred_error : torch.Tensor
        ``likelihood(model(val_x)).mean - val_y`` computed after switching
        ``model`` and ``likelihood`` to eval mode (estimate minus target --
        note the sign is opposite to ``valid_error``).

    Side effects: leaves ``model`` and ``likelihood`` in eval mode.
    """
    val_x, val_y = data
    # Nothing here is back-propagated, so skip building an autograd graph;
    # this also makes the returned tensors gradient-free.
    with torch.no_grad():
        # forward() is called directly, bypassing model.__call__. For a
        # gpytorch ExactGP this presumably yields the prior over val_x rather
        # than the posterior -- NOTE(review): confirm this is the intended
        # validation loss.
        f_val_est = model.forward(val_x)
        loss = -mll(f_val_est, val_y)
        valid_loss = loss.item() * val_x.size(0)
        y_val_est = likelihood(f_val_est)
        valid_error = val_y - y_val_est.mean

        # === predictive distribution (eval-mode call) ===
        model.eval()
        likelihood.eval()
        predictions = likelihood(model(val_x))
        ypred_val = predictions.mean
        val_pred_error = ypred_val - val_y
    return valid_loss, valid_error, val_pred_error