Skip to content

Commit 18dc02c

Browse files
alvarobarttburtenshawfrascuchongabrielmbmb
authored
Update argilla integration to use argilla_sdk v2 (#705)
* Update `_Argilla` base and `TextGenerationToArgilla` * Fix `_dataset.records.log` and rename to `ArgillaBase` Co-authored-by: Ben Burtenshaw <burtenshaw@users.noreply.github.com> * Update `TextGenerationToArgilla` subclass inheritance * Remove unused `logger.info` message * Update `PreferenceToArgilla` * Update `argilla` extra to install `argilla_sdk` For the moment it's being installed as `pip install git+https://github.com/argilla-io/argilla-python.git@main` * Add `ArgillaBase` and subclasses unit tests * Install `argilla_sdk` from source and add `ipython` * upgrade argilla dep to latest rc * udate code with latest changes * chore: remove unnecessary workspace definition * fix: wrong argilla module import * Update docstrings * Fix lint * Add check for `api_url` and `api_key` * Fix unit tests * Fix unit tests * Update argilla dependency version --------- Co-authored-by: Ben Burtenshaw <burtenshaw@users.noreply.github.com> Co-authored-by: Francisco Aranda <francis@argilla.io> Co-authored-by: Gabriel Martín Blázquez <gmartinbdev@gmail.com>
1 parent be61d20 commit 18dc02c

File tree

8 files changed

+206
-150
lines changed

8 files changed

+206
-150
lines changed

Diff for: pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ tests = [
7171

7272
# Optional LLMs, integrations, etc
7373
anthropic = ["anthropic >= 0.20.0"]
74-
argilla = ["argilla >= 1.29.0"]
74+
argilla = ["argilla >= 2.0.0", "ipython"]
7575
cohere = ["cohere >= 5.2.0"]
7676
groq = ["groq >= 0.4.1"]
7777
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]

Diff for: src/distilabel/steps/argilla/base.py

+43-29
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import importlib.util
1516
import os
16-
import warnings
1717
from abc import ABC, abstractmethod
1818
from typing import TYPE_CHECKING, Any, List, Optional
1919

@@ -28,15 +28,16 @@
2828
from distilabel.steps.base import Step, StepInput
2929

3030
if TYPE_CHECKING:
31-
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
31+
from argilla import Argilla, Dataset
3232

3333
from distilabel.steps.typing import StepOutput
3434

3535

36+
_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
3637
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"
3738

3839

39-
class Argilla(Step, ABC):
40+
class ArgillaBase(Step, ABC):
4041
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
4142
required to interact with Argilla, as well as some extra validations on top of it. It also defines
4243
the abstract methods that need to be implemented in order to add a new dataset type as a step.
@@ -70,55 +71,61 @@ class Argilla(Step, ABC):
7071
)
7172
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
7273
default=None,
73-
description="The workspace where the dataset will be created in Argilla. Defaults"
74+
description="The workspace where the dataset will be created in Argilla. Defaults "
7475
"to `None` which means it will be created in the default workspace.",
7576
)
7677

7778
api_url: Optional[RuntimeParameter[str]] = Field(
78-
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
79+
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
7980
description="The base URL to use for the Argilla API requests.",
8081
)
8182
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
8283
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
8384
description="The API key to authenticate the requests to the Argilla API.",
8485
)
8586

86-
_rg_dataset: Optional["RemoteFeedbackDataset"] = PrivateAttr(...)
87+
_client: Optional["Argilla"] = PrivateAttr(...)
88+
_dataset: Optional["Dataset"] = PrivateAttr(...)
8789

8890
def model_post_init(self, __context: Any) -> None:
8991
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
9092
super().model_post_init(__context)
9193

92-
try:
93-
import argilla as rg # noqa
94-
except ImportError as ie:
94+
if importlib.util.find_spec("argilla") is None:
9595
raise ImportError(
96-
"Argilla is not installed. Please install it using `pip install argilla`."
97-
) from ie
98-
99-
warnings.filterwarnings("ignore")
96+
"Argilla is not installed. Please install it using `pip install argilla"
97+
" --upgrade`."
98+
)
10099

