diff --git a/core/control-plane/contract-core/src/test/java/org/eclipse/edc/connector/contract/negotiation/ContractNegotiationIntegrationTest.java b/core/control-plane/contract-core/src/test/java/org/eclipse/edc/connector/contract/negotiation/ContractNegotiationIntegrationTest.java index d36132e1a5e..09e0e879f67 100644 --- a/core/control-plane/contract-core/src/test/java/org/eclipse/edc/connector/contract/negotiation/ContractNegotiationIntegrationTest.java +++ b/core/control-plane/contract-core/src/test/java/org/eclipse/edc/connector/contract/negotiation/ContractNegotiationIntegrationTest.java @@ -153,7 +153,7 @@ void init() { .protocolWebhook(protocolWebhook) .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), any(), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), any(), any(), any())).thenReturn(ServiceResult.success(participantAgent)); consumerService = new ContractNegotiationProtocolServiceImpl(consumerStore, new NoopTransactionContext(), validationService, offerResolver, protocolTokenValidator, new ContractNegotiationObservableImpl(), monitor, mock()); providerService = new ContractNegotiationProtocolServiceImpl(providerStore, new NoopTransactionContext(), validationService, offerResolver, protocolTokenValidator, new ContractNegotiationObservableImpl(), monitor, mock()); } diff --git a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImpl.java b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImpl.java index 2a12d2bb86c..c6835822f08 100644 --- a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImpl.java +++ b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImpl.java @@ -60,7 +60,7 @@ public CatalogProtocolServiceImpl(DatasetResolver datasetResolver, @Override @NotNull public ServiceResult getCatalog(CatalogRequestMessage message, TokenRepresentation tokenRepresentation) { - return transactionContext.execute(() -> protocolTokenValidator.verify(tokenRepresentation, CATALOGING_REQUEST_SCOPE) + return transactionContext.execute(() -> protocolTokenValidator.verify(tokenRepresentation, CATALOGING_REQUEST_SCOPE, message) .map(agent -> { try (var datasets = datasetResolver.query(agent, message.getQuerySpec())) { var dataServices = dataServiceRegistry.getDataServices(); diff --git a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImpl.java b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImpl.java index c26b218a361..2680eed4ced 100644 --- a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImpl.java +++ b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImpl.java @@ -39,6 +39,7 @@ import org.eclipse.edc.spi.monitor.Monitor; import org.eclipse.edc.spi.result.ServiceResult; import org.eclipse.edc.spi.telemetry.Telemetry; +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; import org.eclipse.edc.transaction.spi.TransactionContext; import org.jetbrains.annotations.NotNull; @@ -85,7 +86,7 @@ public ContractNegotiationProtocolServiceImpl(ContractNegotiationStore store, @NotNull public ServiceResult notifyRequested(ContractRequestMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchValidatableOffer(message) - .compose(validatableOffer -> verifyRequest(tokenRepresentation, validatableOffer.getContractPolicy()) + .compose(validatableOffer -> verifyRequest(tokenRepresentation, validatableOffer.getContractPolicy(), message) .compose(agent -> validateOffer(agent, validatableOffer)) .compose(validatedOffer -> { var result = message.getProviderPid() == null @@ -110,12 +111,12 @@ public ServiceResult notifyRequested(ContractRequestMessage @WithSpan @NotNull public ServiceResult notifyOffered(ContractOfferMessage message, TokenRepresentation tokenRepresentation) { - return transactionContext.execute(() -> verifyRequest(tokenRepresentation, message.getContractOffer().getPolicy()) + return transactionContext.execute(() -> verifyRequest(tokenRepresentation, message.getContractOffer().getPolicy(), message) .compose(agent -> { ServiceResult result = message.getConsumerPid() == null ? createNegotiation(message, agent.getIdentity(), CONSUMER, message.getCallbackAddress()) : getAndLeaseNegotiation(message.getProviderPid()) - .compose(negotiation -> validateRequest(agent, negotiation).map(it -> negotiation)); + .compose(negotiation -> validateRequest(agent, negotiation).map(it -> negotiation)); return result.onSuccess(negotiation -> { if (negotiation.shouldIgnoreIncomingMessage(message.getId())) { @@ -135,7 +136,7 @@ public ServiceResult notifyOffered(ContractOfferMessage mes @NotNull public ServiceResult notifyAccepted(ContractNegotiationEventMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(message.getProcessId()) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message) .compose(agent -> validateRequest(agent, contractNegotiation))) .compose(cn -> onMessageDo(message, contractNegotiation -> acceptedAction(message, contractNegotiation)))); @@ -146,7 +147,7 @@ public ServiceResult notifyAccepted(ContractNegotiationEven @NotNull public ServiceResult notifyAgreed(ContractAgreementMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(message.getProcessId()) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message) .compose(agent -> validateAgreed(message, agent, contractNegotiation))) .compose(cn -> onMessageDo(message, contractNegotiation -> agreedAction(message, contractNegotiation)))); } @@ -156,7 +157,7 @@ public ServiceResult notifyAgreed(ContractAgreementMessage @NotNull public ServiceResult notifyVerified(ContractAgreementVerificationMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(message.getProcessId()) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message) .compose(agent -> validateRequest(agent, contractNegotiation))) .compose(cn -> onMessageDo(message, contractNegotiation -> verifiedAction(message, contractNegotiation)))); } @@ -166,7 +167,7 @@ public ServiceResult notifyVerified(ContractAgreementVerifi @NotNull public ServiceResult notifyFinalized(ContractNegotiationEventMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(message.getProcessId()) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message) .compose(agent -> validateRequest(agent, contractNegotiation))) .compose(cn -> onMessageDo(message, contractNegotiation -> finalizedAction(message, contractNegotiation)))); } @@ -176,7 +177,7 @@ public ServiceResult notifyFinalized(ContractNegotiationEve @NotNull public ServiceResult notifyTerminated(ContractNegotiationTerminationMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(message.getProcessId()) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message) .compose(agent -> validateRequest(agent, contractNegotiation))) .compose(cn -> onMessageDo(message, contractNegotiation -> terminatedAction(message, contractNegotiation)))); } @@ -186,7 +187,7 @@ public ServiceResult notifyTerminated(ContractNegotiationTe @NotNull public ServiceResult findById(String id, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> getNegotiation(id) - .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy()) + .compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), null) .compose(agent -> validateRequest(agent, contractNegotiation) .map(it -> contractNegotiation)))); } @@ -318,8 +319,8 @@ private ServiceResult getAndLeaseNegotiation(String negotia .flatMap(ServiceResult::from); } - private ServiceResult verifyRequest(TokenRepresentation tokenRepresentation, Policy policy) { - return protocolTokenValidator.verify(tokenRepresentation, CONTRACT_NEGOTIATION_REQUEST_SCOPE, policy) + private ServiceResult verifyRequest(TokenRepresentation tokenRepresentation, Policy policy, RemoteMessage message) { + return protocolTokenValidator.verify(tokenRepresentation, CONTRACT_NEGOTIATION_REQUEST_SCOPE, policy, message) .onFailure(failure -> monitor.debug(() -> "Verification Failed: %s".formatted(failure.getFailureDetail()))); } diff --git a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImpl.java b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImpl.java index 04032e87c36..2bfa6ba60fc 100644 --- a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImpl.java +++ b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImpl.java @@ -21,11 +21,13 @@ import org.eclipse.edc.spi.agent.ParticipantAgent; import org.eclipse.edc.spi.agent.ParticipantAgentService; import org.eclipse.edc.spi.iam.IdentityService; +import org.eclipse.edc.spi.iam.RequestContext; import org.eclipse.edc.spi.iam.RequestScope; import org.eclipse.edc.spi.iam.TokenRepresentation; import org.eclipse.edc.spi.iam.VerificationContext; import org.eclipse.edc.spi.monitor.Monitor; import org.eclipse.edc.spi.result.ServiceResult; +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; /** * Implementation of {@link ProtocolTokenValidator} which uses the {@link PolicyEngine} for extracting @@ -48,8 +50,8 @@ public ProtocolTokenValidatorImpl(IdentityService identityService, PolicyEngine } @Override - public ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy) { - var tokenValidation = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy)); + public ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy, RemoteMessage message) { + var tokenValidation = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy, message)); if (tokenValidation.failed()) { monitor.debug(() -> "Unauthorized: %s".formatted(tokenValidation.getFailureDetail())); return ServiceResult.unauthorized("Unauthorized"); @@ -60,9 +62,11 @@ public ServiceResult verify(TokenRepresentation tokenRepresent return ServiceResult.success(participantAgent); } - private VerificationContext createVerificationContext(String scope, Policy policy) { + private VerificationContext createVerificationContext(String scope, Policy policy, RemoteMessage message) { var requestScopeBuilder = RequestScope.Builder.newInstance(); + var requestContext = RequestContext.Builder.newInstance().message(message).direction(RequestContext.Direction.Ingress).build(); var policyContext = PolicyContextImpl.Builder.newInstance() + .additional(RequestContext.class, requestContext) .additional(RequestScope.Builder.class, requestScopeBuilder) .build(); policyEngine.evaluate(scope, policy, policyContext); diff --git a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImpl.java b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImpl.java index a82b58f8b6d..918f0a8977d 100644 --- a/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImpl.java +++ b/core/control-plane/control-plane-aggregate-services/src/main/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImpl.java @@ -40,6 +40,7 @@ import org.eclipse.edc.spi.telemetry.Telemetry; import org.eclipse.edc.spi.types.domain.DataAddress; import org.eclipse.edc.spi.types.domain.agreement.ContractAgreement; +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; import org.eclipse.edc.transaction.spi.TransactionContext; import org.eclipse.edc.validator.spi.DataAddressValidatorRegistry; import org.jetbrains.annotations.NotNull; @@ -97,7 +98,7 @@ public TransferProcessProtocolServiceImpl(TransferProcessStore transferProcessSt @NotNull public ServiceResult notifyRequested(TransferRequestMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchNotifyRequestContext(message) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, message)) .compose(context -> validateDestination(message, context)) .compose(context -> validateAgreement(message, context)) .compose(context -> requestedAction(message, context.agreement().getAssetId()))); @@ -108,7 +109,7 @@ public ServiceResult notifyRequested(TransferRequestMessage mes @NotNull public ServiceResult notifyStarted(TransferStartMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, message)) .compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> startedAction(message, transferProcess))) ); } @@ -118,7 +119,7 @@ public ServiceResult notifyStarted(TransferStartMessage message @NotNull public ServiceResult notifyCompleted(TransferCompletionMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, message)) .compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> completedAction(message, transferProcess))) ); } @@ -126,7 +127,7 @@ public ServiceResult notifyCompleted(TransferCompletionMessage @Override public @NotNull ServiceResult notifySuspended(TransferSuspensionMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, message)) .compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> suspendedAction(message, transferProcess))) ); } @@ -136,7 +137,7 @@ public ServiceResult notifyCompleted(TransferCompletionMessage @NotNull public ServiceResult notifyTerminated(TransferTerminationMessage message, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, message)) .compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> terminatedAction(message, transferProcess))) ); } @@ -146,7 +147,7 @@ public ServiceResult notifyTerminated(TransferTerminationMessag @NotNull public ServiceResult findById(String id, TokenRepresentation tokenRepresentation) { return transactionContext.execute(() -> fetchRequestContext(id, this::findTransferProcessById) - .compose(context -> verifyRequest(tokenRepresentation, context)) + .compose(context -> verifyRequest(tokenRepresentation, context, null)) .compose(context -> validateCounterParty(context.participantAgent(), context.agreement(), context.transferProcess()))); } @@ -281,8 +282,8 @@ private ServiceResult fetchRequestContext(T i return tpProvider.apply(input).compose(transferProcess -> findContractByTransferProcess(transferProcess).map(agreement -> new TransferRequestMessageContext(agreement, transferProcess))); } - private ServiceResult verifyRequest(TokenRepresentation tokenRepresentation, TransferRequestMessageContext context) { - var result = protocolTokenValidator.verify(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy()); + private ServiceResult verifyRequest(TokenRepresentation tokenRepresentation, TransferRequestMessageContext context, RemoteMessage message) { + var result = protocolTokenValidator.verify(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy(), message); if (result.failed()) { monitor.debug(() -> "Verification Failed: %s".formatted(result.getFailureDetail())); return ServiceResult.notFound("Not found"); diff --git a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImplTest.java b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImplTest.java index 3d29cc36851..da1a302587d 100644 --- a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImplTest.java +++ b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/catalog/CatalogProtocolServiceImplTest.java @@ -59,6 +59,23 @@ class CatalogProtocolServiceImplTest { private final CatalogProtocolServiceImpl service = new CatalogProtocolServiceImpl(datasetResolver, dataServiceRegistry, protocolTokenValidator, "participantId", transactionContext); + private ParticipantAgent createParticipantAgent() { + return new ParticipantAgent(emptyMap(), emptyMap()); + } + + private Dataset createDataset() { + var dataService = DataService.Builder.newInstance().build(); + var distribution = Distribution.Builder.newInstance().dataService(dataService).format("any").build(); + return Dataset.Builder.newInstance() + .offer(UUID.randomUUID().toString(), Policy.Builder.newInstance().build()) + .distribution(distribution) + .build(); + } + + private TokenRepresentation createTokenRepresentation() { + return TokenRepresentation.Builder.newInstance().build(); + } + @Nested class GetCatalog { @@ -70,7 +87,7 @@ void shouldReturnCatalogWithConnectorDataServiceAndItsDataset() { var participantAgent = createParticipantAgent(); var dataService = DataService.Builder.newInstance().build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CATALOGING_REQUEST_SCOPE))).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CATALOGING_REQUEST_SCOPE), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(dataServiceRegistry.getDataServices()).thenReturn(List.of(dataService)); when(datasetResolver.query(any(), any())).thenReturn(Stream.of(createDataset())); @@ -91,7 +108,7 @@ void shouldFail_whenTokenValidationFails() { var message = CatalogRequestMessage.Builder.newInstance().protocol("protocol").querySpec(querySpec).build(); var tokenRepresentation = createTokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CATALOGING_REQUEST_SCOPE))).thenReturn(ServiceResult.unauthorized("unauthorized")); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CATALOGING_REQUEST_SCOPE), eq(message))).thenReturn(ServiceResult.unauthorized("unauthorized")); var result = service.getCatalog(message, tokenRepresentation); @@ -142,22 +159,4 @@ void shouldFail_whenTokenValidationFails() { assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(UNAUTHORIZED); } } - - - private ParticipantAgent createParticipantAgent() { - return new ParticipantAgent(emptyMap(), emptyMap()); - } - - private Dataset createDataset() { - var dataService = DataService.Builder.newInstance().build(); - var distribution = Distribution.Builder.newInstance().dataService(dataService).format("any").build(); - return Dataset.Builder.newInstance() - .offer(UUID.randomUUID().toString(), Policy.Builder.newInstance().build()) - .distribution(distribution) - .build(); - } - - private TokenRepresentation createTokenRepresentation() { - return TokenRepresentation.Builder.newInstance().build(); - } } diff --git a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImplTest.java b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImplTest.java index 66dfcb593cb..f64fac62bc3 100644 --- a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImplTest.java +++ b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/contractnegotiation/ContractNegotiationProtocolServiceImplTest.java @@ -84,6 +84,7 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -113,210 +114,6 @@ void setUp() { consumerOfferResolver, protocolTokenValidator, observable, mock(), mock()); } - @Nested - class NotifyRequested { - @Test - void shouldInitiateNegotiation_whenNegotiationDoesNotExist() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var validatedOffer = new ValidatedConsumerOffer(CONSUMER_ID, contractOffer); - var message = ContractRequestMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .contractOffer(contractOffer) - .consumerPid("consumerPid") - .build(); - var validatableOffer = mock(ValidatableConsumerOffer.class); - - when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); - when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); - when(validationService.validateInitialOffer(participantAgent, validatableOffer)).thenReturn(Result.success(validatedOffer)); - - var result = service.notifyRequested(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - var calls = ArgumentCaptor.forClass(ContractNegotiation.class); - verify(store, never()).findByIdAndLease(any()); - verify(store).save(calls.capture()); - assertThat(calls.getAllValues()).anySatisfy(n -> { - assertThat(n.getState()).isEqualTo(REQUESTED.code()); - assertThat(n.getCounterPartyAddress()).isEqualTo(message.getCallbackAddress()); - assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); - assertThat(n.getCorrelationId()).isEqualTo(message.getConsumerPid()); - assertThat(n.getContractOffers()).hasSize(1); - assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); - }); - verify(listener).requested(any()); - verify(validationService).validateInitialOffer(participantAgent, validatableOffer); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void shouldTransitionToRequested_whenNegotiationFound() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var validatedOffer = new ValidatedConsumerOffer(CONSUMER_ID, contractOffer); - var negotiation = createContractNegotiationOffered(); - var message = ContractRequestMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .processId("processId") - .contractOffer(contractOffer) - .consumerPid("consumerPid") - .providerPid("providerPid") - .build(); - - var validatableOffer = mock(ValidatableConsumerOffer.class); - - when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); - when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById(any())).thenReturn(negotiation); - when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); - when(validationService.validateInitialOffer(participantAgent, validatableOffer)).thenReturn(Result.success(validatedOffer)); - - - var result = service.notifyRequested(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - verify(store).findByIdAndLease("providerPid"); - var calls = ArgumentCaptor.forClass(ContractNegotiation.class); - verify(store).save(calls.capture()); - assertThat(calls.getAllValues()).anySatisfy(n -> { - assertThat(n.getState()).isEqualTo(REQUESTED.code()); - assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); - assertThat(n.getContractOffers()).hasSize(2); - assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); - }); - verify(listener).requested(any()); - verify(store).findByIdAndLease("providerPid"); - verify(validationService).validateInitialOffer(participantAgent, validatableOffer); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void shouldReturnNotFound_whenOfferNotFound() { - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var message = ContractRequestMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .contractOffer(contractOffer) - .consumerPid("consumerPid") - .build(); - var validatableOffer = mock(ValidatableConsumerOffer.class); - - when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); - when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.notFound("")); - - var result = service.notifyRequested(message, tokenRepresentation); - - assertThat(result) - .isFailed() - .extracting(ServiceFailure::getReason) - .isEqualTo(NOT_FOUND); - } - } - - @Nested - class NotifyOffered { - - @Test - void shouldInitiateNegotiation_whenNegotiationDoesNotExist() { - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var message = ContractOfferMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .contractOffer(contractOffer) - .providerPid("providerPid") - .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) - .thenReturn(ServiceResult.success(participantAgent())); - - var result = service.notifyOffered(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - var calls = ArgumentCaptor.forClass(ContractNegotiation.class); - verify(store, never()).findByIdAndLease(any()); - verify(store).save(calls.capture()); - assertThat(calls.getAllValues()).anySatisfy(n -> { - assertThat(n.getState()).isEqualTo(OFFERED.code()); - assertThat(n.getType()).isEqualTo(CONSUMER); - assertThat(n.getCounterPartyId()).isEqualTo("counterPartyId"); - assertThat(n.getCounterPartyAddress()).isEqualTo(message.getCallbackAddress()); - assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); - assertThat(n.getCorrelationId()).isEqualTo(message.getConsumerPid()); - assertThat(n.getContractOffers()).hasSize(1); - assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); - }); - verify(listener).offered(any()); - verifyNoInteractions(validationService, consumerOfferResolver); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void shouldTransitionToOffered_whenNegotiationAlreadyExist() { - var processId = "processId"; - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var message = ContractOfferMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .contractOffer(contractOffer) - .processId("providerPid") - .consumerPid("consumerPid") - .providerPid("providerPid") - .build(); - var negotiation = createContractNegotiationRequested(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) - .thenReturn(ServiceResult.success(participantAgent)); - when(store.findById(processId)).thenReturn(negotiation); - when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); - when(validationService.validateRequest(participantAgent, negotiation)).thenReturn(Result.success()); - - var result = service.notifyOffered(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - var updatedNegotiation = result.getContent(); - assertThat(updatedNegotiation.getContractOffers()).hasSize(2); - assertThat(updatedNegotiation.getLastContractOffer()).isEqualTo(contractOffer); - verify(store).findByIdAndLease("providerPid"); - verify(listener).offered(any()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void shouldReturnNotFound_whenOfferNotFound() { - var tokenRepresentation = tokenRepresentation(); - var contractOffer = contractOffer(); - var message = ContractOfferMessage.Builder.newInstance() - .callbackAddress("callbackAddress") - .protocol("protocol") - .contractOffer(contractOffer) - .consumerPid("consumerPid") - .providerPid("providerPid") - .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) - .thenReturn(ServiceResult.success(participantAgent())); - when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); - when(store.findByCorrelationIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); - - var result = service.notifyOffered(message, tokenRepresentation); - - assertThat(result) - .isFailed() - .extracting(ServiceFailure::getReason) - .isEqualTo(NOT_FOUND); - } - } - @Test void notifyAccepted_shouldTransitionToAccepted() { var contractNegotiation = createContractNegotiationOffered(); @@ -331,7 +128,7 @@ void notifyAccepted_shouldTransitionToAccepted() { .type(ContractNegotiationEventMessage.Type.ACCEPTED) .policy(Policy.Builder.newInstance().build()) .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent)); when(store.findById(any())).thenReturn(contractNegotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(contractNegotiation)); @@ -364,7 +161,7 @@ void notifyAgreed_shouldTransitionToAgreed() { .contractAgreement(contractAgreement) .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById(any())).thenReturn(negotiationConsumerRequested); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiationConsumerRequested)); when(validationService.validateConfirmed(eq(participantAgent), eq(contractAgreement), any(ContractOffer.class))).thenReturn(Result.success()); @@ -399,7 +196,7 @@ void notifyVerified_shouldTransitionToVerified() { .providerPid("providerPid") .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); when(validationService.validateRequest(any(ParticipantAgent.class), any(ContractNegotiation.class))).thenReturn(Result.success()); @@ -430,7 +227,7 @@ void notifyFinalized_shouldTransitionToFinalized() { .build(); var tokenRepresentation = tokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); @@ -461,7 +258,7 @@ void notifyTerminated_shouldTransitionToTerminated() { .build(); var tokenRepresentation = tokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); @@ -486,7 +283,7 @@ void findById_shouldReturnNegotiation_whenValidCounterParty() { var contractOffer = contractOffer(); var negotiation = contractNegotiationBuilder().id(id).type(PROVIDER).contractOffer(contractOffer).state(VERIFIED.code()).build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), isNull())) .thenReturn(ServiceResult.success(participantAgent)); when(store.findById(id)).thenReturn(negotiation); when(validationService.validateRequest(participantAgent, negotiation)).thenReturn(Result.success()); @@ -520,7 +317,7 @@ void findById_shouldReturnBadRequest_whenCounterPartyUnauthorized() { var tokenRepresentation = tokenRepresentation(); var contractOffer = contractOffer(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), isNull())).thenReturn(ServiceResult.success(participantAgent)); var negotiation = contractNegotiationBuilder().id(id).type(PROVIDER).contractOffer(contractOffer).state(VERIFIED.code()).build(); @@ -539,7 +336,7 @@ void findById_shouldReturnBadRequest_whenCounterPartyUnauthorized() { @ArgumentsSource(NotifyArguments.class) void notify_shouldReturnNotFound_whenNotFound(MethodCall methodCall, M message) { var tokenRepresentation = tokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); when(store.findByCorrelationIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); @@ -562,7 +359,7 @@ void notify_shouldReturnBadRequest_whenValidationFails when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(createContractNegotiationOffered()); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(createContractNegotiationOffered())); @@ -586,7 +383,7 @@ void notify_shouldReturnUnauthorized_whenTokenValidati when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer)); when(store.findById(any())).thenReturn(createContractNegotiationOffered()); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())).thenReturn(ServiceResult.unauthorized("unauthorized")); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.unauthorized("unauthorized")); var result = methodCall.call(service, message, tokenRepresentation); @@ -718,6 +515,210 @@ public Stream provideArguments(ExtensionContext extensionCo } + @Nested + class NotifyRequested { + @Test + void shouldInitiateNegotiation_whenNegotiationDoesNotExist() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var validatedOffer = new ValidatedConsumerOffer(CONSUMER_ID, contractOffer); + var message = ContractRequestMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .contractOffer(contractOffer) + .consumerPid("consumerPid") + .build(); + var validatableOffer = mock(ValidatableConsumerOffer.class); + + when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); + when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); + when(validationService.validateInitialOffer(participantAgent, validatableOffer)).thenReturn(Result.success(validatedOffer)); + + var result = service.notifyRequested(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + var calls = ArgumentCaptor.forClass(ContractNegotiation.class); + verify(store, never()).findByIdAndLease(any()); + verify(store).save(calls.capture()); + assertThat(calls.getAllValues()).anySatisfy(n -> { + assertThat(n.getState()).isEqualTo(REQUESTED.code()); + assertThat(n.getCounterPartyAddress()).isEqualTo(message.getCallbackAddress()); + assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); + assertThat(n.getCorrelationId()).isEqualTo(message.getConsumerPid()); + assertThat(n.getContractOffers()).hasSize(1); + assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); + }); + verify(listener).requested(any()); + verify(validationService).validateInitialOffer(participantAgent, validatableOffer); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void shouldTransitionToRequested_whenNegotiationFound() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var validatedOffer = new ValidatedConsumerOffer(CONSUMER_ID, contractOffer); + var negotiation = createContractNegotiationOffered(); + var message = ContractRequestMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .processId("processId") + .contractOffer(contractOffer) + .consumerPid("consumerPid") + .providerPid("providerPid") + .build(); + + var validatableOffer = mock(ValidatableConsumerOffer.class); + + when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); + when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.success(validatableOffer)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById(any())).thenReturn(negotiation); + when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); + when(validationService.validateInitialOffer(participantAgent, validatableOffer)).thenReturn(Result.success(validatedOffer)); + + + var result = service.notifyRequested(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + verify(store).findByIdAndLease("providerPid"); + var calls = ArgumentCaptor.forClass(ContractNegotiation.class); + verify(store).save(calls.capture()); + assertThat(calls.getAllValues()).anySatisfy(n -> { + assertThat(n.getState()).isEqualTo(REQUESTED.code()); + assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); + assertThat(n.getContractOffers()).hasSize(2); + assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); + }); + verify(listener).requested(any()); + verify(store).findByIdAndLease("providerPid"); + verify(validationService).validateInitialOffer(participantAgent, validatableOffer); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void shouldReturnNotFound_whenOfferNotFound() { + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var message = ContractRequestMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .contractOffer(contractOffer) + .consumerPid("consumerPid") + .build(); + var validatableOffer = mock(ValidatableConsumerOffer.class); + + when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); + when(consumerOfferResolver.resolveOffer(contractOffer.getId())).thenReturn(ServiceResult.notFound("")); + + var result = service.notifyRequested(message, tokenRepresentation); + + assertThat(result) + .isFailed() + .extracting(ServiceFailure::getReason) + .isEqualTo(NOT_FOUND); + } + } + + @Nested + class NotifyOffered { + + @Test + void shouldInitiateNegotiation_whenNegotiationDoesNotExist() { + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var message = ContractOfferMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .contractOffer(contractOffer) + .providerPid("providerPid") + .build(); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) + .thenReturn(ServiceResult.success(participantAgent())); + + var result = service.notifyOffered(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + var calls = ArgumentCaptor.forClass(ContractNegotiation.class); + verify(store, never()).findByIdAndLease(any()); + verify(store).save(calls.capture()); + assertThat(calls.getAllValues()).anySatisfy(n -> { + assertThat(n.getState()).isEqualTo(OFFERED.code()); + assertThat(n.getType()).isEqualTo(CONSUMER); + assertThat(n.getCounterPartyId()).isEqualTo("counterPartyId"); + assertThat(n.getCounterPartyAddress()).isEqualTo(message.getCallbackAddress()); + assertThat(n.getProtocol()).isEqualTo(message.getProtocol()); + assertThat(n.getCorrelationId()).isEqualTo(message.getConsumerPid()); + assertThat(n.getContractOffers()).hasSize(1); + assertThat(n.getLastContractOffer()).isEqualTo(contractOffer); + }); + verify(listener).offered(any()); + verifyNoInteractions(validationService, consumerOfferResolver); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void shouldTransitionToOffered_whenNegotiationAlreadyExist() { + var processId = "processId"; + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var message = ContractOfferMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .contractOffer(contractOffer) + .processId("providerPid") + .consumerPid("consumerPid") + .providerPid("providerPid") + .build(); + var negotiation = createContractNegotiationRequested(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) + .thenReturn(ServiceResult.success(participantAgent)); + when(store.findById(processId)).thenReturn(negotiation); + when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); + when(validationService.validateRequest(participantAgent, negotiation)).thenReturn(Result.success()); + + var result = service.notifyOffered(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + var updatedNegotiation = result.getContent(); + assertThat(updatedNegotiation.getContractOffers()).hasSize(2); + assertThat(updatedNegotiation.getLastContractOffer()).isEqualTo(contractOffer); + verify(store).findByIdAndLease("providerPid"); + verify(listener).offered(any()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void shouldReturnNotFound_whenOfferNotFound() { + var tokenRepresentation = tokenRepresentation(); + var contractOffer = contractOffer(); + var message = ContractOfferMessage.Builder.newInstance() + .callbackAddress("callbackAddress") + .protocol("protocol") + .contractOffer(contractOffer) + .consumerPid("consumerPid") + .providerPid("providerPid") + .build(); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) + .thenReturn(ServiceResult.success(participantAgent())); + when(store.findByIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); + when(store.findByCorrelationIdAndLease(any())).thenReturn(StoreResult.notFound("not found")); + + var result = service.notifyOffered(message, tokenRepresentation); + + assertThat(result) + .isFailed() + .extracting(ServiceFailure::getReason) + .isEqualTo(NOT_FOUND); + } + } + @Nested class IdempotencyProcessStateReplication { @@ -732,7 +733,7 @@ void notify_shouldStoreReceivedMessageId(Method when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); @@ -762,7 +763,7 @@ void notify_shouldIgnoreMessage_whenAlreadyRece when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); @@ -788,7 +789,7 @@ void notify_shouldIgnoreMessage_whenFinalState( when(validatableOffer.getContractPolicy()).thenReturn(createPolicy()); when(consumerOfferResolver.resolveOffer(any())).thenReturn(ServiceResult.success(validatableOffer)); - when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any())) + when(protocolTokenValidator.verify(any(), eq(CONTRACT_NEGOTIATION_REQUEST_SCOPE), any(), eq(message))) .thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(negotiation); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(negotiation)); diff --git a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImplTest.java b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImplTest.java index a2b2c15aa56..a7ec5ecfb69 100644 --- a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImplTest.java +++ b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/protocol/ProtocolTokenValidatorImplTest.java @@ -20,15 +20,18 @@ import org.eclipse.edc.spi.agent.ParticipantAgentService; import org.eclipse.edc.spi.iam.ClaimToken; import org.eclipse.edc.spi.iam.IdentityService; +import org.eclipse.edc.spi.iam.RequestContext; import org.eclipse.edc.spi.iam.TokenRepresentation; import org.eclipse.edc.spi.result.Result; import org.eclipse.edc.spi.result.ServiceFailure; +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; import org.junit.jupiter.api.Test; import static java.util.Collections.emptyMap; import static org.eclipse.edc.junit.assertions.AbstractResultAssert.assertThat; import static org.eclipse.edc.spi.result.ServiceFailure.Reason.UNAUTHORIZED; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.mock; @@ -51,11 +54,14 @@ void shouldVerifyToken() { when(identityService.verifyJwtToken(any(), any())).thenReturn(Result.success(claimToken)); when(agentService.createFor(any())).thenReturn(participantAgent); - var result = validator.verify(tokenRepresentation, "scope", policy); + var result = validator.verify(tokenRepresentation, "scope", policy, new TestMessage()); assertThat(result).isSucceeded().isSameAs(participantAgent); verify(agentService).createFor(claimToken); - verify(policyEngine).evaluate(eq("scope"), same(policy), any()); + verify(policyEngine).evaluate(eq("scope"), same(policy), argThat(ctx -> { + var reqContext = ctx.getContextData(RequestContext.class); + return reqContext.getMessage().getClass().equals(TestMessage.class) && reqContext.getDirection().equals(RequestContext.Direction.Ingress); + })); verify(identityService).verifyJwtToken(same(tokenRepresentation), any()); } @@ -63,8 +69,25 @@ void shouldVerifyToken() { void shouldReturnUnauthorized_whenTokenIsNotValid() { when(identityService.verifyJwtToken(any(), any())).thenReturn(Result.failure("failure")); - var result = validator.verify(TokenRepresentation.Builder.newInstance().build(), "scope", Policy.Builder.newInstance().build()); + var result = validator.verify(TokenRepresentation.Builder.newInstance().build(), "scope", Policy.Builder.newInstance().build(), new TestMessage()); assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(UNAUTHORIZED); } + + static class TestMessage implements RemoteMessage { + @Override + public String getProtocol() { + return null; + } + + @Override + public String getCounterPartyAddress() { + return "http://connector"; + } + + @Override + public String getCounterPartyId() { + return null; + } + } } diff --git a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImplTest.java b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImplTest.java index 96e09c3f568..b2cfe104c71 100644 --- a/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImplTest.java +++ b/core/control-plane/control-plane-aggregate-services/src/test/java/org/eclipse/edc/connector/service/transferprocess/TransferProcessProtocolServiceImplTest.java @@ -87,6 +87,7 @@ import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.ArgumentMatchers.isNull; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; @@ -131,7 +132,7 @@ void notifyRequested_validAgreement_shouldInitiateTransfer() { .dataDestination(DataAddress.Builder.newInstance().type("any").build()) .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -162,7 +163,7 @@ void notifyRequested_doNothingIfProcessAlreadyExist() { var participantAgent = participantAgent(); var tokenRepresentation = tokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -187,7 +188,7 @@ void notifyRequested_invalidAgreement_shouldNotInitiateTransfer() { var participantAgent = participantAgent(); var tokenRepresentation = tokenRepresentation(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.failure("error")); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.success()); @@ -212,7 +213,7 @@ void notifyRequested_invalidDestination_shouldNotInitiateTransfer() { .build(); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(dataAddressValidator.validateDestination(any())).thenReturn(ValidationResult.failure(violation("invalid data address", "path"))); var result = service.notifyRequested(message, tokenRepresentation); @@ -234,7 +235,7 @@ void notifyRequested_missingDestination_shouldInitiateTransfer() { .callbackAddress("http://any") .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); when(validationService.validateAgreement(any(ParticipantAgent.class), any())).thenReturn(Result.success(null)); @@ -252,168 +253,6 @@ void notifyRequested_missingDestination_shouldInitiateTransfer() { verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); } - @Nested - class NotifyStarted { - @Test - void shouldTransitionToStarted() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferStartMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .dataAddress(DataAddress.Builder.newInstance().type("test").build()) - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcess(STARTED, "transferProcessId"); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - - var result = service.notifyStarted(message, tokenRepresentation); - - var startedDataCaptor = ArgumentCaptor.forClass(TransferProcessStartedData.class); - var transferProcessCaptor = ArgumentCaptor.forClass(TransferProcess.class); - assertThat(result).isSucceeded(); - verify(listener).preStarted(any()); - verify(store).save(transferProcessCaptor.capture()); - verify(store).save(argThat(t -> t.getState() == STARTED.code())); - verify(listener).started(any(), startedDataCaptor.capture()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - assertThat(startedDataCaptor.getValue().getDataAddress()).usingRecursiveComparison().isEqualTo(message.getDataAddress()); - } - - @Test - void shouldReturnConflict_whenTransferCannotBeStarted() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); - var message = TransferStartMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .build(); - var agreement = contractAgreement(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - - var result = service.notifyStarted(message, tokenRepresentation); - - assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); - // state didn't change - verify(store, times(1)).save(argThat(tp -> tp.getState() == DEPROVISIONING.code())); - verifyNoInteractions(listener); - } - - @Test - void shouldReturnBadRequest_whenCounterPartyUnauthorized() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferStartMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .dataAddress(DataAddress.Builder.newInstance().type("test").build()) - .build(); - var agreement = contractAgreement(); - - var transferProcess = transferProcess(REQUESTED, "transferProcessId"); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.failure("error")); - - var result = service.notifyStarted(message, tokenRepresentation); - - assertThat(result) - .isFailed() - .extracting(ServiceFailure::getReason) - .isEqualTo(BAD_REQUEST); - - verify(store, times(1)).save(any()); - - } - } - - @Nested - class NotifyStartedResumed { - - @Test - void shouldTransitionToStartedAndStartDataFlow_whenProvider() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferStartMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .dataAddress(DataAddress.Builder.newInstance().type("test").build()) - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcessBuilder().id("transferProcessId") - .state(SUSPENDED.code()).type(PROVIDER).build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - - var result = service.notifyStarted(message, tokenRepresentation); - - var transferProcessCaptor = ArgumentCaptor.forClass(TransferProcess.class); - assertThat(result).isSucceeded(); - verify(store).save(transferProcessCaptor.capture()); - var storedTransferProcess = transferProcessCaptor.getValue(); - assertThat(storedTransferProcess.getState()).isEqualTo(STARTING.code()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void shouldReturnError_whenStatusIsNotSuspendedAndTypeProvider() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferStartMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .dataAddress(DataAddress.Builder.newInstance().type("test").build()) - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcessBuilder().id("transferProcessId") - .state(REQUESTED.code()).type(PROVIDER).build(); - var dataFlowResponse = DataFlowResponse.Builder.newInstance().dataPlaneId("dataPlaneId").build(); - when(dataFlowManager.start(any(), any())).thenReturn(StatusResult.success(dataFlowResponse)); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - - var result = service.notifyStarted(message, tokenRepresentation); - - assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); - } - } - @Test void notifyCompleted_shouldTransitionToCompleted() { var participantAgent = participantAgent(); @@ -429,7 +268,7 @@ void notifyCompleted_shouldTransitionToCompleted() { var transferProcess = transferProcess(STARTED, "transferProcessId"); when(store.findById("correlationId")).thenReturn(transferProcess); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); @@ -457,7 +296,7 @@ void notifyCompleted_shouldReturnConflict_whenStatusIsNotValid() { .build(); var agreement = contractAgreement(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById("correlationId")).thenReturn(transferProcess); when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); @@ -486,7 +325,7 @@ void notifyCompleted_shouldReturnBadRequest_whenCounterPartyUnauthorized() { var agreement = contractAgreement(); var transferProcess = transferProcess(STARTED, "transferProcessId"); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById("correlationId")).thenReturn(transferProcess); when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); @@ -503,208 +342,53 @@ void notifyCompleted_shouldReturnBadRequest_whenCounterPartyUnauthorized() { } - @Nested - class NotifySuspended { - @Test - void consumer_shouldTransitionToSuspended() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferSuspensionMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcessBuilder().state(STARTED.code()).type(CONSUMER).build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - var result = service.notifySuspended(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - verify(store).save(argThat(t -> t.getState() == SUSPENDED.code())); - verify(listener).suspended(any()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } + @Test + void notifyTerminated_shouldTransitionToTerminated() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferTerminationMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcess(STARTED, "transferProcessId"); - @Test - void provider_shouldSuspendDataFlowAndTransitionToSuspended() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferSuspensionMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcessBuilder().state(STARTED.code()).type(PROVIDER).build(); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + var result = service.notifyTerminated(message, tokenRepresentation); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - when(dataFlowManager.suspend(any())).thenReturn(StatusResult.success()); + assertThat(result).isSucceeded(); + verify(listener).preTerminated(any()); + verify(store).save(argThat(t -> t.getState() == TERMINATED.code())); + verify(listener).terminated(any()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } - var result = service.notifySuspended(message, tokenRepresentation); + @Test + void notifyTerminated_shouldReturnConflict_whenTransferProcessCannotBeTerminated() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); + var agreement = contractAgreement(); + var message = TransferTerminationMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); - assertThat(result).isSucceeded(); - verify(store).save(argThat(t -> t.getState() == SUSPENDED.code())); - verify(listener).suspended(any()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void provider_shouldReturnConflict_whenDataFlowCannotBeSuspended() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferSuspensionMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcessBuilder().state(STARTED.code()).type(PROVIDER).build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - when(dataFlowManager.suspend(any())).thenReturn(StatusResult.failure(FATAL_ERROR)); - - var result = service.notifySuspended(message, tokenRepresentation); - - assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); - verify(store, times(1)).save(argThat(tp -> tp.getState() == STARTED.code())); - } - - @Test - void shouldReturnConflict_whenTransferProcessCannotBeSuspended() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); - var agreement = contractAgreement(); - var message = TransferSuspensionMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - - var result = service.notifySuspended(message, tokenRepresentation); - - assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); - // state didn't change - verify(store, times(1)).save(argThat(tp -> tp.getState() == DEPROVISIONING.code())); - verifyNoInteractions(listener); - } - - @Test - void shouldReturnBadRequest_whenCounterPartyUnauthorized() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var agreement = contractAgreement(); - var transferProcess = transferProcess(TERMINATED, UUID.randomUUID().toString()); - var message = TransferSuspensionMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.failure("error")); - - var result = service.notifySuspended(message, tokenRepresentation); - - assertThat(result) - .isFailed() - .extracting(ServiceFailure::getReason) - .isEqualTo(BAD_REQUEST); - - verify(store, times(1)).save(any()); - - } - } - - @Test - void notifyTerminated_shouldTransitionToTerminated() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var message = TransferTerminationMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - var agreement = contractAgreement(); - var transferProcess = transferProcess(STARTED, "transferProcessId"); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); - when(store.findById("correlationId")).thenReturn(transferProcess); - when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); - when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); - when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); - var result = service.notifyTerminated(message, tokenRepresentation); - - assertThat(result).isSucceeded(); - verify(listener).preTerminated(any()); - verify(store).save(argThat(t -> t.getState() == TERMINATED.code())); - verify(listener).terminated(any()); - verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); - } - - @Test - void notifyTerminated_shouldReturnConflict_whenTransferProcessCannotBeTerminated() { - var participantAgent = participantAgent(); - var tokenRepresentation = tokenRepresentation(); - var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); - var agreement = contractAgreement(); - var message = TransferTerminationMessage.Builder.newInstance() - .protocol("protocol") - .consumerPid("consumerPid") - .providerPid("providerPid") - .counterPartyAddress("http://any") - .processId("correlationId") - .code("TestCode") - .reason("TestReason") - .build(); - - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById("correlationId")).thenReturn(transferProcess); when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); @@ -734,7 +418,7 @@ void notifyTerminated_shouldReturnBadRequest_whenCounterPartyUnauthorized() { .reason("TestReason") .build(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); when(store.findById("correlationId")).thenReturn(transferProcess); when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); @@ -759,7 +443,7 @@ void findById_shouldReturnTransferProcess_whenValidCounterParty() { var transferProcess = transferProcess(INITIAL, processId); var agreement = contractAgreement(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), isNull())).thenReturn(ServiceResult.success(participantAgent)); when(store.findById(processId)).thenReturn(transferProcess); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); @@ -795,7 +479,7 @@ void findById_shouldReturnBadRequest_whenCounterPartyUnauthorized() { var tokenRepresentation = tokenRepresentation(); var agreement = contractAgreement(); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent)); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), isNull())).thenReturn(ServiceResult.success(participantAgent)); when(store.findById(processId)).thenReturn(transferProcess); when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.failure("error")); @@ -833,7 +517,7 @@ void notify_shouldFail_whenTokenValidationFails(Method when(store.findById(any())).thenReturn(transferProcessBuilder().build()); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(transferProcessBuilder().build())); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); - when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.unauthorized("unauthorized")); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.unauthorized("unauthorized")); var result = methodCall.call(service, message, tokenRepresentation); @@ -913,6 +597,323 @@ private M build(TransferRemoteMessage.Builder< } } + @Nested + class NotifyStarted { + @Test + void shouldTransitionToStarted() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferStartMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .dataAddress(DataAddress.Builder.newInstance().type("test").build()) + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcess(STARTED, "transferProcessId"); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + + var result = service.notifyStarted(message, tokenRepresentation); + + var startedDataCaptor = ArgumentCaptor.forClass(TransferProcessStartedData.class); + var transferProcessCaptor = ArgumentCaptor.forClass(TransferProcess.class); + assertThat(result).isSucceeded(); + verify(listener).preStarted(any()); + verify(store).save(transferProcessCaptor.capture()); + verify(store).save(argThat(t -> t.getState() == STARTED.code())); + verify(listener).started(any(), startedDataCaptor.capture()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + assertThat(startedDataCaptor.getValue().getDataAddress()).usingRecursiveComparison().isEqualTo(message.getDataAddress()); + } + + @Test + void shouldReturnConflict_whenTransferCannotBeStarted() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); + var message = TransferStartMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .build(); + var agreement = contractAgreement(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + + var result = service.notifyStarted(message, tokenRepresentation); + + assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); + // state didn't change + verify(store, times(1)).save(argThat(tp -> tp.getState() == DEPROVISIONING.code())); + verifyNoInteractions(listener); + } + + @Test + void shouldReturnBadRequest_whenCounterPartyUnauthorized() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferStartMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .dataAddress(DataAddress.Builder.newInstance().type("test").build()) + .build(); + var agreement = contractAgreement(); + + var transferProcess = transferProcess(REQUESTED, "transferProcessId"); + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.failure("error")); + + var result = service.notifyStarted(message, tokenRepresentation); + + assertThat(result) + .isFailed() + .extracting(ServiceFailure::getReason) + .isEqualTo(BAD_REQUEST); + + verify(store, times(1)).save(any()); + + } + } + + @Nested + class NotifyStartedResumed { + + @Test + void shouldTransitionToStartedAndStartDataFlow_whenProvider() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferStartMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .dataAddress(DataAddress.Builder.newInstance().type("test").build()) + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcessBuilder().id("transferProcessId") + .state(SUSPENDED.code()).type(PROVIDER).build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + + var result = service.notifyStarted(message, tokenRepresentation); + + var transferProcessCaptor = ArgumentCaptor.forClass(TransferProcess.class); + assertThat(result).isSucceeded(); + verify(store).save(transferProcessCaptor.capture()); + var storedTransferProcess = transferProcessCaptor.getValue(); + assertThat(storedTransferProcess.getState()).isEqualTo(STARTING.code()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void shouldReturnError_whenStatusIsNotSuspendedAndTypeProvider() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferStartMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .dataAddress(DataAddress.Builder.newInstance().type("test").build()) + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcessBuilder().id("transferProcessId") + .state(REQUESTED.code()).type(PROVIDER).build(); + var dataFlowResponse = DataFlowResponse.Builder.newInstance().dataPlaneId("dataPlaneId").build(); + when(dataFlowManager.start(any(), any())).thenReturn(StatusResult.success(dataFlowResponse)); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + + var result = service.notifyStarted(message, tokenRepresentation); + + assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); + } + } + + @Nested + class NotifySuspended { + @Test + void consumer_shouldTransitionToSuspended() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferSuspensionMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcessBuilder().state(STARTED.code()).type(CONSUMER).build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + var result = service.notifySuspended(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + verify(store).save(argThat(t -> t.getState() == SUSPENDED.code())); + verify(listener).suspended(any()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void provider_shouldSuspendDataFlowAndTransitionToSuspended() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferSuspensionMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcessBuilder().state(STARTED.code()).type(PROVIDER).build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + when(dataFlowManager.suspend(any())).thenReturn(StatusResult.success()); + + var result = service.notifySuspended(message, tokenRepresentation); + + assertThat(result).isSucceeded(); + verify(store).save(argThat(t -> t.getState() == SUSPENDED.code())); + verify(listener).suspended(any()); + verify(transactionContext, atLeastOnce()).execute(any(TransactionContext.ResultTransactionBlock.class)); + } + + @Test + void provider_shouldReturnConflict_whenDataFlowCannotBeSuspended() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var message = TransferSuspensionMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + var agreement = contractAgreement(); + var transferProcess = transferProcessBuilder().state(STARTED.code()).type(PROVIDER).build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + when(dataFlowManager.suspend(any())).thenReturn(StatusResult.failure(FATAL_ERROR)); + + var result = service.notifySuspended(message, tokenRepresentation); + + assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); + verify(store, times(1)).save(argThat(tp -> tp.getState() == STARTED.code())); + } + + @Test + void shouldReturnConflict_whenTransferProcessCannotBeSuspended() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var transferProcess = transferProcess(DEPROVISIONING, UUID.randomUUID().toString()); + var agreement = contractAgreement(); + var message = TransferSuspensionMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.success()); + + var result = service.notifySuspended(message, tokenRepresentation); + + assertThat(result).isFailed().extracting(ServiceFailure::getReason).isEqualTo(CONFLICT); + // state didn't change + verify(store, times(1)).save(argThat(tp -> tp.getState() == DEPROVISIONING.code())); + verifyNoInteractions(listener); + } + + @Test + void shouldReturnBadRequest_whenCounterPartyUnauthorized() { + var participantAgent = participantAgent(); + var tokenRepresentation = tokenRepresentation(); + var agreement = contractAgreement(); + var transferProcess = transferProcess(TERMINATED, UUID.randomUUID().toString()); + var message = TransferSuspensionMessage.Builder.newInstance() + .protocol("protocol") + .consumerPid("consumerPid") + .providerPid("providerPid") + .counterPartyAddress("http://any") + .processId("correlationId") + .code("TestCode") + .reason("TestReason") + .build(); + + when(protocolTokenValidator.verify(eq(tokenRepresentation), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent)); + when(store.findById("correlationId")).thenReturn(transferProcess); + when(store.findByIdAndLease("correlationId")).thenReturn(StoreResult.success(transferProcess)); + when(negotiationStore.findContractAgreement(any())).thenReturn(agreement); + when(validationService.validateRequest(participantAgent, agreement)).thenReturn(Result.failure("error")); + + var result = service.notifySuspended(message, tokenRepresentation); + + assertThat(result) + .isFailed() + .extracting(ServiceFailure::getReason) + .isEqualTo(BAD_REQUEST); + + verify(store, times(1)).save(any()); + + } + } + @Nested class IdempotencyProcessStateReplication { @@ -922,7 +923,7 @@ void notify_shouldStoreReceivedMessageId(Method TransferProcess.Type type, TransferProcessStates currentState) { var transferProcess = transferProcessBuilder().state(currentState.code()).type(type).build(); - when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent())); + when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(transferProcess); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); @@ -946,7 +947,7 @@ void notify_shouldIgnoreMessage_whenAlreadyRece TransferProcessStates currentState) { var transferProcess = transferProcessBuilder().state(currentState.code()).type(type).build(); transferProcess.protocolMessageReceived(message.getId()); - when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent())); + when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(transferProcess); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); @@ -965,7 +966,7 @@ void notify_shouldIgnoreMessage_whenAlreadyRece void notify_shouldIgnoreMessage_whenFinalState(MethodCall methodCall, M message, TransferProcess.Type type) { var transferProcess = transferProcessBuilder().state(COMPLETED.code()).type(type).build(); - when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any())).thenReturn(ServiceResult.success(participantAgent())); + when(protocolTokenValidator.verify(any(), eq(TRANSFER_PROCESS_REQUEST_SCOPE), any(), eq(message))).thenReturn(ServiceResult.success(participantAgent())); when(store.findById(any())).thenReturn(transferProcess); when(store.findByIdAndLease(any())).thenReturn(StoreResult.success(transferProcess)); when(negotiationStore.findContractAgreement(any())).thenReturn(contractAgreement()); diff --git a/data-protocols/dsp/dsp-http-core/src/main/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImpl.java b/data-protocols/dsp/dsp-http-core/src/main/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImpl.java index f777590eb62..43515ad4655 100644 --- a/data-protocols/dsp/dsp-http-core/src/main/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImpl.java +++ b/data-protocols/dsp/dsp-http-core/src/main/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImpl.java @@ -27,6 +27,7 @@ import org.eclipse.edc.spi.http.EdcHttpClient; import org.eclipse.edc.spi.iam.AudienceResolver; import org.eclipse.edc.spi.iam.IdentityService; +import org.eclipse.edc.spi.iam.RequestContext; import org.eclipse.edc.spi.iam.RequestScope; import org.eclipse.edc.spi.iam.TokenParameters; import org.eclipse.edc.spi.response.StatusResult; @@ -95,8 +96,14 @@ public CompletableFuture> dispatch( var policyScope = policyScopes.get(message.getClass()); if (policyScope != null) { var requestScopeBuilder = RequestScope.Builder.newInstance(); + var requestContext = RequestContext.Builder.newInstance() + .message(message) + .direction(RequestContext.Direction.Egress) + .build(); + var context = PolicyContextImpl.Builder.newInstance() .additional(RequestScope.Builder.class, requestScopeBuilder) + .additional(RequestContext.class, requestContext) .build(); var policyProvider = (Function) policyScope.policyProvider; policyEngine.evaluate(policyScope.scope, policyProvider.apply(message), context); diff --git a/data-protocols/dsp/dsp-http-core/src/test/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImplTest.java b/data-protocols/dsp/dsp-http-core/src/test/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImplTest.java index 59df78ade01..e3d960a56c5 100644 --- a/data-protocols/dsp/dsp-http-core/src/test/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImplTest.java +++ b/data-protocols/dsp/dsp-http-core/src/test/java/org/eclipse/edc/protocol/dsp/dispatcher/DspHttpRemoteMessageDispatcherImplTest.java @@ -28,6 +28,7 @@ import org.eclipse.edc.spi.http.EdcHttpClient; import org.eclipse.edc.spi.iam.AudienceResolver; import org.eclipse.edc.spi.iam.IdentityService; +import org.eclipse.edc.spi.iam.RequestContext; import org.eclipse.edc.spi.iam.RequestScope; import org.eclipse.edc.spi.iam.TokenParameters; import org.eclipse.edc.spi.iam.TokenRepresentation; @@ -208,6 +209,10 @@ void dispatch_PolicyEvaluatedScope() { verify(identityService).obtainClientCredentials(captor.capture()); verify(httpClient).executeAsync(argThat(r -> authToken.equals(r.headers().get("Authorization"))), isA(List.class)); verify(requestFactory).createRequest(message); + verify(policyEngine).evaluate(any(), any(), argThat(ctx -> { + var requestContext = ctx.getContextData(RequestContext.class); + return requestContext.getMessage().getClass().equals(TestMessage.class) && requestContext.getDirection().equals(RequestContext.Direction.Egress); + })); assertThat(captor.getValue()).satisfies(tr -> { assertThat(tr.getStringClaim(SCOPE_CLAIM)).isEqualTo("policy-test-scope"); assertThat(tr.getStringClaim(AUDIENCE_CLAIM)).isEqualTo(AUDIENCE_VALUE); diff --git a/spi/common/core-spi/src/main/java/org/eclipse/edc/spi/iam/RequestContext.java b/spi/common/core-spi/src/main/java/org/eclipse/edc/spi/iam/RequestContext.java new file mode 100644 index 00000000000..c071b5205a6 --- /dev/null +++ b/spi/common/core-spi/src/main/java/org/eclipse/edc/spi/iam/RequestContext.java @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2024 Bayerische Motoren Werke Aktiengesellschaft (BMW AG) + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0 + * + * SPDX-License-Identifier: Apache-2.0 + * + * Contributors: + * Bayerische Motoren Werke Aktiengesellschaft (BMW AG) - initial API and implementation + * + */ + +package org.eclipse.edc.spi.iam; + +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; + +/** + * Provides additional context for scope extractors. + */ +public class RequestContext { + + private RemoteMessage message; + private Direction direction; + + private RequestContext() { + } + + /** + * Return the {@link RemoteMessage} associated to the request + * + * @return The message + */ + public RemoteMessage getMessage() { + return message; + } + + /** + * Returns the direction of the message Egress/Ingress + * + * @return The direction + */ + public Direction getDirection() { + return direction; + } + + public enum Direction { + Egress, + Ingress + } + + public static class Builder { + private final RequestContext context; + + private Builder() { + context = new RequestContext(); + } + + public static Builder newInstance() { + return new Builder(); + } + + public Builder message(RemoteMessage message) { + context.message = message; + return this; + } + + public Builder direction(Direction direction) { + context.direction = direction; + return this; + } + + public RequestContext build() { + return context; + } + + } +} diff --git a/spi/control-plane/control-plane-spi/src/main/java/org/eclipse/edc/connector/spi/protocol/ProtocolTokenValidator.java b/spi/control-plane/control-plane-spi/src/main/java/org/eclipse/edc/connector/spi/protocol/ProtocolTokenValidator.java index b95c1d6280f..2ec92dce863 100644 --- a/spi/control-plane/control-plane-spi/src/main/java/org/eclipse/edc/connector/spi/protocol/ProtocolTokenValidator.java +++ b/spi/control-plane/control-plane-spi/src/main/java/org/eclipse/edc/connector/spi/protocol/ProtocolTokenValidator.java @@ -19,6 +19,7 @@ import org.eclipse.edc.spi.agent.ParticipantAgent; import org.eclipse.edc.spi.iam.TokenRepresentation; import org.eclipse.edc.spi.result.ServiceResult; +import org.eclipse.edc.spi.types.domain.message.RemoteMessage; /** * Token validator to be used in protocol layer for verifying the token according the @@ -35,7 +36,19 @@ public interface ProtocolTokenValidator { * @return Returns the extracted {@link ParticipantAgent} if successful, failure otherwise */ default ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope) { - return verify(tokenRepresentation, policyScope, Policy.Builder.newInstance().build()); + return verify(tokenRepresentation, policyScope, Policy.Builder.newInstance().build(), null); + } + + /** + * Verify the {@link TokenRepresentation} + * + * @param tokenRepresentation The token + * @param policyScope The policy scope + * @param message The {@link RemoteMessage} + * @return Returns the extracted {@link ParticipantAgent} if successful, failure otherwise + */ + default ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope, RemoteMessage message) { + return verify(tokenRepresentation, policyScope, Policy.Builder.newInstance().build(), message); } /** @@ -44,7 +57,8 @@ default ServiceResult verify(TokenRepresentation tokenRepresen * @param tokenRepresentation The token * @param policyScope The policy scope * @param policy The policy + * @param message The {@link RemoteMessage} * @return Returns the extracted {@link ParticipantAgent} if successful, failure otherwise */ - ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy); + ServiceResult verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy, RemoteMessage message); }