Skip to content

Commit 86d4e80

Browse files
committed
Merge branch 'main' into develop
2 parents ed0f08e + a22c7e2 commit 86d4e80

File tree

2 files changed

+127
-16
lines changed

2 files changed

+127
-16
lines changed

src/distilabel/steps/tasks/base.py

+6 -6
Original file line number | Diff line number | Diff line change
@@ -91,21 +91,20 @@ def format_output(
9191
def _format_outputs(
9292
self,
9393
outputs: "GenerateOutput",
94-
inputs: Union[List[Dict[str, Any]], None] = None,
94+
input: Union[Dict[str, Any], None] = None,
9595
) -> List[Dict[str, Any]]:
9696
"""Formats the outputs of the task using the `format_output` method. If the output
9797
is `None` (i.e. the LLM failed to generate a response), then the outputs will be
9898
set to `None` as well.
9999
100100
Args:
101-
outputs: The outputs of the LLM.
102-
inputs: The inputs used to generate the outputs.
101+
outputs: The outputs (`n` generations) for the provided `input`.
102+
input: The input used to generate the output.
103103
104104
Returns:
105105
A list containing a dictionary with the outputs of the task for each input.
106106
"""
107-
if inputs is None:
108-
inputs = [None] # type: ignore
107+
inputs = [None] if input is None else [input]
109108

110109
formatted_outputs = []
111110
for output, input in zip(outputs, inputs * len(outputs)): # type: ignore
@@ -195,6 +194,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
195194

196195
formatted_inputs = self._format_inputs(inputs)
197196

197+
# `outputs` is a list containing a list of generations per input
198198
outputs = self.llm.generate(
199199
inputs=formatted_inputs,
200200
num_generations=self.num_generations, # type: ignore
@@ -203,7 +203,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
203203

204204
task_outputs = []
205205
for input, input_outputs in zip(inputs, outputs):
206-
formatted_outputs = self._format_outputs(input_outputs, inputs)
206+
formatted_outputs = self._format_outputs(input_outputs, input)
207207

208208
if self.group_generations:
209209
combined = group_dicts(*formatted_outputs)

tests/unit/steps/tasks/test_base.py

+121 -10
Original file line number | Diff line number | Diff line change
@@ -31,16 +31,22 @@
3131
class DummyTask(Task):
3232
@property
3333
def inputs(self) -> List[str]:
34-
return ["instruction"]
34+
return ["instruction", "additional_info"]
3535

3636
def format_input(self, input: Dict[str, Any]) -> "ChatType":
3737
return [
3838
{"role": "system", "content": ""},
3939
{"role": "user", "content": input["instruction"]},
4040
]
4141

42-
def format_output(self, output: Union[str, None], input: Dict[str, Any]) -> dict:
43-
return {"output": output}
42+
@property
43+
def outputs(self) -> List[str]:
44+
return ["output", "info_from_input"]
45+
46+
def format_output(
47+
self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
48+
) -> Dict[str, Any]:
49+
return {"output": output, "info_from_input": input["additional_info"]} # type: ignore
4450

4551

4652
class DummyRuntimeLLM(DummyLLM):
@@ -85,37 +91,139 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
8591
Task(name="task", llm=DummyLLM()) # type: ignore
8692

8793
@pytest.mark.parametrize(
88-
"group_generations, expected",
94+
"input, group_generations, expected",
8995
[
9096
(
97+
[
98+
{"instruction": "test_0", "additional_info": "additional_info_0"},
99+
{"instruction": "test_1", "additional_info": "additional_info_1"},
100+
{"instruction": "test_2", "additional_info": "additional_info_2"},
101+
],
91102
False,
92103
[
93104
{
94-
"instruction": "test",
105+
"instruction": "test_0",
106+
"additional_info": "additional_info_0",
107+
"output": "output",
108+
"info_from_input": "additional_info_0",
109+
"model_name": "test",
110+
"distilabel_metadata": {"raw_output_task": "output"},
111+
},
112+
{
113+
"instruction": "test_0",
114+
"additional_info": "additional_info_0",
115+
"output": "output",
116+
"info_from_input": "additional_info_0",
117+
"model_name": "test",
118+
"distilabel_metadata": {"raw_output_task": "output"},
119+
},
120+
{
121+
"instruction": "test_0",
122+
"additional_info": "additional_info_0",
123+
"output": "output",
124+
"info_from_input": "additional_info_0",
125+
"model_name": "test",
126+
"distilabel_metadata": {"raw_output_task": "output"},
127+
},
128+
{
129+
"instruction": "test_1",
130+
"additional_info": "additional_info_1",
131+
"output": "output",
132+
"info_from_input": "additional_info_1",
133+
"model_name": "test",
134+
"distilabel_metadata": {"raw_output_task": "output"},
135+
},
136+
{
137+
"instruction": "test_1",
138+
"additional_info": "additional_info_1",
139+
"output": "output",
140+
"info_from_input": "additional_info_1",
141+
"model_name": "test",
142+
"distilabel_metadata": {"raw_output_task": "output"},
143+
},
144+
{
145+
"instruction": "test_1",
146+
"additional_info": "additional_info_1",
95147
"output": "output",
148+
"info_from_input": "additional_info_1",
96149
"model_name": "test",
97150
"distilabel_metadata": {"raw_output_task": "output"},
98151
},
99152
{
100-
"instruction": "test",
153+
"instruction": "test_2",
154+
"additional_info": "additional_info_2",
101155
"output": "output",
156+
"info_from_input": "additional_info_2",
102157
"model_name": "test",
103158
"distilabel_metadata": {"raw_output_task": "output"},
104159
},
105160
{
106-
"instruction": "test",
161+
"instruction": "test_2",
162+
"additional_info": "additional_info_2",
107163
"output": "output",
164+
"info_from_input": "additional_info_2",
165+
"model_name": "test",
166+
"distilabel_metadata": {"raw_output_task": "output"},
167+
},
168+
{
169+
"instruction": "test_2",
170+
"additional_info": "additional_info_2",
171+
"output": "output",
172+
"info_from_input": "additional_info_2",
108173
"model_name": "test",
109174
"distilabel_metadata": {"raw_output_task": "output"},
110175
},
111176
],
112177
),
113178
(
179+
[
180+
{"instruction": "test_0", "additional_info": "additional_info_0"},
181+
{"instruction": "test_1", "additional_info": "additional_info_1"},
182+
{"instruction": "test_2", "additional_info": "additional_info_2"},
183+
],
114184
True,
115185
[
116186
{
117-
"instruction": "test",
187+
"instruction": "test_0",
188+
"additional_info": "additional_info_0",
189+
"output": ["output", "output", "output"],
190+
"info_from_input": [
191+
"additional_info_0",
192+
"additional_info_0",
193+
"additional_info_0",
194+
],
195+
"model_name": "test",
196+
"distilabel_metadata": [
197+
{"raw_output_task": "output"},
198+
{"raw_output_task": "output"},
199+
{"raw_output_task": "output"},
200+
],
201+
},
202+
{
203+
"instruction": "test_1",
204+
"additional_info": "additional_info_1",
205+
"output": ["output", "output", "output"],
206+
"info_from_input": [
207+
"additional_info_1",
208+
"additional_info_1",
209+
"additional_info_1",
210+
],
211+
"model_name": "test",
212+
"distilabel_metadata": [
213+
{"raw_output_task": "output"},
214+
{"raw_output_task": "output"},
215+
{"raw_output_task": "output"},
216+
],
217+
},
218+
{
219+
"instruction": "test_2",
220+
"additional_info": "additional_info_2",
118221
"output": ["output", "output", "output"],
222+
"info_from_input": [
223+
"additional_info_2",
224+
"additional_info_2",
225+
"additional_info_2",
226+
],
119227
"model_name": "test",
120228
"distilabel_metadata": [
121229
{"raw_output_task": "output"},
@@ -128,7 +236,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
128236
],
129237
)
130238
def test_process(
131-
self, group_generations: bool, expected: List[Dict[str, Any]]
239+
self,
240+
input: List[Dict[str, str]],
241+
group_generations: bool,
242+
expected: List[Dict[str, Any]],
132243
) -> None:
133244
pipeline = Pipeline(name="unit-test-pipeline")
134245
llm = DummyLLM()
@@ -139,7 +250,7 @@ def test_process(
139250
group_generations=group_generations,
140251
num_generations=3,
141252
)
142-
result = next(task.process([{"instruction": "test"}]))
253+
result = next(task.process(input))
143254
assert result == expected
144255

145256
def test_process_with_runtime_parameters(self) -> None:

0 commit comments

Comments
 (0)