Skip to content

Commit f2fe8ad

Browse files
authored
Merge pull request #23 from cics-nd/issue-22
Fix issue 22, tests for example scripts
2 parents ab14278 + 0810592 commit f2fe8ad

File tree

6 files changed

+177
-59
lines changed

6 files changed

+177
-59
lines changed

examples/regression_1d.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import os
1010
import sys
11+
from argparse import ArgumentParser
1112
from time import time
1213

1314
import torch
@@ -30,7 +31,7 @@ def f(x):
3031
return np.sin(2.0 * np.pi * x) + np.cos(3.5 * np.pi * x) - 3.0 * x + 5.0
3132

3233

33-
if __name__ == "__main__":
34+
def main(args):
3435
# Create data:
3536
n = 100
3637
x = np.linspace(0, 1, n).reshape((-1, 1))
@@ -41,9 +42,12 @@ def f(x):
4142
kern = kernels.Linear(1) + kernels.Rbf(1) + kernels.Constant(1)
4243

4344
# Try different models:
44-
model = GPR(x, y, kern)
45-
# model = VFE(x, y, kern)
46-
# model.cuda() # If you want to use GPU
45+
if args.model_type == "GPR":
46+
model = GPR(x, y, kern)
47+
elif args.model_type == "VFE":
48+
model = VFE(x, y, kern)
49+
if args.cuda:
50+
model.cuda() # If you want to use GPU
4751

4852
# Train
4953
model.optimize(method="L-BFGS-B", max_iter=100)
@@ -62,17 +66,30 @@ def f(x):
6266
# Show prediction
6367
x_test = x_test.flatten()
6468
plt.figure()
65-
plt.fill_between(x_test, (mu - unc).flatten(), (mu + unc).flatten(),
66-
color=(0.9,) * 3)
69+
plt.fill_between(
70+
x_test, (mu - unc).flatten(), (mu + unc).flatten(), color=(0.9,) * 3
71+
)
6772
plt.plot(x_test, mu)
6873
plt.plot(x_test, f(x_test))
6974
for y_samp_i in y_samp:
7075
plt.plot(x_test, y_samp_i, color=(0.4, 0.7, 1.0), alpha=0.5)
7176
plt.plot(x, y, "o")
7277
if hasattr(model, "Z"):
7378
plt.plot(
74-
model.Z.detach().cpu().numpy(),
75-
1.0 + plt.ylim()[0] * np.ones(model.Z.shape[0]),
76-
"+"
79+
model.Z.detach().cpu().numpy(),
80+
1.0 + plt.ylim()[0] * np.ones(model.Z.shape[0]),
81+
"+",
7782
)
78-
plt.show()
83+
if args.no_plot:
84+
plt.close()
85+
else:
86+
plt.show()
87+
88+
89+
if __name__ == "__main__":
90+
parser = ArgumentParser()
91+
parser.add_argument("--model-type", type=str, choices=("GPR", "VFE"), default="GPR")
92+
parser.add_argument("--cuda", action="store_true")
93+
parser.add_argument("--no-plot", action="store_true")
94+
95+
main(parser.parse_args())

gptorch/mean_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111

1212
from .util import torch_dtype
1313

14-
torch.set_default_dtype(torch_dtype)
15-
1614

