Skip to content

Commit

Permalink
Add support for SQS FIFO queues.
Browse files Browse the repository at this point in the history
For FIFO queues the AsynchronousMessageListener groups messages with same messageGroupId into so called MessageGroups. The MessageExecutor (renamed to MessageGroupExecutor) handles the messages within those groups sequentially.

Messages from non-FIFO queues are handled as before with the only difference that they are also wrapped in a MessageGroup. Each separate message belongs to its own MessageGroup.

Fixes spring-attic/spring-cloud-aws#387
Fixes spring-attic/spring-cloud-aws#379
Fixes spring-attic/spring-cloud-aws#530
Fixes spring-attic/spring-cloud-aws#756
  • Loading branch information
Tristan Baumbusch authored and maciejwalkowiak committed Feb 8, 2021
1 parent 19a4a54 commit 91cf06a
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,10 @@ private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy d
new GetQueueAttributesRequest(destinationUrl).withAttributeNames(QueueAttributeName.RedrivePolicy));
boolean hasRedrivePolicy = queueAttributes.getAttributes()
.containsKey(QueueAttributeName.RedrivePolicy.toString());
boolean isFifo = queue.endsWith(".fifo");

return new QueueAttributes(hasRedrivePolicy, deletionPolicy, destinationUrl, getMaxNumberOfMessages(),
getVisibilityTimeout(), getWaitTimeOut());
getVisibilityTimeout(), getWaitTimeOut(), isFifo);
}

@Override
Expand Down Expand Up @@ -384,14 +385,17 @@ protected static class QueueAttributes {

private final Integer waitTimeOut;

private final boolean fifo;

public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, String destinationUrl,
Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) {
Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut, boolean fifo) {
this.hasRedrivePolicy = hasRedrivePolicy;
this.deletionPolicy = deletionPolicy;
this.destinationUrl = destinationUrl;
this.maxNumberOfMessages = maxNumberOfMessages;
this.visibilityTimeout = visibilityTimeout;
this.waitTimeOut = waitTimeOut;
this.fifo = fifo;
}

