Skip to content

Commit

Permalink
feat: adds request context in the policy evaluation context (#4052)
Browse files Browse the repository at this point in the history
* feat: adds request context in the policy evaluation context

* pr remarks
  • Loading branch information
wolf4ood authored Mar 27, 2024
1 parent 0df9e65 commit f2ca192
Show file tree
Hide file tree
Showing 13 changed files with 776 additions and 641 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ void init() {
.protocolWebhook(protocolWebhook)
.build();

when(protocolTokenValidator.verify(eq(tokenRepresentation), any(), any())).thenReturn(ServiceResult.success(participantAgent));
when(protocolTokenValidator.verify(eq(tokenRepresentation), any(), any(), any())).thenReturn(ServiceResult.success(participantAgent));
consumerService = new ContractNegotiationProtocolServiceImpl(consumerStore, new NoopTransactionContext(), validationService, offerResolver, protocolTokenValidator, new ContractNegotiationObservableImpl(), monitor, mock());
providerService = new ContractNegotiationProtocolServiceImpl(providerStore, new NoopTransactionContext(), validationService, offerResolver, protocolTokenValidator, new ContractNegotiationObservableImpl(), monitor, mock());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public CatalogProtocolServiceImpl(DatasetResolver datasetResolver,
@Override
@NotNull
public ServiceResult<Catalog> getCatalog(CatalogRequestMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> protocolTokenValidator.verify(tokenRepresentation, CATALOGING_REQUEST_SCOPE)
return transactionContext.execute(() -> protocolTokenValidator.verify(tokenRepresentation, CATALOGING_REQUEST_SCOPE, message)
.map(agent -> {
try (var datasets = datasetResolver.query(agent, message.getQuerySpec())) {
var dataServices = dataServiceRegistry.getDataServices();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.ServiceResult;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.spi.types.domain.message.RemoteMessage;
import org.eclipse.edc.transaction.spi.TransactionContext;
import org.jetbrains.annotations.NotNull;

Expand Down Expand Up @@ -85,7 +86,7 @@ public ContractNegotiationProtocolServiceImpl(ContractNegotiationStore store,
@NotNull
public ServiceResult<ContractNegotiation> notifyRequested(ContractRequestMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchValidatableOffer(message)
.compose(validatableOffer -> verifyRequest(tokenRepresentation, validatableOffer.getContractPolicy())
.compose(validatableOffer -> verifyRequest(tokenRepresentation, validatableOffer.getContractPolicy(), message)
.compose(agent -> validateOffer(agent, validatableOffer))
.compose(validatedOffer -> {
var result = message.getProviderPid() == null
Expand All @@ -110,12 +111,12 @@ public ServiceResult<ContractNegotiation> notifyRequested(ContractRequestMessage
@WithSpan
@NotNull
public ServiceResult<ContractNegotiation> notifyOffered(ContractOfferMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> verifyRequest(tokenRepresentation, message.getContractOffer().getPolicy())
return transactionContext.execute(() -> verifyRequest(tokenRepresentation, message.getContractOffer().getPolicy(), message)
.compose(agent -> {
ServiceResult<ContractNegotiation> result = message.getConsumerPid() == null
? createNegotiation(message, agent.getIdentity(), CONSUMER, message.getCallbackAddress())
: getAndLeaseNegotiation(message.getProviderPid())
.compose(negotiation -> validateRequest(agent, negotiation).map(it -> negotiation));
.compose(negotiation -> validateRequest(agent, negotiation).map(it -> negotiation));

return result.onSuccess(negotiation -> {
if (negotiation.shouldIgnoreIncomingMessage(message.getId())) {
Expand All @@ -135,7 +136,7 @@ public ServiceResult<ContractNegotiation> notifyOffered(ContractOfferMessage mes
@NotNull
public ServiceResult<ContractNegotiation> notifyAccepted(ContractNegotiationEventMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(message.getProcessId())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message)
.compose(agent -> validateRequest(agent, contractNegotiation)))
.compose(cn -> onMessageDo(message, contractNegotiation -> acceptedAction(message, contractNegotiation))));

Expand All @@ -146,7 +147,7 @@ public ServiceResult<ContractNegotiation> notifyAccepted(ContractNegotiationEven
@NotNull
public ServiceResult<ContractNegotiation> notifyAgreed(ContractAgreementMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(message.getProcessId())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message)
.compose(agent -> validateAgreed(message, agent, contractNegotiation)))
.compose(cn -> onMessageDo(message, contractNegotiation -> agreedAction(message, contractNegotiation))));
}
Expand All @@ -156,7 +157,7 @@ public ServiceResult<ContractNegotiation> notifyAgreed(ContractAgreementMessage
@NotNull
public ServiceResult<ContractNegotiation> notifyVerified(ContractAgreementVerificationMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(message.getProcessId())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message)
.compose(agent -> validateRequest(agent, contractNegotiation)))
.compose(cn -> onMessageDo(message, contractNegotiation -> verifiedAction(message, contractNegotiation))));
}
Expand All @@ -166,7 +167,7 @@ public ServiceResult<ContractNegotiation> notifyVerified(ContractAgreementVerifi
@NotNull
public ServiceResult<ContractNegotiation> notifyFinalized(ContractNegotiationEventMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(message.getProcessId())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message)
.compose(agent -> validateRequest(agent, contractNegotiation)))
.compose(cn -> onMessageDo(message, contractNegotiation -> finalizedAction(message, contractNegotiation))));
}
Expand All @@ -176,7 +177,7 @@ public ServiceResult<ContractNegotiation> notifyFinalized(ContractNegotiationEve
@NotNull
public ServiceResult<ContractNegotiation> notifyTerminated(ContractNegotiationTerminationMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(message.getProcessId())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), message)
.compose(agent -> validateRequest(agent, contractNegotiation)))
.compose(cn -> onMessageDo(message, contractNegotiation -> terminatedAction(message, contractNegotiation))));
}
Expand All @@ -186,7 +187,7 @@ public ServiceResult<ContractNegotiation> notifyTerminated(ContractNegotiationTe
@NotNull
public ServiceResult<ContractNegotiation> findById(String id, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> getNegotiation(id)
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy())
.compose(contractNegotiation -> verifyRequest(tokenRepresentation, contractNegotiation.getLastContractOffer().getPolicy(), null)
.compose(agent -> validateRequest(agent, contractNegotiation)
.map(it -> contractNegotiation))));
}
Expand Down Expand Up @@ -318,8 +319,8 @@ private ServiceResult<ContractNegotiation> getAndLeaseNegotiation(String negotia
.flatMap(ServiceResult::from);
}

private ServiceResult<ParticipantAgent> verifyRequest(TokenRepresentation tokenRepresentation, Policy policy) {
return protocolTokenValidator.verify(tokenRepresentation, CONTRACT_NEGOTIATION_REQUEST_SCOPE, policy)
private ServiceResult<ParticipantAgent> verifyRequest(TokenRepresentation tokenRepresentation, Policy policy, RemoteMessage message) {
return protocolTokenValidator.verify(tokenRepresentation, CONTRACT_NEGOTIATION_REQUEST_SCOPE, policy, message)
.onFailure(failure -> monitor.debug(() -> "Verification Failed: %s".formatted(failure.getFailureDetail())));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import org.eclipse.edc.spi.agent.ParticipantAgent;
import org.eclipse.edc.spi.agent.ParticipantAgentService;
import org.eclipse.edc.spi.iam.IdentityService;
import org.eclipse.edc.spi.iam.RequestContext;
import org.eclipse.edc.spi.iam.RequestScope;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.iam.VerificationContext;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.ServiceResult;
import org.eclipse.edc.spi.types.domain.message.RemoteMessage;

/**
* Implementation of {@link ProtocolTokenValidator} which uses the {@link PolicyEngine} for extracting
Expand All @@ -48,8 +50,8 @@ public ProtocolTokenValidatorImpl(IdentityService identityService, PolicyEngine
}

@Override
public ServiceResult<ParticipantAgent> verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy) {
var tokenValidation = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy));
public ServiceResult<ParticipantAgent> verify(TokenRepresentation tokenRepresentation, String policyScope, Policy policy, RemoteMessage message) {
var tokenValidation = identityService.verifyJwtToken(tokenRepresentation, createVerificationContext(policyScope, policy, message));
if (tokenValidation.failed()) {
monitor.debug(() -> "Unauthorized: %s".formatted(tokenValidation.getFailureDetail()));
return ServiceResult.unauthorized("Unauthorized");
Expand All @@ -60,9 +62,11 @@ public ServiceResult<ParticipantAgent> verify(TokenRepresentation tokenRepresent
return ServiceResult.success(participantAgent);
}

private VerificationContext createVerificationContext(String scope, Policy policy) {
private VerificationContext createVerificationContext(String scope, Policy policy, RemoteMessage message) {
var requestScopeBuilder = RequestScope.Builder.newInstance();
var requestContext = RequestContext.Builder.newInstance().message(message).direction(RequestContext.Direction.Ingress).build();
var policyContext = PolicyContextImpl.Builder.newInstance()
.additional(RequestContext.class, requestContext)
.additional(RequestScope.Builder.class, requestScopeBuilder)
.build();
policyEngine.evaluate(scope, policy, policyContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.spi.types.domain.DataAddress;
import org.eclipse.edc.spi.types.domain.agreement.ContractAgreement;
import org.eclipse.edc.spi.types.domain.message.RemoteMessage;
import org.eclipse.edc.transaction.spi.TransactionContext;
import org.eclipse.edc.validator.spi.DataAddressValidatorRegistry;
import org.jetbrains.annotations.NotNull;
Expand Down Expand Up @@ -97,7 +98,7 @@ public TransferProcessProtocolServiceImpl(TransferProcessStore transferProcessSt
@NotNull
public ServiceResult<TransferProcess> notifyRequested(TransferRequestMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchNotifyRequestContext(message)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, message))
.compose(context -> validateDestination(message, context))
.compose(context -> validateAgreement(message, context))
.compose(context -> requestedAction(message, context.agreement().getAssetId())));
Expand All @@ -108,7 +109,7 @@ public ServiceResult<TransferProcess> notifyRequested(TransferRequestMessage mes
@NotNull
public ServiceResult<TransferProcess> notifyStarted(TransferStartMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, message))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> startedAction(message, transferProcess)))
);
}
Expand All @@ -118,15 +119,15 @@ public ServiceResult<TransferProcess> notifyStarted(TransferStartMessage message
@NotNull
public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, message))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> completedAction(message, transferProcess)))
);
}

@Override
public @NotNull ServiceResult<TransferProcess> notifySuspended(TransferSuspensionMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, message))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> suspendedAction(message, transferProcess)))
);
}
Expand All @@ -136,7 +137,7 @@ public ServiceResult<TransferProcess> notifyCompleted(TransferCompletionMessage
@NotNull
public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessage message, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(message, this::findTransferProcess)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, message))
.compose(context -> onMessageDo(message, context.participantAgent(), context.agreement(), transferProcess -> terminatedAction(message, transferProcess)))
);
}
Expand All @@ -146,7 +147,7 @@ public ServiceResult<TransferProcess> notifyTerminated(TransferTerminationMessag
@NotNull
public ServiceResult<TransferProcess> findById(String id, TokenRepresentation tokenRepresentation) {
return transactionContext.execute(() -> fetchRequestContext(id, this::findTransferProcessById)
.compose(context -> verifyRequest(tokenRepresentation, context))
.compose(context -> verifyRequest(tokenRepresentation, context, null))
.compose(context -> validateCounterParty(context.participantAgent(), context.agreement(), context.transferProcess())));
}

Expand Down Expand Up @@ -281,8 +282,8 @@ private <T> ServiceResult<TransferRequestMessageContext> fetchRequestContext(T i
return tpProvider.apply(input).compose(transferProcess -> findContractByTransferProcess(transferProcess).map(agreement -> new TransferRequestMessageContext(agreement, transferProcess)));
}

private ServiceResult<ClaimTokenContext> verifyRequest(TokenRepresentation tokenRepresentation, TransferRequestMessageContext context) {
var result = protocolTokenValidator.verify(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy());
private ServiceResult<ClaimTokenContext> verifyRequest(TokenRepresentation tokenRepresentation, TransferRequestMessageContext context, RemoteMessage message) {
var result = protocolTokenValidator.verify(tokenRepresentation, TRANSFER_PROCESS_REQUEST_SCOPE, context.agreement().getPolicy(), message);
if (result.failed()) {
monitor.debug(() -> "Verification Failed: %s".formatted(result.getFailureDetail()));
return ServiceResult.notFound("Not found");
Expand Down
Loading

0 comments on commit f2ca192

Please sign in to comment.