1715
class Constant(torch.nn.Module):
1816
"""
@@ -26,13 +24,13 @@ def __init__(self, dy: int, val: torch.Tensor = None):
2624
raise ValueError("Provided val doesn't match output dimension")
2725
val = val.clone()
2826
else:
29-
val = torch.zeros(dy)
27+
val = torch.zeros(dy, dtype=torch_dtype)
3028

3129
self._dy = dy
3230
self.val = torch.nn.Parameter(val)
3331

3432
def forward(self, x):
35-
output = torch.zeros(x.shape[0], self._dy)
33+
output = torch.zeros(x.shape[0], self._dy, dtype=torch_dtype)
3634
if self._is_cuda():
3735
output = output.cuda()
3836
return output + self.val

test/test_base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# File: test_base.py
2+
# File Created: Saturday, 30th October 2021 7:39:23 am
3+
# Author: Steven Atkinson (steven@atkinson.mn)
14

25
"""
36
Basic tests for the repo
@@ -10,8 +13,10 @@ def test_torch_dtype():
1013
"""
1114

1215
import torch
16+
1317
dtype = torch.get_default_dtype()
14-
import gptorch
18+
import gptorch # noqa F401
19+
1520
new_dtype = torch.get_default_dtype()
1621

1722
assert new_dtype == dtype

test/test_examples.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# File: test_examples.py
2+
# File Created: Saturday, 30th October 2021 9:04:58 am
3+
# Author: Steven Atkinson (steven@atkinson.mn)
4+
5+
import os
6+
from subprocess import check_call
7+
from sys import executable
8+
9+
import pytest
10+
11+
args = (
12+
("regression_1d.py", "--no-plot"),
13+
("regression_1d.py", "--no-plot", "--model-type", "VFE"),
14+
)
15+
16+
17+
@pytest.mark.parametrize("args", args)
18+
def test_example(args):
19+
basename, args = args[0], args[1:]
20+
script_path = os.path.join(os.path.dirname(__file__), "..", "examples", basename)
21+
check_call((executable, script_path) + args)
22+
23+
24+
if __name__ == "__main__":
25+
pytest.main()

test/test_models/common.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99
from gptorch.models.base import GPModel
1010
from gptorch.util import torch_dtype
1111

12-
torch.set_default_dtype(torch_dtype)
1312

14-
15-
def gaussian_predictions(model: GPModel, x_test: torch.Tensor,
16-
expected_mu: np.ndarray, expected_s: np.ndarray):
13+
def gaussian_predictions(
14+
model: GPModel,
15+
x_test: torch.Tensor,
16+
expected_mu: np.ndarray,
17+
expected_s: np.ndarray,
18+
):
1719
"""
1820
Every GP model with a gaussian likelihood needs the same set of tests run on
1921
its ._predict() method.
@@ -29,10 +31,20 @@ def gaussian_predictions(model: GPModel, x_test: torch.Tensor,
2931
assert mu_diag.shape[1] == model.Y.shape[1]
3032
assert all([ss == ms for ss, ms in zip(mu_diag.shape, s_diag.shape)])
3133

32-
assert all([a == pytest.approx(e) for a, e in zip(
33-
mu_diag.detach().numpy().flatten(), expected_mu.flatten())])
34-
assert all([a == pytest.approx(e) for a, e in zip(
35-
s_diag.detach().numpy().flatten(), expected_s.diagonal().flatten())])
34+
assert all(
35+
[
36+
a == pytest.approx(e)
37+
for a, e in zip(mu_diag.detach().numpy().flatten(), expected_mu.flatten())
38+
]
39+
)
40+
assert all(
41+
[
42+
a == pytest.approx(e)
43+
for a, e in zip(
44+
s_diag.detach().numpy().flatten(), expected_s.diagonal().flatten()
45+
)
46+
]
47+
)
3648

3749
# Predictions with full covariance
3850
mu_full, s_full = model._predict(x_test, diag=False)
@@ -43,8 +55,16 @@ def gaussian_predictions(model: GPModel, x_test: torch.Tensor,
4355
assert mu_full.shape[0] == x_test.shape[0]
4456
assert mu_full.shape[1] == model.Y.shape[1]
4557
assert all([ss == x_test.shape[0] for ss in s_full.shape])
46-
47-
assert all([a == pytest.approx(e) for a, e in zip(
48-
mu_full.detach().numpy().flatten(), expected_mu.flatten())])
49-
assert all([a == pytest.approx(e) for a, e in zip(
50-
s_full.detach().numpy().flatten(), expected_s.flatten())])
58+
assert all(
59+
[
60+
a == pytest.approx(e)
61+
for a, e in zip(mu_full.detach().numpy().flatten(), expected_mu.flatten())
62+
]
63+
)
64+
assert all(
65+
[
66+
a == pytest.approx(e)
67+
for a, e in zip(s_full.detach().numpy().flatten(), expected_s.flatten())
68+
]
69+
)
70+

0 commit comments

Comments
 (0)