Skip to content

Commit 653ff88

Browse files
authored
Support expires_in as a number or string (#119)
1 parent 9a0b6fe commit 653ff88

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

internal/oauth2/oauth2.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ func WaitForCallback(clientConfig ClientConfig, serverConfig ServerConfig, hc *h
325325

326326
type TokenResponse struct {
327327
AccessToken string `json:"access_token,omitempty"`
328-
ExpiresIn int64 `json:"expires_in,omitempty"`
328+
ExpiresIn FlexibleInt64 `json:"expires_in,omitempty"`
329329
IDToken string `json:"id_token,omitempty"`
330330
IssuedTokenType string `json:"issued_token_type,omitempty"`
331331
RefreshToken string `json:"refresh_token,omitempty"`
@@ -334,12 +334,49 @@ type TokenResponse struct {
334334
AuthorizationDetails []map[string]interface{} `json:"authorization_details,omitempty"`
335335
}
336336

337+
// FlexibleInt64 is a type that can be unmarshaled from a JSON number or
338+
// string. This was added to support the `expires_in` field in the token
339+
// response. Typically it is expressed as a JSON number, but at least
340+
// login.microsoft.com returns the number as a string.
341+
type FlexibleInt64 int64
342+
343+
func (f *FlexibleInt64) UnmarshalJSON(b []byte) error {
344+
if len(b) == 0 {
345+
return fmt.Errorf("cannot unmarshal empty int")
346+
}
347+
348+
// check if we have a number in a string, and parse it if so
349+
if b[0] == '"' {
350+
var s string
351+
if err := json.Unmarshal(b, &s); err != nil {
352+
return err
353+
}
354+
355+
i, err := strconv.ParseInt(s, 10, 64)
356+
if err != nil {
357+
return err
358+
}
359+
360+
*f = FlexibleInt64(i)
361+
return nil
362+
}
363+
364+
// finally we assume that we have a number that's not wrapped in a string
365+
var i int64
366+
if err := json.Unmarshal(b, &i); err != nil {
367+
return err
368+
}
369+
370+
*f = FlexibleInt64(i)
371+
return nil
372+
}
373+
337374
func NewTokenResponseFromForm(f url.Values) TokenResponse {
338375
expiresIn, _ := strconv.ParseInt(f.Get("expires_in"), 10, 64)
339376

340377
return TokenResponse{
341378
AccessToken: f.Get("access_token"),
342-
ExpiresIn: expiresIn,
379+
ExpiresIn: FlexibleInt64(expiresIn),
343380
IDToken: f.Get("id_token"),
344381
IssuedTokenType: f.Get("issued_token_type"),
345382
RefreshToken: f.Get("refresh_token"),

internal/oauth2/oauth2_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package oauth2_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
7+
"github.com/pkg/errors"
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/cloudentity/oauth2c/internal/oauth2"
11+
)
12+
13+
func TestUnmarshalExpires(t *testing.T) {
14+
tests := map[string]struct {
15+
bytes []byte
16+
expectedValue oauth2.FlexibleInt64
17+
expectedErr error
18+
}{
19+
"number": {
20+
bytes: []byte(`{"expires_in": 3600}`),
21+
expectedValue: 3600,
22+
expectedErr: nil,
23+
},
24+
"number string": {
25+
bytes: []byte(`{"expires_in": "3600"}`),
26+
expectedValue: 3600,
27+
expectedErr: nil,
28+
},
29+
"null": {
30+
bytes: []byte(`{"expires_in": null}`),
31+
expectedValue: 0,
32+
expectedErr: nil,
33+
},
34+
"other string": {
35+
bytes: []byte(`{"expires_in": "foo"}`),
36+
expectedValue: 0,
37+
expectedErr: errors.New("invalid syntax"),
38+
},
39+
}
40+
41+
for name, test := range tests {
42+
t.Run(name, func(t *testing.T) {
43+
tokenResponse := oauth2.TokenResponse{}
44+
err := json.Unmarshal(test.bytes, &tokenResponse)
45+
if test.expectedErr != nil {
46+
require.Error(t, err)
47+
require.Contains(t, err.Error(), test.expectedErr.Error())
48+
} else {
49+
require.NoError(t, err)
50+
require.Equal(t, test.expectedValue, tokenResponse.ExpiresIn)
51+
}
52+
})
53+
}
54+
}

0 commit comments

Comments
 (0)