Skip to content

Commit

Permalink
B-22056 - fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
ryan-mchugh committed Jan 13, 2025
1 parent c8424e8 commit 20f887c
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 32 deletions.
2 changes: 1 addition & 1 deletion cmd/milmove/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ func buildRoutingConfig(appCtx appcontext.AppContext, v *viper.Viper, redisPool
}

// Notification Receiver
notificationReceiver, err := notifications.InitReceiver(v, appCtx.Logger())
notificationReceiver, err := notifications.InitReceiver(v, appCtx.Logger(), true)
if err != nil {
appCtx.Logger().Fatal("notification receiver not enabled", zap.Error(err))
}
Expand Down
15 changes: 8 additions & 7 deletions pkg/handlers/internalapi/uploads.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,14 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr
uploadStatus = AVStatusTypePROCESSING
}

// Limitation: once the status code header has been written (first response), we are not able to update the status for subsequent responses.
// Standard 200 OK used with common SSE paradigm
rw.WriteHeader(http.StatusOK)
if uploadStatus == AVStatusTypeCLEAN || uploadStatus == AVStatusTypeINFECTED {
rw.WriteHeader(http.StatusOK)
o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus))
o.writeEventStreamMessage(rw, producer, 1, "close", "Connection closed")
return // skip notification loop since object already tagged from anti-virus
} else {
// Limitation: once the status code header has been written (first response), we are not able to update the status for subsequent responses.
// StatusAccepted: Standard code 202 for accepted request, but response not yet ready.
rw.WriteHeader(http.StatusAccepted)
o.writeEventStreamMessage(rw, producer, 0, "message", string(uploadStatus))
}

Expand Down Expand Up @@ -345,7 +344,11 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr

// For loop over 120 seconds, cancel context when done and it breaks the loop
totalReceiverContext, totalReceiverContextCancelFunc := context.WithTimeout(context.Background(), 120*time.Second)
defer totalReceiverContextCancelFunc()
defer func() {
id_counter++
o.writeEventStreamMessage(rw, producer, id_counter, "close", "Connection closed")
totalReceiverContextCancelFunc()
}()