101-
def _rg_init(self) -> None:
100+
def _client_init(self) -> None:
102101
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
103102
try:
104-
if "hf.space" in self.api_url and "HF_TOKEN" in os.environ:
105-
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
106-
else:
107-
headers = None
108-
rg.init(
103+
self._client = rg.Argilla( # type: ignore
109104
api_url=self.api_url,
110-
api_key=self.api_key.get_secret_value(),
111-
extra_headers=headers,
112-
) # type: ignore
105+
api_key=self.api_key.get_secret_value(), # type: ignore
106+
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
107+
if isinstance(self.api_url, str)
108+
and "hf.space" in self.api_url
109+
and "HF_TOKEN" in os.environ
110+
else {},
111+
)
113112
except Exception as e:
114113
raise ValueError(f"Failed to initialize the Argilla API: {e}") from e
115114

116-
def _rg_dataset_exists(self) -> bool:
117-
"""Checks if the dataset already exists in Argilla."""
118-
return self.dataset_name in [
119-
dataset.name
120-
for dataset in rg.FeedbackDataset.list(workspace=self.dataset_workspace) # type: ignore
121-
]
115+
@property
116+
def _dataset_exists_in_workspace(self) -> bool:
117+
"""Checks if the dataset already exists in Argilla in the provided workspace if any.
118+
119+
Returns:
120+
`True` if the dataset exists, `False` otherwise.
121+
"""
122+
return (
123+
self._client.datasets( # type: ignore
124+
name=self.dataset_name, # type: ignore
125+
workspace=self.dataset_workspace,
126+
)
127+
is not None
128+
)
122129

123130
@property
124131
def outputs(self) -> List[str]:
@@ -133,7 +140,14 @@ def load(self) -> None:
133140
"""
134141
super().load()
135142

136-
self._rg_init()
143+
if self.api_url is None or self.api_key is None:
144+
raise ValueError(
145+
"`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
146+
" provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
147+
" and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`."
148+
)
149+
150+
self._client_init()
137151

138152
@property
139153
@abstractmethod

Diff for: src/distilabel/steps/argilla/preference.py

+55-38
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,16 @@
2323
except ImportError:
2424
pass
2525

26-
from distilabel.steps.argilla.base import Argilla
26+
from distilabel.steps.argilla.base import ArgillaBase
2727
from distilabel.steps.base import StepInput
2828

2929
if TYPE_CHECKING:
30-
from argilla import (
31-
RatingQuestion,
32-
SuggestionSchema,
33-
TextField,
34-
TextQuestion,
35-
)
30+
from argilla import RatingQuestion, Suggestion, TextField, TextQuestion
3631

3732
from distilabel.steps.typing import StepOutput
3833

3934

40-
class PreferenceToArgilla(Argilla):
35+
class PreferenceToArgilla(ArgillaBase):
4136
"""Creates a preference dataset in Argilla.
4237
4338
Step that creates a dataset in Argilla during the load phase, and then pushes the input
@@ -153,45 +148,55 @@ def load(self) -> None:
153148
self._ratings = self.input_mappings.get("ratings", "ratings")
154149
self._rationales = self.input_mappings.get("rationales", "rationales")
155150

156-
if self._rg_dataset_exists():
157-
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
158-
name=self.dataset_name,
159-
workspace=self.dataset_workspace,
151+
if self._dataset_exists_in_workspace:
152+
_dataset = self._client.datasets( # type: ignore
153+
name=self.dataset_name, # type: ignore
154+
workspace=self.dataset_workspace, # type: ignore
160155
)
161156

162-
for field in _rg_dataset.fields:
157+
for field in _dataset.fields:
158+
if not isinstance(field, rg.TextField):
159+
continue
163160
if (
164161
field.name
165-
not in [self._id, self._instruction]
162+
not in [self._id, self._instruction] # type: ignore
166163
+ [
167164
f"{self._generations}-{idx}"
168165
for idx in range(self.num_generations)
169166
]
170167
and field.required
171168
):
172169
raise ValueError(
173-
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
174-
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
175-
f" nor `{self._generations}`."
170+
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
171+
f" already exists, but contains at least a required field that is"
172+
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
173+
f" (one per generation starting from 0 up to {self.num_generations - 1})."
176174
)
177175

