Skip to content

Commit

Permalink
Merge pull request #649 from aai-institute/fix/batching_and_skip_inde…
Browse files Browse the repository at this point in the history
…x_tests

Fix/batching and skip index tests
  • Loading branch information
mdbenito authored Feb 21, 2025
2 parents 2472fff + 1b4b7d4 commit 5659ca8
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 89 deletions.
5 changes: 4 additions & 1 deletion src/pydvl/valuation/methods/beta_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ def __init__(
alpha: float,
beta: float,
progress: bool = False,
skip_converged: bool = False,
):
super().__init__(utility, sampler, is_done, progress=progress)
super().__init__(
utility, sampler, is_done, skip_converged=skip_converged, progress=progress
)

self.alpha = alpha
self.beta = beta
Expand Down
8 changes: 2 additions & 6 deletions src/pydvl/valuation/methods/classwise_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,9 @@
class ClasswiseShapleyValuation(Valuation):
"""Class to compute Class-wise Shapley values.
It proceeds by sampling independent permutations of the index set
for each label and index sets sampled from the powerset of the complement
(with respect to the currently evaluated label).
Args:
utility: Classwise utility object with model and classwise scoring function.
sampler: Classwise sampling scheme to use.
utility: Class-wise utility object with model and class-wise scoring function.
sampler: Class-wise sampling scheme to use.
is_done: Stopping criterion to use.
progress: Whether to show a progress bar.
normalize_values: Whether to normalize values after valuation.
Expand Down
7 changes: 6 additions & 1 deletion src/pydvl/valuation/methods/delta_shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class DeltaShapleyValuation(SemivalueValuation):
upper_bound: The upper bound of the size of the subsets to sample from.
seed: The seed for the random number generator used by the sampler.
progress: Whether to show a progress bar
skip_converged: Whether to skip converged indices, as determined by the
stopping criterion's `converged` array.
"""

