diff --git a/cmd/milmove/serve.go b/cmd/milmove/serve.go index 7d4e28a9918..a19f4b2444f 100644 --- a/cmd/milmove/serve.go +++ b/cmd/milmove/serve.go @@ -479,7 +479,7 @@ func buildRoutingConfig(appCtx appcontext.AppContext, v *viper.Viper, redisPool } // Notification Receiver - notificationReceiver, err := notifications.InitReceiver(v, appCtx.Logger()) + notificationReceiver, err := notifications.InitReceiver(v, appCtx.Logger(), true) if err != nil { appCtx.Logger().Fatal("notification receiver not enabled", zap.Error(err)) } diff --git a/pkg/handlers/internalapi/uploads.go b/pkg/handlers/internalapi/uploads.go index a1bff90b220..e4968707b7b 100644 --- a/pkg/handlers/internalapi/uploads.go +++ b/pkg/handlers/internalapi/uploads.go @@ -302,15 +302,14 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr uploadStatus = AVStatusTypePROCESSING } + // Limitation: once the status code header has been written (first response), we are not able to update the status for subsequent responses. + // Standard 200 OK used with common SSE paradigm + rw.WriteHeader(http.StatusOK) if uploadStatus == AVStatusTypeCLEAN || uploadStatus == AVStatusTypeINFECTED { - rw.WriteHeader(http.StatusOK) o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus)) o.writeEventStreamMessage(rw, producer, 1, "close", "Connection closed") return // skip notification loop since object already tagged from anti-virus } else { - // Limitation: once the status code header has been written (first response), we are not able to update the status for subsequent responses. - // StatusAccepted: Standard code 202 for accepted request, but response not yet ready. - rw.WriteHeader(http.StatusAccepted) o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus)) } @@ -345,7 +344,11 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr // For loop over 120 seconds, cancel context when done and it breaks the loop totalReceiverContext, totalReceiverContextCancelFunc := context.WithTimeout(context.Background(), 120*time.Second) - defer totalReceiverContextCancelFunc() + defer func() { + id_counter++ + o.writeEventStreamMessage(rw, producer, id_counter, "close", "Connection closed") + totalReceiverContextCancelFunc() + }() // Cleanup if client closes connection go func() { @@ -356,8 +359,6 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr // Cleanup at end of work go func() { <-totalReceiverContext.Done() - id_counter++ - o.writeEventStreamMessage(rw, producer, id_counter, "close", "Connection closed") _ = o.receiver.CloseoutQueue(o.appCtx, queueUrl) }() diff --git a/pkg/handlers/routing/internalapi_test/uploads_test.go b/pkg/handlers/routing/internalapi_test/uploads_test.go index 5b760f740bc..382cd74a5bf 100644 --- a/pkg/handlers/routing/internalapi_test/uploads_test.go +++ b/pkg/handlers/routing/internalapi_test/uploads_test.go @@ -77,7 +77,7 @@ func (suite *InternalAPISuite) TestUploads() { suite.SetupSiteHandler().ServeHTTP(rr, req) - suite.Equal(http.StatusAccepted, rr.Code) + suite.Equal(http.StatusOK, rr.Code) suite.Equal("text/event-stream", rr.Header().Get("content-type")) message1 := "id: 0\nevent: message\ndata: PROCESSING\n\n" diff --git a/pkg/notifications/mocks/NotificationReceiver.go b/pkg/notifications/mocks/NotificationReceiver.go index df8329e5f60..04c7d931659 100644 --- a/pkg/notifications/mocks/NotificationReceiver.go +++ b/pkg/notifications/mocks/NotificationReceiver.go @@ -3,9 +3,12 @@ package mocks import ( - mock "github.com/stretchr/testify/mock" + context "context" + appcontext "github.com/transcom/mymove/pkg/appcontext" + mock "github.com/stretchr/testify/mock" + notifications "github.com/transcom/mymove/pkg/notifications" ) @@ -88,9 +91,9 @@ func (_m *NotificationReceiver) GetDefaultTopic() (string, error) { return r0, r1 } -// ReceiveMessages provides a mock function with given fields: appCtx, queueUrl -func (_m *NotificationReceiver) ReceiveMessages(appCtx appcontext.AppContext, queueUrl string) ([]notifications.ReceivedMessage, error) { - ret := _m.Called(appCtx, queueUrl) +// ReceiveMessages provides a mock function with given fields: appCtx, queueUrl, timerContext +func (_m *NotificationReceiver) ReceiveMessages(appCtx appcontext.AppContext, queueUrl string, timerContext context.Context) ([]notifications.ReceivedMessage, error) { + ret := _m.Called(appCtx, queueUrl, timerContext) if len(ret) == 0 { panic("no return value specified for ReceiveMessages") @@ -98,19 +101,19 @@ func (_m *NotificationReceiver) ReceiveMessages(appCtx appcontext.AppContext, qu var r0 []notifications.ReceivedMessage var r1 error - if rf, ok := ret.Get(0).(func(appcontext.AppContext, string) ([]notifications.ReceivedMessage, error)); ok { - return rf(appCtx, queueUrl) + if rf, ok := ret.Get(0).(func(appcontext.AppContext, string, context.Context) ([]notifications.ReceivedMessage, error)); ok { + return rf(appCtx, queueUrl, timerContext) } - if rf, ok := ret.Get(0).(func(appcontext.AppContext, string) []notifications.ReceivedMessage); ok { - r0 = rf(appCtx, queueUrl) + if rf, ok := ret.Get(0).(func(appcontext.AppContext, string, context.Context) []notifications.ReceivedMessage); ok { + r0 = rf(appCtx, queueUrl, timerContext) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]notifications.ReceivedMessage) } } - if rf, ok := ret.Get(1).(func(appcontext.AppContext, string) error); ok { - r1 = rf(appCtx, queueUrl) + if rf, ok := ret.Get(1).(func(appcontext.AppContext, string, context.Context) error); ok { + r1 = rf(appCtx, queueUrl, timerContext) } else { r1 = ret.Error(1) } diff --git a/pkg/notifications/notification_receiver.go b/pkg/notifications/notification_receiver.go index 82bc32a02a8..09f9cd8b072 100644 --- a/pkg/notifications/notification_receiver.go +++ b/pkg/notifications/notification_receiver.go @@ -55,12 +55,14 @@ const ( type SnsClient interface { Subscribe(ctx context.Context, params *sns.SubscribeInput, optFns ...func(*sns.Options)) (*sns.SubscribeOutput, error) Unsubscribe(ctx context.Context, params *sns.UnsubscribeInput, optFns ...func(*sns.Options)) (*sns.UnsubscribeOutput, error) + ListSubscriptionsByTopic(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error) } type SqsClient interface { CreateQueue(ctx context.Context, params *sqs.CreateQueueInput, optFns ...func(*sqs.Options)) (*sqs.CreateQueueOutput, error) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) DeleteQueue(ctx context.Context, params *sqs.DeleteQueueInput, optFns ...func(*sqs.Options)) (*sqs.DeleteQueueOutput, error) + ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) } type ViperType interface { @@ -216,7 +218,7 @@ func (n NotificationReceiverContext) GetDefaultTopic() (string, error) { } // InitReceiver initializes the receiver backend, only call this once -func InitReceiver(v ViperType, logger *zap.Logger) (NotificationReceiver, error) { +func InitReceiver(v ViperType, logger *zap.Logger, wipeAllNotificationQueues bool) (NotificationReceiver, error) { if v.GetString(cli.ReceiverBackendFlag) == "sns&sqs" { // Setup notification receiver service with SNS & SQS backend dependencies @@ -239,9 +241,11 @@ func InitReceiver(v ViperType, logger *zap.Logger) (NotificationReceiver, error) notificationReceiver := NewNotificationReceiver(v, snsService, sqsService, awsSNSRegion, awsAccountId) // Remove any remaining previous notification queues on server start - err = notificationReceiver.wipeAllNotificationQueues(snsService, sqsService, logger) - if err != nil { - return nil, err + if wipeAllNotificationQueues { + err = notificationReceiver.wipeAllNotificationQueues(logger) + if err != nil { + return nil, err + } } return notificationReceiver, nil @@ -255,15 +259,14 @@ func (n NotificationReceiverContext) constructArn(awsService string, endpointNam } // Removes ALL previously created notification queues -func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns.Client, sqsService *sqs.Client, logger *zap.Logger) error { - +func (n *NotificationReceiverContext) wipeAllNotificationQueues(logger *zap.Logger) error { defaultTopic, err := n.GetDefaultTopic() if err != nil { return err } logger.Info("Removing previous subscriptions...") - paginator := sns.NewListSubscriptionsByTopicPaginator(snsService, &sns.ListSubscriptionsByTopicInput{ + paginator := sns.NewListSubscriptionsByTopicPaginator(n.snsService, &sns.ListSubscriptionsByTopicInput{ TopicArn: aws.String(n.constructArn("sns", defaultTopic)), }) @@ -276,7 +279,7 @@ func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns. if strings.Contains(*subscription.Endpoint, string(QueuePrefixObjectTagsAdded)) { logger.Info("Subscription ARN: ", zap.String("subscription arn", *subscription.SubscriptionArn)) logger.Info("Endpoint ARN: ", zap.String("endpoint arn", *subscription.Endpoint)) - _, err = snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{ + _, err = n.snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{ SubscriptionArn: subscription.SubscriptionArn, }) if err != nil { @@ -287,7 +290,7 @@ func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns. } logger.Info("Removing previous queues...") - result, err := sqsService.ListQueues(context.Background(), &sqs.ListQueuesInput{ + result, err := n.sqsService.ListQueues(context.Background(), &sqs.ListQueuesInput{ QueueNamePrefix: aws.String(string(QueuePrefixObjectTagsAdded)), }) if err != nil { @@ -295,7 +298,7 @@ func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns. } for _, url := range result.QueueUrls { - _, err = sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ + _, err = n.sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{ QueueUrl: &url, }) if err != nil { diff --git a/pkg/notifications/notification_receiver_test.go b/pkg/notifications/notification_receiver_test.go index e3275827e21..a996a67ce4e 100644 --- a/pkg/notifications/notification_receiver_test.go +++ b/pkg/notifications/notification_receiver_test.go @@ -66,6 +66,10 @@ func (_m *MockSnsClient) Unsubscribe(ctx context.Context, params *sns.Unsubscrib return &sns.UnsubscribeOutput{}, nil } +func (_m *MockSnsClient) ListSubscriptionsByTopic(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error) { + return &sns.ListSubscriptionsByTopicOutput{}, nil +} + // mock - SQS type MockSqsClient struct { mock.Mock @@ -90,11 +94,15 @@ func (_m *MockSqsClient) DeleteQueue(ctx context.Context, params *sqs.DeleteQueu return &sqs.DeleteQueueOutput{}, nil } +func (_m *MockSqsClient) ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) { + return &sqs.ListQueuesOutput{}, nil +} + func (suite *notificationReceiverSuite) TestSuccessPath() { suite.Run("local backend - notification receiver stub", func() { v := viper.New() - localReceiver, err := InitReceiver(v, suite.Logger()) + localReceiver, err := InitReceiver(v, suite.Logger(), true) suite.NoError(err) suite.IsType(StubNotificationReceiver{}, localReceiver) @@ -121,10 +129,12 @@ func (suite *notificationReceiverSuite) TestSuccessPath() { suite.Equal(*receivedMessages[0].Body, fmt.Sprintf("%s:stubMessageBody", createdQueueUrl)) }) - suite.Run("aws backend - notification receiver init", func() { + suite.Run("aws backend - notification receiver InitReceiver", func() { v := Viper{} - receiver, _ := InitReceiver(&v, suite.Logger()) + receiver, err := InitReceiver(&v, suite.Logger(), false) + + suite.NoError(err) suite.IsType(NotificationReceiverContext{}, receiver) defaultTopic, err := receiver.GetDefaultTopic() suite.Equal("fake_sns_topic", defaultTopic)