
Commit a665f1f

introduce utility function (#939)
Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
1 parent 7cee6c8 commit a665f1f

2 files changed: +35 −2 lines changed


integrations/llama_cpp/src/haystack_integrations/components/generators/llama_cpp/chat/chat_generator.py

Lines changed: 16 additions & 1 deletion
@@ -9,6 +9,21 @@
 logger = logging.getLogger(__name__)
 
 
+def _convert_message_to_llamacpp_format(message: ChatMessage) -> Dict[str, str]:
+    """
+    Convert a message to the format expected by Llama.cpp.
+    :returns: A dictionary with the following keys:
+        - `role`
+        - `content`
+        - `name` (optional)
+    """
+    formatted_msg = {"role": message.role.value, "content": message.content}
+    if message.name:
+        formatted_msg["name"] = message.name
+
+    return formatted_msg
+
+
 @component
 class LlamaCppChatGenerator:
     """
@@ -96,7 +111,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
             return {"replies": []}
 
         updated_generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
-        formatted_messages = [msg.to_openai_format() for msg in messages]
+        formatted_messages = [_convert_message_to_llamacpp_format(msg) for msg in messages]
 
         response = self.model.create_chat_completion(messages=formatted_messages, **updated_generation_kwargs)
         replies = [
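
For quick context, a minimal usage sketch of the new helper (the outputs mirror the tests added below; it assumes the llama_cpp integration from this commit is importable):

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
    _convert_message_to_llamacpp_format,
)

# `role` and `content` are always emitted; `name` only when the message has one.
print(_convert_message_to_llamacpp_format(ChatMessage.from_user("I have a question")))
# {'role': 'user', 'content': 'I have a question'}
print(_convert_message_to_llamacpp_format(ChatMessage.from_function("Function call", "function_name")))
# {'role': 'function', 'content': 'Function call', 'name': 'function_name'}

This dict shape is what llama-cpp-python's create_chat_completion accepts for its messages argument, which is why the generator can drop ChatMessage.to_openai_format().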

integrations/llama_cpp/tests/test_chat_generator.py

Lines changed: 19 additions & 1 deletion
@@ -10,7 +10,10 @@
 from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
 from haystack.dataclasses import ChatMessage, ChatRole
 from haystack.document_stores.in_memory import InMemoryDocumentStore
-from haystack_integrations.components.generators.llama_cpp import LlamaCppChatGenerator
+from haystack_integrations.components.generators.llama_cpp.chat.chat_generator import (
+    LlamaCppChatGenerator,
+    _convert_message_to_llamacpp_format,
+)
 
 
 @pytest.fixture
@@ -29,6 +32,21 @@ def download_file(file_link, filename, capsys):
         print("\nModel file already exists.")
 
 
+def test_convert_message_to_llamacpp_format():
+    message = ChatMessage.from_system("You are good assistant")
+    assert _convert_message_to_llamacpp_format(message) == {"role": "system", "content": "You are good assistant"}
+
+    message = ChatMessage.from_user("I have a question")
+    assert _convert_message_to_llamacpp_format(message) == {"role": "user", "content": "I have a question"}
+
+    message = ChatMessage.from_function("Function call", "function_name")
+    assert _convert_message_to_llamacpp_format(message) == {
+        "role": "function",
+        "content": "Function call",
+        "name": "function_name",
+    }
+
+
 class TestLlamaCppChatGenerator:
     @pytest.fixture
     def generator(self, model_path, capsys):