algorithm_name = "Delta-Shapley"
Expand All @@ -67,6 +69,7 @@ def __init__(
lower_bound: int,
upper_bound: int,
seed: Seed | None = None,
skip_converged: bool = False,
progress: bool = False,
):
sampler = StratifiedSampler(
Expand All @@ -79,7 +82,9 @@ def __init__(
)
self.lower_bound = lower_bound
self.upper_bound = upper_bound
super().__init__(utility, sampler, is_done, progress=progress)
super().__init__(
utility, sampler, is_done, progress=progress, skip_converged=skip_converged
)

def log_coefficient(self, n: int, k: int) -> float:
# assert self.lower_bound <= k <= self.upper_bound, "Invalid subset size"
Expand Down
9 changes: 4 additions & 5 deletions src/pydvl/valuation/methods/semivalue.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from abc import abstractmethod
from typing import Any

import numpy as np
from joblib import Parallel, delayed
from typing_extensions import Self

Expand Down Expand Up @@ -58,8 +57,8 @@ class SemivalueValuation(Valuation):
utility: Object to compute utilities.
sampler: Sampling scheme to use.
is_done: Stopping criterion to use.
skip_converged: Whether to skip converged indices. Convergence is determined
by the stopping criterion's `converged` array.
skip_converged: Whether to skip converged indices, as determined by the
stopping criterion's `converged` array.
progress: Whether to show a progress bar.
"""

Expand Down Expand Up @@ -132,9 +131,9 @@ def fit(self, data: Dataset) -> Self:
for update in batch:
self.result = updater(update)
if self.skip_converged:
self.sampler.skip_indices = np.where(
self.sampler.skip_indices = data.indices[
self.is_done.converged
)[0]
]
if self.is_done(self.result):
flag.set()
self.sampler.interrupt()
Expand Down
4 changes: 4 additions & 0 deletions src/pydvl/valuation/samplers/classwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def get_unique_labels(array: NDArray) -> NDArray:
class ClasswiseSampler(IndexSampler):
"""A sampler that samples elements from a dataset in two steps, based on the labels.
It proceeds by sampling out-of-class indices (training points with a different
label to the point of interest), and in-class indices (training points with the
same label as the point of interest), in the complement.
Used by the [class-wise Shapley valuation
method][pydvl.valuation.methods.classwise_shapley.ClasswiseShapleyValuation].
Expand Down
14 changes: 13 additions & 1 deletion src/pydvl/valuation/samplers/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,22 @@ def __init__(
):
super().__init__(seed=seed, truncation=truncation, batch_size=batch_size)

@property
def skip_indices(self) -> IndexSetT:
    """Return the index set currently stored in `_skip_indices`."""
    return self._skip_indices

@skip_indices.setter
def skip_indices(self, indices: IndexSetT) -> None:
    """Store `indices` as the set of indices to skip.

    The value is kept as given; no validation or copying happens here.

    Args:
        indices: The indices to skip.
    """
    # NOTE(review): presumably `_generate` removes these indices before building
    # each permutation -- confirm against its implementation.
    self._skip_indices = indices

def _generate(self, indices: IndexSetT) -> SampleGenerator:
"""Generates the permutation samples.
Args:
indices:
indices: The indices to sample from. If empty, no samples are generated. If
[skip_indices][pydvl.valuation.samplers.base.IndexSampler.skip_indices]
is set, these indices are removed from the set before generating the
permutation.
"""
if len(indices) == 0:
return
Expand Down
13 changes: 7 additions & 6 deletions src/pydvl/valuation/samplers/powerset.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,14 +278,15 @@ def make_strategy(

@abstractmethod
def _generate(self, indices: IndexSetT) -> SampleGenerator:
"""Generates samples iterating in sequence over the outer indices, then over
subsets of the complement of the current index. Each PowersetSampler defines
its own
[subset_iterator][pydvl.valuation.samplers.PowersetSampler.subset_iterator] to
generate the subsets.
"""Generates samples over the powerset of `indices`
Each `PowersetSampler` defines its own way to generate the subsets by
implementing this method. The outer loop is handled by the `index_iterator`.
Batching is handled by the `generate_batches` method.
Args:
indices:"""
indices: The set from which to generate samples.
"""
...

def log_weight(self, n: int, subset_len: int) -> float:
Expand Down
29 changes: 0 additions & 29 deletions tests/valuation/methods/test_semivalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,32 +123,3 @@ def test_games(
# plt.show()

check_values(result, exact_result, atol=0.1)


@pytest.mark.flaky(reruns=1)
@pytest.mark.parametrize(
    "test_game",
    [("shoes", {"left": 3, "right": 2})],
    indirect=["test_game"],
)
@pytest.mark.parametrize("n_jobs", [1, 2])
def test_batch_size(test_game, n_jobs, seed):
    """Check that the sampler's batch size does not change the computed values.

    Beta-Shapley values (alpha=1, beta=1) are computed twice on the same game
    with the same seed, once with `batch_size=1` and once with `batch_size=5`,
    and the two results are compared within a small tolerance.
    """

    def compute_semivalues(batch_size, n_jobs=n_jobs, seed=seed):
        # Build the pieces one by one, then assemble the valuation.
        utility = test_game.u
        sampler = UniformSampler(batch_size=batch_size, seed=seed)
        stopping = MaxUpdates(100)
        valuation = BetaShapleyValuation(
            utility=utility,
            sampler=sampler,
            is_done=stopping,
            progress=False,
            alpha=1,
            beta=1,
        )
        with parallel_config(n_jobs=n_jobs):
            valuation.fit(test_game.data)
        return valuation.values()

    timed_fn = timed(compute_semivalues)
    result_single_batch, result_multi_batch = [
        timed_fn(batch_size=size) for size in (1, 5)
    ]

    # Batches may be processed out of order (batch_2 can arrive before
    # batch_1), so the results are not guaranteed to match exactly.
    check_values(result_single_batch, result_multi_batch, rtol=1e-4, atol=1e-3)
102 changes: 65 additions & 37 deletions tests/valuation/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
from itertools import islice, takewhile
from typing import Any, Iterator, Type
from typing import Any, Callable, Iterator, Type

import numpy as np
import pytest
Expand All @@ -11,6 +11,7 @@

from pydvl.utils.numeric import logcomb, powerset
from pydvl.utils.types import Seed
from pydvl.valuation import EvaluationStrategy, IndexSampler
from pydvl.valuation.samplers import (
AntitheticOwenSampler,
AntitheticPermutationSampler,
Expand Down Expand Up @@ -41,7 +42,8 @@
UniformSampler,
)
from pydvl.valuation.samplers.permutation import PermutationSamplerBase
from pydvl.valuation.types import IndexSetT
from pydvl.valuation.types import IndexSetT, Sample, SampleGenerator
from pydvl.valuation.utility.base import UtilityBase

from .. import recursive_make
from . import _check_idxs, _check_subsets
Expand Down Expand Up @@ -636,54 +638,80 @@ def test_sampler_weights(
)


class TestSampler(PowersetSampler):
    """Minimal `PowersetSampler` stub for tests that only need the outer index
    iteration.

    It is built with `batch_size=1` and `FiniteSequentialIndexIteration`; sample
    generation and the sample limit are deliberate no-ops.
    """

    # This is a helper, not a test case. Its name matches pytest's default
    # `Test*` class discovery and it defines `__init__`, which makes pytest emit
    # a collection warning. Opt out of collection explicitly.
    __test__ = False

    def __init__(self):
        super().__init__(batch_size=1, index_iteration=FiniteSequentialIndexIteration)

    def _generate(self, indices: IndexSetT):
        """Do nothing: this stub never produces samples."""
        return None

    def sample_limit(self, indices: IndexSetT) -> int | None:
        """Return `None`, i.e. no sample limit is reported."""
        return None


@pytest.mark.parametrize(
"sampler_cls, sampler_kwargs, n_batches",
[
(DeterministicUniformSampler, {}, lambda n: 2 ** (n - 1)),
(UniformSampler, {}, lambda n: 2 ** (n - 1)),
(AntitheticSampler, {}, lambda n: 2 ** (n - 1)),
(LOOSampler, {}, lambda n: n),
(PermutationSampler, {}, lambda n: math.factorial(n)),
(AntitheticPermutationSampler, {}, lambda n: math.factorial(n)),
],
)
@pytest.mark.parametrize(
"indices, skip, expected",
[
(np.arange(6), np.array([2, 4]), [0, 1, 3, 5]),
(np.arange(6), np.empty(0), np.arange(6)),
(np.arange(5), np.array([2, 4]), [0, 1, 3]),
(np.arange(3), np.empty(0), np.arange(3)),
(np.empty(0), np.arange(6), np.empty(0)),
],
)
def test_skip_indices(indices, skip, expected):
sampler = TestSampler()
def test_skip_indices(
sampler_cls, sampler_kwargs, n_batches, indices, skip, expected, seed
):
sampler_kwargs["batch_size"] = 2
sampler = recursive_make(sampler_cls, sampler_kwargs, seed=seed)
sampler.skip_indices = skip

result = list(sampler.index_iterator(indices))
# Check that the outer iteration skips indices:
if hasattr(sampler, "index_iterator"):
outer_indices = list(islice(sampler.index_iterator(indices), len(indices)))
assert set(outer_indices) == set(expected)

assert set(result) == set(expected), f"Expected {expected}, but got {result}"
# Check that the generated samples skip indices...
batches = list(
islice(sampler.generate_batches(indices), max(1, n_batches(len(indices))))
)
all_samples = list(flatten(batches))

# ... in sample.subset for permutation samplers
if isinstance(sampler, PermutationSamplerBase):
assert all(
all(idx in expected for idx in sample.subset) for sample in all_samples
)
else: # ... in sample.idx for other samplers
assert all(sample.idx in expected for sample in all_samples)

def test_skip_indices_after_first_batch():
n_indices = 4
indices = np.arange(n_indices)
skip_indices = indices[:2]

# Generate all samples in one batch
sampler = DeterministicUniformSampler(batch_size=2**n_indices)
class TestBatchSampler(IndexSampler):
def __init__(self, batch_size):
super().__init__(batch_size)

batches = sampler.generate_batches(indices)
first_batch = list(next(batches))
assert first_batch, "First batch should not be empty"
def sample_limit(self, indices: IndexSetT) -> int | None: ...

# Skip indices for the next batch of all samples
sampler.skip_indices = skip_indices
def _generate(self, indices: IndexSetT) -> SampleGenerator:
yield from (Sample(idx, np.empty_like(indices)) for idx in indices)

next_batch = list(next(batches))
def log_weight(self, n: int, subset_len: int) -> float: ...

effective_outer_indices = np.setdiff1d(indices, skip_indices)
assert len(next_batch) == len(effective_outer_indices) * 2 ** (n_indices - 1)
for sample in next_batch:
assert sample.idx not in skip_indices, (
f"Sample with skipped index {sample.idx} found"
)
def make_strategy(
self,
utility: UtilityBase,
log_coefficient: Callable[[int, int], float] | None = None,
) -> EvaluationStrategy: ...


@pytest.mark.parametrize("indices", [np.arange(1), np.arange(23)])
@pytest.mark.parametrize("batch_size", [1, 2, 7])
def test_batching(indices, batch_size):
    """Check that `generate_batches` splits samples into correctly sized batches.

    The test asserts that every batch is iterable, that the number of batches is
    `ceil(len(indices) / batch_size)`, that the last batch holds the remainder
    (or a full batch when `batch_size` divides `len(indices)`), and that the
    total number of samples equals `len(indices)`.
    """
    sampler = TestBatchSampler(batch_size)
    batches = list(sampler.generate_batches(indices))

    assert all(hasattr(batch, "__iter__") for batch in batches)

    n_samples = len(indices)
    assert len(batches) == math.ceil(n_samples / batch_size)

    # Compute the expected size first. Written inline as `a == b or c`, the
    # expression parses as `(a == b) or c`, which is always truthy for a
    # non-zero batch size and therefore checks nothing.
    expected_last_batch_size = n_samples % batch_size or batch_size
    assert len(batches[-1]) == expected_last_batch_size

    all_samples = list(flatten(batches))
    assert len(all_samples) == n_samples
6 changes: 3 additions & 3 deletions tests/valuation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def timed(fun: Callable[..., ReturnT]) -> TimedCallable:
any type.
Returns:
A wrapped function that, when called, returns a tuple containing the original
function's result and its execution time in seconds. The decorated function
will have the same input arguments and return type as the original function.
A wrapped function that, when called, saves its execution time in seconds into
the attribute `execution_time`. The wrapped function will have the same
input arguments and return type as the original function.
"""

@wraps(fun)
Expand Down

0 comments on commit 5659ca8

Please sign in to comment.