-
Notifications
You must be signed in to change notification settings - Fork 195
/
Copy pathtest_generator.py
154 lines (144 loc) · 6.25 KB
/
test_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from distilabel.llms.openai import OpenAILLM
from distilabel.steps.tasks.magpie.generator import MagpieGenerator
from tests.unit.conftest import DummyMagpieLLM
class TestMagpieGenerator:
    """Unit tests for the `MagpieGenerator` task (Magpie-style instruction/
    conversation generation as a pipeline *generator* step)."""

    def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
        """Constructing the task with an `LLM` that does not include the
        `MagpieChatTemplateMixin` (e.g. `OpenAILLM`) must raise `ValueError`,
        since Magpie relies on the mixin's pre-query chat template."""
        with pytest.raises(
            ValueError,
            match="`Magpie` task can only be used with an `LLM` that uses the `MagpieChatTemplateMixin`",
        ):
            MagpieGenerator(llm=OpenAILLM(model="gpt-4", api_key="fake"))  # type: ignore

    def test_outputs(self) -> None:
        """`outputs` depends on the `only_instruction` flag: a full
        `conversation` column by default, a single `instruction` column when
        `only_instruction=True`; `model_name` is always present."""
        task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
        # Default mode: the task emits whole conversations.
        assert task.outputs == ["conversation", "model_name"]
        task = MagpieGenerator(
            llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
            only_instruction=True,
        )
        # Instruction-only mode replaces `conversation` with `instruction`.
        assert task.outputs == ["instruction", "model_name"]

    def test_serialization(self) -> None:
        """`dump()` must produce the exact serialized representation of the
        task: LLM config (including the resolved llama3 pre-query template),
        default step attributes, per-parameter runtime metadata, and
        `type_info` for round-trip reconstruction. The expected dict is pinned
        verbatim so any change to defaults or runtime-parameter descriptions
        fails this test deliberately."""
        task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
        assert task.dump() == {
            "llm": {
                "use_magpie_template": True,
                # "llama3" shorthand resolves to the literal llama-3 header template.
                "magpie_pre_query_template": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n",
                "generation_kwargs": {},
                "type_info": {
                    "module": "tests.unit.conftest",
                    "name": "DummyMagpieLLM",
                },
            },
            "n_turns": 1,
            "only_instruction": False,
            "system_prompt": None,
            "name": "magpie_generator_0",
            "resources": {
                "replicas": 1,
                "cpus": None,
                "gpus": None,
                "memory": None,
                "resources": None,
            },
            "input_mappings": {},
            "output_mappings": {},
            "batch_size": 50,
            "group_generations": False,
            "add_raw_output": True,
            "num_generations": 1,
            "num_rows": None,
            # Metadata describing which attributes are tunable at runtime.
            "runtime_parameters_info": [
                {
                    "name": "llm",
                    "runtime_parameters_info": [
                        {
                            "name": "generation_kwargs",
                            "description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
                            "keys": [{"name": "kwargs", "optional": False}],
                        }
                    ],
                },
                {
                    "name": "n_turns",
                    "optional": True,
                    "description": "The number of turns to generate for the conversation.",
                },
                {
                    "name": "only_instruction",
                    "optional": True,
                    "description": "Whether to generate only the instruction. If this argument is `True`, then `n_turns` will be ignored.",
                },
                {
                    "name": "system_prompt",
                    "optional": True,
                    "description": "An optional system prompt that can be used to steer the LLM to generate content of certain topic, guide the style, etc.",
                },
                {
                    "name": "resources",
                    "runtime_parameters_info": [
                        {
                            "name": "replicas",
                            "optional": True,
                            "description": "The number of replicas for the step.",
                        },
                        {
                            "name": "cpus",
                            "optional": True,
                            "description": "The number of CPUs assigned to each step replica.",
                        },
                        {
                            "name": "gpus",
                            "optional": True,
                            "description": "The number of GPUs assigned to each step replica.",
                        },
                        {
                            "name": "memory",
                            "optional": True,
                            "description": "The memory in bytes required for each step replica.",
                        },
                        {
                            "name": "resources",
                            "optional": True,
                            "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.",
                        },
                    ],
                },
                {
                    "name": "batch_size",
                    "optional": True,
                    "description": "The number of rows that will contain the batches generated by the step.",
                },
                {
                    "name": "add_raw_output",
                    "optional": True,
                    "description": "Whether to include the raw output of the LLM in the key `raw_output_<TASK_NAME>` of the `distilabel_metadata` dictionary output column",
                },
                {
                    "name": "num_generations",
                    "optional": True,
                    "description": "The number of generations to be produced per input.",
                },
                {
                    # `num_rows` is the only required runtime parameter: a
                    # generator step must know how many rows to produce.
                    "name": "num_rows",
                    "optional": False,
                    "description": "The number of rows to generate.",
                },
            ],
            "type_info": {
                "module": "distilabel.steps.tasks.magpie.generator",
                "name": "MagpieGenerator",
            },
        }