Skip to content

Commit a80296c

Browse files
authored
Merge pull request #18 from cics-nd/optimize
Replace .compute_loss() with .loss()
2 parents a333ddd + 28a8057 commit a80296c

File tree

3 files changed

+10
-5
lines changed

3 files changed

+10
-5
lines changed

gptorch/models/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def optimize(self, method='Adam', max_iter=2000, verbose=True,
236236
for idx in range(max_iter):
237237
self.optimizer.zero_grad()
238238
# forward
239-
loss = self.compute_loss()
239+
loss = self.loss()
240240
# backward
241241
loss.backward()
242242
self.optimizer.step()
@@ -248,7 +248,7 @@ def optimize(self, method='Adam', max_iter=2000, verbose=True,
248248

249249
def closure():
250250
self.optimizer.zero_grad()
251-
loss = self.compute_loss()
251+
loss = self.loss()
252252
loss.backward()
253253
return loss
254254

@@ -260,7 +260,7 @@ def closure():
260260
for idx in range(max_iter):
261261
self.optimizer.zero_grad()
262262
# forward
263-
loss = self.compute_loss()
263+
loss = self.loss()
264264
# backward
265265
loss.backward()
266266
self.optimizer.step()
@@ -272,7 +272,7 @@ def closure():
272272

273273
def closure():
274274
self.optimizer.zero_grad()
275-
loss = self.compute_loss()
275+
loss = self.loss()
276276
loss.backward()
277277
return loss
278278

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
]
2525

2626
setup(name="gptorch",
27-
version="0.3.0",
27+
version="0.3.1",
2828
description="gptorch - a Gaussian process toolbox built on PyTorch",
2929
author="Yinhao Zhu, Steven Atkinson",
3030
author_email="yzhu10@nd.edu, satkinso@nd.edu",

test/test_models/test_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def test_cpu(self):
4747
assert not gp.X.is_cuda
4848
assert not gp.Y.is_cuda
4949

50+
def test_optimize(self):
51+
gp = self._get_model()
52+
gp.optimize(max_iter=2)
53+
gp.optimize(method="L-BFGS-B", max_iter=2)
54+
5055
def test_predict_f(self):
5156
self._predict_fy("predict_f")
5257

0 commit comments

Comments
 (0)