31
31
class DummyTask(Task):
    """Minimal ``Task`` used by the tests in this module.

    It consumes two input columns (``instruction`` and ``additional_info``)
    and produces two output columns (``output`` and ``info_from_input``), the
    latter being a copy of the ``additional_info`` input value.
    """

    @property
    def inputs(self) -> List[str]:
        """Return the names of the input columns this task reads."""
        return ["instruction", "additional_info"]

    def format_input(self, input: Dict[str, Any]) -> "ChatType":
        """Build the chat messages for one input row.

        Args:
            input: A row containing at least the ``instruction`` key.

        Returns:
            A two-message chat: an empty system message followed by a user
            message holding the row's ``instruction``.
        """
        return [
            {"role": "system", "content": ""},
            {"role": "user", "content": input["instruction"]},
        ]

    @property
    def outputs(self) -> List[str]:
        """Return the names of the output columns this task produces."""
        return ["output", "info_from_input"]

    def format_output(
        self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
    ) -> Dict[str, Any]:
        """Combine the generation with data taken from the originating row.

        Args:
            output: The raw generation, or ``None``.
            input: The row the generation was produced from. Defaults to
                ``None``.

        Returns:
            A dict with ``output`` set to the raw generation and
            ``info_from_input`` set to the row's ``additional_info`` value, or
            to ``None`` when no row was supplied.

        Raises:
            KeyError: If ``input`` is provided but lacks ``additional_info``.
        """
        # ``input`` is optional in the signature, so guard the subscript rather
        # than silencing the type checker and crashing with TypeError on None.
        info_from_input = input["additional_info"] if input is not None else None
        return {"output": output, "info_from_input": info_from_input}
