Skip to content

Commit 5fe9e9e

Browse files
committed
Replace .compute_loss() with .loss()
1 parent a333ddd commit 5fe9e9e

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

gptorch/models/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def optimize(self, method='Adam', max_iter=2000, verbose=True,
236236
for idx in range(max_iter):
237237
self.optimizer.zero_grad()
238238
# forward
239-
loss = self.compute_loss()
239+
loss = self.loss()
240240
# backward
241241
loss.backward()
242242
self.optimizer.step()
@@ -248,7 +248,7 @@ def optimize(self, method='Adam', max_iter=2000, verbose=True,
248248

249249
def closure():
250250
self.optimizer.zero_grad()
251-
loss = self.compute_loss()
251+
loss = self.loss()
252252
loss.backward()
253253
return loss
254254

@@ -260,7 +260,7 @@ def closure():
260260
for idx in range(max_iter):
261261
self.optimizer.zero_grad()
262262
# forward
263-
loss = self.compute_loss()
263+
loss = self.loss()
264264
# backward
265265
loss.backward()
266266
self.optimizer.step()
@@ -272,7 +272,7 @@ def closure():
272272

273273
def closure():
274274
self.optimizer.zero_grad()
275-
loss = self.compute_loss()
275+
loss = self.loss()
276276
loss.backward()
277277
return loss
278278

test/test_models/test_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def test_cpu(self):
4747
assert not gp.X.is_cuda
4848
assert not gp.Y.is_cuda
4949

50+
def test_optimize(self):
51+
gp = self._get_model()
52+
gp.optimize(max_iter=2)
53+
gp.optimize(method="L-BFGS-B", max_iter=2)
54+
5055
def test_predict_f(self):
5156
self._predict_fy("predict_f")
5257

0 commit comments

Comments
 (0)