From 65e96774c2e821f5bad43c7d707687751bb7ed06 Mon Sep 17 00:00:00 2001 From: Benjin Dubishar Date: Wed, 28 May 2025 09:07:28 -0700 Subject: [PATCH 1/2] Improving logging for Azure account sign-in for connection (#19465) * improving logging for Azure account sign-in * Error handling around stale credentials * another case * revert styling * adding tests --- src/connectionconfig/azureHelpers.ts | 46 +++++++--- .../connectionDialogWebviewController.ts | 53 ++++++++---- test/unit/azureHelpers.test.ts | 86 +++++++++++++++++++ .../connectionDialogWebviewController.test.ts | 30 +++++++ 4 files changed, 190 insertions(+), 25 deletions(-) create mode 100644 test/unit/azureHelpers.test.ts diff --git a/src/connectionconfig/azureHelpers.ts b/src/connectionconfig/azureHelpers.ts index 405ba61a57..f55e67d65b 100644 --- a/src/connectionconfig/azureHelpers.ts +++ b/src/connectionconfig/azureHelpers.ts @@ -22,6 +22,7 @@ import { sendErrorEvent } from "../telemetry/telemetry"; import { getErrorMessage, listAllIterator } from "../utils/utils"; import { MssqlVSCodeAzureSubscriptionProvider } from "../azure/MssqlVSCodeAzureSubscriptionProvider"; import { configSelectedAzureSubscriptions } from "../constants/constants"; +import { Logger } from "../models/logger"; //#region VS Code integration @@ -59,6 +60,7 @@ export async function confirmVscodeAzureSignin(): Promise< */ export async function promptForAzureSubscriptionFilter( state: ConnectionDialogWebviewState, + logger: Logger, ): Promise { try { const auth = await confirmVscodeAzureSignin(); @@ -90,7 +92,7 @@ export async function promptForAzureSubscriptionFilter( return true; } catch (error) { state.formError = l10n.t("Error loading Azure subscriptions."); - console.error(state.formError + "\n" + getErrorMessage(error)); + logger.error(state.formError + "\n" + getErrorMessage(error)); return false; } } @@ -176,6 +178,7 @@ export async function fetchServersFromAzure(sub: AzureSubscription): Promise { let accounts: IAccount[] = []; try { @@ -187,7 +190,7 @@ export async function getAccounts( }; }); } catch (error) { - console.error(`Error loading Azure accounts: ${getErrorMessage(error)}`); + logger.error(`Error loading Azure accounts: ${getErrorMessage(error)}`); sendErrorEvent( TelemetryViews.ConnectionDialog, @@ -213,19 +216,37 @@ export async function getAccounts( export async function getTenants( azureAccountService: AzureAccountService, accountId: string, + logger: Logger, ): Promise { let tenants: ITenant[] = []; try { const account = (await azureAccountService.getAccounts()).find( (account) => account.displayInfo?.userId === accountId, ); - if (!account) { + + if (!account?.properties?.tenants) { + const missingProp = !account + ? "account" + : !account.properties + ? "properties" + : "tenants"; + const message = `Unable to retrieve tenants for the selected account due to undefined ${missingProp}`; + logger.error(message); + + sendErrorEvent( + TelemetryViews.ConnectionDialog, + TelemetryActions.LoadAzureTenantsForEntraAuth, + new Error(message), + true, // includeErrorMessage + undefined, // errorCode + `missing_${missingProp}`, // errorType + ); + return []; } + tenants = account.properties.tenants; - if (!tenants) { - return []; - } + return tenants.map((tenant) => { return { displayName: tenant.displayName, @@ -233,7 +254,7 @@ export async function getTenants( }; }); } catch (error) { - console.error(`Error loading Azure tenants: ${getErrorMessage(error)}`); + logger.error(`Error loading Azure tenants: ${getErrorMessage(error)}`); sendErrorEvent( TelemetryViews.ConnectionDialog, @@ -307,9 +328,9 @@ export async function constructAzureAccountForTenant( //#endregion -//#region Miscellaneous Auzre helpers +//#region Miscellaneous Azure helpers -function extractFromResourceId(resourceId: string, property: string): string | undefined { +export function extractFromResourceId(resourceId: string, property: string): string | undefined { if (!property.endsWith("/")) { property += "/"; } @@ -322,7 +343,12 @@ function extractFromResourceId(resourceId: string, property: string): string | u startIndex += property.length; } - return resourceId.substring(startIndex, resourceId.indexOf("/", startIndex)); + let endIndex = resourceId.indexOf("/", startIndex); + if (endIndex === -1) { + endIndex = undefined; + } + + return resourceId.substring(startIndex, endIndex); } //#endregion diff --git a/src/connectionconfig/connectionDialogWebviewController.ts b/src/connectionconfig/connectionDialogWebviewController.ts index b6941c863f..047bcaed45 100644 --- a/src/connectionconfig/connectionDialogWebviewController.ts +++ b/src/connectionconfig/connectionDialogWebviewController.ts @@ -155,7 +155,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< // Load connection form components this.state.formComponents = await generateConnectionComponents( this._mainController.connectionManager, - getAccounts(this._mainController.azureAccountService), + getAccounts(this._mainController.azureAccountService, this.logger), this.getAzureActionButtons(), ); @@ -299,7 +299,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< this.registerReducer("filterAzureSubscriptions", async (state) => { try { - if (await promptForAzureSubscriptionFilter(state)) { + if (await promptForAzureSubscriptionFilter(state, this.logger)) { await this.loadAllAzureServers(state); } } catch (err) { @@ -514,6 +514,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< const tenants = await getTenants( this._mainController.azureAccountService, this.state.connectionProfile.accountId, + this.logger, ); if (tenants.length === 1) { hiddenProperties.push("tenantId"); @@ -935,6 +936,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< accountsComponent.options = await getAccounts( this._mainController.azureAccountService, + this.logger, ); this.state.connectionProfile.accountId = account.key.id; @@ -956,15 +958,28 @@ export class ConnectionDialogWebviewController extends FormWebviewController< (account) => account.displayInfo.userId === this.state.connectionProfile.accountId, ); if (account) { - const session = - await this._mainController.azureAccountService.getAccountSecurityToken( - account, - undefined, + let isTokenExpired = false; + try { + const session = + await this._mainController.azureAccountService.getAccountSecurityToken( + account, + undefined, + ); + isTokenExpired = !AzureController.isTokenValid( + session.token, + session.expiresOn, ); - const isTokenExpired = !AzureController.isTokenValid( - session.token, - session.expiresOn, - ); + } catch (err) { + this.logger.verbose( + `Error getting token or checking validity; prompting for refresh. Error: ${getErrorMessage(err)}`, + ); + + this.vscodeWrapper.showErrorMessage( + "Error validating Entra authentication token; you may need to refresh your token.", + ); + + isTokenExpired = true; + } if (isTokenExpired) { actionButtons.push({ @@ -979,12 +994,18 @@ export class ConnectionDialogWebviewController extends FormWebviewController< this.state.connectionProfile.accountId, ); if (account) { - const session = - await this._mainController.azureAccountService.getAccountSecurityToken( - account, - undefined, + try { + const session = + await this._mainController.azureAccountService.getAccountSecurityToken( + account, + undefined, + ); + this.logger.log("Token refreshed", session.expiresOn); + } catch (err) { + this.logger.error( + `Error refreshing token: ${getErrorMessage(err)}`, ); - this.logger.log("Token refreshed", session.expiresOn); + } } }, }); @@ -1016,6 +1037,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< tenants = await getTenants( this._mainController.azureAccountService, this.state.connectionProfile.accountId, + this.logger, ); if (tenantComponent) { tenantComponent.options = tenants; @@ -1040,6 +1062,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< tenants = await getTenants( this._mainController.azureAccountService, this.state.connectionProfile.accountId, + this.logger, ); if (tenantComponent) { tenantComponent.options = tenants; diff --git a/test/unit/azureHelpers.test.ts b/test/unit/azureHelpers.test.ts new file mode 100644 index 0000000000..14cd592e88 --- /dev/null +++ b/test/unit/azureHelpers.test.ts @@ -0,0 +1,86 @@ +/*--------------------------------------------------------------------------------------------- + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for license information. + *--------------------------------------------------------------------------------------------*/ + +import { expect } from "chai"; +import { AzureAccountService } from "../../src/services/azureAccountService"; +import * as sinon from "sinon"; +import * as azureHelpers from "../../src/connectionconfig/azureHelpers"; +import { Logger } from "../../src/models/logger"; +import { IAccount } from "vscode-mssql"; + +suite("Azure Helpers", () => { + let sandbox: sinon.SinonSandbox; + let mockAzureAccountService: AzureAccountService; + let mockLogger: Logger; + + setup(() => { + sandbox = sinon.createSandbox(); + mockAzureAccountService = sandbox.createStubInstance(AzureAccountService); + mockLogger = sandbox.createStubInstance(Logger); + }); + + teardown(() => { + sandbox.restore(); + }); + test("getTenants handles error cases", async () => { + const getAccountsStub = mockAzureAccountService.getAccounts as sinon.SinonStub; + // undefined tenants + getAccountsStub.resolves([ + { + displayInfo: { + userId: "test-user-id", + }, + properties: { + tenants: undefined, + }, + } as IAccount, + ]); + + let result = await azureHelpers.getTenants( + mockAzureAccountService, + "test-user-id", + mockLogger, + ); + expect(result).to.be.an("array").that.is.empty; + expect( + (mockLogger.error as sinon.SinonStub).calledWithMatch("undefined tenants"), + "logger should have been called with 'undefined tenants'", + ).to.be.true; + + // reset mocks for next case + getAccountsStub.reset(); + (mockLogger.error as sinon.SinonStub).resetHistory(); + + // undefined properties + getAccountsStub.resolves([ + { + displayInfo: { + userId: "test-user-id", + }, + properties: undefined, + } as IAccount, + ]); + + result = await azureHelpers.getTenants(mockAzureAccountService, "test-user-id", mockLogger); + expect(result).to.be.an("array").that.is.empty; + expect( + (mockLogger.error as sinon.SinonStub).calledWithMatch("undefined properties"), + "logger should have been called with 'undefined properties'", + ).to.be.true; + }); + + test("extractFromResourceId", () => { + const resourceId = + "subscriptions/test-subscription/resourceGroups/test-resource-group/providers/Microsoft.Sql/servers/test-server/databases/test-database"; + let result = azureHelpers.extractFromResourceId(resourceId, "servers"); + expect(result).to.equal("test-server"); + + result = azureHelpers.extractFromResourceId(resourceId, "databases"); + expect(result).to.equal("test-database"); + + result = azureHelpers.extractFromResourceId(resourceId, "fakeProperty"); + expect(result).to.be.undefined; + }); +}); diff --git a/test/unit/connectionDialogWebviewController.test.ts b/test/unit/connectionDialogWebviewController.test.ts index b77ff92183..b4c2c6de4d 100644 --- a/test/unit/connectionDialogWebviewController.test.ts +++ b/test/unit/connectionDialogWebviewController.test.ts @@ -38,6 +38,7 @@ import { import { CreateSessionResponse } from "../../src/models/contracts/objectExplorer/createSessionRequest"; import { TreeNodeInfo } from "../../src/objectExplorer/nodes/treeNodeInfo"; import { mockGetCapabilitiesRequest } from "./mocks"; +import { AzureController } from "../../src/azure/azureController"; suite("ConnectionDialogWebviewController Tests", () => { let sandbox: sinon.SinonSandbox; @@ -491,4 +492,33 @@ suite("ConnectionDialogWebviewController Tests", () => { }); }); }); + + test("getAzureActionButtons", async () => { + controller.state.connectionProfile.authenticationType = AuthenticationType.AzureMFA; + controller.state.connectionProfile.accountId = "TestEntraAccountId"; + + const actionButtons = await controller["getAzureActionButtons"](); + expect(actionButtons.length).to.equal(1, "Should always have the Sign In button"); + expect(actionButtons[0].id).to.equal("azureSignIn"); + + controller.state.connectionProfile.authenticationType = AuthenticationType.AzureMFA; + controller.state.connectionProfile.accountId = "TestUserId"; + + const isTokenValidStub = sandbox.stub(AzureController, "isTokenValid").returns(false); + + // When there's no error, we should have refreshToken button + let buttons = await controller["getAzureActionButtons"](); + expect(buttons.length).to.equal(2); + expect(buttons[1].id).to.equal("refreshToken"); + + // Test error handling when getAccountSecurityToken throws + isTokenValidStub.restore(); + sandbox + .stub(mainController.azureAccountService, "getAccountSecurityToken") + .throws(new Error("Test error")); + + buttons = await controller["getAzureActionButtons"](); + expect(buttons.length).to.equal(2); + expect(buttons[1].id).to.equal("refreshToken"); + }); }); From 47441d97a1b24a1fe2937acf269e595ddb1e8b1a Mon Sep 17 00:00:00 2001 From: Benjin Dubishar Date: Wed, 28 May 2025 11:15:05 -0700 Subject: [PATCH 2/2] Additional entra logging (#19497) --- .../connectionDialogWebviewController.ts | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/connectionconfig/connectionDialogWebviewController.ts b/src/connectionconfig/connectionDialogWebviewController.ts index 047bcaed45..e6625eeb7e 100644 --- a/src/connectionconfig/connectionDialogWebviewController.ts +++ b/src/connectionconfig/connectionDialogWebviewController.ts @@ -924,7 +924,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController< callback: async () => { const account = await this._mainController.azureAccountService.addAccount(); this.logger.verbose( - `Added Azure account '${account.displayInfo}', ${account.key.id}`, + `Added Azure account '${account.displayInfo?.displayName}', ${account.key.id}`, ); const accountsComponent = this.getFormComponent(this.state, "accountId"); @@ -939,12 +939,14 @@ export class ConnectionDialogWebviewController extends FormWebviewController< this.logger, ); - this.state.connectionProfile.accountId = account.key.id; - this.logger.verbose( - `Read ${accountsComponent.options.length} Azure accounts, selecting '${account.key.id}'`, + `Read ${accountsComponent.options.length} Azure accounts: ${accountsComponent.options.map((a) => a.value).join(", ")}`, ); + this.state.connectionProfile.accountId = account.key.id; + + this.logger.verbose(`Selecting '${account.key.id}'`); + this.updateState(); await this.handleAzureMFAEdits("accountId"); },