public boolean hasRedrivePolicy() {
Expand Down Expand Up @@ -424,6 +428,10 @@ public SqsMessageDeletionPolicy getDeletionPolicy() {
return this.deletionPolicy;
}

boolean isFifo() {
return fifo;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,21 @@

package io.awspring.cloud.messaging.listener;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;

import com.amazonaws.services.sqs.model.DeleteMessageRequest;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.MessageSystemAttributeName;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;

import org.springframework.core.task.AsyncTaskExecutor;
Expand Down Expand Up @@ -329,12 +333,16 @@ public void run() {
try {
ReceiveMessageResult receiveMessageResult = getAmazonSqs()
.receiveMessage(this.queueAttributes.getReceiveMessageRequest());
CountDownLatch messageBatchLatch = new CountDownLatch(receiveMessageResult.getMessages().size());
for (Message message : receiveMessageResult.getMessages()) {

final List<MessageGroup> messageGroups = queueAttributes.isFifo()
? groupByMessageGroupId(receiveMessageResult) : groupByMessage(receiveMessageResult);
CountDownLatch messageBatchLatch = new CountDownLatch(messageGroups.size());
for (MessageGroup messageGroup : messageGroups) {
if (isQueueRunning(this.logicalQueueName)) {
MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message,
this.queueAttributes);
getTaskExecutor().execute(new SignalExecutingRunnable(messageBatchLatch, messageExecutor));
MessageGroupExecutor messageGroupExecutor = new MessageGroupExecutor(this.logicalQueueName,
messageGroup, this.queueAttributes);
getTaskExecutor()
.execute(new SignalExecutingRunnable(messageBatchLatch, messageGroupExecutor));
}
else {
messageBatchLatch.countDown();
Expand Down Expand Up @@ -363,11 +371,40 @@ public void run() {
SimpleMessageListenerContainer.this.scheduledFutureByQueue.remove(this.logicalQueueName);
}

private List<MessageGroup> groupByMessageGroupId(final ReceiveMessageResult receiveMessageResult) {
return receiveMessageResult.getMessages().stream()
.collect(Collectors.groupingBy(message -> message.getMessageAttributes()
.get(MessageSystemAttributeName.MessageGroupId.name())))
.values().stream().map(MessageGroup::new).collect(Collectors.toList());
}

private List<MessageGroup> groupByMessage(final ReceiveMessageResult receiveMessageResult) {
return receiveMessageResult.getMessages().stream().map(MessageGroup::new).collect(Collectors.toList());
}

}

private static final class MessageGroup {

private final List<Message> messages;

MessageGroup(final Message message) {
this.messages = Collections.singletonList(message);
}

MessageGroup(final List<Message> messages) {
this.messages = messages;
}

public List<Message> getMessages() {
return this.messages;
}

}

private final class MessageExecutor implements Runnable {
private final class MessageGroupExecutor implements Runnable {

private final Message message;
private final MessageGroup messageGroup;

private final String logicalQueueName;

Expand All @@ -377,24 +414,27 @@ private final class MessageExecutor implements Runnable {

private final SqsMessageDeletionPolicy deletionPolicy;

private MessageExecutor(String logicalQueueName, Message message, QueueAttributes queueAttributes) {
private MessageGroupExecutor(String logicalQueueName, MessageGroup messageGroup,
QueueAttributes queueAttributes) {
this.logicalQueueName = logicalQueueName;
this.message = message;
this.messageGroup = messageGroup;
this.queueUrl = queueAttributes.getReceiveMessageRequest().getQueueUrl();
this.hasRedrivePolicy = queueAttributes.hasRedrivePolicy();
this.deletionPolicy = queueAttributes.getDeletionPolicy();
}

@Override
public void run() {
String receiptHandle = this.message.getReceiptHandle();
org.springframework.messaging.Message<String> queueMessage = getMessageForExecution();
try {
executeMessage(queueMessage);
applyDeletionPolicyOnSuccess(receiptHandle);
}
catch (MessagingException messagingException) {
applyDeletionPolicyOnError(receiptHandle);
for (Message message : this.messageGroup.getMessages()) {
String receiptHandle = message.getReceiptHandle();
org.springframework.messaging.Message<String> queueMessage = getMessageForExecution(message);
try {
executeMessage(queueMessage);
applyDeletionPolicyOnSuccess(receiptHandle);
}
catch (MessagingException messagingException) {
applyDeletionPolicyOnError(receiptHandle);
}
}
}

Expand All @@ -418,20 +458,19 @@ private void deleteMessage(String receiptHandle) {
new DeleteMessageHandler(receiptHandle));
}

private org.springframework.messaging.Message<String> getMessageForExecution() {
private org.springframework.messaging.Message<String> getMessageForExecution(final Message message) {
HashMap<String, Object> additionalHeaders = new HashMap<>();
additionalHeaders.put(QueueMessageHandler.LOGICAL_RESOURCE_ID, this.logicalQueueName);
if (this.deletionPolicy == SqsMessageDeletionPolicy.NEVER) {
String receiptHandle = this.message.getReceiptHandle();
String receiptHandle = message.getReceiptHandle();
QueueMessageAcknowledgment acknowledgment = new QueueMessageAcknowledgment(
SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, receiptHandle);
additionalHeaders.put(QueueMessageHandler.ACKNOWLEDGMENT, acknowledgment);
}
additionalHeaders.put(QueueMessageHandler.VISIBILITY,
new QueueMessageVisibility(SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl,
this.message.getReceiptHandle()));
additionalHeaders.put(QueueMessageHandler.VISIBILITY, new QueueMessageVisibility(
SimpleMessageListenerContainer.this.getAmazonSqs(), this.queueUrl, message.getReceiptHandle()));

return createMessage(this.message, additionalHeaders);
return createMessage(message, additionalHeaders);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,14 @@
package io.awspring.cloud.messaging.listener;

import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.Logger;
Expand All @@ -38,12 +41,14 @@
import com.amazonaws.services.sqs.model.GetQueueUrlResult;
import com.amazonaws.services.sqs.model.Message;
import com.amazonaws.services.sqs.model.MessageAttributeValue;
import com.amazonaws.services.sqs.model.MessageSystemAttributeName;
import com.amazonaws.services.sqs.model.OverLimitException;
import com.amazonaws.services.sqs.model.QueueAttributeName;
import com.amazonaws.services.sqs.model.ReceiveMessageRequest;
import com.amazonaws.services.sqs.model.ReceiveMessageResult;
import io.awspring.cloud.core.support.documentation.RuntimeUse;
import io.awspring.cloud.messaging.config.annotation.EnableSqs;
import io.awspring.cloud.messaging.core.MessageAttributeDataTypes;
import io.awspring.cloud.messaging.listener.annotation.SqsListener;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -80,6 +85,12 @@
import static org.mockito.Mockito.withSettings;
import static org.mockito.MockitoAnnotations.initMocks;

/**
* @author Agim Emruli
* @author Alain Sahli
* @author Mete Alpaslan Katırcıoğlu
* @since 1.0
*/
/**
* @author Agim Emruli
* @author Alain Sahli
Expand Down Expand Up @@ -132,6 +143,13 @@ private static void mockGetQueueUrl(AmazonSQSAsync sqs, String queueName, String
.thenReturn(new GetQueueUrlResult().withQueueUrl(queueUrl));
}

private static Message fifoMessage(final String messageGroupId, final String content) {
Map<String, MessageAttributeValue> headers = new HashMap<>();
headers.put(MessageSystemAttributeName.MessageGroupId.name(), new MessageAttributeValue()
.withDataType(MessageAttributeDataTypes.STRING).withStringValue(messageGroupId));
return new Message().withMessageAttributes(headers).withBody(content);
}

@BeforeEach
void setUp() {
initMocks(this);
Expand Down Expand Up @@ -256,6 +274,75 @@ public void handleMessage(org.springframework.messaging.Message<?> message) thro
container.stop();
}

@Test
void testReceiveMessagesFromFifoQueue() throws Exception {
SimpleMessageListenerContainer container = new SimpleMessageListenerContainer();

AmazonSQSAsync sqs = mock(AmazonSQSAsync.class, withSettings().stubOnly());
container.setAmazonSqs(sqs);

CountDownLatch countDownLatch = new CountDownLatch(10);
List<String> actualHandledMessages = new ArrayList<>();
QueueMessageHandler messageHandler = new QueueMessageHandler() {

@Override
public void handleMessage(org.springframework.messaging.Message<?> message) throws MessagingException {
assertThat(message.getPayload()).isInstanceOf(String.class);
actualHandledMessages.add((String) message.getPayload());
countDownLatch.countDown();
}
};
container.setMessageHandler(messageHandler);
StaticApplicationContext applicationContext = new StaticApplicationContext();
applicationContext.registerSingleton("fifoTestMessageListener", FifoTestMessageListener.class);
messageHandler.setApplicationContext(applicationContext);
container.setBeanName("testContainerName");
messageHandler.afterPropertiesSet();

mockGetQueueUrl(sqs, "testQueue.fifo", "http://testSimpleReceiveMessage.amazonaws.com");
mockGetQueueAttributesWithEmptyResult(sqs, "http://testSimpleReceiveMessage.amazonaws.com");

container.afterPropertiesSet();

final Message group1Msg1 = fifoMessage("1", "group1Msg1");
final Message group1Msg2 = fifoMessage("1", "group1Msg2");
final Message group1Msg3 = fifoMessage("1", "group1Msg3");
final Message group1Msg4 = fifoMessage("1", "group1Msg4");
final Message group1Msg5 = fifoMessage("1", "group1Msg5");
final Message group1Msg6 = fifoMessage("1", "group1Msg6");
final Message group1Msg7 = fifoMessage("1", "group1Msg7");
final Message group2Msg1 = fifoMessage("2", "group2Msg1");
final Message group2Msg2 = fifoMessage("2", "group2Msg2");
final Message group3Msg1 = fifoMessage("3", "group3Msg1");

when(sqs.receiveMessage(
new ReceiveMessageRequest("http://testSimpleReceiveMessage.amazonaws.com").withAttributeNames("All")
.withMessageAttributeNames("All").withMaxNumberOfMessages(10).withWaitTimeSeconds(20)))
.thenReturn(new ReceiveMessageResult().withMessages(group1Msg1, group1Msg2, group1Msg3,
group1Msg4, group1Msg5, group1Msg6, group1Msg7, group2Msg1, group2Msg2,
group3Msg1))
.thenReturn(new ReceiveMessageResult());
when(sqs.getQueueAttributes(any(GetQueueAttributesRequest.class))).thenReturn(new GetQueueAttributesResult());

container.start();

assertThat(countDownLatch.await(3, TimeUnit.SECONDS)).isTrue();

final List<String> actualGroup1Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group1"))
.collect(Collectors.toList());
final List<String> actualGroup2Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group2"))
.collect(Collectors.toList());
final List<String> actualGroup3Messages = actualHandledMessages.stream().filter(msg -> msg.startsWith("group3"))
.collect(Collectors.toList());

assertThat(actualGroup1Messages).containsExactly("group1Msg1", "group1Msg2", "group1Msg3", "group1Msg4",
"group1Msg5", "group1Msg6", "group1Msg7");
assertThat(actualGroup2Messages).containsExactly("group2Msg1", "group2Msg2");
assertThat(actualGroup3Messages).containsExactly("group3Msg1");

container.stop();
}

@Test
void testContainerDoesNotProcessMessageAfterBeingStopped() throws Exception {
SimpleMessageListenerContainer container = new SimpleMessageListenerContainer();
Expand Down Expand Up @@ -1287,6 +1374,22 @@ CountDownLatch getCountDownLatch() {

}

private static class FifoTestMessageListener {

private String message;

@RuntimeUse
@SqsListener("testQueue.fifo")
private void handleMessage(String message) {
this.message = message;
}

String getMessage() {
return this.message;
}

}

private static class AnotherTestMessageListener {

private String message;
Expand Down

0 comments on commit 91cf06a

Please sign in to comment.