Commit dace285

feat: extend eval dataloaders (#576)
1 parent 90b4727 commit dace285

File tree

5 files changed: 60 additions (+), 29 deletions (−)


packages/ragbits-evaluate/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@
 
 ## Unreleased
 
+- Add support for slicing dataset (#576)
+- Separate load and map ops in data loaders (#576)
+
 ## 0.18.0 (2025-05-22)
 
 ### Changed
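
The slicing entry refers to Hugging Face's split-slicing syntax, which the loaders now pass through via their `split` argument. A minimal, ragbits-independent sketch of that syntax (the `qa.json` file name is hypothetical):

```python
from datasets import load_dataset

# Register a local file under a split named "data", then keep only the
# first 100 rows of that split.
subset = load_dataset(
    "json",
    data_files={"data": "qa.json"},
    split="data[:100]",
)

# Other slice forms the API accepts: "data[10:20]", "data[50%:]", "data[:25%]".
print(len(subset))
```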

packages/ragbits-evaluate/src/ragbits/evaluate/dataloaders/base.py

Lines changed: 36 additions & 3 deletions
@@ -3,12 +3,14 @@
 from types import ModuleType
 from typing import ClassVar, Generic
 
+from datasets import load_dataset
 from pydantic import BaseModel
 from typing_extensions import Self
 
 from ragbits.core.sources.base import Source
 from ragbits.core.utils.config_handling import ObjectConstructionConfig, WithConstructionConfig
 from ragbits.evaluate import dataloaders
+from ragbits.evaluate.dataloaders.exceptions import DataLoaderIncorrectFormatDataError
 from ragbits.evaluate.pipelines.base import EvaluationDataT
 
 
@@ -28,14 +30,19 @@ class DataLoader(WithConstructionConfig, Generic[EvaluationDataT], ABC):
     default_module: ClassVar[ModuleType | None] = dataloaders
     configuration_key: ClassVar[str] = "dataloader"
 
-    def __init__(self, source: Source) -> None:
+    def __init__(self, source: Source, *, split: str = "data", required_keys: set[str] | None = None) -> None:
         """
         Initialize the data loader.
 
         Args:
             source: The source to load the evaluation data from.
+            split: The split to load the data from. Split is fixed for data loaders to "data",
+                but you can slice it using the [Hugging Face API](https://huggingface.co/docs/datasets/v1.11.0/splits.html#slicing-api).
+            required_keys: The required columns for the evaluation data.
         """
         self.source = source
+        self.split = split
+        self.required_keys = required_keys or set()
 
     @classmethod
     def from_config(cls, config: dict) -> Self:
@@ -52,11 +59,37 @@ def from_config(cls, config: dict) -> Self:
         config["source"] = Source.subclass_from_config(dataloader_config.source)
         return super().from_config(config)
 
-    @abstractmethod
     async def load(self) -> Iterable[EvaluationDataT]:
         """
         Load the data.
 
         Returns:
-            The loaded data.
+            The loaded evaluation data.
+
+        Raises:
+            DataLoaderIncorrectFormatDataError: If evaluation dataset is incorrectly formatted.
+        """
+        data_path = await self.source.fetch()
+        dataset = load_dataset(
+            path=str(data_path.parent),
+            data_files={"data": str(data_path.name)},
+            split=self.split,
+        )
+        if not self.required_keys.issubset(dataset.features):
+            raise DataLoaderIncorrectFormatDataError(
+                required_features=list(self.required_keys),
+                data_path=data_path,
+            )
+        return await self.map(dataset.to_list())
+
+    @abstractmethod
+    async def map(self, dataset: Iterable[dict]) -> Iterable[EvaluationDataT]:
+        """
+        Map the dataset to the evaluation data.
+
+        Args:
+            dataset: The dataset to map.
+
+        Returns:
+            The evaluation data.
         """

packages/ragbits-evaluate/src/ragbits/evaluate/dataloaders/document_search.py

Lines changed: 13 additions & 25 deletions
@@ -1,24 +1,22 @@
 from collections.abc import Iterable
 
-from datasets import load_dataset
-
 from ragbits.core.sources.base import Source
 from ragbits.evaluate.dataloaders.base import DataLoader
-from ragbits.evaluate.dataloaders.exceptions import DataLoaderIncorrectFormatDataError
 from ragbits.evaluate.pipelines.document_search import DocumentSearchData
 
 
 class DocumentSearchDataLoader(DataLoader[DocumentSearchData]):
     """
     Document search evaluation data loader.
 
-    The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files)
-    and contain the following features: "question, "passages".
+    The source used for this data loader should point to a file that can be loaded by [Hugging Face](https://huggingface.co/docs/datasets/loading#local-and-remote-files).
     """
 
     def __init__(
         self,
         source: Source,
+        *,
+        split: str = "data",
         question_key: str = "question",
         document_ids_key: str = "document_ids",
         passages_key: str = "passages",
@@ -29,42 +27,32 @@ def __init__(
 
         Args:
             source: The source to load the data from.
+            split: The split to load the data from. Split is fixed for data loaders to "data",
+                but you can slice it using the [Hugging Face API](https://huggingface.co/docs/datasets/v1.11.0/splits.html#slicing-api).
             question_key: The dataset column name that contains the question.
             document_ids_key: The dataset column name that contains the document ids. Document ids are optional.
             passages_key: The dataset column name that contains the passages. Passages are optional.
             page_numbers_key: The dataset column name that contains the page numbers. Page numbers are optional.
         """
-        super().__init__(source)
+        super().__init__(source=source, split=split, required_keys={question_key})
         self.question_key = question_key
         self.document_ids_key = document_ids_key
         self.passages_key = passages_key
         self.page_numbers_key = page_numbers_key
 
-    async def load(self) -> Iterable[DocumentSearchData]:
+    async def map(self, dataset: Iterable[dict]) -> Iterable[DocumentSearchData]:
         """
-        Load the data from source and format them.
+        Map the dataset to the document search data schema.
 
-        Returns:
-            The document search evaluation data.
+        Args:
+            dataset: The dataset to map.
 
-        Raises:
-            DataLoaderIncorrectFormatDataError: If evaluation dataset is incorrectly formatted.
+        Returns:
+            The document search data.
         """
-        data_path = await self.source.fetch()
-        dataset = load_dataset(
-            path=str(data_path.parent),
-            split="train",
-            data_files={"train": str(data_path.name)},
-        )
-        if self.question_key not in dataset.features:
-            raise DataLoaderIncorrectFormatDataError(
-                required_features=[self.question_key],
-                data_path=data_path,
-            )
-
         return [
             DocumentSearchData(
-                question=data.get(self.question_key),
+                question=data.get(self.question_key, ""),
                 reference_document_ids=data.get(self.document_ids_key),
                 reference_passages=data.get(self.passages_key),
                 reference_page_numbers=data.get(self.page_numbers_key),
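
Combined, the changes let a caller pick columns and evaluate on a slice in one call. A hedged usage sketch (LocalFileSource and its module path are guesses at a concrete Source implementation, and the "query" column is invented; only DocumentSearchDataLoader's signature comes from this diff):

```python
import asyncio

# Hypothetical concrete Source; the actual class and module are not shown in this diff.
from ragbits.core.sources.local import LocalFileSource
from ragbits.evaluate.dataloaders.document_search import DocumentSearchDataLoader

loader = DocumentSearchDataLoader(
    source=LocalFileSource(path="eval/questions.json"),  # hypothetical
    split="data[:50]",      # evaluate on the first 50 rows only
    question_key="query",   # this dataset stores questions under "query"
)

# load() fetches the file, checks that "query" exists, and delegates to map().
eval_data = asyncio.run(loader.load())
```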

packages/ragbits-evaluate/tests/unit/test_evaluator.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,6 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import cast
+from typing import Any, cast
 from unittest.mock import Mock
 
 import pytest
@@ -58,6 +58,9 @@ def __init__(self, dataset_size: int = 4) -> None:
         self.dataset_size = dataset_size
 
     async def load(self) -> Iterable[MockEvaluationData]:
+        return await self.map()
+
+    async def map(self, *args: Any, **kwargs: Any) -> Iterable[MockEvaluationData]:  # noqa: ANN401
         return [MockEvaluationData(input_data=i) for i in range(1, self.dataset_size + 1)]
 
     @classmethod

packages/ragbits-evaluate/tests/unit/test_optimizer.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 from collections.abc import Iterable
 from dataclasses import dataclass
+from typing import Any
 from unittest.mock import Mock
 
 import pytest
@@ -55,6 +56,9 @@ def __init__(self, dataset_size: int = 4) -> None:
         self.dataset_size = dataset_size
 
     async def load(self) -> Iterable[MockEvaluationData]:
+        return await self.map()
+
+    async def map(self, *args: Any, **kwargs: Any) -> Iterable[MockEvaluationData]:  # noqa: ANN401
         return [MockEvaluationData(input_data=i) for i in range(1, self.dataset_size + 1)]
 
     @classmethod
