Skip to content

Commit

Permalink
Update FedAuth tests to use ManagedIdentity (#2629)
Browse files Browse the repository at this point in the history
  • Loading branch information
lilgreenbird authored Mar 6, 2025
1 parent f0f59f6 commit dea2010
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import java.util.Random;
import java.util.concurrent.atomic.AtomicReference;

import com.microsoft.sqlserver.jdbc.TestUtils;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
Expand All @@ -23,7 +22,6 @@

@RunWith(JUnitPlatform.class)
@Tag(Constants.fedAuth)
@Tag(Constants.requireSecret)
public class ConcurrentLoginTest extends FedauthCommon {

final AtomicReference<Throwable> throwableRef = new AtomicReference<Throwable>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import java.sql.DriverManager;
import java.sql.SQLException;
import java.sql.Statement;
import java.text.MessageFormat;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
Expand All @@ -29,7 +28,6 @@

@RunWith(JUnitPlatform.class)
@Tag(Constants.fedAuth)
@Tag(Constants.requireSecret)
public class ConnectionEncryptionTest extends FedauthCommon {

static String charTable = TestUtils.escapeSingleQuotes(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

@RunWith(JUnitPlatform.class)
@Tag(Constants.fedAuth)
@Tag(Constants.requireSecret)
public class ErrorMessageTest extends FedauthCommon {

String badUserName = "abc" + azureUserName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import com.azure.core.credential.AccessToken;
import com.azure.core.credential.TokenRequestContext;
import com.azure.identity.ManagedIdentityCredential;
import com.azure.identity.ManagedIdentityCredentialBuilder;
import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ClientCredentialParameters;
import com.microsoft.aad.msal4j.ConfidentialClientApplication;
Expand All @@ -21,6 +25,7 @@
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collections;
import java.util.Date;
import java.util.HashSet;
import java.util.Locale;
Expand Down Expand Up @@ -216,25 +221,21 @@ public static void getConfigs() throws Exception {
static void getFedauthInfo() {
int retry = 0;
long interval = THROTTLE_RETRY_INTERVAL;
ManagedIdentityCredential credential = new ManagedIdentityCredentialBuilder()
.clientId(akvProviderManagedClientId).build();

while (retry <= THROTTLE_RETRY_COUNT) {
try {
Set<String> scopes = new HashSet<>();
scopes.add(spn + "/.default");
if (null == fedauthClientApp) {
IClientCredential credential = ClientCredentialFactory.createFromSecret(applicationKey);
fedauthClientApp = ConfidentialClientApplication.builder(applicationClientID, credential)
.executorService(Executors.newFixedThreadPool(1))
.setTokenCacheAccessAspect(FedauthTokenCache.getInstance()).authority(stsurl).build();
}
TokenRequestContext requestContext = new TokenRequestContext()
.setScopes(Collections.singletonList(spn + "/.default"));

final CompletableFuture<IAuthenticationResult> future = fedauthClientApp
.acquireToken(ClientCredentialParameters.builder(scopes).build());
AccessToken token = credential.getToken(requestContext).block();

final IAuthenticationResult authenticationResult = future.get();

secondsBeforeExpiration = TimeUnit.MILLISECONDS
.toSeconds(authenticationResult.expiresOnDate().getTime() - new Date().getTime());
accessToken = authenticationResult.accessToken();
if (token != null) {
secondsBeforeExpiration = TimeUnit.MILLISECONDS
.toSeconds(token.getExpiresAt().toInstant().toEpochMilli() - new Date().getTime());
accessToken = token.getToken();
}

retry = THROTTLE_RETRY_COUNT + 1;
} catch (MsalThrottlingException te) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@

@RunWith(JUnitPlatform.class)
@Tag(Constants.fedAuth)
@Tag(Constants.requireSecret)
public class FedauthTest extends FedauthCommon {
static String charTable = TestUtils
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("JDBC_FedAuthTest")));
Expand Down Expand Up @@ -286,6 +285,7 @@ public void testAADPasswordApplicationName() throws Exception {
*/
@Deprecated
@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalAuthDeprecated() {
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";AADSecurePrincipalId=" + applicationClientID
Expand All @@ -308,6 +308,7 @@ public void testAADServicePrincipalAuthDeprecated() {
* encryption.
*/
@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalAuth() {
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
Expand All @@ -326,6 +327,7 @@ public void testAADServicePrincipalAuth() {
}

@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidSecret() throws Exception {
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";Username=" + applicationClientID + ";Password="
Expand Down Expand Up @@ -364,6 +366,7 @@ public void testActiveDirectoryPasswordFailureOnSubsequentConnectionsWithInvalid
}

@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalCertAuthFailureOnSubsequentConnectionsWithInvalidatedTokenCacheWithInvalidPassword() throws Exception {
// Should succeed on valid cert field values
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
Expand All @@ -389,6 +392,7 @@ public void testAADServicePrincipalCertAuthFailureOnSubsequentConnectionsWithInv
* Test invalid connection property combinations when using AAD Service Principal Authentication.
*/
@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalAuthWrong() {
String baseUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipal + ";";
Expand Down Expand Up @@ -426,6 +430,7 @@ public void testAADServicePrincipalAuthWrong() {
* encryption.
*/
@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalCertAuth() {
// certificate from AKV has no password
String url = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
Expand All @@ -449,6 +454,7 @@ public void testAADServicePrincipalCertAuth() {
* Test invalid connection property combinations when using AAD Service Principal Certificate Authentication.
*/
@Test
@Tag(Constants.requireSecret)
public void testAADServicePrincipalCertAuthWrong() {
String baseUrl = "jdbc:sqlserver://" + azureServer + ";database=" + azureDatabase + ";authentication="
+ SqlAuthentication.ActiveDirectoryServicePrincipalCertificate + ";userName="
Expand Down Expand Up @@ -488,23 +494,6 @@ public void testAccessTokenCallbackClassConnection() throws Exception {
try (Connection conn1 = DriverManager.getConnection(cs)) {}
}

@Test
public void testAccessTokenCache() {
try {
SilentParameters silentParameters = SilentParameters.builder(Collections.singleton(spn + "/.default"))
.build();

// this will fail if not cached
CompletableFuture<IAuthenticationResult> future = fedauthClientApp.acquireTokenSilently(silentParameters);
IAuthenticationResult authenticationResult = future.get();
assertNotNull(authenticationResult.accessToken());
assertTrue(authenticationResult.accessToken().equals(accessToken), accessToken);
} catch (Exception e) {
fail(e.getMessage());
}

}

private static void validateException(String url, String resourceKey) {
try (Connection conn = DriverManager.getConnection(url)) {
fail(TestResource.getResource("R_expectedFailPassed"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.junit.platform.runner.JUnitPlatform;
import org.junit.runner.RunWith;

import com.azure.identity.ManagedIdentityCredential;
import com.azure.identity.ManagedIdentityCredentialBuilder;
import com.microsoft.sqlserver.jdbc.RandomUtil;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionAzureKeyVaultProvider;
import com.microsoft.sqlserver.jdbc.SQLServerColumnEncryptionJavaKeyStoreProvider;
Expand All @@ -37,7 +39,6 @@

@RunWith(JUnitPlatform.class)
@Tag(Constants.fedAuth)
@Tag(Constants.requireSecret)
public class FedauthWithAE extends FedauthCommon {

static String cmkName1 = Constants.CMK_NAME + "fedauthAE1";
Expand Down Expand Up @@ -282,16 +283,17 @@ private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_JKS() th

private SQLServerColumnEncryptionKeyStoreProvider setupKeyStoreProvider_AKV() throws SQLServerException {
SQLServerConnection.unregisterColumnEncryptionKeyStoreProviders();
return registerAKVProvider(
new SQLServerColumnEncryptionAzureKeyVaultProvider(applicationClientID, applicationKey));
return registerAKVProvider();
}

private SQLServerColumnEncryptionKeyStoreProvider registerAKVProvider(
SQLServerColumnEncryptionKeyStoreProvider provider) throws SQLServerException {
Map<String, SQLServerColumnEncryptionKeyStoreProvider> map1 = new HashMap<String, SQLServerColumnEncryptionKeyStoreProvider>();
map1.put(provider.getName(), provider);
SQLServerConnection.registerColumnEncryptionKeyStoreProviders(map1);
return provider;
private SQLServerColumnEncryptionKeyStoreProvider registerAKVProvider() throws SQLServerException {
Map<String, SQLServerColumnEncryptionKeyStoreProvider> map = new HashMap<String, SQLServerColumnEncryptionKeyStoreProvider>();
ManagedIdentityCredential credential = new ManagedIdentityCredentialBuilder()
.clientId(akvProviderManagedClientId).build();
akvProvider = new SQLServerColumnEncryptionAzureKeyVaultProvider(credential);
map.put(Constants.AZURE_KEY_VAULT_NAME, akvProvider);
SQLServerConnection.registerColumnEncryptionKeyStoreProviders(map);
return akvProvider;
}

private void createCMK(String cmkName, String keyStoreName, String keyPath, Statement stmt) throws SQLException {
Expand Down

0 comments on commit dea2010

Please sign in to comment.