Skip to content

Porting "Improving logging for Azure account sign-in for connection" #19465 and #19497 #19493

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/connectionconfig/azureHelpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { sendErrorEvent } from "../telemetry/telemetry";
import { getErrorMessage, listAllIterator } from "../utils/utils";
import { MssqlVSCodeAzureSubscriptionProvider } from "../azure/MssqlVSCodeAzureSubscriptionProvider";
import { configSelectedAzureSubscriptions } from "../constants/constants";
import { Logger } from "../models/logger";

//#region VS Code integration

Expand Down Expand Up @@ -59,6 +60,7 @@ export async function confirmVscodeAzureSignin(): Promise<
*/
export async function promptForAzureSubscriptionFilter(
state: ConnectionDialogWebviewState,
logger: Logger,
): Promise<boolean> {
try {
const auth = await confirmVscodeAzureSignin();
Expand Down Expand Up @@ -90,7 +92,7 @@ export async function promptForAzureSubscriptionFilter(
return true;
} catch (error) {
state.formError = l10n.t("Error loading Azure subscriptions.");
console.error(state.formError + "\n" + getErrorMessage(error));
logger.error(state.formError + "\n" + getErrorMessage(error));
return false;
}
}
Expand Down Expand Up @@ -176,6 +178,7 @@ export async function fetchServersFromAzure(sub: AzureSubscription): Promise<Azu

export async function getAccounts(
azureAccountService: AzureAccountService,
logger: Logger,
): Promise<FormItemOptions[]> {
let accounts: IAccount[] = [];
try {
Expand All @@ -187,7 +190,7 @@ export async function getAccounts(
};
});
} catch (error) {
console.error(`Error loading Azure accounts: ${getErrorMessage(error)}`);
logger.error(`Error loading Azure accounts: ${getErrorMessage(error)}`);

sendErrorEvent(
TelemetryViews.ConnectionDialog,
Expand All @@ -213,27 +216,45 @@ export async function getAccounts(
export async function getTenants(
azureAccountService: AzureAccountService,
accountId: string,
logger: Logger,
): Promise<FormItemOptions[]> {
let tenants: ITenant[] = [];
try {
const account = (await azureAccountService.getAccounts()).find(
(account) => account.displayInfo?.userId === accountId,
);
if (!account) {

if (!account?.properties?.tenants) {
const missingProp = !account
? "account"
: !account.properties
? "properties"
: "tenants";
const message = `Unable to retrieve tenants for the selected account due to undefined ${missingProp}`;
logger.error(message);

sendErrorEvent(
TelemetryViews.ConnectionDialog,
TelemetryActions.LoadAzureTenantsForEntraAuth,
new Error(message),
true, // includeErrorMessage
undefined, // errorCode
`missing_${missingProp}`, // errorType
);

return [];
}

tenants = account.properties.tenants;
if (!tenants) {
return [];
}

return tenants.map((tenant) => {
return {
displayName: tenant.displayName,
value: tenant.id,
};
});
} catch (error) {
console.error(`Error loading Azure tenants: ${getErrorMessage(error)}`);
logger.error(`Error loading Azure tenants: ${getErrorMessage(error)}`);

sendErrorEvent(
TelemetryViews.ConnectionDialog,
Expand Down Expand Up @@ -307,9 +328,9 @@ export async function constructAzureAccountForTenant(

//#endregion

//#region Miscellaneous Auzre helpers
//#region Miscellaneous Azure helpers

function extractFromResourceId(resourceId: string, property: string): string | undefined {
export function extractFromResourceId(resourceId: string, property: string): string | undefined {
if (!property.endsWith("/")) {
property += "/";
}
Expand All @@ -322,7 +343,12 @@ function extractFromResourceId(resourceId: string, property: string): string | u
startIndex += property.length;
}

return resourceId.substring(startIndex, resourceId.indexOf("/", startIndex));
let endIndex = resourceId.indexOf("/", startIndex);
if (endIndex === -1) {
endIndex = undefined;
}

return resourceId.substring(startIndex, endIndex);
}

//#endregion
63 changes: 44 additions & 19 deletions src/connectionconfig/connectionDialogWebviewController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
// Load connection form components
this.state.formComponents = await generateConnectionComponents(
this._mainController.connectionManager,
getAccounts(this._mainController.azureAccountService),
getAccounts(this._mainController.azureAccountService, this.logger),
this.getAzureActionButtons(),
);

Expand Down Expand Up @@ -299,7 +299,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<

this.registerReducer("filterAzureSubscriptions", async (state) => {
try {
if (await promptForAzureSubscriptionFilter(state)) {
if (await promptForAzureSubscriptionFilter(state, this.logger)) {
await this.loadAllAzureServers(state);
}
} catch (err) {
Expand Down Expand Up @@ -514,6 +514,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
const tenants = await getTenants(
this._mainController.azureAccountService,
this.state.connectionProfile.accountId,
this.logger,
);
if (tenants.length === 1) {
hiddenProperties.push("tenantId");
Expand Down Expand Up @@ -923,7 +924,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
callback: async () => {
const account = await this._mainController.azureAccountService.addAccount();
this.logger.verbose(
`Added Azure account '${account.displayInfo}', ${account.key.id}`,
`Added Azure account '${account.displayInfo?.displayName}', ${account.key.id}`,
);

const accountsComponent = this.getFormComponent(this.state, "accountId");
Expand All @@ -935,14 +936,17 @@ export class ConnectionDialogWebviewController extends FormWebviewController<

accountsComponent.options = await getAccounts(
this._mainController.azureAccountService,
this.logger,
);

this.state.connectionProfile.accountId = account.key.id;

this.logger.verbose(
`Read ${accountsComponent.options.length} Azure accounts, selecting '${account.key.id}'`,
`Read ${accountsComponent.options.length} Azure accounts: ${accountsComponent.options.map((a) => a.value).join(", ")}`,
);

this.state.connectionProfile.accountId = account.key.id;

this.logger.verbose(`Selecting '${account.key.id}'`);

this.updateState();
await this.handleAzureMFAEdits("accountId");
},
Expand All @@ -956,15 +960,28 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
(account) => account.displayInfo.userId === this.state.connectionProfile.accountId,
);
if (account) {
const session =
await this._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
let isTokenExpired = false;
try {
const session =
await this._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
);
isTokenExpired = !AzureController.isTokenValid(
session.token,
session.expiresOn,
);
const isTokenExpired = !AzureController.isTokenValid(
session.token,
session.expiresOn,
);
} catch (err) {
this.logger.verbose(
`Error getting token or checking validity; prompting for refresh. Error: ${getErrorMessage(err)}`,
);

this.vscodeWrapper.showErrorMessage(
"Error validating Entra authentication token; you may need to refresh your token.",
);

isTokenExpired = true;
}

if (isTokenExpired) {
actionButtons.push({
Expand All @@ -979,12 +996,18 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
this.state.connectionProfile.accountId,
);
if (account) {
const session =
await this._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
try {
const session =
await this._mainController.azureAccountService.getAccountSecurityToken(
account,
undefined,
);
this.logger.log("Token refreshed", session.expiresOn);
} catch (err) {
this.logger.error(
`Error refreshing token: ${getErrorMessage(err)}`,
);
this.logger.log("Token refreshed", session.expiresOn);
}
}
},
});
Expand Down Expand Up @@ -1016,6 +1039,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
tenants = await getTenants(
this._mainController.azureAccountService,
this.state.connectionProfile.accountId,
this.logger,
);
if (tenantComponent) {
tenantComponent.options = tenants;
Expand All @@ -1040,6 +1064,7 @@ export class ConnectionDialogWebviewController extends FormWebviewController<
tenants = await getTenants(
this._mainController.azureAccountService,
this.state.connectionProfile.accountId,
this.logger,
);
if (tenantComponent) {
tenantComponent.options = tenants;
Expand Down
86 changes: 86 additions & 0 deletions test/unit/azureHelpers.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { expect } from "chai";
import { AzureAccountService } from "../../src/services/azureAccountService";
import * as sinon from "sinon";
import * as azureHelpers from "../../src/connectionconfig/azureHelpers";
import { Logger } from "../../src/models/logger";
import { IAccount } from "vscode-mssql";

suite("Azure Helpers", () => {
let sandbox: sinon.SinonSandbox;
let mockAzureAccountService: AzureAccountService;
let mockLogger: Logger;

setup(() => {
sandbox = sinon.createSandbox();
mockAzureAccountService = sandbox.createStubInstance(AzureAccountService);
mockLogger = sandbox.createStubInstance(Logger);
});

teardown(() => {
sandbox.restore();
});
test("getTenants handles error cases", async () => {
const getAccountsStub = mockAzureAccountService.getAccounts as sinon.SinonStub;
// undefined tenants
getAccountsStub.resolves([
{
displayInfo: {
userId: "test-user-id",
},
properties: {
tenants: undefined,
},
} as IAccount,
]);

let result = await azureHelpers.getTenants(
mockAzureAccountService,
"test-user-id",
mockLogger,
);
expect(result).to.be.an("array").that.is.empty;
expect(
(mockLogger.error as sinon.SinonStub).calledWithMatch("undefined tenants"),
"logger should have been called with 'undefined tenants'",
).to.be.true;

// reset mocks for next case
getAccountsStub.reset();
(mockLogger.error as sinon.SinonStub).resetHistory();

// undefined properties
getAccountsStub.resolves([
{
displayInfo: {
userId: "test-user-id",
},
properties: undefined,
} as IAccount,
]);

result = await azureHelpers.getTenants(mockAzureAccountService, "test-user-id", mockLogger);
expect(result).to.be.an("array").that.is.empty;
expect(
(mockLogger.error as sinon.SinonStub).calledWithMatch("undefined properties"),
"logger should have been called with 'undefined properties'",
).to.be.true;
});

test("extractFromResourceId", () => {
const resourceId =
"subscriptions/test-subscription/resourceGroups/test-resource-group/providers/Microsoft.Sql/servers/test-server/databases/test-database";
let result = azureHelpers.extractFromResourceId(resourceId, "servers");
expect(result).to.equal("test-server");

result = azureHelpers.extractFromResourceId(resourceId, "databases");
expect(result).to.equal("test-database");

result = azureHelpers.extractFromResourceId(resourceId, "fakeProperty");
expect(result).to.be.undefined;
});
});
30 changes: 30 additions & 0 deletions test/unit/connectionDialogWebviewController.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import {
import { CreateSessionResponse } from "../../src/models/contracts/objectExplorer/createSessionRequest";
import { TreeNodeInfo } from "../../src/objectExplorer/nodes/treeNodeInfo";
import { mockGetCapabilitiesRequest } from "./mocks";
import { AzureController } from "../../src/azure/azureController";

suite("ConnectionDialogWebviewController Tests", () => {
let sandbox: sinon.SinonSandbox;
Expand Down Expand Up @@ -491,4 +492,33 @@ suite("ConnectionDialogWebviewController Tests", () => {
});
});
});

test("getAzureActionButtons", async () => {
controller.state.connectionProfile.authenticationType = AuthenticationType.AzureMFA;
controller.state.connectionProfile.accountId = "TestEntraAccountId";

const actionButtons = await controller["getAzureActionButtons"]();
expect(actionButtons.length).to.equal(1, "Should always have the Sign In button");
expect(actionButtons[0].id).to.equal("azureSignIn");

controller.state.connectionProfile.authenticationType = AuthenticationType.AzureMFA;
controller.state.connectionProfile.accountId = "TestUserId";

const isTokenValidStub = sandbox.stub(AzureController, "isTokenValid").returns(false);

// When there's no error, we should have refreshToken button
let buttons = await controller["getAzureActionButtons"]();
expect(buttons.length).to.equal(2);
expect(buttons[1].id).to.equal("refreshToken");

// Test error handling when getAccountSecurityToken throws
isTokenValidStub.restore();
sandbox
.stub(mainController.azureAccountService, "getAccountSecurityToken")
.throws(new Error("Test error"));

buttons = await controller["getAzureActionButtons"]();
expect(buttons.length).to.equal(2);
expect(buttons[1].id).to.equal("refreshToken");
});
});
Loading