Skip to content

Commit ba85907

Browse files
committed
Add magpie unit tests
1 parent 53ff036 commit ba85907

File tree

7 files changed

+532
-11
lines changed

7 files changed

+532
-11
lines changed

Diff for: src/distilabel/steps/tasks/magpie/base.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,6 @@ class MagpieBase(RuntimeParametersMixin):
6161
" content of certain topic, guide the style, etc.",
6262
)
6363

64-
@property
65-
def outputs(self) -> List[str]:
66-
"""Either a multi-turn conversation or the instruction generated."""
67-
if self.only_instruction:
68-
return ["instruction"]
69-
return ["conversation"]
70-
7164
def _prepare_inputs_for_instruction_generation(
7265
self, inputs: List[Dict[str, Any]]
7366
) -> List["FormattedInput"]:
@@ -121,7 +114,7 @@ def _generate_multi_turn_conversation(
121114
) -> List[Dict[str, Any]]:
122115
conversations = self._prepare_inputs_for_instruction_generation(inputs)
123116

124-
for _ in range(self.n_turns - 1): # type: ignore
117+
for _ in range(self.n_turns): # type: ignore
125118
# Generate instruction or user message
126119
outputs = self.llm.generate(
127120
inputs=conversations,
@@ -355,6 +348,13 @@ def format_input(self, input: Dict[str, Any]) -> "ChatType":
355348
"""Does nothing."""
356349
return []
357350

351+
@property
352+
def outputs(self) -> List[str]:
353+
"""Either a multi-turn conversation or the instruction generated."""
354+
if self.only_instruction:
355+
return ["instruction"]
356+
return ["conversation"]
357+
358358
def format_output(
359359
self,
360360
output: Union[str, None],

Diff for: src/distilabel/steps/tasks/magpie/generator.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import TYPE_CHECKING, Any, Dict, Union
15+
from typing import TYPE_CHECKING, Any, Dict, List, Union
1616

1717
from pydantic import Field
1818

@@ -211,6 +211,13 @@ def format_output(
211211
"""Does nothing."""
212212
return {}
213213

214+
@property
215+
def outputs(self) -> List[str]:
216+
"""Either a multi-turn conversation or the instruction generated."""
217+
if self.only_instruction:
218+
return ["instruction"]
219+
return ["conversation"]
220+
214221
def process(self, offset: int = 0) -> "GeneratorStepOutput":
215222
"""Generates the desired number of instructions or conversations using Magpie.
216223

Diff for: tests/unit/conftest.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import TYPE_CHECKING
15+
from typing import TYPE_CHECKING, Any, List
1616

1717
import pytest
18-
from distilabel.llms.base import AsyncLLM
18+
from distilabel.llms.base import LLM, AsyncLLM
19+
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
1920

2021
if TYPE_CHECKING:
2122
from distilabel.llms.typing import GenerateOutput
@@ -37,6 +38,22 @@ async def agenerate(
3738
return ["output" for _ in range(num_generations)]
3839

3940

41+
class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
42+
def load(self) -> None:
43+
pass
44+
45+
@property
46+
def model_name(self) -> str:
47+
return "test"
48+
49+
def generate(
50+
self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any
51+
) -> List["GenerateOutput"]:
52+
return [
53+
["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs))
54+
]
55+
56+
4057
@pytest.fixture
4158
def dummy_llm() -> AsyncLLM:
4259
return DummyLLM()

Diff for: tests/unit/llms/mixins/test_magpie.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright 2023-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
from distilabel.llms.mixins.magpie import MAGPIE_PRE_QUERY_TEMPLATES
17+
18+
from tests.unit.conftest import DummyMagpieLLM
19+
20+
21+
class TestMagpieChatTemplateMixin:
22+
def test_magpie_pre_query_template_set(self) -> None:
23+
with pytest.raises(
24+
ValueError,
25+
match="Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is `None`",
26+
):
27+
DummyMagpieLLM(use_magpie_template=True)
28+
29+
def test_magpie_pre_query_template_alias_resolved(self) -> None:
30+
llm = DummyMagpieLLM(magpie_pre_query_template="llama3")
31+
assert llm.magpie_pre_query_template == MAGPIE_PRE_QUERY_TEMPLATES["llama3"]
32+
33+
def test_apply_magpie_pre_query_template(self) -> None:
34+
llm = DummyMagpieLLM(magpie_pre_query_template="<user>")
35+
36+
assert (
37+
llm.apply_magpie_pre_query_template(
38+
prompt="<system>Hello hello</system>", input=[]
39+
)
40+
== "<system>Hello hello</system>"
41+
)
42+
43+
llm = DummyMagpieLLM(
44+
use_magpie_template=True, magpie_pre_query_template="<user>"
45+
)
46+
47+
assert (
48+
llm.apply_magpie_pre_query_template(
49+
prompt="<system>Hello hello</system>", input=[]
50+
)
51+
== "<system>Hello hello</system><user>"
52+
)
53+
54+
assert (
55+
llm.apply_magpie_pre_query_template(
56+
prompt="<system>Hello hello</system><user>Hey</user>",
57+
input=[{"role": "user", "content": "Hey"}],
58+
)
59+
== "<system>Hello hello</system><user>Hey</user>"
60+
)

Diff for: tests/unit/steps/tasks/magpie/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2023-present, Argilla, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+

0 commit comments

Comments
 (0)