23
23
except ImportError :
24
24
pass
25
25
26
- from distilabel .steps .argilla .base import Argilla
26
+ from distilabel .steps .argilla .base import ArgillaBase
27
27
from distilabel .steps .base import StepInput
28
28
29
29
if TYPE_CHECKING :
30
- from argilla import (
31
- RatingQuestion ,
32
- SuggestionSchema ,
33
- TextField ,
34
- TextQuestion ,
35
- )
30
+ from argilla import RatingQuestion , Suggestion , TextField , TextQuestion
36
31
37
32
from distilabel .steps .typing import StepOutput
38
33
39
34
40
- class PreferenceToArgilla (Argilla ):
35
+ class PreferenceToArgilla (ArgillaBase ):
41
36
"""Creates a preference dataset in Argilla.
42
37
43
38
Step that creates a dataset in Argilla during the load phase, and then pushes the input
@@ -153,45 +148,55 @@ def load(self) -> None:
153
148
self ._ratings = self .input_mappings .get ("ratings" , "ratings" )
154
149
self ._rationales = self .input_mappings .get ("rationales" , "rationales" )
155
150
156
- if self ._rg_dataset_exists () :
157
- _rg_dataset = rg . FeedbackDataset . from_argilla ( # type: ignore
158
- name = self .dataset_name ,
159
- workspace = self .dataset_workspace ,
151
+ if self ._dataset_exists_in_workspace :
152
+ _dataset = self . _client . datasets ( # type: ignore
153
+ name = self .dataset_name , # type: ignore
154
+ workspace = self .dataset_workspace , # type: ignore
160
155
)
161
156
162
- for field in _rg_dataset .fields :
157
+ for field in _dataset .fields :
158
+ if not isinstance (field , rg .TextField ):
159
+ continue
163
160
if (
164
161
field .name
165
- not in [self ._id , self ._instruction ]
162
+ not in [self ._id , self ._instruction ] # type: ignore
166
163
+ [
167
164
f"{ self ._generations } -{ idx } "
168
165
for idx in range (self .num_generations )
169
166
]
170
167
and field .required
171
168
):
172
169
raise ValueError (
173
- f"The dataset { self .dataset_name } in the workspace { self .dataset_workspace } already exists,"
174
- f" but contains at least a required field that is neither `{ self ._id } `, `{ self ._instruction } `,"
175
- f" nor `{ self ._generations } `."
170
+ f"The dataset '{ self .dataset_name } ' in the workspace '{ self .dataset_workspace } '"
171
+ f" already exists, but contains at least a required field that is"
172
+ f" neither `{ self ._id } `, `{ self ._instruction } `, nor `{ self ._generations } `"
173
+ f" (one per generation starting from 0 up to { self .num_generations - 1 } )."
176
174
)
177
175
178
- self ._rg_dataset = _rg_dataset
176
+ self ._dataset = _dataset
179
177
else :
180
- _rg_dataset = rg .FeedbackDataset ( # type: ignore
178
+ _settings = rg .Settings ( # type: ignore
181
179
fields = [
182
180
rg .TextField (name = self ._id , title = self ._id ), # type: ignore
183
181
rg .TextField (name = self ._instruction , title = self ._instruction ), # type: ignore
184
182
* self ._generation_fields (), # type: ignore
185
183
],
186
184
questions = self ._rating_rationale_pairs (), # type: ignore
187
185
)
188
- self . _rg_dataset = _rg_dataset . push_to_argilla (
189
- name = self .dataset_name , # type: ignore
186
+ _dataset = rg . Dataset ( # type: ignore
187
+ name = self .dataset_name ,
190
188
workspace = self .dataset_workspace ,
189
+ settings = _settings ,
190
+ client = self ._client ,
191
191
)
192
+ self ._dataset = _dataset .create ()
192
193
193
194
def _generation_fields (self ) -> List ["TextField" ]:
194
- """Method to generate the fields for each of the generations."""
195
+ """Method to generate the fields for each of the generations.
196
+
197
+ Returns:
198
+ A list containing `TextField`s for each text generation.
199
+ """
195
200
return [
196
201
rg .TextField ( # type: ignore
197
202
name = f"{ self ._generations } -{ idx } " ,
@@ -204,7 +209,12 @@ def _generation_fields(self) -> List["TextField"]:
204
209
def _rating_rationale_pairs (
205
210
self ,
206
211
) -> List [Union ["RatingQuestion" , "TextQuestion" ]]:
207
- """Method to generate the rating and rationale questions for each of the generations."""
212
+ """Method to generate the rating and rationale questions for each of the generations.
213
+
214
+ Returns:
215
+ A list of questions containing a `RatingQuestion` and `TextQuestion` pair for
216
+ each text generation.
217
+ """
208
218
questions = []
209
219
for idx in range (self .num_generations ):
210
220
questions .extend (
@@ -236,20 +246,27 @@ def inputs(self) -> List[str]:
236
246
provide the `ratings` and the `rationales` for the generations."""
237
247
return ["instruction" , "generations" ]
238
248
239
- def _add_suggestions_if_any (
240
- self , input : Dict [str , Any ]
241
- ) -> List ["SuggestionSchema" ]:
242
- """Method to generate the suggestions for the `FeedbackRecord` based on the input."""
249
+ @property
250
+ def optional_inputs (self ) -> List [str ]:
251
+ """The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
252
+ return ["ratings" , "rationales" ]
253
+
254
+ def _add_suggestions_if_any (self , input : Dict [str , Any ]) -> List ["Suggestion" ]:
255
+ """Method to generate the suggestions for the `rg.Record` based on the input.
256
+
257
+ Returns:
258
+ A list of `Suggestion`s for the rating and rationales questions.
259
+ """
243
260
# Since the `suggestions` i.e. answers to the `questions` are optional, will default to {}
244
261
suggestions = []
245
262
# If `ratings` is in `input`, then add those as suggestions
246
263
if self ._ratings in input :
247
264
suggestions .extend (
248
265
[
249
- {
250
- "question_name" : f" { self . _generations } - { idx } - rating" ,
251
- "value" : rating ,
252
- }
266
+ rg . Suggestion ( # type: ignore
267
+ value = rating ,
268
+ question_name = f" { self . _generations } - { idx } - rating" ,
269
+ )
253
270
for idx , rating in enumerate (input [self ._ratings ])
254
271
if rating is not None
255
272
and isinstance (rating , int )
@@ -260,10 +277,10 @@ def _add_suggestions_if_any(
260
277
if self ._rationales in input :
261
278
suggestions .extend (
262
279
[
263
- {
264
- "question_name" : f" { self . _generations } - { idx } - rationale" ,
265
- "value" : rationale ,
266
- }
280
+ rg . Suggestion ( # type: ignore
281
+ value = rationale ,
282
+ question_name = f" { self . _generations } - { idx } - rationale" ,
283
+ )
267
284
for idx , rationale in enumerate (input [self ._rationales ])
268
285
if rationale is not None and isinstance (rationale , str )
269
286
],
@@ -272,7 +289,7 @@ def _add_suggestions_if_any(
272
289
273
290
@override
274
291
def process (self , inputs : StepInput ) -> "StepOutput" : # type: ignore
275
- """Creates and pushes the records as FeedbackRecords to the Argilla dataset.
292
+ """Creates and pushes the records as `rg.Record`s to the Argilla dataset.
276
293
277
294
Args:
278
295
inputs: A list of Python dictionaries with the inputs of the task.
@@ -293,7 +310,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
293
310
}
294
311
295
312
records .append ( # type: ignore
296
- rg .FeedbackRecord ( # type: ignore
313
+ rg .Record ( # type: ignore
297
314
fields = {
298
315
"id" : instruction_id ,
299
316
"instruction" : input ["instruction" ], # type: ignore
@@ -302,5 +319,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
302
319
suggestions = self ._add_suggestions_if_any (input ), # type: ignore
303
320
)
304
321
)
305
- self ._rg_dataset . add_records (records ) # type: ignore
322
+ self ._dataset . records . log (records ) # type: ignore
306
323
yield inputs
0 commit comments