Skip to content

Commit

Permalink
Merge pull request #648 from aai-institute/fix/grouped-dataset-tweaks
Browse files Browse the repository at this point in the history
grouped dataset tweaks
  • Loading branch information
mdbenito authored Feb 21, 2025
2 parents 26cade3 + d2c55da commit bb1d4ad
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
`GroupedDataset`, fixing inconsistencies in how the latter operates on indices.
Also, both now return objects of the same type when slicing.
[PR #631](https://github.com/aai-institute/pyDVL/pull/631)
[PR #648](https://github.com/aai-institute/pyDVL/pull/648)
- Use tighter bounds for the calculation of the minimal sample size that guarantees
an epsilon-delta approximation in group testing (Jia et al. 2023)
[PR #602](https://github.com/aai-institute/pyDVL/pull/602)
Expand Down
57 changes: 44 additions & 13 deletions src/pydvl/valuation/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,22 @@ class RawData:
x: NDArray
y: NDArray

def __post_init__(self):
try:
if len(self.x) != len(self.y):
raise ValueError("x and y must have the same length")
except TypeError as e:
raise TypeError("x and y must be numpy arrays") from e

# Make the unpacking operator work
def __iter__(self): # No way to type the return Iterator properly
return iter((self.x, self.y))

def __getitem__(self, item: int | slice | Sequence[int]) -> RawData:
    """Return the selected samples as a new ``RawData``.

    Args:
        item: A single position, a slice, or a sequence of positions.

    Returns:
        A ``RawData`` built from ``x[item]`` and ``y[item]``, both at least
        1-dimensional.
    """
    # A bare integer drops the sample axis: a scalar comes back for 1-d data,
    # and for a 2-d ``x`` a row whose length no longer matches ``y``. Selecting
    # it as a one-element list keeps ``x`` and ``y`` aligned.
    if isinstance(item, (int, np.integer)):
        item = [item]
    # np.atleast_1d still covers index objects that yield a scalar
    # (e.g. a 0-d array used as index).
    return RawData(np.atleast_1d(self.x[item]), np.atleast_1d(self.y[item]))

def __len__(self):
return len(self.x)


class Dataset:
Expand Down Expand Up @@ -437,7 +447,7 @@ def __init__(
!!! tip "Changed in version 0.10.0"
No longer holds split data, but only x, y and group information. Added
methods to retrieve indices for groups and vicecersa.
methods to retrieve indices for groups and vice versa.
"""
super().__init__(
x=x,
Expand All @@ -456,7 +466,12 @@ def __init__(
)

# data index -> abstract index (group id)
self.data_to_group: NDArray[np.int_] = np.array(data_groups, dtype=int)
try:
self.data_to_group: NDArray[np.int_] = np.array(data_groups, dtype=int)
except ValueError as e:
raise ValueError(
"data_groups must be a mapping from integer data indices to integer group ids"
) from e
# abstract index (group id) -> data index
self.group_to_data: OrderedDict[int, list[int]] = OrderedDict(
{k: [] for k in set(data_groups)}
Expand All @@ -469,6 +484,11 @@ def __init__(
if group_names is not None
else np.array(list(self.group_to_data.keys()), dtype=np.str_)
)
if len(self._group_names) != len(self.group_to_data):
raise ValueError(
f"The number of group names ({len(self._group_names)}) "
f"does not match the number of groups ({len(self.group_to_data)})"
)

def __len__(self):
return len(self._indices)
Expand All @@ -478,13 +498,14 @@ def __getitem__(
) -> GroupedDataset:
if isinstance(idx, int):
idx = [idx]
indices = self.data_indices(idx)
return GroupedDataset(
x=self._x[self.data_indices(idx)],
y=self._y[self.data_indices(idx)],
data_groups=self.data_to_group[self.data_indices(idx)],
x=self._x[indices],
y=self._y[indices],
data_groups=self.data_to_group[indices],
feature_names=self.feature_names,
target_names=self.target_names,
data_names=self._data_names[self.data_indices(idx)],
data_names=self._data_names[indices],
group_names=self._group_names[idx],
description="(SLICED): " + self.description,
)
Expand Down Expand Up @@ -699,7 +720,7 @@ def from_arrays(
point. The length of this array must be equal to the number of
data points in the dataset.
kwargs: Additional keyword arguments that will be passed to the
[Dataset][pydvl.valuation.dataset.Dataset] constructor.
[GroupedDataset][pydvl.valuation.dataset.GroupedDataset] constructor.
Returns:
Dataset with the passed X and y arrays split across training and
Expand All @@ -708,12 +729,12 @@ def from_arrays(
!!! tip "New in version 0.4.0"
!!! tip "Changed in version 0.6.0"
Added kwargs to pass to the [Dataset][pydvl.valuation.dataset.Dataset]
constructor.
Added kwargs to pass to the
[GroupedDataset][pydvl.valuation.dataset.GroupedDataset] constructor.
!!! tip "Changed in version 0.10.0"
Returns a tuple of two [GroupedDataset][pydvl.valuation.dataset.GroupedDataset]
objects.
Returns a tuple of two
[GroupedDataset][pydvl.valuation.dataset.GroupedDataset] objects.
"""

if data_groups is None:
Expand All @@ -734,7 +755,11 @@ def from_arrays(

@classmethod
def from_dataset(
cls, data: Dataset, data_groups: Sequence[int] | NDArray[np.int_]
cls,
data: Dataset,
data_groups: Sequence[int] | NDArray[np.int_],
group_names: Sequence[str] | NDArray[np.str_] | None = None,
**kwargs,
) -> GroupedDataset:
"""Creates a [GroupedDataset][pydvl.valuation.dataset.GroupedDataset] object from a
[Dataset][pydvl.valuation.dataset.Dataset] object and a mapping of data groups.
Expand All @@ -755,6 +780,10 @@ def from_dataset(
data_groups: An array holding the group index or name for each data
point. The length of this array must be equal to the number of
data points in the dataset.
group_names: Names of the groups. If not provided, the numerical group ids
from `data_groups` will be used.
kwargs: Additional arguments to be passed to the
[GroupedDataset][pydvl.valuation.dataset.GroupedDataset] constructor.
Returns:
A [GroupedDataset][pydvl.valuation.dataset.GroupedDataset] with the initial
Expand All @@ -767,4 +796,6 @@ def from_dataset(
feature_names=data.feature_names,
target_names=data.target_names,
description=data.description,
group_names=group_names,
**kwargs,
)
11 changes: 5 additions & 6 deletions src/pydvl/valuation/stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,13 +695,12 @@ class HistoryDeviation(StoppingCriterion):
r"""A simple check for relative distance to a previous step in the
computation.
The method used by (Ghorbani and Zou, 2019)<sup><a href="#ghorbani_data_2019">1</a></sup> computes the relative
distances between the current values $v_i^t$ and the values at the previous
checkpoint $v_i^{t-\tau}$. If the sum is below a given threshold, the
computation is terminated.
The method used by Ghorbani and Zou (2019)<sup><a
href="#ghorbani_data_2019">1</a></sup> computes the relative distances between the
current values $v_i^t$ and the values at the previous checkpoint $v_i^{t-\tau}$. If
the sum is below a given threshold, the computation is terminated.
$$\sum_{i=1}^n \frac{\left| v_i^t - v_i^{t-\tau} \right|}{v_i^t} <
\epsilon.$$
$$\sum_{i=1}^n \frac{\left| v_i^t - v_i^{t-\tau} \right|}{v_i^t} < \epsilon.$$
When the denominator is zero, the summand is set to the value of $v_i^{
t-\tau}$.
Expand Down
106 changes: 104 additions & 2 deletions tests/valuation/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from sklearn.datasets import load_wine, make_classification

from pydvl.valuation.dataset import Dataset, GroupedDataset
from pydvl.valuation.dataset import Dataset, GroupedDataset, RawData
from pydvl.valuation.result import ValuationResult


Expand Down Expand Up @@ -108,7 +108,7 @@ def test_grouped_dataset_results():
train_size = 0.5
data_groups = np.random.randint(low=0, high=3, size=len(X)).flatten()
train, test = GroupedDataset.from_arrays(
X, y, data_groups=data_groups, train_size=train_size
X, y, train_size=train_size, data_groups=data_groups
)

v = ValuationResult.zeros(indices=train.indices, data_names=train.names)
Expand Down Expand Up @@ -171,6 +171,32 @@ def test_getitem_returns_correct_grouped_dataset(
assert np.array_equal(sliced_dataset.names, expected_group_names)


def test_default_group_names():
    """Group names default to the string form of the group ids when omitted."""
    x = np.array([[1, 2], [3, 4], [5, 6]])
    y = np.array([0, 1, 0])
    data_groups = [0, 1, 0]
    dataset = GroupedDataset(x=x, y=y, data_groups=data_groups)
    # Two distinct group ids (0 and 1) are present, so one default name is
    # expected for each of them.
    expected = ["0", "1"]
    # np.array_equal also compares shapes; an elementwise `==` would broadcast
    # (or fail with an unrelated error) if the number of names were wrong.
    assert np.array_equal(dataset.names, expected)


def test_incomplete_group_names():
    """Passing fewer group names than there are groups raises a ValueError."""
    features = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
    targets = np.array([0, 1, -1, 1])
    groups = [0, 1, 0, 2]  # three distinct group ids ...
    too_few_names = ["g1", "g3"]  # ... but only two names
    with pytest.raises(ValueError, match="The number of group names"):
        GroupedDataset(
            x=features,
            y=targets,
            data_groups=groups,
            group_names=too_few_names,
        )


@pytest.mark.parametrize(
"idx, expected_x, expected_y, expected_names",
[
Expand Down Expand Up @@ -525,3 +551,79 @@ def test_grouped_logical_indices_returns_correct_indices(
)
result = dataset.logical_indices(data_indices)
assert np.array_equal(result, expected_indices)


@pytest.mark.parametrize(
    "x, y",
    [
        pytest.param(np.array([1, 2]), np.array([1, 2, 3]), id="x_shorter"),
        pytest.param(np.array([1, 2, 3]), np.array([1, 2]), id="y_shorter"),
        pytest.param(np.zeros((3, 2)), np.zeros(2), id="different_shapes"),
    ],
)
def test_rawdata_creation_raises_on_length_mismatch(x, y):
    """RawData refuses inputs whose lengths differ."""
    expected_message = "x and y must have the same length"
    with pytest.raises(ValueError, match=expected_message):
        RawData(x, y)


@pytest.mark.parametrize(
    "x, y",
    [
        (np.array([[1]]), 1),  # y has no length
        (3, np.array([1])),  # x has no length
        (1, 2),  # neither has a length
        (None, None),
    ],
)
def test_rawdata_creation_raises_on_non_arrays(x, y):
    """Inputs without a length are rejected with a TypeError."""
    expected_message = "x and y must be numpy arrays"
    with pytest.raises(TypeError, match=expected_message):
        RawData(x, y)


def test_rawdata_iteration():
    """Unpacking a RawData yields its x array followed by its y array."""
    features = np.array([1, 2, 3])
    targets = np.array([4, 5, 6])
    got_features, got_targets = RawData(features, targets)
    assert np.array_equal(got_features, features)
    assert np.array_equal(got_targets, targets)


@pytest.mark.parametrize(
    "idx, expected_x, expected_y",
    [
        pytest.param(1, np.array([2]), np.array([5]), id="single_index"),
        pytest.param(slice(1, None), np.array([2, 3]), np.array([5, 6]), id="slice"),
        pytest.param([0, 2], np.array([1, 3]), np.array([4, 6]), id="sequence"),
    ],
)
def test_rawdata_getitem(idx, expected_x, expected_y):
    """Indexing a RawData returns a RawData holding the selected entries."""
    source = RawData(np.array([1, 2, 3]), np.array([4, 5, 6]))
    selected = source[idx]
    assert isinstance(selected, RawData)
    for actual, wanted in ((selected.x, expected_x), (selected.y, expected_y)):
        assert np.array_equal(actual, wanted)


@pytest.mark.parametrize(
    "x, y, length",
    [
        pytest.param(np.array([1, 2, 3]), np.array([4, 5, 6]), 3, id="1d_arrays"),
        pytest.param(np.zeros((5, 2)), np.zeros(5), 5, id="2d_arrays"),
        pytest.param(np.array([]), np.array([]), 0, id="empty_arrays"),
    ],
)
def test_rawdata_len(x, y, length):
    """len() of a RawData matches the expected number of entries."""
    assert len(RawData(x, y)) == length

0 comments on commit bb1d4ad

Please sign in to comment.