Skip to content

Commit ecfa12d

Browse files
committed
Fix unit tests magpie
1 parent 35de70d commit ecfa12d

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

Diff for: tests/unit/steps/tasks/magpie/test_base.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
3030
def test_outputs(self) -> None:
3131
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
3232

33-
assert task.outputs == ["conversation"]
33+
assert task.outputs == ["conversation", "model_name"]
3434

3535
task = Magpie(
3636
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
3737
only_instruction=True,
3838
)
3939

40-
assert task.outputs == ["instruction"]
40+
assert task.outputs == ["instruction", "model_name"]
4141

4242
def test_process(self) -> None:
4343
task = Magpie(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"), n_turns=1)
@@ -50,18 +50,21 @@ def test_process(self) -> None:
5050
{"role": "user", "content": "Hello Magpie"},
5151
{"role": "assistant", "content": "Hello Magpie"},
5252
],
53+
"model_name": "test",
5354
},
5455
{
5556
"conversation": [
5657
{"role": "user", "content": "Hello Magpie"},
5758
{"role": "assistant", "content": "Hello Magpie"},
5859
],
60+
"model_name": "test",
5961
},
6062
{
6163
"conversation": [
6264
{"role": "user", "content": "Hello Magpie"},
6365
{"role": "assistant", "content": "Hello Magpie"},
6466
],
67+
"model_name": "test",
6568
},
6669
]
6770

@@ -79,6 +82,7 @@ def test_process_with_n_turns(self) -> None:
7982
{"role": "user", "content": "Hello Magpie"},
8083
{"role": "assistant", "content": "Hello Magpie"},
8184
],
85+
"model_name": "test",
8286
},
8387
{
8488
"conversation": [
@@ -88,6 +92,7 @@ def test_process_with_n_turns(self) -> None:
8892
{"role": "user", "content": "Hello Magpie"},
8993
{"role": "assistant", "content": "Hello Magpie"},
9094
],
95+
"model_name": "test",
9196
},
9297
{
9398
"conversation": [
@@ -97,6 +102,7 @@ def test_process_with_n_turns(self) -> None:
97102
{"role": "user", "content": "Hello Magpie"},
98103
{"role": "assistant", "content": "Hello Magpie"},
99104
],
105+
"model_name": "test",
100106
},
101107
]
102108

@@ -115,31 +121,37 @@ def test_process_with_system_prompt_per_row(self) -> None:
115121
)
116122
) == [
117123
{
124+
"system_prompt": "You're a math expert assistant.",
118125
"conversation": [
119126
{"role": "system", "content": "You're a math expert assistant."},
120127
{"role": "user", "content": "Hello Magpie"},
121128
{"role": "assistant", "content": "Hello Magpie"},
122129
{"role": "user", "content": "Hello Magpie"},
123130
{"role": "assistant", "content": "Hello Magpie"},
124131
],
132+
"model_name": "test",
125133
},
126134
{
135+
"system_prompt": "You're a florist expert assistant.",
127136
"conversation": [
128137
{"role": "system", "content": "You're a florist expert assistant."},
129138
{"role": "user", "content": "Hello Magpie"},
130139
{"role": "assistant", "content": "Hello Magpie"},
131140
{"role": "user", "content": "Hello Magpie"},
132141
{"role": "assistant", "content": "Hello Magpie"},
133142
],
143+
"model_name": "test",
134144
},
135145
{
146+
"system_prompt": "You're a plumber expert assistant.",
136147
"conversation": [
137148
{"role": "system", "content": "You're a plumber expert assistant."},
138149
{"role": "user", "content": "Hello Magpie"},
139150
{"role": "assistant", "content": "Hello Magpie"},
140151
{"role": "user", "content": "Hello Magpie"},
141152
{"role": "assistant", "content": "Hello Magpie"},
142153
],
154+
"model_name": "test",
143155
},
144156
]
145157

@@ -152,9 +164,18 @@ def test_process_only_instruction(self) -> None:
152164
task.load()
153165

154166
assert next(task.process(inputs=[{}, {}, {}])) == [
155-
{"instruction": "Hello Magpie"},
156-
{"instruction": "Hello Magpie"},
157-
{"instruction": "Hello Magpie"},
167+
{
168+
"instruction": "Hello Magpie",
169+
"model_name": "test",
170+
},
171+
{
172+
"instruction": "Hello Magpie",
173+
"model_name": "test",
174+
},
175+
{
176+
"instruction": "Hello Magpie",
177+
"model_name": "test",
178+
},
158179
]
159180

160181
def test_serialization(self) -> None:

Diff for: tests/unit/steps/tasks/magpie/test_generator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def test_raise_value_error_llm_no_magpie_mixin(self) -> None:
3030
def test_outputs(self) -> None:
3131
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))
3232

33-
assert task.outputs == ["conversation"]
33+
assert task.outputs == ["conversation", "model_name"]
3434

3535
task = MagpieGenerator(
3636
llm=DummyMagpieLLM(magpie_pre_query_template="llama3"),
3737
only_instruction=True,
3838
)
3939

40-
assert task.outputs == ["instruction"]
40+
assert task.outputs == ["instruction", "model_name"]
4141

4242
def test_serialization(self) -> None:
4343
task = MagpieGenerator(llm=DummyMagpieLLM(magpie_pre_query_template="llama3"))

0 commit comments

Comments
 (0)