Skip to content

Commit 9cd0882

Browse files
authored
Merge pull request #14 from orsinium-labs/ws-auth
Support Authorization over WebSocket
2 parents 6772915 + 43d6076 commit 9cd0882

File tree

3 files changed

+137
-7
lines changed

3 files changed

+137
-7
lines changed

josh.go

-4
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,6 @@ func Read[T any](t string, r io.Reader) (Data[T], error) {
9999

100100
// Resp is a response type.
101101
//
102-
// The generic type T is the type of the data response.
103-
//
104-
// If the handler never returns data, use [Void] instead.
105-
//
106102
// https://jsonapi.org/format/#document-top-level
107103
type Resp struct {
108104
// The response status code.

middlewares/auth.go

+38-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,16 @@ type AuthValidator[U any] func(string) (U, error)
1313
//
1414
// The validators input is the "Bearer" token provided in the "Authorization" request header.
1515
//
16-
// If the validator returns an error, that erro is immediately returned
16+
// If the "Authorization" header is not provided, we also check the "Sec-WebSocket-Protocol"
17+
// as a necessary workaround for authenticating WebSocket requests from browsers
18+
// because the browser WebSocket API doesn't allow setting custom headers:
19+
//
20+
// https://github.com/whatwg/websockets/issues/16
21+
//
22+
// The very first protocol after the "Authorization" protocol is treated as
23+
// the authorization token.
24+
//
25+
// If the validator returns an error, that error is immediately returned
1726
// as an "Unathorized" response. Otherwise, the returned value (typically, the user
1827
// or their ID) is added into the request context using [josh.WithSingleton].
1928
func Auth[U any](v AuthValidator[U], h josh.Handler) josh.Handler {
@@ -31,17 +40,43 @@ func Auth[U any](v AuthValidator[U], h josh.Handler) josh.Handler {
3140

3241
func validateRequest[U any](validator AuthValidator[U], r josh.Req) (U, error) {
3342
header := r.Header.Get("Authorization")
34-
var def U
3543
if header == "" {
36-
return def, errors.New("Authorization header not found")
44+
// If Authorization header is not provided,
45+
// try authenticating it as a WebSocket request.
46+
return validateWSRequest(validator, r)
3747
}
3848
token, hasPrefix := strings.CutPrefix(header, "Bearer ")
3949
if !hasPrefix {
50+
var def U
4051
return def, errors.New("Unsupported Authorization type")
4152
}
4253
token = strings.TrimSpace(token)
4354
if token == "" {
55+
var def U
4456
return def, errors.New("Authorization token is empty")
4557
}
4658
return validator(token)
4759
}
60+
61+
func validateWSRequest[U any](validator AuthValidator[U], r josh.Req) (U, error) {
62+
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-WebSocket-Protocol
63+
headers := r.Header.Values("Sec-WebSocket-Protocol")
64+
var def U
65+
if len(headers) == 0 {
66+
return def, errors.New("Authorization header not found")
67+
}
68+
69+
foundAuth := false
70+
for _, tokens := range headers {
71+
for _, token := range strings.Split(tokens, ",") {
72+
token = strings.TrimSpace(token)
73+
if token == "Authorization" {
74+
foundAuth = true
75+
} else if foundAuth {
76+
return validator(token)
77+
}
78+
}
79+
}
80+
81+
return def, errors.New("Authorization header not found")
82+
}

middlewares/auth_test.go

+99
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
package middlewares_test
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/orsinium-labs/josh"
11+
"github.com/orsinium-labs/josh/middlewares"
12+
)
13+
14+
func TestAuth(t *testing.T) {
15+
var req josh.Req
16+
url := "http://example.com/foo"
17+
18+
req = httptest.NewRequest("GET", url, nil)
19+
checkAuth(t, req, 401)
20+
21+
req = httptest.NewRequest("GET", url, nil)
22+
req.Header.Add("Authorization", "secret")
23+
checkAuth(t, req, 401)
24+
25+
req = httptest.NewRequest("GET", url, nil)
26+
req.Header.Add("Authorization", "Bearer ohno")
27+
checkAuth(t, req, 401)
28+
29+
req = httptest.NewRequest("GET", url, nil)
30+
req.Header.Add("Authorization", "Bearer secret")
31+
checkAuth(t, req, 200)
32+
33+
req = httptest.NewRequest("GET", url, nil)
34+
req.Header.Add("Sec-WebSocket-Protocol", "secret")
35+
checkAuth(t, req, 401)
36+
37+
req = httptest.NewRequest("GET", url, nil)
38+
req.Header.Add("Sec-WebSocket-Protocol", "secret, Authorization")
39+
checkAuth(t, req, 401)
40+
41+
req = httptest.NewRequest("GET", url, nil)
42+
req.Header.Add("Sec-WebSocket-Protocol", "Authorization, ohno")
43+
checkAuth(t, req, 401)
44+
45+
req = httptest.NewRequest("GET", url, nil)
46+
req.Header.Add("Sec-WebSocket-Protocol", "Authorization, secret")
47+
checkAuth(t, req, 200)
48+
49+
req = httptest.NewRequest("GET", url, nil)
50+
req.Header.Add("Sec-WebSocket-Protocol", "Authorization")
51+
req.Header.Add("Sec-WebSocket-Protocol", "secret")
52+
checkAuth(t, req, 200)
53+
54+
req = httptest.NewRequest("GET", url, nil)
55+
req.Header.Add("Sec-WebSocket-Protocol", "secret")
56+
req.Header.Add("Sec-WebSocket-Protocol", "Authorization")
57+
checkAuth(t, req, 401)
58+
59+
req = httptest.NewRequest("GET", url, nil)
60+
req.Header.Add("Sec-WebSocket-Protocol", "Authorization")
61+
req.Header.Add("Sec-WebSocket-Protocol", "ohno")
62+
req.Header.Add("Sec-WebSocket-Protocol", "secret")
63+
checkAuth(t, req, 401)
64+
}
65+
66+
func checkAuth(t *testing.T, req *http.Request, code int) {
67+
t.Helper()
68+
type User string
69+
h := func(r josh.Req) josh.Resp {
70+
u, err := josh.GetSingleton[User](r)
71+
if err != nil {
72+
t.Fatal(err)
73+
}
74+
if u != "Aragorn" {
75+
t.Fatalf("got %s, expected Aragorn", u)
76+
}
77+
return josh.Ok("all is good")
78+
}
79+
v := func(token string) (User, error) {
80+
if token == "secret" {
81+
return "Aragorn", nil
82+
}
83+
return "", fmt.Errorf("bad token: %s", token)
84+
}
85+
h = middlewares.Auth(v, h)
86+
hh := josh.Wrap(h)
87+
88+
w := httptest.NewRecorder()
89+
hh(w, req)
90+
resp := w.Result()
91+
body, err := io.ReadAll(resp.Body)
92+
if err != nil {
93+
t.Fatalf("read body: %v", err)
94+
}
95+
t.Log(string(body))
96+
if resp.StatusCode != code {
97+
t.Fatalf("got %d, expected %d", resp.StatusCode, code)
98+
}
99+
}

0 commit comments

Comments
 (0)