Skip to content

Commit 51c7781

Browse files
authored
feat: add history to prompt (#480)
1 parent 1eafdfb commit 51c7781

File tree

3 files changed

+152
-0
lines changed

3 files changed

+152
-0
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Unreleased
44
- Make the score in VectorStoreResult consistent (always bigger is better)
55
- Add router option to LiteLLMEmbedder (#440)
6+
- New methods in Prompt class for appending conversation history (#480)
67
- Fix: make unflatten_dict symmetric to flatten_dict (#461)
78
- Cost and capabilities config for custom litellm models (#481)
89

packages/ragbits-core/src/ragbits/core/prompt/prompt.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
142142
# Additional few shot examples that can be added dynamically using methods
143143
# (in opposite to the static `few_shots` attribute which is defined in the class)
144144
self._instance_few_shots: list[FewShotExample[InputT, OutputT]] = []
145+
146+
# Additional conversation history that can be added dynamically using methods
147+
self._conversation_history: list[dict[str, Any]] = []
145148
super().__init__()
146149

147150
@property
@@ -166,6 +169,7 @@ def chat(self) -> ChatFormat:
166169
),
167170
*self.list_few_shots(),
168171
{"role": "user", "content": user_content},
172+
*self._conversation_history,
169173
]
170174
return chat
171175

@@ -216,6 +220,57 @@ def list_few_shots(self) -> ChatFormat:
216220
result.append({"role": "assistant", "content": assistant_content})
217221
return result
218222

223+
def add_user_message(self, message: str | dict[str, Any] | InputT) -> "Prompt[InputT, OutputT]":
    """
    Append a user message to the dynamically built conversation history.

    Args:
        message (str | dict[str, Any] | InputT): The user message content. Accepts:
            - a plain string, stored as-is;
            - a dict (e.g. {"type": "text", "text": "..."} or image content), stored as-is;
            - an InputT model, rendered through the user prompt template, with any
              images found in the input attached as extra content parts.

    Returns:
        Prompt[InputT, OutputT]: This prompt instance, to allow chaining.
    """
    content: str | list[dict[str, Any]] | dict[str, Any] | InputT

    if not isinstance(message, BaseModel):
        # Strings and dicts are taken verbatim.
        content = message
    else:
        # Narrow to InputT so the rendering helpers receive the expected type.
        model = cast(InputT, message)
        rendered = self._render_template(self.user_prompt_template, model)
        images = self._get_images_from_input_data(model)
        if images:
            # Multi-part content: the rendered text first, then one part per image.
            parts: list[dict[str, Any]] = [{"type": "text", "text": rendered}]
            parts.extend(self._create_message_with_image(image) for image in images)
            content = parts
        else:
            content = rendered

    self._conversation_history.append({"role": "user", "content": content})
    return self
259+
def add_assistant_message(self, message: str | OutputT) -> "Prompt[InputT, OutputT]":
    """
    Add an assistant message to the conversation history.

    Args:
        message (str | OutputT): The assistant message content. A Pydantic model
            (OutputT) is serialized with ``model_dump_json()``; any other value
            is coerced to ``str``.

    Returns:
        Prompt[InputT, OutputT]: The current prompt instance to allow chaining.
    """
    if isinstance(message, BaseModel):
        # Serialize structured outputs so the stored history stays plain chat format.
        message = message.model_dump_json()
    self._conversation_history.append({"role": "assistant", "content": str(message)})
    return self
273+
219274
def list_images(self) -> list[str]:
220275
"""
221276
Returns the images in form of URLs or base64 encoded strings.

packages/ragbits-core/tests/unit/prompts/test_prompt.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,3 +537,99 @@ def sync_parser(response: str) -> str:
537537
test_prompt.response_parser = sync_parser
538538
resp_sync = await test_prompt.parse_response(resp)
539539
assert resp_sync == "hello human"
540+
541+
542+
def test_add_user_message_with_string():
    """A plain-string user message is appended verbatim after the rendered prompt."""

    class GreetingPrompt(Prompt):
        user_prompt = "Hello"

    chat = GreetingPrompt().add_user_message("Additional message").chat

    assert chat == [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "Additional message"},
    ]
552+
553+
554+
def test_add_user_message_with_input_model():
    """An input model is rendered through the user prompt template before appending."""

    class GreetingPrompt(Prompt[_PromptInput, str]):
        user_prompt = "Hello {{ name }}"

    prompt = GreetingPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
    prompt.add_user_message(_PromptInput(name="Bob", age=25, theme="jazz"))

    expected = [
        {"role": "user", "content": "Hello Alice"},
        {"role": "user", "content": "Hello Bob"},
    ]
    assert prompt.chat == expected
564+
565+
566+
def test_add_user_message_with_image():
    """An input model carrying an image yields two-part content (text + image)."""

    class ImagePrompt(Prompt):
        user_prompt = "What is on this image?"
        image_input_fields = ["image"]

    prompt = ImagePrompt(_ImagePromptInput(image=_get_image_bytes()))
    prompt.add_user_message(_ImagePromptInput(image=_get_image_bytes()))

    assert len(prompt.chat) == 2
    for message in prompt.chat:
        assert message["role"] == "user"
        assert len(message["content"]) == 2  # one text part + one image part
581+
582+
583+
def test_add_assistant_message():
    """A string assistant message is appended to the chat after the user prompt."""

    class SongPrompt(Prompt[_PromptInput, _PromptOutput]):
        user_prompt = "Hello {{ name }}"

    prompt = SongPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
    prompt.add_assistant_message("Assistant response")

    expected = [
        {"role": "user", "content": "Hello Alice"},
        {"role": "assistant", "content": "Assistant response"},
    ]
    assert prompt.chat == expected
596+
597+
598+
def test_add_assistant_message_with_model():
    """A model output passed as an assistant message is serialized to JSON."""

    class SongPrompt(Prompt[_PromptInput, _PromptOutput]):
        user_prompt = "Hello {{ name }}"

    output = _PromptOutput(song_title="Test Song", song_lyrics="Test Lyrics")
    prompt = SongPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
    prompt.add_assistant_message(output)

    assert prompt.chat == [
        {"role": "user", "content": "Hello Alice"},
        {"role": "assistant", "content": output.model_dump_json()},
    ]
612+
613+
614+
def test_conversation_history():
    """Interleaved user/assistant messages appear in insertion order after the prompt."""

    class SongPrompt(Prompt[_PromptInput, _PromptOutput]):
        user_prompt = "Hello {{ name }}"

    reply = _PromptOutput(song_title="Jazz Song", song_lyrics="Jazz lyrics")
    prompt = (
        SongPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
        .add_user_message("How are you?")
        .add_assistant_message("I'm doing well!")
        .add_user_message(_PromptInput(name="Bob", age=25, theme="jazz"))
        .add_assistant_message(reply)
    )

    assert prompt.chat == [
        {"role": "user", "content": "Hello Alice"},
        {"role": "user", "content": "How are you?"},
        {"role": "assistant", "content": "I'm doing well!"},
        {"role": "user", "content": "Hello Bob"},
        {"role": "assistant", "content": reply.model_dump_json()},
    ]

0 commit comments

Comments
 (0)