forked from uncertainty-toolbox/uncertainty-toolbox
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_metrics_scoring_rule.py
84 lines (67 loc) · 2.37 KB
/
test_metrics_scoring_rule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Tests for scoring rule metrics.
"""
import numpy as np
import pytest
from uncertainty_toolbox.metrics_scoring_rule import (
nll_gaussian,
crps_gaussian,
check_score,
interval_score,
)
@pytest.fixture
def supply_test_set():
    """Provide a small three-point data set shared by the scoring rule tests.

    Returns:
        A tuple ``(y_pred, y_std, y_true)`` of 1D numpy arrays holding the
        predicted means, the predictive standard deviations and the
        observed values, in that order.
    """
    predicted_means = np.array([1, 2, 3])
    predicted_stds = np.array([0.1, 0.5, 1])
    observed_values = np.array([1.5, 3, 2])
    return predicted_means, predicted_stds, observed_values
def test_nll_gaussian_on_test_set(supply_test_set):
    """Check the Gaussian NLL of the shared test set against a fixed value."""
    # The fixture yields (y_pred, y_std, y_true); pass them positionally.
    y_pred, y_std, y_true = supply_test_set
    expected_nll = 4.920361108686675
    deviation = np.abs(nll_gaussian(y_pred, y_std, y_true) - expected_nll)
    assert deviation < 1e-6
def test_nll_gaussian_on_one_pt():
    """Sanity check the Gaussian NLL with a single point at the mean.

    With a standard deviation of 1 / sqrt(2 * pi), the Gaussian density at
    its mean equals 1, so the negative log likelihood should be 0.
    """
    unit_density_std = 1 / np.sqrt(2 * np.pi)
    nll = nll_gaussian(
        np.array([0]),
        np.array([unit_density_std]),
        np.array([0]),
    )
    assert np.abs(nll) < 1e-6
def test_crps_gaussian_on_test_set(supply_test_set):
    """Check the Gaussian CRPS of the shared test set against a fixed value."""
    # The fixture yields (y_pred, y_std, y_true); pass them positionally.
    y_pred, y_std, y_true = supply_test_set
    expected_crps = 0.59080610693
    deviation = np.abs(crps_gaussian(y_pred, y_std, y_true) - expected_crps)
    assert deviation < 1e-6
def test_check_score_on_test_set(supply_test_set):
    """Check the check score of the shared test set against a fixed value."""
    # The fixture yields (y_pred, y_std, y_true); pass them positionally.
    y_pred, y_std, y_true = supply_test_set
    expected_score = 0.29801437323836477
    deviation = np.abs(check_score(y_pred, y_std, y_true) - expected_score)
    assert deviation < 1e-6
def test_check_score_on_one_pt():
    """Sanity check that the check score is (near) zero at the exact quantile.

    For a prediction with mean 0 and standard deviation 1, an observation
    at 1 lies roughly at the 0.5 + 0.341 quantile. Scoring only that
    quantile should therefore give a value close to 0.
    """
    # Same quantile for both ends of the range, so only it is evaluated
    # (presumably as a single point, given resolution=1).
    one_std_quantile = 0.5 + 0.341
    score = check_score(
        y_pred=np.array([0]),
        y_std=np.array([1]),
        y_true=np.array([1]),
        start_q=one_std_quantile,
        end_q=one_std_quantile,
        resolution=1,
    )
    assert np.abs(score) < 1e-2
def test_interval_score_on_test_set(supply_test_set):
    """Check the interval score of the shared test set against a fixed value."""
    # The fixture yields (y_pred, y_std, y_true); pass them positionally.
    y_pred, y_std, y_true = supply_test_set
    expected_score = 3.20755700861995
    deviation = np.abs(interval_score(y_pred, y_std, y_true) - expected_score)
    assert deviation < 1e-6
def test_interval_score_on_one_pt():
    """Sanity check the interval score for one point at the distribution center.

    The observation sits at the mean of a unit-variance prediction, and the
    two requested coverage levels (0.682 and 0.954) correspond to intervals
    reaching about one and two standard deviations from the mean. The
    expected score is the average of the two interval widths:
    ((1 std) * 2 + (2 std) * 2) / 2 = 3.
    """
    expected_score = 3
    score = interval_score(
        y_pred=np.array([0]),
        y_std=np.array([1]),
        y_true=np.array([0]),
        start_p=0.682,
        end_p=0.954,
        resolution=2,
    )
    assert np.abs(score - expected_score) < 1e-2