diff --git a/cli/pkg/aws/dynamo_locking.go b/cli/pkg/aws/dynamo_locking.go index c38dad9e2..d149e3d40 100644 --- a/cli/pkg/aws/dynamo_locking.go +++ b/cli/pkg/aws/dynamo_locking.go @@ -25,17 +25,23 @@ const ( ) type DynamoDbLock struct { - DynamoDb *dynamodb.Client + DynamoDb DynamoDBClient } -func isResourceNotFoundExceptionError(err error) bool { - if err != nil { - var apiError smithy.APIError - if errors.As(err, &apiError) { - switch apiError.(type) { - case *types.ResourceNotFoundException: - return true - } +type DynamoDBClient interface { + DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) + CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) + UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) + DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) + GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) +} + +func isTableNotFoundExceptionError(err error) bool { + var apiError smithy.APIError + if errors.As(err, &apiError) { + switch apiError.(type) { + case *types.TableNotFoundException: + return true } } return false @@ -49,7 +55,7 @@ func (dynamoDbLock *DynamoDbLock) waitUntilTableCreated(ctx context.Context) err cnt := 0 if err != nil { - if !isResourceNotFoundExceptionError(err) { + if !isTableNotFoundExceptionError(err) { return err } } @@ -58,7 +64,7 @@ func (dynamoDbLock *DynamoDbLock) waitUntilTableCreated(ctx context.Context) err time.Sleep(TableCreationInterval) status, err = dynamoDbLock.DynamoDb.DescribeTable(ctx, input) if err != nil { - if !isResourceNotFoundExceptionError(err) { + if !isTableNotFoundExceptionError(err) { return err } } @@ -78,15 +84,14 @@ func (dynamoDbLock *DynamoDbLock) createTableIfNotExists(ctx context.Context) er _, err := dynamoDbLock.DynamoDb.DescribeTable(ctx, &dynamodb.DescribeTableInput{ TableName: aws.String(TABLE_NAME), }) - - if err != nil { - if !isResourceNotFoundExceptionError(err) { - return err - } + if err == nil { // Table exists + return nil + } + if !isTableNotFoundExceptionError(err) { + return err } createtbl_input := &dynamodb.CreateTableInput{ - AttributeDefinitions: []types.AttributeDefinition{ { AttributeName: aws.String("PK"), @@ -213,7 +218,8 @@ func (dynamoDbLock *DynamoDbLock) GetLock(lockId string) (*int, error) { } type TransactionLock struct { - TransactionID int `dynamodbav:"transaction_id"` + TransactionID int `dynamodbav:"transaction_id"` + Timeout string `dynamodbav:"timeout"` } var t TransactionLock diff --git a/cli/pkg/aws/dynamo_locking_test.go b/cli/pkg/aws/dynamo_locking_test.go new file mode 100644 index 000000000..a05b452de --- /dev/null +++ b/cli/pkg/aws/dynamo_locking_test.go @@ -0,0 +1,91 @@ +package aws + +import ( + "context" + "testing" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type mockDynamoDbClient struct { + table map[string]map[string]types.AttributeValue + Options dynamodb.Options + MockDescribeTable func(ctx context.Context, params dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) + MockUpdateItem func(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) + MockGetItem func(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) + MockDeleteItem func(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) +} + +func (m *mockDynamoDbClient) DescribeTable(ctx context.Context, params *dynamodb.DescribeTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DescribeTableOutput, error) { + if m.table == nil || m.table[aws.ToString(params.TableName)] == nil { + return nil, &types.TableNotFoundException{} + } + if m.table[aws.ToString(params.TableName)] != nil { + return &dynamodb.DescribeTableOutput{Table: &types.TableDescription{TableName: params.TableName}}, nil + } + return nil, nil +} + +func (m *mockDynamoDbClient) CreateTable(ctx context.Context, params *dynamodb.CreateTableInput, optFns ...func(*dynamodb.Options)) (*dynamodb.CreateTableOutput, error) { + m.table[aws.ToString(params.TableName)] = make(map[string]types.AttributeValue) + return nil, nil +} + +func (m *mockDynamoDbClient) UpdateItem(ctx context.Context, params *dynamodb.UpdateItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.UpdateItemOutput, error) { + // TODO: Implement this + return &dynamodb.UpdateItemOutput{}, nil +} + +func (m *mockDynamoDbClient) GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) { + return &dynamodb.GetItemOutput{ + Item: map[string]types.AttributeValue{ + "PK": &types.AttributeValueMemberS{Value: "LOCK"}, + "SK": &types.AttributeValueMemberS{Value: "RES#example-resource"}, + "transaction_id": &types.AttributeValueMemberN{Value: "123"}, + "timeout": &types.AttributeValueMemberS{Value: "2024-04-01T00:00:00Z"}, + }, + }, nil +} + +func (m *mockDynamoDbClient) DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) { + m.table[aws.ToString(params.TableName)][aws.ToString(¶ms.Key["SK"].(*types.AttributeValueMemberS).Value)] = nil + return &dynamodb.DeleteItemOutput{}, nil +} + +func TestDynamoDbLock_Lock(t *testing.T) { + client := mockDynamoDbClient{table: make(map[string]map[string]types.AttributeValue)} + dynamodbLock := DynamoDbLock{ + DynamoDb: &client, + } + dynamodbLock.DynamoDb.CreateTable(context.Background(), &dynamodb.CreateTableInput{TableName: aws.String(TABLE_NAME)}) + + // Set up the input parameters for the Lock method + transactionId := 123 + resource := "example-resource" + + locked, err := dynamodbLock.Lock(transactionId, resource) + if err != nil { + t.Fatalf("Error: %v", err) + } + if !locked { + t.Fatalf("Expected true, got %v", locked) + } +} +func TestDynamoDbLock_GetLock(t *testing.T) { + // Create a mock DynamoDB client + client := mockDynamoDbClient{table: make(map[string]map[string]types.AttributeValue)} + dynamodbLock := DynamoDbLock{ + DynamoDb: &client, + } + dynamodbLock.DynamoDb.CreateTable(context.Background(), &dynamodb.CreateTableInput{TableName: aws.String(TABLE_NAME)}) + + id, err := dynamodbLock.GetLock("example-resource") + if err != nil { + t.Fatalf("Error: %v", err) + } + if *id != 123 { + t.Fatalf("Expected 123, got %v", id) + } +}