178-
self._rg_dataset = _rg_dataset
176+
self._dataset = _dataset
179177
else:
180-
_rg_dataset = rg.FeedbackDataset( # type: ignore
178+
_settings = rg.Settings( # type: ignore
181179
fields=[
182180
rg.TextField(name=self._id, title=self._id), # type: ignore
183181
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
184182
*self._generation_fields(), # type: ignore
185183
],
186184
questions=self._rating_rationale_pairs(), # type: ignore
187185
)
188-
self._rg_dataset = _rg_dataset.push_to_argilla(
189-
name=self.dataset_name, # type: ignore
186+
_dataset = rg.Dataset( # type: ignore
187+
name=self.dataset_name,
190188
workspace=self.dataset_workspace,
189+
settings=_settings,
190+
client=self._client,
191191
)
192+
self._dataset = _dataset.create()
192193

193194
def _generation_fields(self) -> List["TextField"]:
194-
"""Method to generate the fields for each of the generations."""
195+
"""Method to generate the fields for each of the generations.
196+
197+
Returns:
198+
A list containing `TextField`s for each text generation.
199+
"""
195200
return [
196201
rg.TextField( # type: ignore
197202
name=f"{self._generations}-{idx}",
@@ -204,7 +209,12 @@ def _generation_fields(self) -> List["TextField"]:
204209
def _rating_rationale_pairs(
205210
self,
206211
) -> List[Union["RatingQuestion", "TextQuestion"]]:
207-
"""Method to generate the rating and rationale questions for each of the generations."""
212+
"""Method to generate the rating and rationale questions for each of the generations.
213+
214+
Returns:
215+
A list of questions containing a `RatingQuestion` and `TextQuestion` pair for
216+
each text generation.
217+
"""
208218
questions = []
209219
for idx in range(self.num_generations):
210220
questions.extend(
@@ -236,20 +246,27 @@ def inputs(self) -> List[str]:
236246
provide the `ratings` and the `rationales` for the generations."""
237247
return ["instruction", "generations"]
238248

239-
def _add_suggestions_if_any(
240-
self, input: Dict[str, Any]
241-
) -> List["SuggestionSchema"]:
242-
"""Method to generate the suggestions for the `FeedbackRecord` based on the input."""
249+
@property
250+
def optional_inputs(self) -> List[str]:
251+
"""The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
252+
return ["ratings", "rationales"]
253+
254+
def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
255+
"""Method to generate the suggestions for the `rg.Record` based on the input.
256+
257+
Returns:
258+
A list of `Suggestion`s for the rating and rationales questions.
259+
"""
243260
# Since the `suggestions` i.e. answers to the `questions` are optional, will default to {}
244261
suggestions = []
245262
# If `ratings` is in `input`, then add those as suggestions
246263
if self._ratings in input:
247264
suggestions.extend(
248265
[
249-
{
250-
"question_name": f"{self._generations}-{idx}-rating",
251-
"value": rating,
252-
}
266+
rg.Suggestion( # type: ignore
267+
value=rating,
268+
question_name=f"{self._generations}-{idx}-rating",
269+
)
253270
for idx, rating in enumerate(input[self._ratings])
254271
if rating is not None
255272
and isinstance(rating, int)
@@ -260,10 +277,10 @@ def _add_suggestions_if_any(
260277
if self._rationales in input:
261278
suggestions.extend(
262279
[
263-
{
264-
"question_name": f"{self._generations}-{idx}-rationale",
265-
"value": rationale,
266-
}
280+
rg.Suggestion( # type: ignore
281+
value=rationale,
282+
question_name=f"{self._generations}-{idx}-rationale",
283+
)
267284
for idx, rationale in enumerate(input[self._rationales])
268285
if rationale is not None and isinstance(rationale, str)
269286
],
@@ -272,7 +289,7 @@ def _add_suggestions_if_any(
272289

273290
@override
274291
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
275-
"""Creates and pushes the records as FeedbackRecords to the Argilla dataset.
292+
"""Creates and pushes the records as `rg.Record`s to the Argilla dataset.
276293
277294
Args:
278295
inputs: A list of Python dictionaries with the inputs of the task.
@@ -293,7 +310,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
293310
}
294311

295312
records.append( # type: ignore
296-
rg.FeedbackRecord( # type: ignore
313+
rg.Record( # type: ignore
297314
fields={
298315
"id": instruction_id,
299316
"instruction": input["instruction"], # type: ignore
@@ -302,5 +319,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
302319
suggestions=self._add_suggestions_if_any(input), # type: ignore
303320
)
304321
)
305-
self._rg_dataset.add_records(records) # type: ignore
322+
self._dataset.records.log(records) # type: ignore
306323
yield inputs

0 commit comments

Comments
 (0)