Skip to content

Commit

Permalink
Fix get_metrics flags and add tests for future (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
jteijema authored Feb 6, 2025
1 parent c415c8d commit 7a8968b
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 8 deletions.
25 changes: 17 additions & 8 deletions asreviewcontrib/insights/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def _tnr(labels, intercept, x_absolute=False):

return _slice_metric(x, y, intercept)


def loss(state_obj, priors=False):
"""Compute the loss for active learning problem.
Expand All @@ -185,6 +186,7 @@ def loss(state_obj, priors=False):

return _loss_value(labels)


def get_metrics(
state_obj,
recall=None,
Expand All @@ -196,14 +198,21 @@ def get_metrics(
y_absolute=False,
version=None,
):
recall = (
[recall]
if recall and not isinstance(recall, list)
else [0.1, 0.25, 0.5, 0.75, 0.9]
)
wss = [wss] if wss and not isinstance(wss, list) else [0.95]
erf = [erf] if erf and not isinstance(erf, list) else [0.10]
cm = [cm] if cm and not isinstance(cm, list) else [0.1, 0.25, 0.5, 0.75, 0.9]
def ensure_list_of_floats(value, default):
if value is None:
return default
if isinstance(value, float):
return [value]
if isinstance(value, list) and all(isinstance(i, float) for i in value):
return value
raise ValueError(
f"Invalid input: {value}. Must be a float or a list of floats."
)

recall = ensure_list_of_floats(recall, [0.1, 0.25, 0.5, 0.75, 0.9])
wss = ensure_list_of_floats(wss, [0.95])
erf = ensure_list_of_floats(erf, [0.10])
cm = ensure_list_of_floats(cm, [0.1, 0.25, 0.5, 0.75, 0.9])

labels = _pad_simulation_labels(state_obj, priors=priors)

Expand Down
32 changes: 32 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,35 @@ def test_single_value_formats():
assert isinstance(_wss([1,1,0,0], 0.5), float)
assert isinstance(_loss_value([1,1,0,0]), float)
assert isinstance(_erf([1,1,0,0], 0.5), float)

def test_get_metrics():
with open_state(
Path(TEST_ASREVIEW_FILES, "sim_van_de_schoot_2017_stop_if_min.asreview")
) as s:
metrics = get_metrics(s, wss=[0.75, 0.85, 0.95], erf=[0.75, 0.85, 0.95])

wss_data = next(
(item["value"] for item in metrics["data"]["items"] if item["id"] == "wss"),
None
)
assert wss_data is not None, "WSS key missing in metrics"

erf_data = next(
(item["value"] for item in metrics["data"]["items"] if item["id"] == "erf"),
None
)
assert erf_data is not None, "ERF key missing in metrics"

wss_values = {val[0]: val[1] for val in wss_data}
for value in [0.75, 0.85, 0.95]:
assert value in wss_values, f"WSS value {value} missing in output"

for wss_score in wss_values.values():
assert 0 <= wss_score <= 1, f"WSS value {wss_score} out of expected range"

erf_values = {val[0]: val[1] for val in wss_data}
for value in [0.75, 0.85, 0.95]:
assert value in erf_values, f"ERF value {value} missing in output"

for erf_score in erf_values.values():
assert 0 <= erf_score <= 1, f"ERF value {wss_score} out of expected range"

0 comments on commit 7a8968b

Please sign in to comment.