44
50
45
51
46
52
class DummyRuntimeLLM (DummyLLM ):
@@ -85,37 +91,139 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
85
91
Task (name = "task" , llm = DummyLLM ()) # type: ignore
86
92
87
93
@pytest .mark .parametrize (
88
- "group_generations, expected" ,
94
+ "input, group_generations, expected" ,
89
95
[
90
96
(
97
+ [
98
+ {"instruction" : "test_0" , "additional_info" : "additional_info_0" },
99
+ {"instruction" : "test_1" , "additional_info" : "additional_info_1" },
100
+ {"instruction" : "test_2" , "additional_info" : "additional_info_2" },
101
+ ],
91
102
False ,
92
103
[
93
104
{
94
- "instruction" : "test" ,
105
+ "instruction" : "test_0" ,
106
+ "additional_info" : "additional_info_0" ,
107
+ "output" : "output" ,
108
+ "info_from_input" : "additional_info_0" ,
109
+ "model_name" : "test" ,
110
+ "distilabel_metadata" : {"raw_output_task" : "output" },
111
+ },
112
+ {
113
+ "instruction" : "test_0" ,
114
+ "additional_info" : "additional_info_0" ,
115
+ "output" : "output" ,
116
+ "info_from_input" : "additional_info_0" ,
117
+ "model_name" : "test" ,
118
+ "distilabel_metadata" : {"raw_output_task" : "output" },
119
+ },
120
+ {
121
+ "instruction" : "test_0" ,
122
+ "additional_info" : "additional_info_0" ,
123
+ "output" : "output" ,
124
+ "info_from_input" : "additional_info_0" ,
125
+ "model_name" : "test" ,
126
+ "distilabel_metadata" : {"raw_output_task" : "output" },
127
+ },
128
+ {
129
+ "instruction" : "test_1" ,
130
+ "additional_info" : "additional_info_1" ,
131
+ "output" : "output" ,
132
+ "info_from_input" : "additional_info_1" ,
133
+ "model_name" : "test" ,
134
+ "distilabel_metadata" : {"raw_output_task" : "output" },
135
+ },
136
+ {
137
+ "instruction" : "test_1" ,
138
+ "additional_info" : "additional_info_1" ,
139
+ "output" : "output" ,
140
+ "info_from_input" : "additional_info_1" ,
141
+ "model_name" : "test" ,
142
+ "distilabel_metadata" : {"raw_output_task" : "output" },
143
+ },
144
+ {
145
+ "instruction" : "test_1" ,
146
+ "additional_info" : "additional_info_1" ,
95
147
"output" : "output" ,
148
+ "info_from_input" : "additional_info_1" ,
96
149
"model_name" : "test" ,
97
150
"distilabel_metadata" : {"raw_output_task" : "output" },
98
151
},
99
152
{
100
- "instruction" : "test" ,
153
+ "instruction" : "test_2" ,
154
+ "additional_info" : "additional_info_2" ,
101
155
"output" : "output" ,
156
+ "info_from_input" : "additional_info_2" ,
102
157
"model_name" : "test" ,
103
158
"distilabel_metadata" : {"raw_output_task" : "output" },
104
159
},
105
160
{
106
- "instruction" : "test" ,
161
+ "instruction" : "test_2" ,
162
+ "additional_info" : "additional_info_2" ,
107
163
"output" : "output" ,
164
+ "info_from_input" : "additional_info_2" ,
165
+ "model_name" : "test" ,
166
+ "distilabel_metadata" : {"raw_output_task" : "output" },
167
+ },
168
+ {
169
+ "instruction" : "test_2" ,
170
+ "additional_info" : "additional_info_2" ,
171
+ "output" : "output" ,
172
+ "info_from_input" : "additional_info_2" ,
108
173
"model_name" : "test" ,
109
174
"distilabel_metadata" : {"raw_output_task" : "output" },
110
175
},
111
176
],
112
177
),
113
178
(
179
+ [
180
+ {"instruction" : "test_0" , "additional_info" : "additional_info_0" },
181
+ {"instruction" : "test_1" , "additional_info" : "additional_info_1" },
182
+ {"instruction" : "test_2" , "additional_info" : "additional_info_2" },
183
+ ],
114
184
True ,
115
185
[
116
186
{
117
- "instruction" : "test" ,
187
+ "instruction" : "test_0" ,
188
+ "additional_info" : "additional_info_0" ,
189
+ "output" : ["output" , "output" , "output" ],
190
+ "info_from_input" : [
191
+ "additional_info_0" ,
192
+ "additional_info_0" ,
193
+ "additional_info_0" ,
194
+ ],
195
+ "model_name" : "test" ,
196
+ "distilabel_metadata" : [
197
+ {"raw_output_task" : "output" },
198
+ {"raw_output_task" : "output" },
199
+ {"raw_output_task" : "output" },
200
+ ],
201
+ },
202
+ {
203
+ "instruction" : "test_1" ,
204
+ "additional_info" : "additional_info_1" ,
205
+ "output" : ["output" , "output" , "output" ],
206
+ "info_from_input" : [
207
+ "additional_info_1" ,
208
+ "additional_info_1" ,
209
+ "additional_info_1" ,
210
+ ],
211
+ "model_name" : "test" ,
212
+ "distilabel_metadata" : [
213
+ {"raw_output_task" : "output" },
214
+ {"raw_output_task" : "output" },
215
+ {"raw_output_task" : "output" },
216
+ ],
217
+ },
218
+ {
219
+ "instruction" : "test_2" ,
220
+ "additional_info" : "additional_info_2" ,
118
221
"output" : ["output" , "output" , "output" ],
222
+ "info_from_input" : [
223
+ "additional_info_2" ,
224
+ "additional_info_2" ,
225
+ "additional_info_2" ,
226
+ ],
119
227
"model_name" : "test" ,
120
228
"distilabel_metadata" : [
121
229
{"raw_output_task" : "output" },
@@ -128,7 +236,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
128
236
],
129
237
)
130
238
def test_process (
131
- self , group_generations : bool , expected : List [Dict [str , Any ]]
239
+ self ,
240
+ input : List [Dict [str , str ]],
241
+ group_generations : bool ,
242
+ expected : List [Dict [str , Any ]],
132
243
) -> None :
133
244
pipeline = Pipeline (name = "unit-test-pipeline" )
134
245
llm = DummyLLM ()
@@ -139,7 +250,7 @@ def test_process(
139
250
group_generations = group_generations ,
140
251
num_generations = 3 ,
141
252
)
142
- result = next (task .process ([{ "instruction" : "test" }] ))
253
+ result = next (task .process (input ))
143
254
assert result == expected
144
255
145
256
def test_process_with_runtime_parameters (self ) -> None :
0 commit comments