diff --git a/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/AbstractMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/AbstractMessageListenerContainer.java index 1ffc96158..69076724b 100644 --- a/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/AbstractMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/AbstractMessageListenerContainer.java @@ -331,9 +331,10 @@ private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy d new GetQueueAttributesRequest(destinationUrl).withAttributeNames(QueueAttributeName.RedrivePolicy)); boolean hasRedrivePolicy = queueAttributes.getAttributes() .containsKey(QueueAttributeName.RedrivePolicy.toString()); + boolean isFifo = queue.endsWith(".fifo"); return new QueueAttributes(hasRedrivePolicy, deletionPolicy, destinationUrl, getMaxNumberOfMessages(), - getVisibilityTimeout(), getWaitTimeOut()); + getVisibilityTimeout(), getWaitTimeOut(), isFifo); } @Override @@ -384,14 +385,17 @@ protected static class QueueAttributes { private final Integer waitTimeOut; + private final boolean fifo; + public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, String destinationUrl, - Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) { + Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut, boolean fifo) { this.hasRedrivePolicy = hasRedrivePolicy; this.deletionPolicy = deletionPolicy; this.destinationUrl = destinationUrl; this.maxNumberOfMessages = maxNumberOfMessages; this.visibilityTimeout = visibilityTimeout; this.waitTimeOut = waitTimeOut; + this.fifo = fifo; } public boolean hasRedrivePolicy() { @@ -424,6 +428,10 @@ public SqsMessageDeletionPolicy getDeletionPolicy() { return this.deletionPolicy; } + boolean isFifo() { + return fifo; + } + } } diff --git a/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainer.java index 016b29328..66d0a12fe 100644 --- a/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainer.java @@ -16,7 +16,9 @@ package io.awspring.cloud.messaging.listener; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; @@ -24,9 +26,11 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; import com.amazonaws.services.sqs.model.DeleteMessageRequest; import com.amazonaws.services.sqs.model.Message; +import com.amazonaws.services.sqs.model.MessageSystemAttributeName; import com.amazonaws.services.sqs.model.ReceiveMessageResult; import org.springframework.core.task.AsyncTaskExecutor; @@ -329,12 +333,16 @@ public void run() { try { ReceiveMessageResult receiveMessageResult = getAmazonSqs() .receiveMessage(this.queueAttributes.getReceiveMessageRequest()); - CountDownLatch messageBatchLatch = new CountDownLatch(receiveMessageResult.getMessages().size()); - for (Message message : receiveMessageResult.getMessages()) { + + final List messageGroups = queueAttributes.isFifo() + ? groupByMessageGroupId(receiveMessageResult) : groupByMessage(receiveMessageResult); + CountDownLatch messageBatchLatch = new CountDownLatch(messageGroups.size()); + for (MessageGroup messageGroup : messageGroups) { if (isQueueRunning(this.logicalQueueName)) { - MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, - this.queueAttributes); - getTaskExecutor().execute(new SignalExecutingRunnable(messageBatchLatch, messageExecutor)); + MessageGroupExecutor messageGroupExecutor = new MessageGroupExecutor(this.logicalQueueName, + messageGroup, this.queueAttributes); + getTaskExecutor() + .execute(new SignalExecutingRunnable(messageBatchLatch, messageGroupExecutor)); } else { messageBatchLatch.countDown(); @@ -363,11 +371,40 @@ public void run() { SimpleMessageListenerContainer.this.scheduledFutureByQueue.remove(this.logicalQueueName); } + private List groupByMessageGroupId(final ReceiveMessageResult receiveMessageResult) { + return receiveMessageResult.getMessages().stream() + .collect(Collectors.groupingBy(message -> message.getMessageAttributes() + .get(MessageSystemAttributeName.MessageGroupId.name()))) + .values().stream().map(MessageGroup::new).collect(Collectors.toList()); + } + + private List groupByMessage(final ReceiveMessageResult receiveMessageResult) { + return receiveMessageResult.getMessages().stream().map(MessageGroup::new).collect(Collectors.toList()); + } + + } + + private static final class MessageGroup { + + private final List messages; + + MessageGroup(final Message message) { + this.messages = Collections.singletonList(message); + } + + MessageGroup(final List messages) { + this.messages = messages; + } + + public List getMessages() { + return this.messages; + } + } - private final class MessageExecutor implements Runnable { + private final class MessageGroupExecutor implements Runnable { - private final Message message; + private final MessageGroup messageGroup; private final String logicalQueueName; @@ -377,9 +414,10 @@ private final class MessageExecutor implements Runnable { private final SqsMessageDeletionPolicy deletionPolicy; - private MessageExecutor(String logicalQueueName, Message message, QueueAttributes queueAttributes) { + private MessageGroupExecutor(String logicalQueueName, MessageGroup messageGroup, + QueueAttributes queueAttributes) { this.logicalQueueName = logicalQueueName; - this.message = message; + this.messageGroup = messageGroup; this.queueUrl = queueAttributes.getReceiveMessageRequest().getQueueUrl(); this.hasRedrivePolicy = queueAttributes.hasRedrivePolicy(); this.deletionPolicy = queueAttributes.getDeletionPolicy(); @@ -387,14 +425,16 @@ private MessageExecutor(String logicalQueueName, Message message, QueueAttribute @Override public void run() { - String receiptHandle = this.message.getReceiptHandle(); - org.springframework.messaging.Message queueMessage = getMessageForExecution(); - try { - executeMessage(queueMessage); - applyDeletionPolicyOnSuccess(receiptHandle); - } - catch (MessagingException messagingException) { - applyDeletionPolicyOnError(receiptHandle); + for (Message message : this.messageGroup.getMessages()) { + String receiptHandle = message.getReceiptHandle(); + org.springframework.messaging.Message queueMessage = getMessageForExecution(message); + try { + executeMessage(queueMessage); + applyDeletionPolicyOnSuccess(receiptHandle); + } + catch (MessagingException messagingException) { + applyDeletionPolicyOnError(receiptHandle); + } } } @@ -418,20 +458,19 @@ private void deleteMessage(String receiptHandle) { new DeleteMessageHandler(receiptHandle)); } - private org.springframework.messaging.Message getMessageForExecution() { + private org.springframework.messaging.Message getMessageForExecution(final Message message) { HashMap additionalHeaders = new HashMap<>(); additionalHeaders.put(QueueMessageHandler.LOGICAL_RESOURCE_ID, this.logicalQueueName); if (this.deletionPolicy == SqsMessageDeletionPolicy.NEVER) { - String receiptHandle = this.message.getReceiptHandle(); + String receiptHandle = message.getReceiptHandle(); QueueMessageAcknowledgment acknowledgment = new QueueMessageAcknowledgment( SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, receiptHandle); additionalHeaders.put(QueueMessageHandler.ACKNOWLEDGMENT, acknowledgment); } - additionalHeaders.put(QueueMessageHandler.VISIBILITY, - new QueueMessageVisibility(SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, - this.message.getReceiptHandle())); + additionalHeaders.put(QueueMessageHandler.VISIBILITY, new QueueMessageVisibility( + SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, message.getReceiptHandle())); - return createMessage(this.message, additionalHeaders); + return createMessage(message, additionalHeaders); } } diff --git a/spring-cloud-aws-messaging/src/test/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainerTest.java b/spring-cloud-aws-messaging/src/test/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainerTest.java index d48b998ef..d6eb01108 100644 --- a/spring-cloud-aws-messaging/src/test/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainerTest.java +++ b/spring-cloud-aws-messaging/src/test/java/io/awspring/cloud/messaging/listener/SimpleMessageListenerContainerTest.java @@ -17,11 +17,14 @@ package io.awspring.cloud.messaging.listener; import java.nio.charset.Charset; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import ch.qos.logback.classic.Level; import ch.qos.logback.classic.Logger; @@ -38,12 +41,14 @@ import com.amazonaws.services.sqs.model.GetQueueUrlResult; import com.amazonaws.services.sqs.model.Message; import com.amazonaws.services.sqs.model.MessageAttributeValue; +import com.amazonaws.services.sqs.model.MessageSystemAttributeName; import com.amazonaws.services.sqs.model.OverLimitException; import com.amazonaws.services.sqs.model.QueueAttributeName; import com.amazonaws.services.sqs.model.ReceiveMessageRequest; import com.amazonaws.services.sqs.model.ReceiveMessageResult; import io.awspring.cloud.core.support.documentation.RuntimeUse; import io.awspring.cloud.messaging.config.annotation.EnableSqs; +import io.awspring.cloud.messaging.core.MessageAttributeDataTypes; import io.awspring.cloud.messaging.listener.annotation.SqsListener; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -80,6 +85,12 @@ import static org.mockito.Mockito.withSettings; import static org.mockito.MockitoAnnotations.initMocks; +/** + * @author Agim Emruli + * @author Alain Sahli + * @author Mete Alpaslan Katırcıoğlu + * @since 1.0 + */ /** * @author Agim Emruli * @author Alain Sahli @@ -132,6 +143,13 @@ private static void mockGetQueueUrl(AmazonSQSAsync sqs, String queueName, String .thenReturn(new GetQueueUrlResult().withQueueUrl(queueUrl)); } + private static Message fifoMessage(final String messageGroupId, final String content) { + Map headers = new HashMap<>(); + headers.put(MessageSystemAttributeName.MessageGroupId.name(), new MessageAttributeValue() + .withDataType(MessageAttributeDataTypes.STRING).withStringValue(messageGroupId)); + return new Message().withMessageAttributes(headers).withBody(content); + } + @BeforeEach void setUp() { initMocks(this); @@ -256,6 +274,75 @@ public void handleMessage(org.springframework.messaging.Message message) thro container.stop(); } + @Test + void testReceiveMessagesFromFifoQueue() throws Exception { + SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); + + AmazonSQSAsync sqs = mock(AmazonSQSAsync.class, withSettings().stubOnly()); + container.setAmazonSqs(sqs); + + CountDownLatch countDownLatch = new CountDownLatch(10); + List actualHandledMessages = new ArrayList<>(); + QueueMessageHandler messageHandler = new QueueMessageHandler() { + + @Override + public void handleMessage(org.springframework.messaging.Message message) throws MessagingException { + assertThat(message.getPayload()).isInstanceOf(String.class); + actualHandledMessages.add((String) message.getPayload()); + countDownLatch.countDown(); + } + }; + container.setMessageHandler(messageHandler); + StaticApplicationContext applicationContext = new StaticApplicationContext(); + applicationContext.registerSingleton("fifoTestMessageListener", FifoTestMessageListener.class); + messageHandler.setApplicationContext(applicationContext); + container.setBeanName("testContainerName"); + messageHandler.afterPropertiesSet(); + + mockGetQueueUrl(sqs, "testQueue.fifo", "http://testSimpleReceiveMessage.amazonaws.com"); + mockGetQueueAttributesWithEmptyResult(sqs, "http://testSimpleReceiveMessage.amazonaws.com"); + + container.afterPropertiesSet(); + + final Message group1Msg1 = fifoMessage("1", "group1Msg1"); + final Message group1Msg2 = fifoMessage("1", "group1Msg2"); + final Message group1Msg3 = fifoMessage("1", "group1Msg3"); + final Message group1Msg4 = fifoMessage("1", "group1Msg4"); + final Message group1Msg5 = fifoMessage("1", "group1Msg5"); + final Message group1Msg6 = fifoMessage("1", "group1Msg6"); + final Message group1Msg7 = fifoMessage("1", "group1Msg7"); + final Message group2Msg1 = fifoMessage("2", "group2Msg1"); + final Message group2Msg2 = fifoMessage("2", "group2Msg2"); + final Message group3Msg1 = fifoMessage("3", "group3Msg1"); + + when(sqs.receiveMessage( + new ReceiveMessageRequest("http://testSimpleReceiveMessage.amazonaws.com").withAttributeNames("All") + .withMessageAttributeNames("All").withMaxNumberOfMessages(10).withWaitTimeSeconds(20))) + .thenReturn(new ReceiveMessageResult().withMessages(group1Msg1, group1Msg2, group1Msg3, + group1Msg4, group1Msg5, group1Msg6, group1Msg7, group2Msg1, group2Msg2, + group3Msg1)) + .thenReturn(new ReceiveMessageResult()); + when(sqs.getQueueAttributes(any(GetQueueAttributesRequest.class))).thenReturn(new GetQueueAttributesResult()); + + container.start(); + + assertThat(countDownLatch.await(3, TimeUnit.SECONDS)).isTrue(); + + final List actualGroup1Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group1")) + .collect(Collectors.toList()); + final List actualGroup2Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group2")) + .collect(Collectors.toList()); + final List actualGroup3Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group3")) + .collect(Collectors.toList()); + + assertThat(actualGroup1Messages).containsExactly("group1Msg1", "group1Msg2", "group1Msg3", "group1Msg4", + "group1Msg5", "group1Msg6", "group1Msg7"); + assertThat(actualGroup2Messages).containsExactly("group2Msg1", "group2Msg2"); + assertThat(actualGroup3Messages).containsExactly("group3Msg1"); + + container.stop(); + } + @Test void testContainerDoesNotProcessMessageAfterBeingStopped() throws Exception { SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); @@ -1287,6 +1374,22 @@ CountDownLatch getCountDownLatch() { } + private static class FifoTestMessageListener { + + private String message; + + @RuntimeUse + @SqsListener("testQueue.fifo") + private void handleMessage(String message) { + this.message = message; + } + + String getMessage() { + return this.message; + } + + } + private static class AnotherTestMessageListener { private String message;