Skip to content

Commit 40053e6

Browse files
committed
Migrates loginGuard to middleware set with action descriptor options
1 parent 99294d4 commit 40053e6

File tree

9 files changed

+139
-30
lines changed

9 files changed

+139
-30
lines changed

cli/azd/cmd/actions/action_descriptor.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,8 @@ type ActionDescriptorOptions struct {
191191
HelpOptions ActionHelpOptions
192192
// Defines grouping options for the command
193193
GroupingOptions CommandGroupOptions
194+
// Whether or not the command requires a principal login
195+
RequireLogin bool
194196
}
195197

196198
// Completion function used for cobra command flag completion

cli/azd/cmd/container.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ func registerCommonDependencies(container *ioc.NestedContainer) {
152152
ioc.RegisterInstance[auth.HttpClient](container, client)
153153

154154
// Auth
155-
container.MustRegisterSingleton(auth.NewLoggedInGuard)
156155
container.MustRegisterSingleton(auth.NewMultiTenantCredentialProvider)
157156
container.MustRegisterSingleton(func(mgr *auth.Manager) CredentialProviderFn {
158157
return mgr.CredentialForCurrentUser
@@ -579,6 +578,9 @@ func registerCommonDependencies(container *ioc.NestedContainer) {
579578
})
580579
container.MustRegisterScoped(auth.NewManager)
581580
container.MustRegisterSingleton(azapi.NewUserProfileService)
581+
container.MustRegisterScoped(func(authManager *auth.Manager) middleware.CurrentUserAuthManager {
582+
return authManager
583+
})
582584
container.MustRegisterSingleton(account.NewSubscriptionsService)
583585
container.MustRegisterSingleton(account.NewManager)
584586
container.MustRegisterSingleton(account.NewSubscriptionsManager)

cli/azd/cmd/middleware/login_guard.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
6+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
7+
"github.com/azure/azure-dev/cli/azd/cmd/actions"
8+
"github.com/azure/azure-dev/cli/azd/pkg/auth"
9+
"github.com/azure/azure-dev/cli/azd/pkg/cloud"
10+
)
11+
12+
type CurrentUserAuthManager interface {
13+
Cloud() *cloud.Cloud
14+
CredentialForCurrentUser(
15+
ctx context.Context,
16+
options *auth.CredentialForCurrentUserOptions,
17+
) (azcore.TokenCredential, error)
18+
}
19+
20+
// LoginGuardMiddleware ensures that the user is logged in otherwise it returns an error
21+
type LoginGuardMiddleware struct {
22+
authManager CurrentUserAuthManager
23+
}
24+
25+
// NewLoginGuardMiddleware creates a new instance of the LoginGuardMiddleware
26+
func NewLoginGuardMiddleware(authManager CurrentUserAuthManager) Middleware {
27+
return &LoginGuardMiddleware{
28+
authManager: authManager,
29+
}
30+
}
31+
32+
// Run ensures that the user is logged in otherwise it returns an error
33+
func (l *LoginGuardMiddleware) Run(ctx context.Context, next NextFn) (*actions.ActionResult, error) {
34+
cred, err := l.authManager.CredentialForCurrentUser(ctx, nil)
35+
if err != nil {
36+
return nil, err
37+
}
38+
39+
_, err = auth.EnsureLoggedInCredential(ctx, cred, l.authManager.Cloud())
40+
if err != nil {
41+
return nil, err
42+
}
43+
44+
// At this point we have ensured a logged in user, continue execution of the action
45+
return next(ctx)
46+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
8+
"github.com/azure/azure-dev/cli/azd/cmd/actions"
9+
"github.com/azure/azure-dev/cli/azd/pkg/auth"
10+
"github.com/azure/azure-dev/cli/azd/pkg/cloud"
11+
"github.com/azure/azure-dev/cli/azd/test/mocks"
12+
"github.com/stretchr/testify/mock"
13+
"github.com/stretchr/testify/require"
14+
)
15+
16+
func Test_LoginGuard_Run(t *testing.T) {
17+
t.Run("LoggedIn", func(t *testing.T) {
18+
mockContext := mocks.NewMockContext(context.Background())
19+
20+
mockAuthManager := &mockCurrentUserAuthManager{}
21+
mockAuthManager.On("Cloud").Return(cloud.AzurePublic())
22+
mockAuthManager.
23+
On("CredentialForCurrentUser", *mockContext.Context, mock.Anything).
24+
Return(mockContext.Credentials, nil)
25+
26+
middleware := LoginGuardMiddleware{
27+
authManager: mockAuthManager,
28+
}
29+
30+
result, err := middleware.Run(*mockContext.Context, next)
31+
require.NoError(t, err)
32+
require.NotNil(t, result)
33+
})
34+
t.Run("NotLoggedIn", func(t *testing.T) {
35+
mockContext := mocks.NewMockContext(context.Background())
36+
37+
mockAuthManager := &mockCurrentUserAuthManager{}
38+
mockAuthManager.On("Cloud").Return(cloud.AzurePublic())
39+
mockAuthManager.
40+
On("CredentialForCurrentUser", *mockContext.Context, mock.Anything).
41+
Return(nil, auth.ErrNoCurrentUser)
42+
43+
middleware := LoginGuardMiddleware{
44+
authManager: mockAuthManager,
45+
}
46+
47+
result, err := middleware.Run(*mockContext.Context, next)
48+
require.Error(t, err)
49+
require.Nil(t, result)
50+
})
51+
}
52+
53+
func next(ctx context.Context) (*actions.ActionResult, error) {
54+
return &actions.ActionResult{}, nil
55+
}
56+
57+
type mockCurrentUserAuthManager struct {
58+
mock.Mock
59+
}
60+
61+
func (m *mockCurrentUserAuthManager) Cloud() *cloud.Cloud {
62+
args := m.Called()
63+
return args.Get(0).(*cloud.Cloud)
64+
}
65+
66+
func (m *mockCurrentUserAuthManager) CredentialForCurrentUser(
67+
ctx context.Context,
68+
options *auth.CredentialForCurrentUserOptions,
69+
) (azcore.TokenCredential, error) {
70+
args := m.Called(ctx, options)
71+
72+
tokenVal := args.Get(0)
73+
if tokenVal == nil {
74+
return nil, args.Error(1)
75+
}
76+
77+
return tokenVal.(azcore.TokenCredential), args.Error(1)
78+
}

cli/azd/cmd/pipeline.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"github.com/MakeNowJust/heredoc/v2"
1111
"github.com/azure/azure-dev/cli/azd/cmd/actions"
1212
"github.com/azure/azure-dev/cli/azd/internal"
13-
"github.com/azure/azure-dev/cli/azd/pkg/auth"
1413
"github.com/azure/azure-dev/cli/azd/pkg/environment"
1514
"github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning"
1615
"github.com/azure/azure-dev/cli/azd/pkg/input"
@@ -89,6 +88,7 @@ func pipelineActions(root *actions.ActionDescriptor) *actions.ActionDescriptor {
8988
GroupingOptions: actions.CommandGroupOptions{
9089
RootLevelHelp: actions.CmdGroupBeta,
9190
},
91+
RequireLogin: true,
9292
})
9393

