From b84940ef0cb70d1f17f17fabd566c73a0ade7ac6 Mon Sep 17 00:00:00 2001 From: Michael Lumish Date: Wed, 4 Dec 2024 10:40:22 -0500 Subject: [PATCH 1/2] grpc-js-xds: Add XdsChannelCredentials --- packages/grpc-js-xds/src/index.ts | 2 +- packages/grpc-js-xds/src/load-balancer-cds.ts | 161 +++++++++++++- packages/grpc-js-xds/src/resources.ts | 13 +- packages/grpc-js-xds/src/xds-credentials.ts | 38 +++- .../cluster-resource-type.ts | 98 ++++++++- .../listener-resource-type.ts | 15 ++ packages/grpc-js-xds/test/client.ts | 4 +- packages/grpc-js-xds/test/framework.ts | 19 +- .../grpc-js-xds/test/test-xds-credentials.ts | 203 +++++++++++++++++- packages/grpc-js/src/certificate-provider.ts | 18 +- packages/grpc-js/src/channel-credentials.ts | 119 ++++++---- packages/grpc-js/src/experimental.ts | 2 +- packages/grpc-js/src/internal-channel.ts | 1 - packages/grpc-js/src/load-balancing-call.ts | 3 +- packages/grpc-js/src/resolving-call.ts | 10 +- packages/grpc-js/src/server-credentials.ts | 3 + packages/grpc-js/src/subchannel-interface.ts | 9 + packages/grpc-js/src/subchannel.ts | 6 +- packages/grpc-js/src/transport.ts | 14 +- packages/grpc-js/test/common.ts | 3 + .../grpc-js/test/test-channel-credentials.ts | 37 +++- 21 files changed, 690 insertions(+), 88 deletions(-) diff --git a/packages/grpc-js-xds/src/index.ts b/packages/grpc-js-xds/src/index.ts index e080d1943..24a39b8c7 100644 --- a/packages/grpc-js-xds/src/index.ts +++ b/packages/grpc-js-xds/src/index.ts @@ -31,7 +31,7 @@ import * as typed_struct_lb from './lb-policy-registry/typed-struct'; import * as pick_first_lb from './lb-policy-registry/pick-first'; export { XdsServer } from './server'; -export { XdsServerCredentials } from './xds-credentials'; +export { XdsChannelCredentials, XdsServerCredentials } from './xds-credentials'; /** * Register the "xds:" name scheme with the @grpc/grpc-js library. diff --git a/packages/grpc-js-xds/src/load-balancer-cds.ts b/packages/grpc-js-xds/src/load-balancer-cds.ts index 1cb598228..9bcc18014 100644 --- a/packages/grpc-js-xds/src/load-balancer-cds.ts +++ b/packages/grpc-js-xds/src/load-balancer-cds.ts @@ -30,7 +30,11 @@ import { XdsConfig } from './xds-dependency-manager'; import { LocalityEndpoint, PriorityChildRaw } from './load-balancer-priority'; import { Locality__Output } from './generated/envoy/config/core/v3/Locality'; import { AGGREGATE_CLUSTER_BACKWARDS_COMPAT, EXPERIMENTAL_OUTLIER_DETECTION } from './environment'; -import { XDS_CONFIG_KEY } from './resolver-xds'; +import { XDS_CLIENT_KEY, XDS_CONFIG_KEY } from './resolver-xds'; +import { ContainsValueMatcher, Matcher, PrefixValueMatcher, RejectValueMatcher, SafeRegexValueMatcher, SuffixValueMatcher, ValueMatcher } from './matcher'; +import { StringMatcher__Output } from './generated/envoy/type/matcher/v3/StringMatcher'; +import { isIPv6 } from 'net'; +import { formatIPv6, parseIPv6 } from './cidr'; const TRACER_NAME = 'cds_balancer'; @@ -67,6 +71,125 @@ class CdsLoadBalancingConfig implements TypedLoadBalancingConfig { } } +type SupportedSanType = 'DNS' | 'URI' | 'email' | 'IP Address'; + +function isSupportedSanType(type: string): type is SupportedSanType { + return ['DNS', 'URI', 'email', 'IP Address'].includes(type); +} + +class DnsExactValueMatcher implements ValueMatcher { + constructor(private targetValue: string, private ignoreCase: boolean) { + if (ignoreCase) { + this.targetValue = this.targetValue.toLowerCase(); + } + } + apply(entry: string): boolean { + let [type, value] = entry.split(':'); + if (!isSupportedSanType(type)) { + return false; + } + if (!value) { + return false; + } + if (this.ignoreCase) { + value = value.toLowerCase(); + } + if (type === 'DNS' && value.startsWith('*.') && this.targetValue.includes('.', 1)) { + return value.substring(2) === this.targetValue.substring(this.targetValue.indexOf('.') + 1); + } else { + return value === this.targetValue; + } + } + + toString() { + return 'DnsExact(' + this.targetValue + ', ignore_case=' + this.ignoreCase + ')'; + } +} + +function canonicalizeSanEntryValue(type: SupportedSanType, value: string): string { + if (type === 'IP Address' && isIPv6(value)) { + return formatIPv6(parseIPv6(value)); + } + return value; +} + +class SanEntryMatcher implements ValueMatcher { + private childMatcher: ValueMatcher; + constructor(matcherConfig: StringMatcher__Output) { + const ignoreCase = matcherConfig.ignore_case; + switch(matcherConfig.match_pattern) { + case 'exact': + throw new Error('Unexpected exact matcher in SAN entry matcher'); + case 'prefix': + this.childMatcher = new PrefixValueMatcher(matcherConfig.prefix!, ignoreCase); + break; + case 'suffix': + this.childMatcher = new SuffixValueMatcher(matcherConfig.suffix!, ignoreCase); + break; + case 'safe_regex': + this.childMatcher = new SafeRegexValueMatcher(matcherConfig.safe_regex!.regex); + break; + case 'contains': + this.childMatcher = new ContainsValueMatcher(matcherConfig.contains!, ignoreCase); + break; + default: + this.childMatcher = new RejectValueMatcher(); + } + } + apply(entry: string): boolean { + let [type, value] = entry.split(':'); + if (!isSupportedSanType(type)) { + return false; + } + value = canonicalizeSanEntryValue(type, value); + if (!entry) { + return false; + } + return this.childMatcher.apply(value); + } + toString(): string { + return this.childMatcher.toString(); + } + +} + +export class SanMatcher implements ValueMatcher { + private childMatchers: ValueMatcher[]; + constructor(matcherConfigs: StringMatcher__Output[]) { + this.childMatchers = matcherConfigs.map(config => { + if (config.match_pattern === 'exact') { + return new DnsExactValueMatcher(config.exact!, config.ignore_case); + } else { + return new SanEntryMatcher(config); + } + }); + } + apply(value: string): boolean { + if (this.childMatchers.length === 0) { + return true; + } + for (const entry of value.split(', ')) { + for (const matcher of this.childMatchers) { + const checkResult = matcher.apply(entry); + if (checkResult) { + return true; + } + } + } + return false; + } + toString(): string { + return 'SanMatcher(' + this.childMatchers.map(matcher => matcher.toString()).sort().join(', ') + ')'; + } + + equals(other: SanMatcher): boolean { + return this.toString() === other.toString(); + } +} + +export const CA_CERT_PROVIDER_KEY = 'grpc.internal.ca_cert_provider'; +export const IDENTITY_CERT_PROVIDER_KEY = 'grpc.internal.identity_cert_provider'; +export const SAN_MATCHER_KEY = 'grpc.internal.san_matcher'; const RECURSION_DEPTH_LIMIT = 15; @@ -102,6 +225,8 @@ export class CdsLoadBalancer implements LoadBalancer { private priorityNames: string[] = []; private nextPriorityChildNumber = 0; + private latestSanMatcher: SanMatcher | null = null; + constructor(private readonly channelControlHelper: ChannelControlHelper) { this.childBalancer = new ChildLoadBalancerHandler(channelControlHelper); } @@ -140,7 +265,7 @@ export class CdsLoadBalancer implements LoadBalancer { leafClusters = getLeafClusters(xdsConfig, clusterName); } catch (e) { trace('xDS config parsing failed with error ' + (e as Error).message); - this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `xDS config parsing failed with error ${(e as Error).message}`, metadata: new Metadata()})); + this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `xDS config parsing failed with error ${(e as Error).message}`})); return; } const priorityChildren: {[name: string]: PriorityChildRaw} = {}; @@ -165,7 +290,7 @@ export class CdsLoadBalancer implements LoadBalancer { typedChildConfig = parseLoadBalancingConfig(childConfig); } catch (e) { trace('LB policy config parsing failed with error ' + (e as Error).message); - this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `LB policy config parsing failed with error ${(e as Error).message}`, metadata: new Metadata()})); + this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `LB policy config parsing failed with error ${(e as Error).message}`})); return; } this.childBalancer.updateAddressList(endpointList, typedChildConfig, {...options, [ROOT_CLUSTER_KEY]: clusterName}); @@ -272,17 +397,39 @@ export class CdsLoadBalancer implements LoadBalancer { } else { childConfig = xdsClusterImplConfig; } - trace(JSON.stringify(childConfig, undefined, 2)); let typedChildConfig: TypedLoadBalancingConfig; try { typedChildConfig = parseLoadBalancingConfig(childConfig); } catch (e) { trace('LB policy config parsing failed with error ' + (e as Error).message); - this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `LB policy config parsing failed with error ${(e as Error).message}`, metadata: new Metadata()})); + this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `LB policy config parsing failed with error ${(e as Error).message}`})); return; } - trace(JSON.stringify(typedChildConfig.toJsonObject(), undefined, 2)); - this.childBalancer.updateAddressList(childEndpointList, typedChildConfig, options); + const childOptions: ChannelOptions = {...options}; + if (clusterConfig.cluster.securityUpdate) { + const securityUpdate = clusterConfig.cluster.securityUpdate; + const xdsClient = options[XDS_CLIENT_KEY] as XdsClient; + const caCertProvider = xdsClient.getCertificateProvider(securityUpdate.caCertificateProviderInstance); + if (!caCertProvider) { + this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `Cluster ${clusterName} configured with CA certificate provider ${securityUpdate.caCertificateProviderInstance} not in bootstrap`})); + return; + } + if (securityUpdate.identityCertificateProviderInstance) { + const identityCertProvider = xdsClient.getCertificateProvider(securityUpdate.identityCertificateProviderInstance); + if (!identityCertProvider) { + this.channelControlHelper.updateState(connectivityState.TRANSIENT_FAILURE, new UnavailablePicker({code: status.UNAVAILABLE, details: `Cluster ${clusterName} configured with identity certificate provider ${securityUpdate.identityCertificateProviderInstance} not in bootstrap`})); + return; + } + childOptions[IDENTITY_CERT_PROVIDER_KEY] = identityCertProvider; + } + childOptions[CA_CERT_PROVIDER_KEY] = caCertProvider; + const sanMatcher = new SanMatcher(securityUpdate.subjectAltNameMatchers); + if (this.latestSanMatcher === null || !this.latestSanMatcher.equals(sanMatcher)) { + this.latestSanMatcher = sanMatcher; + } + childOptions[SAN_MATCHER_KEY] = this.latestSanMatcher; + } + this.childBalancer.updateAddressList(childEndpointList, typedChildConfig, childOptions); } } exitIdle(): void { diff --git a/packages/grpc-js-xds/src/resources.ts b/packages/grpc-js-xds/src/resources.ts index 244e9ec34..ba0bfeb16 100644 --- a/packages/grpc-js-xds/src/resources.ts +++ b/packages/grpc-js-xds/src/resources.ts @@ -29,6 +29,7 @@ import { ClusterConfig__Output } from './generated/envoy/extensions/clusters/agg import { HttpConnectionManager__Output } from './generated/envoy/extensions/filters/network/http_connection_manager/v3/HttpConnectionManager'; import { EXPERIMENTAL_FEDERATION } from './environment'; import { DownstreamTlsContext__Output } from './generated/envoy/extensions/transport_sockets/tls/v3/DownstreamTlsContext'; +import { UpstreamTlsContext__Output } from './generated/envoy/extensions/transport_sockets/tls/v3/UpstreamTlsContext'; export const EDS_TYPE_URL = 'type.googleapis.com/envoy.config.endpoint.v3.ClusterLoadAssignment'; export const CDS_TYPE_URL = 'type.googleapis.com/envoy.config.cluster.v3.Cluster'; @@ -55,10 +56,16 @@ export const DOWNSTREAM_TLS_CONTEXT_TYPE_URL = 'type.googleapis.com/envoy.extens export type DownstreamTlsContextTypeUrl = 'type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext'; +export const UPSTREAM_TLS_CONTEXT_TYPE_URL = 'type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext'; + +export type UpstreamTlsContextTypeUrl = 'type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext'; + +export type ResourceTypeUrl = AdsTypeUrl | HttpConnectionManagerTypeUrl | ClusterConfigTypeUrl | DownstreamTlsContextTypeUrl | UpstreamTlsContextTypeUrl; + /** * Map type URLs to their corresponding message types */ -export type AdsOutputType = T extends EdsTypeUrl +export type AdsOutputType = T extends EdsTypeUrl ? ClusterLoadAssignment__Output : T extends CdsTypeUrl ? Cluster__Output @@ -70,6 +77,8 @@ export type AdsOutputType(targetTypeUrl: T, message: Buffer): AdsOutputType { +export function decodeSingleResource(targetTypeUrl: T, message: Buffer): AdsOutputType { const name = targetTypeUrl.substring(targetTypeUrl.lastIndexOf('/') + 1); const type = resourceRoot.lookup(name); if (type) { diff --git a/packages/grpc-js-xds/src/xds-credentials.ts b/packages/grpc-js-xds/src/xds-credentials.ts index 4b60b2f50..6edac916f 100644 --- a/packages/grpc-js-xds/src/xds-credentials.ts +++ b/packages/grpc-js-xds/src/xds-credentials.ts @@ -15,7 +15,43 @@ * */ -import { ServerCredentials } from "@grpc/grpc-js"; +import { CallCredentials, ChannelCredentials, ChannelOptions, ServerCredentials, VerifyOptions, experimental } from "@grpc/grpc-js"; +import { CA_CERT_PROVIDER_KEY, IDENTITY_CERT_PROVIDER_KEY, SAN_MATCHER_KEY, SanMatcher } from "./load-balancer-cds"; +import GrpcUri = experimental.GrpcUri; +import SecureConnector = experimental.SecureConnector; +import createCertificateProviderChannelCredentials = experimental.createCertificateProviderChannelCredentials; + +export class XdsChannelCredentials extends ChannelCredentials { + constructor(private fallbackCredentials: ChannelCredentials) { + super(); + } + _isSecure(): boolean { + return true; + } + _equals(other: ChannelCredentials): boolean { + return other instanceof XdsChannelCredentials && this.fallbackCredentials === other.fallbackCredentials; + } + _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector { + if (options[CA_CERT_PROVIDER_KEY]) { + const verifyOptions: VerifyOptions = {}; + if (options[SAN_MATCHER_KEY]) { + const matcher = options[SAN_MATCHER_KEY] as SanMatcher; + verifyOptions.checkServerIdentity = (hostname, cert) => { + if (cert.subjectaltname && matcher.apply(cert.subjectaltname)) { + return undefined; + } else { + return new Error('No matching subject alternative name found in certificate'); + } + } + } + const certProviderCreds = createCertificateProviderChannelCredentials(options[CA_CERT_PROVIDER_KEY], options[IDENTITY_CERT_PROVIDER_KEY] ?? null, verifyOptions); + return certProviderCreds._createSecureConnector(channelTarget, options, callCredentials); + } else { + return this.fallbackCredentials._createSecureConnector(channelTarget, options, callCredentials); + } + } + +} export class XdsServerCredentials extends ServerCredentials { constructor(private fallbackCredentials: ServerCredentials) { diff --git a/packages/grpc-js-xds/src/xds-resource-type/cluster-resource-type.ts b/packages/grpc-js-xds/src/xds-resource-type/cluster-resource-type.ts index c4081baa8..d6442622f 100644 --- a/packages/grpc-js-xds/src/xds-resource-type/cluster-resource-type.ts +++ b/packages/grpc-js-xds/src/xds-resource-type/cluster-resource-type.ts @@ -15,7 +15,7 @@ * */ -import { CDS_TYPE_URL, CLUSTER_CONFIG_TYPE_URL, decodeSingleResource } from "../resources"; +import { CDS_TYPE_URL, CLUSTER_CONFIG_TYPE_URL, decodeSingleResource, UPSTREAM_TLS_CONTEXT_TYPE_URL } from "../resources"; import { XdsDecodeContext, XdsDecodeResult, XdsResourceType } from "./xds-resource-type"; import { LoadBalancingConfig, experimental, logVerbosity } from "@grpc/grpc-js"; import { XdsServerConfig } from "../xds-bootstrap"; @@ -31,6 +31,8 @@ import { convertToLoadBalancingConfig } from "../lb-policy-registry"; import SuccessRateEjectionConfig = experimental.SuccessRateEjectionConfig; import FailurePercentageEjectionConfig = experimental.FailurePercentageEjectionConfig; import parseLoadBalancingConfig = experimental.parseLoadBalancingConfig; +import { StringMatcher__Output } from "../generated/envoy/type/matcher/v3/StringMatcher"; +import { CertificateValidationContext__Output } from "../generated/envoy/extensions/transport_sockets/tls/v3/CertificateValidationContext"; const TRACER_NAME = 'xds_client'; @@ -38,6 +40,11 @@ function trace(text: string): void { experimental.trace(logVerbosity.DEBUG, TRACER_NAME, text); } +export interface SecurityUpdate { + caCertificateProviderInstance: string; + identityCertificateProviderInstance?: string; + subjectAltNameMatchers: StringMatcher__Output[]; +} export interface CdsUpdate { type: 'AGGREGATE' | 'EDS' | 'LOGICAL_DNS'; @@ -49,6 +56,7 @@ export interface CdsUpdate { dnsHostname?: string; lbPolicyConfig: LoadBalancingConfig[]; outlierDetectionUpdate?: experimental.OutlierDetectionRawConfig; + securityUpdate?: SecurityUpdate; } function convertOutlierDetectionUpdate(outlierDetection: OutlierDetection__Output | null): experimental.OutlierDetectionRawConfig | undefined { @@ -201,6 +209,85 @@ export class ClusterResourceType extends XdsResourceType { } } } + let securityUpdate: SecurityUpdate | undefined = undefined; + if (message.transport_socket) { + const transportSocket = message.transport_socket; + if (!transportSocket.typed_config) { + trace('transportSocket.typed_config missing'); + return null; + } + if (transportSocket.typed_config.type_url !== UPSTREAM_TLS_CONTEXT_TYPE_URL) { + trace('Incorrect transportSocket.typed_config.type_url: ' + transportSocket.typed_config.type_url) + return null; + } + const upstreamTlsContext = decodeSingleResource(UPSTREAM_TLS_CONTEXT_TYPE_URL, transportSocket.typed_config.value); + if (!upstreamTlsContext.common_tls_context) { + trace('Could not decode UpstreamTlsContext'); + return null; + } + trace('Decoded UpstreamTlsContext: ' + JSON.stringify(upstreamTlsContext, undefined, 2)); + const commonTlsContext = upstreamTlsContext.common_tls_context; + let validationContext: CertificateValidationContext__Output; + switch (commonTlsContext.validation_context_type) { + case 'validation_context_sds_secret_config': + return null; + case 'validation_context': + if (!commonTlsContext.validation_context) { + return null; + } + validationContext = commonTlsContext.validation_context; + break; + case 'combined_validation_context': + if (!commonTlsContext.combined_validation_context?.default_validation_context) { + return null; + } + validationContext = commonTlsContext.combined_validation_context.default_validation_context; + break; + default: + return null; + } + if (!validationContext.ca_certificate_provider_instance) { + return null; + } + if (!(validationContext.ca_certificate_provider_instance.instance_name in context.bootstrap.certificateProviders)) { + return null; + } + if (validationContext.verify_certificate_spki.length > 0) { + return null; + } + if (validationContext.verify_certificate_hash.length > 0) { + return null; + } + if (validationContext.require_signed_certificate_timestamp) { + return null; + } + if (validationContext.crl) { + return null; + } + if (validationContext.custom_validator_config) { + return null; + } + if (commonTlsContext.tls_certificate_provider_instance) { + if (!(commonTlsContext.tls_certificate_provider_instance.instance_name in context.bootstrap.certificateProviders)) { + return null; + } + } else { + if (commonTlsContext.tls_certificates.length > 0 || commonTlsContext.tls_certificate_sds_secret_configs.length > 0) { + return null; + } + } + if (commonTlsContext.tls_params) { + return null; + } + if (commonTlsContext.custom_handshaker) { + return null; + } + securityUpdate = { + caCertificateProviderInstance: validationContext.ca_certificate_provider_instance.instance_name, + identityCertificateProviderInstance: commonTlsContext.tls_certificate_provider_instance?.instance_name, + subjectAltNameMatchers: validationContext.match_subject_alt_names + } + } if (message.cluster_discovery_type === 'cluster_type') { if (!(message.cluster_type?.typed_config && message.cluster_type.typed_config.type_url === CLUSTER_CONFIG_TYPE_URL)) { return null; @@ -214,7 +301,8 @@ export class ClusterResourceType extends XdsResourceType { name: message.name, aggregateChildren: clusterConfig.clusters, outlierDetectionUpdate: convertOutlierDetectionUpdate(null), - lbPolicyConfig: [lbPolicyConfig] + lbPolicyConfig: [lbPolicyConfig], + securityUpdate: securityUpdate }; } else { let maxConcurrentRequests: number | undefined = undefined; @@ -238,7 +326,8 @@ export class ClusterResourceType extends XdsResourceType { edsServiceName: message.eds_cluster_config.service_name === '' ? undefined : message.eds_cluster_config.service_name, lrsLoadReportingServer: message.lrs_server ? context.server : undefined, outlierDetectionUpdate: convertOutlierDetectionUpdate(message.outlier_detection), - lbPolicyConfig: [lbPolicyConfig] + lbPolicyConfig: [lbPolicyConfig], + securityUpdate: securityUpdate } } else if (message.type === 'LOGICAL_DNS') { if (!message.load_assignment) { @@ -268,7 +357,8 @@ export class ClusterResourceType extends XdsResourceType { dnsHostname: `${socketAddress.address}:${socketAddress.port_value}`, lrsLoadReportingServer: message.lrs_server ? context.server : undefined, outlierDetectionUpdate: convertOutlierDetectionUpdate(message.outlier_detection), - lbPolicyConfig: [lbPolicyConfig] + lbPolicyConfig: [lbPolicyConfig], + securityUpdate: securityUpdate }; } } diff --git a/packages/grpc-js-xds/src/xds-resource-type/listener-resource-type.ts b/packages/grpc-js-xds/src/xds-resource-type/listener-resource-type.ts index 8dacbae99..dd7c1ce07 100644 --- a/packages/grpc-js-xds/src/xds-resource-type/listener-resource-type.ts +++ b/packages/grpc-js-xds/src/xds-resource-type/listener-resource-type.ts @@ -198,6 +198,21 @@ function validateFilterChain(context: XdsDecodeContext, filterChain: FilterChain trace('require_client_certificate set without validationContext'); return false; } + if (validationContext && validationContext.verify_certificate_spki.length > 0) { + return false; + } + if (validationContext && validationContext.verify_certificate_hash.length > 0) { + return false; + } + if (validationContext?.require_signed_certificate_timestamp) { + return false; + } + if (validationContext?.crl) { + return false; + } + if (validationContext?.custom_validator_config) { + return false; + } if (commonTlsContext.tls_params) { trace('tls_params set'); return false; diff --git a/packages/grpc-js-xds/test/client.ts b/packages/grpc-js-xds/test/client.ts index 79bf2627b..dc2c7f243 100644 --- a/packages/grpc-js-xds/test/client.ts +++ b/packages/grpc-js-xds/test/client.ts @@ -15,7 +15,7 @@ * */ -import { ChannelCredentials, ChannelOptions, credentials, loadPackageDefinition, ServiceError } from "@grpc/grpc-js"; +import { ChannelCredentials, ChannelOptions, credentials, loadPackageDefinition, Metadata, ServiceError } from "@grpc/grpc-js"; import { loadSync } from "@grpc/proto-loader"; import { ProtoGrpcType } from "./generated/echo"; import { EchoTestServiceClient } from "./generated/grpc/testing/EchoTestService"; @@ -76,7 +76,7 @@ export class XdsTestClient { sendOneCall(callback: (error: ServiceError | null) => void) { const deadline = new Date(); - deadline.setMilliseconds(deadline.getMilliseconds() + 500); + deadline.setMilliseconds(deadline.getMilliseconds() + 1500); this.client.echo({message: 'test'}, {deadline}, (error, value) => { callback(error); }); diff --git a/packages/grpc-js-xds/test/framework.ts b/packages/grpc-js-xds/test/framework.ts index 87973265b..af15054ac 100644 --- a/packages/grpc-js-xds/test/framework.ts +++ b/packages/grpc-js-xds/test/framework.ts @@ -24,12 +24,13 @@ import { Route } from "../src/generated/envoy/config/route/v3/Route"; import { Listener } from "../src/generated/envoy/config/listener/v3/Listener"; import { HttpConnectionManager } from "../src/generated/envoy/extensions/filters/network/http_connection_manager/v3/HttpConnectionManager"; import { AnyExtension } from "@grpc/proto-loader"; -import { CLUSTER_CONFIG_TYPE_URL, HTTP_CONNECTION_MANGER_TYPE_URL } from "../src/resources"; +import { CLUSTER_CONFIG_TYPE_URL, HTTP_CONNECTION_MANGER_TYPE_URL, UPSTREAM_TLS_CONTEXT_TYPE_URL } from "../src/resources"; import { LocalityLbEndpoints } from "../src/generated/envoy/config/endpoint/v3/LocalityLbEndpoints"; import { LbEndpoint } from "../src/generated/envoy/config/endpoint/v3/LbEndpoint"; import { ClusterConfig } from "../src/generated/envoy/extensions/clusters/aggregate/v3/ClusterConfig"; import { Any } from "../src/generated/google/protobuf/Any"; import { ControlPlaneServer } from "./xds-server"; +import { UpstreamTlsContext } from "../src/generated/envoy/extensions/transport_sockets/tls/v3/UpstreamTlsContext"; interface Endpoint { locality: Locality; @@ -71,7 +72,13 @@ export interface FakeCluster { } export class FakeEdsCluster implements FakeCluster { - constructor(private clusterName: string, private endpointName: string, private endpoints: Endpoint[], private loadBalancingPolicyOverride?: Any | 'RING_HASH') {} + constructor( + private clusterName: string, + private endpointName: string, + private endpoints: Endpoint[], + private loadBalancingPolicyOverride?: Any | 'RING_HASH' | undefined, + private upstreamTlsContext?: UpstreamTlsContext + ) {} getEndpointConfig(): ClusterLoadAssignment { return { @@ -111,6 +118,14 @@ export class FakeEdsCluster implements FakeCluster { } else { result.lb_policy = 'ROUND_ROBIN'; } + if (this.upstreamTlsContext) { + result.transport_socket = { + typed_config: { + '@type': UPSTREAM_TLS_CONTEXT_TYPE_URL, + ...this.upstreamTlsContext + } + } + } return result; } diff --git a/packages/grpc-js-xds/test/test-xds-credentials.ts b/packages/grpc-js-xds/test/test-xds-credentials.ts index 36ad3bb14..27063d3e6 100644 --- a/packages/grpc-js-xds/test/test-xds-credentials.ts +++ b/packages/grpc-js-xds/test/test-xds-credentials.ts @@ -20,14 +20,20 @@ import { createBackends } from './backend'; import { FakeEdsCluster, FakeRouteGroup, FakeServerRoute } from './framework'; import { ControlPlaneServer } from './xds-server'; import { XdsTestClient } from './client'; -import { XdsServerCredentials } from '../src'; -import { credentials, ServerCredentials } from '@grpc/grpc-js'; +import { XdsChannelCredentials, XdsServerCredentials } from '../src'; +import { credentials, ServerCredentials, experimental } from '@grpc/grpc-js'; import { readFileSync } from 'fs'; import * as path from 'path'; import { Listener } from '../src/generated/envoy/config/listener/v3/Listener'; import { DownstreamTlsContext } from '../src/generated/envoy/extensions/transport_sockets/tls/v3/DownstreamTlsContext'; import { AnyExtension } from '@grpc/proto-loader'; import { DOWNSTREAM_TLS_CONTEXT_TYPE_URL } from '../src/resources'; +import { UpstreamTlsContext } from '../src/generated/envoy/extensions/transport_sockets/tls/v3/UpstreamTlsContext'; +import { StringMatcher } from '../src/generated/envoy/type/matcher/v3/StringMatcher'; +import FileWatcherCertificateProvider = experimental.FileWatcherCertificateProvider; +import createCertificateProviderChannelCredentials = experimental.createCertificateProviderChannelCredentials; + +const caPath = path.join(__dirname, 'fixtures', 'ca.pem'); const ca = readFileSync(path.join(__dirname, 'fixtures', 'ca.pem')); const key = readFileSync(path.join(__dirname, 'fixtures', 'server1.key')); @@ -168,3 +174,196 @@ describe('Server xDS Credentials', () => { assert.strictEqual(error, null); }); }); +describe('Client xDS credentials', () => { + let xdsServer: ControlPlaneServer; + let client: XdsTestClient; + beforeEach(done => { + xdsServer = new ControlPlaneServer(); + xdsServer.startServer(error => { + done(error); + }); + }); + afterEach(() => { + client?.close(); + xdsServer?.shutdownServer(); + }); + it('Should use fallback credentials when certificate providers are not configured', async () => { + const [backend] = await createBackends(1, true, ServerCredentials.createInsecure()); + const serverRoute = new FakeServerRoute(backend.getPort(), 'serverRoute'); + xdsServer.setRdsResource(serverRoute.getRouteConfiguration()); + xdsServer.setLdsResource(serverRoute.getListener()); + xdsServer.addResponseListener((typeUrl, responseState) => { + if (responseState.state === 'NACKED') { + client?.stopCalls(); + assert.fail(`Client NACKED ${typeUrl} resource with message ${responseState.errorMessage}`); + } + }); + const cluster = new FakeEdsCluster('cluster1', 'endpoint1', [{backends: [backend], locality:{region: 'region1'}}]); + const routeGroup = new FakeRouteGroup('listener1', 'route1', [{cluster: cluster}]); + await routeGroup.startAllBackends(xdsServer); + xdsServer.setEdsResource(cluster.getEndpointConfig()); + xdsServer.setCdsResource(cluster.getClusterConfig()); + xdsServer.setRdsResource(routeGroup.getRouteConfiguration()); + xdsServer.setLdsResource(routeGroup.getListener()); + client = XdsTestClient.createFromServer('listener1', xdsServer, new XdsChannelCredentials(credentials.createInsecure())); + const error = await client.sendOneCallAsync(); + assert.strictEqual(error, null); + }); + it('Should use CA certificates when configured', async () => { + const [backend] = await createBackends(1, true, ServerCredentials.createSsl(null, [{private_key: key, cert_chain: cert}])); + const serverRoute = new FakeServerRoute(backend.getPort(), 'serverRoute'); + xdsServer.setRdsResource(serverRoute.getRouteConfiguration()); + xdsServer.setLdsResource(serverRoute.getListener()); + xdsServer.addResponseListener((typeUrl, responseState) => { + if (responseState.state === 'NACKED') { + client?.stopCalls(); + assert.fail(`Client NACKED ${typeUrl} resource with message ${responseState.errorMessage}`); + } + }); + const upstreamTlsContext: UpstreamTlsContext = { + common_tls_context: { + validation_context: { + ca_certificate_provider_instance: { + instance_name: 'test_certificates' + } + } + } + }; + const cluster = new FakeEdsCluster('cluster1', 'endpoint1', [{backends: [backend], locality:{region: 'region1'}}], undefined, upstreamTlsContext); + const routeGroup = new FakeRouteGroup('listener1', 'route1', [{cluster: cluster}]); + await routeGroup.startAllBackends(xdsServer); + xdsServer.setEdsResource(cluster.getEndpointConfig()); + xdsServer.setCdsResource(cluster.getClusterConfig()); + xdsServer.setRdsResource(routeGroup.getRouteConfiguration()); + xdsServer.setLdsResource(routeGroup.getListener()); + client = XdsTestClient.createFromServer('listener1', xdsServer, new XdsChannelCredentials(credentials.createInsecure())); + const error = await client.sendOneCallAsync(); + assert.strictEqual(error, null); + }); + it('Should use identity and CA certificates when configured', async () => { + const [backend] = await createBackends(1, true, ServerCredentials.createSsl(ca, [{private_key: key, cert_chain: cert}], true)); + const serverRoute = new FakeServerRoute(backend.getPort(), 'serverRoute'); + xdsServer.setRdsResource(serverRoute.getRouteConfiguration()); + xdsServer.setLdsResource(serverRoute.getListener()); + xdsServer.addResponseListener((typeUrl, responseState) => { + if (responseState.state === 'NACKED') { + client?.stopCalls(); + assert.fail(`Client NACKED ${typeUrl} resource with message ${responseState.errorMessage}`); + } + }); + const upstreamTlsContext: UpstreamTlsContext = { + common_tls_context: { + tls_certificate_provider_instance: { + instance_name: 'test_certificates' + }, + validation_context: { + ca_certificate_provider_instance: { + instance_name: 'test_certificates' + } + } + } + }; + const cluster = new FakeEdsCluster('cluster1', 'endpoint1', [{backends: [backend], locality:{region: 'region1'}}], undefined, upstreamTlsContext); + const routeGroup = new FakeRouteGroup('listener1', 'route1', [{cluster: cluster}]); + await routeGroup.startAllBackends(xdsServer); + xdsServer.setEdsResource(cluster.getEndpointConfig()); + xdsServer.setCdsResource(cluster.getClusterConfig()); + xdsServer.setRdsResource(routeGroup.getRouteConfiguration()); + xdsServer.setLdsResource(routeGroup.getListener()); + client = XdsTestClient.createFromServer('listener1', xdsServer, new XdsChannelCredentials(credentials.createInsecure())); + const error = await client.sendOneCallAsync(); + assert.strictEqual(error, null); + }); + describe('Subject Alternative Name matching', () => { + interface SanTestCase { + name: string; + matchers: StringMatcher[]; + expectedSuccess: boolean; + } + const testCases: SanTestCase[] = [ + { + name: 'empty match', + matchers: [], + expectedSuccess: true + }, + { + name: 'exact DNS match', + matchers: [{ + exact: 'waterzooi.test.google.be', + ignore_case: false + }], + expectedSuccess: true + }, + { + name: 'wildcard DNS match', + matchers: [{ + exact: 'foo.test.google.fr', + ignore_case: false + }], + expectedSuccess: true + }, + { + name: 'exact IP match', + matchers: [{ + exact: '192.168.1.3', + ignore_case: false + }], + expectedSuccess: true + }, + { + name: 'suffix match', + matchers: [{ + suffix: 'test.google.fr', + ignore_case: false + }], + expectedSuccess: true + }, + { + name: 'unmatched matcher', + matchers: [{ + exact: 'incorret', + ignore_case: false + }], + expectedSuccess: false + }, + ]; + for (const {name, matchers, expectedSuccess} of testCases) { + it(name, async () => { + const [backend] = await createBackends(1, true, ServerCredentials.createSsl(null, [{private_key: key, cert_chain: cert}])); + const serverRoute = new FakeServerRoute(backend.getPort(), 'serverRoute'); + xdsServer.setRdsResource(serverRoute.getRouteConfiguration()); + xdsServer.setLdsResource(serverRoute.getListener()); + xdsServer.addResponseListener((typeUrl, responseState) => { + if (responseState.state === 'NACKED') { + client?.stopCalls(); + assert.fail(`Client NACKED ${typeUrl} resource with message ${responseState.errorMessage}`); + } + }); + const upstreamTlsContext: UpstreamTlsContext = { + common_tls_context: { + validation_context: { + ca_certificate_provider_instance: { + instance_name: 'test_certificates' + }, + match_subject_alt_names: matchers + } + } + }; + const cluster = new FakeEdsCluster('cluster1', 'endpoint1', [{backends: [backend], locality:{region: 'region1'}}], undefined, upstreamTlsContext); + const routeGroup = new FakeRouteGroup('listener1', 'route1', [{cluster: cluster}]); + await routeGroup.startAllBackends(xdsServer); + xdsServer.setEdsResource(cluster.getEndpointConfig()); + xdsServer.setCdsResource(cluster.getClusterConfig()); + xdsServer.setRdsResource(routeGroup.getRouteConfiguration()); + xdsServer.setLdsResource(routeGroup.getListener()); + client = XdsTestClient.createFromServer('listener1', xdsServer, new XdsChannelCredentials(credentials.createInsecure())); + const error = await client.sendOneCallAsync(); + if (expectedSuccess) { + assert.strictEqual(error, null); + } else { + assert.ok(error); + } + }); + } + }); +}); diff --git a/packages/grpc-js/src/certificate-provider.ts b/packages/grpc-js/src/certificate-provider.ts index 6eaf8447e..6a93936a5 100644 --- a/packages/grpc-js/src/certificate-provider.ts +++ b/packages/grpc-js/src/certificate-provider.ts @@ -59,9 +59,9 @@ export interface FileWatcherCertificateProviderConfig { export class FileWatcherCertificateProvider implements CertificateProvider { private refreshTimer: NodeJS.Timeout | null = null; private fileResultPromise: Promise<[PromiseSettledResult, PromiseSettledResult, PromiseSettledResult]> | null = null; - private latestCaUpdate: CaCertificateUpdate | null = null; + private latestCaUpdate: CaCertificateUpdate | null | undefined = undefined; private caListeners: Set = new Set(); - private latestIdentityUpdate: IdentityCertificateUpdate | null = null; + private latestIdentityUpdate: IdentityCertificateUpdate | null | undefined = undefined; private identityListeners: Set = new Set(); private lastUpdateTime: Date | null = null; @@ -105,6 +105,8 @@ export class FileWatcherCertificateProvider implements CertificateProvider { this.latestCaUpdate = { caCertificate: caCertificateResult.value }; + } else { + this.latestCaUpdate = null; } for (const listener of this.identityListeners) { listener(this.latestIdentityUpdate); @@ -128,8 +130,8 @@ export class FileWatcherCertificateProvider implements CertificateProvider { } if (timeSinceLastUpdate > this.config.refreshIntervalMs * 2) { // Clear out old updates if they are definitely stale - this.latestCaUpdate = null; - this.latestIdentityUpdate = null; + this.latestCaUpdate = undefined; + this.latestIdentityUpdate = undefined; } this.refreshTimer = setInterval(() => this.updateCertificates(), this.config.refreshIntervalMs); trace('File watcher started watching'); @@ -149,7 +151,9 @@ export class FileWatcherCertificateProvider implements CertificateProvider { addCaCertificateListener(listener: CaCertificateUpdateListener): void { this.caListeners.add(listener); this.maybeStartWatchingFiles(); - process.nextTick(listener, this.latestCaUpdate); + if (this.latestCaUpdate !== undefined) { + process.nextTick(listener, this.latestCaUpdate); + } } removeCaCertificateListener(listener: CaCertificateUpdateListener): void { this.caListeners.delete(listener); @@ -158,7 +162,9 @@ export class FileWatcherCertificateProvider implements CertificateProvider { addIdentityCertificateListener(listener: IdentityCertificateUpdateListener): void { this.identityListeners.add(listener); this.maybeStartWatchingFiles(); - process.nextTick(listener, this.latestIdentityUpdate); + if (this.latestIdentityUpdate !== undefined) { + process.nextTick(listener, this.latestIdentityUpdate); + } } removeIdentityCertificateListener(listener: IdentityCertificateUpdateListener): void { this.identityListeners.delete(listener); diff --git a/packages/grpc-js/src/channel-credentials.ts b/packages/grpc-js/src/channel-credentials.ts index dee6e06f6..13c9c5520 100644 --- a/packages/grpc-js/src/channel-credentials.ts +++ b/packages/grpc-js/src/channel-credentials.ts @@ -65,6 +65,7 @@ export interface VerifyOptions { export interface SecureConnector { connect(socket: Socket): Promise; + getCallCredentials(): CallCredentials; destroy(): void; } @@ -74,24 +75,14 @@ export interface SecureConnector { * over a channel initialized with an instance of this class. */ export abstract class ChannelCredentials { - protected callCredentials: CallCredentials; - - protected constructor(callCredentials?: CallCredentials) { - this.callCredentials = callCredentials || CallCredentials.createEmpty(); - } /** * Returns a copy of this object with the included set of per-call credentials * expanded to include callCredentials. * @param callCredentials A CallCredentials object to associate with this * instance. */ - abstract compose(callCredentials: CallCredentials): ChannelCredentials; - - /** - * Gets the set of per-call credentials associated with this instance. - */ - _getCallCredentials(): CallCredentials { - return this.callCredentials; + compose(callCredentials: CallCredentials): ChannelCredentials { + return new ComposedChannelCredentialsImpl(this, callCredentials); } /** @@ -106,7 +97,7 @@ export abstract class ChannelCredentials { */ abstract _equals(other: ChannelCredentials): boolean; - abstract _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector; + abstract _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector; /** * Return a new ChannelCredentials instance with a given set of credentials. @@ -175,7 +166,7 @@ class InsecureChannelCredentialsImpl extends ChannelCredentials { super(); } - compose(callCredentials: CallCredentials): never { + override compose(callCredentials: CallCredentials): never { throw new Error('Cannot compose insecure credentials'); } _isSecure(): boolean { @@ -184,11 +175,14 @@ class InsecureChannelCredentialsImpl extends ChannelCredentials { _equals(other: ChannelCredentials): boolean { return other instanceof InsecureChannelCredentialsImpl; } - _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector { + _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector { return { connect(socket) { return Promise.resolve(socket); }, + getCallCredentials: () => { + return callCredentials ?? CallCredentials.createEmpty(); + }, destroy() {} } } @@ -251,7 +245,7 @@ function getConnectionOptions(secureContext: SecureContext, verifyOptions: Verif } class SecureConnectorImpl implements SecureConnector { - constructor(private connectionOptions: ConnectionOptions) { + constructor(private connectionOptions: ConnectionOptions, private callCredentials: CallCredentials) { } connect(socket: Socket): Promise { const tlsConnectOptions: ConnectionOptions = { @@ -267,6 +261,9 @@ class SecureConnectorImpl implements SecureConnector { }); }); } + getCallCredentials(): CallCredentials { + return this.callCredentials; + } destroy() {} } @@ -278,11 +275,6 @@ class SecureChannelCredentialsImpl extends ChannelCredentials { super(); } - compose(callCredentials: CallCredentials): ChannelCredentials { - const combinedCallCredentials = - this.callCredentials.compose(callCredentials); - return new ComposedChannelCredentialsImpl(this, combinedCallCredentials); - } _isSecure(): boolean { return true; } @@ -300,26 +292,35 @@ class SecureChannelCredentialsImpl extends ChannelCredentials { return false; } } - _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector { + _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector { const connectionOptions = getConnectionOptions(this.secureContext, this.verifyOptions, channelTarget, options); - return new SecureConnectorImpl(connectionOptions); + return new SecureConnectorImpl(connectionOptions, callCredentials ?? CallCredentials.createEmpty()); } } class CertificateProviderChannelCredentialsImpl extends ChannelCredentials { private refcount: number = 0; - private latestCaUpdate: CaCertificateUpdate | null = null; - private latestIdentityUpdate: IdentityCertificateUpdate | null = null; + /** + * `undefined` means that the certificates have not yet been loaded. `null` + * means that an attempt to load them has completed, and has failed. + */ + private latestCaUpdate: CaCertificateUpdate | null | undefined = undefined; + /** + * `undefined` means that the certificates have not yet been loaded. `null` + * means that an attempt to load them has completed, and has failed. + */ + private latestIdentityUpdate: IdentityCertificateUpdate | null | undefined = undefined; private caCertificateUpdateListener: CaCertificateUpdateListener = this.handleCaCertificateUpdate.bind(this); private identityCertificateUpdateListener: IdentityCertificateUpdateListener = this.handleIdentityCertitificateUpdate.bind(this); + private secureContextWatchers: ((context: SecureContext | null) => void)[] = []; private static SecureConnectorImpl = class implements SecureConnector { - constructor(private parent: CertificateProviderChannelCredentialsImpl, private channelTarget: GrpcUri, private options: ChannelOptions) {} + constructor(private parent: CertificateProviderChannelCredentialsImpl, private channelTarget: GrpcUri, private options: ChannelOptions, private callCredentials: CallCredentials) {} connect(socket: Socket): Promise { - return new Promise((resolve, reject) => { - const secureContext = this.parent.getLatestSecureContext(); + return new Promise(async (resolve, reject) => { + const secureContext = await this.parent.getSecureContext(); if (!secureContext) { - reject(new Error('Credentials not loaded')); + reject(new Error('Failed to load credentials')); return; } const connnectionOptions = getConnectionOptions(secureContext, this.parent.verifyOptions, this.channelTarget, this.options); @@ -336,6 +337,10 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials { }); } + getCallCredentials(): CallCredentials { + return this.callCredentials; + } + destroy() { this.parent.unref(); } @@ -347,14 +352,6 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials { ) { super(); } - compose(callCredentials: CallCredentials): ChannelCredentials { - const combinedCallCredentials = - this.callCredentials.compose(callCredentials); - return new ComposedChannelCredentialsImpl( - this, - combinedCallCredentials - ); - } _isSecure(): boolean { return true; } @@ -384,24 +381,55 @@ class CertificateProviderChannelCredentialsImpl extends ChannelCredentials { this.identityCertificateProvider?.removeIdentityCertificateListener(this.identityCertificateUpdateListener); } } - _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector { + _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector { this.ref(); - return new CertificateProviderChannelCredentialsImpl.SecureConnectorImpl(this, channelTarget, options); + return new CertificateProviderChannelCredentialsImpl.SecureConnectorImpl(this, channelTarget, options, callCredentials ?? CallCredentials.createEmpty()); + } + + private maybeUpdateWatchers() { + if (this.hasReceivedUpdates()) { + for (const watcher of this.secureContextWatchers) { + watcher(this.getLatestSecureContext()); + } + this.secureContextWatchers = []; + } } private handleCaCertificateUpdate(update: CaCertificateUpdate | null) { this.latestCaUpdate = update; + this.maybeUpdateWatchers(); } private handleIdentityCertitificateUpdate(update: IdentityCertificateUpdate | null) { this.latestIdentityUpdate = update; + this.maybeUpdateWatchers(); + } + + private hasReceivedUpdates(): boolean { + if (this.latestCaUpdate === undefined) { + return false; + } + if (this.identityCertificateProvider && this.latestIdentityUpdate === undefined) { + return false; + } + return true; + } + + private getSecureContext(): Promise { + if (this.hasReceivedUpdates()) { + return Promise.resolve(this.getLatestSecureContext()); + } else { + return new Promise(resolve => { + this.secureContextWatchers.push(resolve); + }); + } } private getLatestSecureContext(): SecureContext | null { - if (this.latestCaUpdate === null) { + if (!this.latestCaUpdate) { return null; } - if (this.identityCertificateProvider !== null && this.latestIdentityUpdate === null) { + if (this.identityCertificateProvider !== null && !this.latestIdentityUpdate) { return null; } return createSecureContext({ @@ -420,9 +448,9 @@ export function createCertificateProviderChannelCredentials(caCertificateProvide class ComposedChannelCredentialsImpl extends ChannelCredentials { constructor( private channelCredentials: ChannelCredentials, - callCreds: CallCredentials + private callCredentials: CallCredentials ) { - super(callCreds); + super(); if (!channelCredentials._isSecure()) { throw new Error('Cannot compose insecure credentials'); } @@ -451,7 +479,8 @@ class ComposedChannelCredentialsImpl extends ChannelCredentials { return false; } } - _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions): SecureConnector { - return this.channelCredentials._createSecureConnector(channelTarget, options); + _createSecureConnector(channelTarget: GrpcUri, options: ChannelOptions, callCredentials?: CallCredentials): SecureConnector { + const combinedCallCredentials = this.callCredentials.compose(callCredentials ?? CallCredentials.createEmpty()); + return this.channelCredentials._createSecureConnector(channelTarget, options, combinedCallCredentials); } } diff --git a/packages/grpc-js/src/experimental.ts b/packages/grpc-js/src/experimental.ts index a616066db..b74728cd0 100644 --- a/packages/grpc-js/src/experimental.ts +++ b/packages/grpc-js/src/experimental.ts @@ -63,5 +63,5 @@ export { FileWatcherCertificateProvider, FileWatcherCertificateProviderConfig } from './certificate-provider'; -export { createCertificateProviderChannelCredentials } from './channel-credentials'; +export { createCertificateProviderChannelCredentials, SecureConnector } from './channel-credentials'; export { SUBCHANNEL_ARGS_EXCLUDE_KEY_PREFIX } from './internal-channel'; diff --git a/packages/grpc-js/src/internal-channel.ts b/packages/grpc-js/src/internal-channel.ts index 0ca264905..94b76fa9b 100644 --- a/packages/grpc-js/src/internal-channel.ts +++ b/packages/grpc-js/src/internal-channel.ts @@ -759,7 +759,6 @@ export class InternalChannel { method, finalOptions, this.filterStackFactory.clone(), - this.credentials._getCallCredentials(), callNumber ); diff --git a/packages/grpc-js/src/load-balancing-call.ts b/packages/grpc-js/src/load-balancing-call.ts index 2b23ff7f9..150300ae7 100644 --- a/packages/grpc-js/src/load-balancing-call.ts +++ b/packages/grpc-js/src/load-balancing-call.ts @@ -161,7 +161,8 @@ export class LoadBalancingCall implements Call, DeadlineInfoProvider { ); switch (pickResult.pickResultType) { case PickResultType.COMPLETE: - this.credentials + const combinedCallCredentials = this.credentials.compose(pickResult.subchannel!.getCallCredentials()); + combinedCallCredentials .generateMetadata({ method_name: this.methodName, service_url: this.serviceUrl }) .then( credsMetadata => { diff --git a/packages/grpc-js/src/resolving-call.ts b/packages/grpc-js/src/resolving-call.ts index 2c81e7883..a341a379f 100644 --- a/packages/grpc-js/src/resolving-call.ts +++ b/packages/grpc-js/src/resolving-call.ts @@ -62,12 +62,18 @@ export class ResolvingCall implements Call { private configReceivedTime: Date | null = null; private childStartTime: Date | null = null; + /** + * Credentials configured for this specific call. Does not include + * call credentials associated with the channel credentials used to create + * the channel. + */ + private credentials: CallCredentials = CallCredentials.createEmpty(); + constructor( private readonly channel: InternalChannel, private readonly method: string, options: CallStreamOptions, private readonly filterStackFactory: FilterStackFactory, - private credentials: CallCredentials, private callNumber: number ) { this.deadline = options.deadline; @@ -351,7 +357,7 @@ export class ResolvingCall implements Call { } } setCredentials(credentials: CallCredentials): void { - this.credentials = this.credentials.compose(credentials); + this.credentials = credentials; } addStatusWatcher(watcher: (status: StatusObject) => void) { diff --git a/packages/grpc-js/src/server-credentials.ts b/packages/grpc-js/src/server-credentials.ts index 071e2b687..68fc76038 100644 --- a/packages/grpc-js/src/server-credentials.ts +++ b/packages/grpc-js/src/server-credentials.ts @@ -338,6 +338,9 @@ class InterceptorServerCredentials extends ServerCredentials { override _removeWatcher(watcher: SecureContextWatcher): void { this.childCredentials._removeWatcher(watcher); } + override _getSettings(): SecureServerOptions | null { + return this.childCredentials._getSettings(); + } } export function createServerCredentialsWithInterceptors(credentials: ServerCredentials, interceptors: ServerInterceptor[]): ServerCredentials { diff --git a/packages/grpc-js/src/subchannel-interface.ts b/packages/grpc-js/src/subchannel-interface.ts index 6c314189a..ddf37d044 100644 --- a/packages/grpc-js/src/subchannel-interface.ts +++ b/packages/grpc-js/src/subchannel-interface.ts @@ -15,6 +15,7 @@ * */ +import { CallCredentials } from './call-credentials'; import type { SubchannelRef } from './channelz'; import { ConnectivityState } from './connectivity-state'; import { Subchannel } from './subchannel'; @@ -61,6 +62,11 @@ export interface SubchannelInterface { * to avoid implementing getRealSubchannel */ realSubchannelEquals(other: SubchannelInterface): boolean; + /** + * Get the call credentials associated with the channel credentials for this + * subchannel. + */ + getCallCredentials(): CallCredentials; } export abstract class BaseSubchannelWrapper implements SubchannelInterface { @@ -134,4 +140,7 @@ export abstract class BaseSubchannelWrapper implements SubchannelInterface { realSubchannelEquals(other: SubchannelInterface): boolean { return this.getRealSubchannel() === other.getRealSubchannel(); } + getCallCredentials(): CallCredentials { + return this.child.getCallCredentials(); + } } diff --git a/packages/grpc-js/src/subchannel.ts b/packages/grpc-js/src/subchannel.ts index 3074f63eb..cdf72861f 100644 --- a/packages/grpc-js/src/subchannel.ts +++ b/packages/grpc-js/src/subchannel.ts @@ -46,6 +46,7 @@ import { import { SubchannelCallInterceptingListener } from './subchannel-call'; import { SubchannelCall } from './subchannel-call'; import { CallEventTracker, SubchannelConnector, Transport } from './transport'; +import { CallCredentials } from './call-credentials'; const TRACER_NAME = 'subchannel'; @@ -54,7 +55,7 @@ const TRACER_NAME = 'subchannel'; * to calculate it */ const KEEPALIVE_MAX_TIME_MS = ~(1 << 31); -export class Subchannel { +export class Subchannel implements SubchannelInterface { /** * The subchannel's current connectivity state. Invariant: `session` === `null` * if and only if `connectivityState` is IDLE or TRANSIENT_FAILURE. @@ -515,4 +516,7 @@ export class Subchannel { this.keepaliveTime = newKeepaliveTime; } } + getCallCredentials(): CallCredentials { + return this.secureConnector.getCallCredentials(); + } } diff --git a/packages/grpc-js/src/transport.ts b/packages/grpc-js/src/transport.ts index 97c2ffbcd..85c43479a 100644 --- a/packages/grpc-js/src/transport.ts +++ b/packages/grpc-js/src/transport.ts @@ -729,9 +729,17 @@ export class Http2SubchannelConnector implements SubchannelConnector { if (this.isShutdown) { return Promise.reject(); } - const tcpConnection = await this.tcpConnect(address, options); - const secureConnection = await secureConnector.connect(tcpConnection); - return this.createSession(secureConnection, address, options); + let tcpConnection: net.Socket | null = null; + let secureConnection: net.Socket | null = null; + try { + tcpConnection = await this.tcpConnect(address, options); + secureConnection = await secureConnector.connect(tcpConnection); + return this.createSession(secureConnection, address, options); + } catch (e) { + tcpConnection?.destroy(); + secureConnection?.destroy(); + throw e; + } } shutdown(): void { diff --git a/packages/grpc-js/test/common.ts b/packages/grpc-js/test/common.ts index 5efbf9808..f64cbcba9 100644 --- a/packages/grpc-js/test/common.ts +++ b/packages/grpc-js/test/common.ts @@ -258,6 +258,9 @@ export class MockSubchannel implements SubchannelInterface { } addHealthStateWatcher(listener: HealthListener): void {} removeHealthStateWatcher(listener: HealthListener): void {} + getCallCredentials(): grpc.CallCredentials { + return grpc.CallCredentials.createEmpty(); + } } export { assert2 }; diff --git a/packages/grpc-js/test/test-channel-credentials.ts b/packages/grpc-js/test/test-channel-credentials.ts index f65e52e91..32c9c5cd8 100644 --- a/packages/grpc-js/test/test-channel-credentials.ts +++ b/packages/grpc-js/test/test-channel-credentials.ts @@ -21,12 +21,13 @@ import * as path from 'path'; import { promisify } from 'util'; import { CallCredentials } from '../src/call-credentials'; -import { ChannelCredentials } from '../src/channel-credentials'; +import { ChannelCredentials, createCertificateProviderChannelCredentials } from '../src/channel-credentials'; import * as grpc from '../src'; import { ServiceClient, ServiceClientConstructor } from '../src/make-client'; import { assert2, loadProtoFile, mockFunction } from './common'; import { sendUnaryData, ServerUnaryCall, ServiceError } from '../src'; +import { FileWatcherCertificateProvider } from '../src/certificate-provider'; const protoFile = path.join(__dirname, 'fixtures', 'echo_service.proto'); const echoService = loadProtoFile(protoFile) @@ -87,7 +88,7 @@ describe('ChannelCredentials Implementation', () => { const channelCreds = ChannelCredentials.createSsl(); const callCreds = new CallCredentialsMock(); const composedChannelCreds = channelCreds.compose(callCreds); - assert.strictEqual(composedChannelCreds._getCallCredentials(), callCreds); + assert.ok(composedChannelCreds instanceof ChannelCredentials); }); it('should be chainable', () => { @@ -99,11 +100,9 @@ describe('ChannelCredentials Implementation', () => { .compose(callCreds2); // Build a mock object that should be an identical copy const composedCallCreds = callCreds1.compose(callCreds2); - assert.ok( - composedCallCreds._equals( - composedChannelCreds._getCallCredentials() as CallCredentialsMock - ) - ); + const composedChannelCreds2 = ChannelCredentials.createSsl() + .compose(composedCallCreds); + assert.ok(composedChannelCreds._equals(composedChannelCreds2)); }); }); }); @@ -194,4 +193,28 @@ describe('ChannelCredentials usage', () => { ); assert2.afterMustCallsSatisfied(done); }); + it('Should handle certificate providers', done => { + const certificateProvider = new FileWatcherCertificateProvider({ + caCertificateFile: `${__dirname}/fixtures/ca.pem`, + certificateFile: `${__dirname}/fixtures/server1.pem`, + privateKeyFile: `${__dirname}/fixtures/server1.pem`, + refreshIntervalMs: 1000 + }); + const channelCreds = createCertificateProviderChannelCredentials(certificateProvider, null); + const client = new echoService(`localhost:${portNum}`, channelCreds, { + 'grpc.ssl_target_name_override': hostnameOverride, + 'grpc.default_authority': hostnameOverride, + }); + client.echo( + { value: 'test value', value2: 3 }, + new grpc.Metadata({waitForReady: true}), + (error: ServiceError, response: any) => { + client.close(); + assert.ifError(error); + assert.deepStrictEqual(response, { value: 'test value', value2: 3 }); + done(); + } + ); + + }) }); From 41f3fc096d4fa3beaca9c006c84f7bd4c23e3184 Mon Sep 17 00:00:00 2001 From: Michael Lumish Date: Mon, 9 Dec 2024 14:52:12 -0500 Subject: [PATCH 2/2] Remove test that became invalid --- packages/grpc-js/test/test-channel-credentials.ts | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/packages/grpc-js/test/test-channel-credentials.ts b/packages/grpc-js/test/test-channel-credentials.ts index 32c9c5cd8..1e1fb89cd 100644 --- a/packages/grpc-js/test/test-channel-credentials.ts +++ b/packages/grpc-js/test/test-channel-credentials.ts @@ -90,20 +90,6 @@ describe('ChannelCredentials Implementation', () => { const composedChannelCreds = channelCreds.compose(callCreds); assert.ok(composedChannelCreds instanceof ChannelCredentials); }); - - it('should be chainable', () => { - const callCreds1 = new CallCredentialsMock(); - const callCreds2 = new CallCredentialsMock(); - // Associate both call credentials with channelCreds - const composedChannelCreds = ChannelCredentials.createSsl() - .compose(callCreds1) - .compose(callCreds2); - // Build a mock object that should be an identical copy - const composedCallCreds = callCreds1.compose(callCreds2); - const composedChannelCreds2 = ChannelCredentials.createSsl() - .compose(composedCallCreds); - assert.ok(composedChannelCreds._equals(composedChannelCreds2)); - }); }); });