Skip to content

Commit 7ecfa2e

Browse files
committed
add BasicAuthWithBcryptHashAndPrompt middleware
1 parent 3c3ee31 commit 7ecfa2e

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

basic_auth.go

+27
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,33 @@ func BasicAuthWithPrompt(user, passwd string) func(http.Handler) http.Handler {
113113
}
114114
}
115115

116+
// BasicAuthWithBcryptHashAndPrompt middleware requires basic auth and matches user & bcrypt hashed password
117+
// If the user is not authorized, it will prompt for basic auth
118+
func BasicAuthWithBcryptHashAndPrompt(user, hashedPassword string) func(http.Handler) http.Handler {
119+
checkFn := func(reqUser, reqPasswd string) bool {
120+
if reqUser != user {
121+
return false
122+
}
123+
err := bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(reqPasswd))
124+
return err == nil
125+
}
126+
127+
return func(h http.Handler) http.Handler {
128+
fn := func(w http.ResponseWriter, r *http.Request) {
129+
// extract basic auth from request
130+
u, p, ok := r.BasicAuth()
131+
if ok && checkFn(u, p) {
132+
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), contextKey(baContextKey), true)))
133+
return
134+
}
135+
// not authorized, prompt for basic auth
136+
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
137+
http.Error(w, "Unauthorized", http.StatusUnauthorized)
138+
}
139+
return http.HandlerFunc(fn)
140+
}
141+
}
142+
116143
// GenerateBcryptHash generates a bcrypt hash from a password
117144
func GenerateBcryptHash(password string) (string, error) {
118145
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)

basic_auth_test.go

+84
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,87 @@ func TestArgon2InvalidInputs(t *testing.T) {
360360
assert.Equal(t, http.StatusForbidden, resp.StatusCode)
361361
})
362362
}
363+
364+
func TestBasicAuthWithBcryptHashAndPrompt(t *testing.T) {
365+
hashedPassword, err := bcrypt.GenerateFromPassword([]byte("good"), bcrypt.MinCost)
366+
require.NoError(t, err)
367+
t.Logf("hashed password: %s", string(hashedPassword))
368+
369+
mw := BasicAuthWithBcryptHashAndPrompt("dev", string(hashedPassword))
370+
371+
ts := httptest.NewServer(mw(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
372+
t.Logf("request %s", r.URL)
373+
w.WriteHeader(http.StatusOK)
374+
_, err := w.Write([]byte("blah"))
375+
require.NoError(t, err)
376+
assert.True(t, IsAuthorized(r.Context()))
377+
})))
378+
defer ts.Close()
379+
380+
u := fmt.Sprintf("%s%s", ts.URL, "/something")
381+
client := http.Client{Timeout: 5 * time.Second}
382+
383+
tests := []struct {
384+
name string
385+
username string
386+
password string
387+
expectedStatus int
388+
checkPrompt bool
389+
}{
390+
{
391+
name: "no auth provided",
392+
username: "",
393+
password: "",
394+
expectedStatus: http.StatusUnauthorized,
395+
checkPrompt: true,
396+
},
397+
{
398+
name: "correct credentials",
399+
username: "dev",
400+
password: "good",
401+
expectedStatus: http.StatusOK,
402+
checkPrompt: false,
403+
},
404+
{
405+
name: "wrong username",
406+
username: "wrong",
407+
password: "good",
408+
expectedStatus: http.StatusUnauthorized,
409+
checkPrompt: true,
410+
},
411+
{
412+
name: "wrong password",
413+
username: "dev",
414+
password: "bad",
415+
expectedStatus: http.StatusUnauthorized,
416+
checkPrompt: true,
417+
},
418+
{
419+
name: "empty password",
420+
username: "dev",
421+
password: "",
422+
expectedStatus: http.StatusUnauthorized,
423+
checkPrompt: true,
424+
},
425+
}
426+
427+
for _, tc := range tests {
428+
t.Run(tc.name, func(t *testing.T) {
429+
req, err := http.NewRequest("GET", u, http.NoBody)
430+
require.NoError(t, err)
431+
432+
if tc.username != "" || tc.password != "" {
433+
req.SetBasicAuth(tc.username, tc.password)
434+
}
435+
436+
resp, err := client.Do(req)
437+
require.NoError(t, err)
438+
assert.Equal(t, tc.expectedStatus, resp.StatusCode)
439+
440+
if tc.checkPrompt {
441+
assert.Equal(t, `Basic realm="restricted", charset="UTF-8"`, resp.Header.Get("WWW-Authenticate"),
442+
"should include WWW-Authenticate header")
443+
}
444+
})
445+
}
446+
}

0 commit comments

Comments
 (0)