9494
group.Add("config", &actions.ActionDescriptorOptions{
@@ -134,7 +134,6 @@ type pipelineConfigAction struct {
134134

135135
func newPipelineConfigAction(
136136
env *environment.Environment,
137-
_ auth.LoggedInGuard,
138137
console input.Console,
139138
flags *pipelineConfigFlags,
140139
prompters prompt.Prompter,

cli/azd/cmd/root.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ func NewRootCmd(
303303
GroupingOptions: actions.CommandGroupOptions{
304304
RootLevelHelp: actions.CmdGroupStart,
305305
},
306+
RequireLogin: true,
306307
}).
307308
UseMiddleware("hooks", middleware.NewHooksMiddleware).
308309
UseMiddleware("extensions", middleware.NewExtensionsMiddleware)
@@ -356,6 +357,9 @@ func NewRootCmd(
356357
UseMiddleware("ux", middleware.NewUxMiddleware).
357358
UseMiddlewareWhen("telemetry", middleware.NewTelemetryMiddleware, func(descriptor *actions.ActionDescriptor) bool {
358359
return !descriptor.Options.DisableTelemetry
360+
}).
361+
UseMiddlewareWhen("loginGuard", middleware.NewLoginGuardMiddleware, func(descriptor *actions.ActionDescriptor) bool {
362+
return descriptor.Options.RequireLogin
359363
})
360364

361365
// Register common dependencies for the IoC rootContainer

cli/azd/cmd/up.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/azure/azure-dev/cli/azd/cmd/actions"
1414
"github.com/azure/azure-dev/cli/azd/internal"
1515
"github.com/azure/azure-dev/cli/azd/internal/cmd"
16-
"github.com/azure/azure-dev/cli/azd/pkg/auth"
1716
"github.com/azure/azure-dev/cli/azd/pkg/environment"
1817
"github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning"
1918
"github.com/azure/azure-dev/cli/azd/pkg/infra/provisioning/bicep"
@@ -83,7 +82,6 @@ func newUpAction(
8382
flags *upFlags,
8483
console input.Console,
8584
env *environment.Environment,
86-
_ auth.LoggedInGuard,
8785
projectConfig *project.ProjectConfig,
8886
provisioningManager *provisioning.Manager,
8987
envManager environment.Manager,

cli/azd/pkg/auth/guard.go

Lines changed: 0 additions & 25 deletions
This file was deleted.

cli/azd/pkg/auth/manager.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ func (m *Manager) LoginScopes() []string {
173173
return LoginScopes(m.cloud)
174174
}
175175

176+
// Cloud returns the cloud that the manager is configured to use.
177+
func (m *Manager) Cloud() *cloud.Cloud {
178+
return m.cloud
179+
}
180+
176181
func loginScopesMap(cloud *cloud.Cloud) map[string]struct{} {
177182
resourceManagerUrl := cloud.Configuration.Services[azcloud.ResourceManager].Endpoint
178183

0 commit comments

Comments
 (0)