// Cleanup if client closes connection
go func() {
Expand All @@ -356,8 +359,6 @@ func (o *CustomGetUploadStatusResponse) WriteResponse(rw http.ResponseWriter, pr
// Cleanup at end of work
go func() {
<-totalReceiverContext.Done()
id_counter++
o.writeEventStreamMessage(rw, producer, id_counter, "close", "Connection closed")
_ = o.receiver.CloseoutQueue(o.appCtx, queueUrl)
}()

Expand Down
2 changes: 1 addition & 1 deletion pkg/handlers/routing/internalapi_test/uploads_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (suite *InternalAPISuite) TestUploads() {

suite.SetupSiteHandler().ServeHTTP(rr, req)

suite.Equal(http.StatusAccepted, rr.Code)
suite.Equal(http.StatusOK, rr.Code)
suite.Equal("text/event-stream", rr.Header().Get("content-type"))

message1 := "id: 0\nevent: message\ndata: PROCESSING\n\n"
Expand Down
23 changes: 13 additions & 10 deletions pkg/notifications/mocks/NotificationReceiver.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 13 additions & 10 deletions pkg/notifications/notification_receiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,14 @@ const (
type SnsClient interface {
Subscribe(ctx context.Context, params *sns.SubscribeInput, optFns ...func(*sns.Options)) (*sns.SubscribeOutput, error)
Unsubscribe(ctx context.Context, params *sns.UnsubscribeInput, optFns ...func(*sns.Options)) (*sns.UnsubscribeOutput, error)
ListSubscriptionsByTopic(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error)
}

type SqsClient interface {
CreateQueue(ctx context.Context, params *sqs.CreateQueueInput, optFns ...func(*sqs.Options)) (*sqs.CreateQueueOutput, error)
ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error)
DeleteQueue(ctx context.Context, params *sqs.DeleteQueueInput, optFns ...func(*sqs.Options)) (*sqs.DeleteQueueOutput, error)
ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error)
}

type ViperType interface {
Expand Down Expand Up @@ -216,7 +218,7 @@ func (n NotificationReceiverContext) GetDefaultTopic() (string, error) {
}

// InitReceiver initializes the receiver backend, only call this once
func InitReceiver(v ViperType, logger *zap.Logger) (NotificationReceiver, error) {
func InitReceiver(v ViperType, logger *zap.Logger, wipeAllNotificationQueues bool) (NotificationReceiver, error) {

if v.GetString(cli.ReceiverBackendFlag) == "sns&sqs" {
// Setup notification receiver service with SNS & SQS backend dependencies
Expand All @@ -239,9 +241,11 @@ func InitReceiver(v ViperType, logger *zap.Logger) (NotificationReceiver, error)
notificationReceiver := NewNotificationReceiver(v, snsService, sqsService, awsSNSRegion, awsAccountId)

// Remove any remaining previous notification queues on server start
err = notificationReceiver.wipeAllNotificationQueues(snsService, sqsService, logger)
if err != nil {
return nil, err
if wipeAllNotificationQueues {
err = notificationReceiver.wipeAllNotificationQueues(logger)
if err != nil {
return nil, err
}
}

return notificationReceiver, nil
Expand All @@ -255,15 +259,14 @@ func (n NotificationReceiverContext) constructArn(awsService string, endpointNam
}

// Removes ALL previously created notification queues
func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns.Client, sqsService *sqs.Client, logger *zap.Logger) error {

func (n *NotificationReceiverContext) wipeAllNotificationQueues(logger *zap.Logger) error {
defaultTopic, err := n.GetDefaultTopic()
if err != nil {
return err
}

logger.Info("Removing previous subscriptions...")
paginator := sns.NewListSubscriptionsByTopicPaginator(snsService, &sns.ListSubscriptionsByTopicInput{
paginator := sns.NewListSubscriptionsByTopicPaginator(n.snsService, &sns.ListSubscriptionsByTopicInput{
TopicArn: aws.String(n.constructArn("sns", defaultTopic)),
})

Expand All @@ -276,7 +279,7 @@ func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns.
if strings.Contains(*subscription.Endpoint, string(QueuePrefixObjectTagsAdded)) {
logger.Info("Subscription ARN: ", zap.String("subscription arn", *subscription.SubscriptionArn))
logger.Info("Endpoint ARN: ", zap.String("endpoint arn", *subscription.Endpoint))
_, err = snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{
_, err = n.snsService.Unsubscribe(context.Background(), &sns.UnsubscribeInput{
SubscriptionArn: subscription.SubscriptionArn,
})
if err != nil {
Expand All @@ -287,15 +290,15 @@ func (n *NotificationReceiverContext) wipeAllNotificationQueues(snsService *sns.
}

logger.Info("Removing previous queues...")
result, err := sqsService.ListQueues(context.Background(), &sqs.ListQueuesInput{
result, err := n.sqsService.ListQueues(context.Background(), &sqs.ListQueuesInput{
QueueNamePrefix: aws.String(string(QueuePrefixObjectTagsAdded)),
})
if err != nil {
return err
}

for _, url := range result.QueueUrls {
_, err = sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{
_, err = n.sqsService.DeleteQueue(context.Background(), &sqs.DeleteQueueInput{
QueueUrl: &url,
})
if err != nil {
Expand Down
16 changes: 13 additions & 3 deletions pkg/notifications/notification_receiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ func (_m *MockSnsClient) Unsubscribe(ctx context.Context, params *sns.Unsubscrib
return &sns.UnsubscribeOutput{}, nil
}

func (_m *MockSnsClient) ListSubscriptionsByTopic(context.Context, *sns.ListSubscriptionsByTopicInput, ...func(*sns.Options)) (*sns.ListSubscriptionsByTopicOutput, error) {
return &sns.ListSubscriptionsByTopicOutput{}, nil
}

// mock - SQS
type MockSqsClient struct {
mock.Mock
Expand All @@ -90,11 +94,15 @@ func (_m *MockSqsClient) DeleteQueue(ctx context.Context, params *sqs.DeleteQueu
return &sqs.DeleteQueueOutput{}, nil
}

func (_m *MockSqsClient) ListQueues(ctx context.Context, params *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) {
return &sqs.ListQueuesOutput{}, nil
}

func (suite *notificationReceiverSuite) TestSuccessPath() {

suite.Run("local backend - notification receiver stub", func() {
v := viper.New()
localReceiver, err := InitReceiver(v, suite.Logger())
localReceiver, err := InitReceiver(v, suite.Logger(), true)

suite.NoError(err)
suite.IsType(StubNotificationReceiver{}, localReceiver)
Expand All @@ -121,10 +129,12 @@ func (suite *notificationReceiverSuite) TestSuccessPath() {
suite.Equal(*receivedMessages[0].Body, fmt.Sprintf("%s:stubMessageBody", createdQueueUrl))
})

suite.Run("aws backend - notification receiver init", func() {
suite.Run("aws backend - notification receiver InitReceiver", func() {
v := Viper{}

receiver, _ := InitReceiver(&v, suite.Logger())
receiver, err := InitReceiver(&v, suite.Logger(), false)

suite.NoError(err)
suite.IsType(NotificationReceiverContext{}, receiver)
defaultTopic, err := receiver.GetDefaultTopic()
suite.Equal("fake_sns_topic", defaultTopic)
Expand Down

0 comments on commit 20f887c

Please sign in to comment.