diff --git a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java index 2d5d8b64e9d..471e5a48233 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/chat/prompt/Prompt.java @@ -198,21 +198,21 @@ else if (message instanceof ToolResponseMessage toolResponseMessage) { * @return a new {@link Prompt} instance with the augmented system message. */ public Prompt augmentSystemMessage(Function systemMessageAugmenter) { - var messagesCopy = new ArrayList<>(this.messages); - for (int i = 0; i <= this.messages.size() - 1; i++) { + boolean found = false; + for (int i = 0; i < messagesCopy.size(); i++) { Message message = messagesCopy.get(i); if (message instanceof SystemMessage systemMessage) { messagesCopy.set(i, systemMessageAugmenter.apply(systemMessage)); + found = true; break; } - if (i == 0) { - // If no system message is found, create a new one with the provided text - // and add it as the first item in the list. - messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage(""))); - } } - + if (!found) { + // If no system message is found, create a new one with the provided text + // and add it as the first item in the list. + messagesCopy.add(0, systemMessageAugmenter.apply(new SystemMessage(""))); + } return new Prompt(messagesCopy, null == this.chatOptions ? null : this.chatOptions.copy()); } diff --git a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java index db88c270103..2b7c9efdb5b 100644 --- a/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java +++ b/spring-ai-model/src/test/java/org/springframework/ai/chat/prompt/PromptTests.java @@ -239,4 +239,26 @@ void augmentSystemMessageWhenNone() { assertThat(prompt.getSystemMessage().getText()).isEqualTo(""); } + @Test + void augmentSystemMessageWhenNotFirst() { + Message[] messages = { new UserMessage("Hi"), new SystemMessage("Hello") }; + Prompt prompt = Prompt.builder().messages(messages).build(); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getUserMessage()).isNotNull(); + assertThat(prompt.getUserMessage().getText()).isEqualTo("Hi"); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + + Prompt copy = prompt.augmentSystemMessage(message -> message.mutate().text("How are you?").build()); + + assertThat(copy.getSystemMessage()).isNotNull(); + assertThat(copy.getInstructions().size()).isEqualTo(messages.length); + assertThat(copy.getSystemMessage().getText()).isEqualTo("How are you?"); + + assertThat(prompt.getSystemMessage()).isNotNull(); + assertThat(prompt.getUserMessage()).isNotNull(); + assertThat(prompt.getUserMessage().getText()).isEqualTo("Hi"); + assertThat(prompt.getSystemMessage().getText()).isEqualTo("Hello"); + } + }