Skip to content

Commit

Permalink
session: Update access and refresh token in session
Browse files Browse the repository at this point in the history
Update access and refresh token in session if the ones stored in db
have expired and been refreshed. Also trigger the session update
regardless of strict validation mode to ensure token validity.

GitHub-PR: #109

Signed-off-by: Athina Plaskasoviti <athinapl@arrikto.com>
Reviewed-by: Athanasios Markou <athamark@arrikto.com>
  • Loading branch information
Athina Plaskasoviti authored and Athanasios Markou committed Mar 13, 2023
1 parent 58427f7 commit 0c4ea9a
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 18 deletions.
4 changes: 3 additions & 1 deletion authenticators/opaque.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ func (s *OpaqueTokenAuthenticator) AuthenticateRequest(r *http.Request) (*authen

ctx := common.SetTLSContext(r.Context(), s.CaBundle)

userInfo, err := oidc.GetUserInfo(ctx, s.Provider, s.Oauth2Config.TokenSource(ctx, opaque))
newToken, _, err := oidc.TokenSource(ctx, s.Oauth2Config, opaque)

userInfo, err := oidc.GetUserInfo(ctx, s.Provider, newToken)
if err != nil {
var reqErr *common.RequestError
if !errors.As(err, &reqErr) {
Expand Down
21 changes: 17 additions & 4 deletions authenticators/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,25 @@ func (sa *SessionAuthenticator) AuthenticateRequest(r *http.Request) (*authentic
return nil, false, nil
}

ctx := common.SetTLSContext(r.Context(), sa.CaBundle)
token := session.Values[sessions.UserSessionOAuth2Tokens].(oauth2.Token)

newToken, err := sessions.SaveToken(session, ctx, sa.Oauth2Config, &token, httptest.NewRecorder())
if err != nil {
logger.Errorf("Failed to refresh token: %v", err)
// Access token has expired
logger.Info("OAuth2 tokens have expired, revoking OIDC session")
revokeErr := sessions.RevokeOIDCSession(ctx, httptest.NewRecorder(),
session, sa.Provider, sa.Oauth2Config, sa.CaBundle)
if revokeErr != nil {
logger.Errorf("Failed to revoke tokens: %v", revokeErr)
}
return nil, false, err
}

// User is logged in
if sa.StrictSessionValidation {
ctx := common.SetTLSContext(r.Context(), sa.CaBundle)
token := session.Values[sessions.UserSessionOAuth2Tokens].(oauth2.Token)
// TokenSource takes care of automatically renewing the access token.
_, err := oidc.GetUserInfo(ctx, sa.Provider, sa.Oauth2Config.TokenSource(ctx, &token))
_, err = oidc.GetUserInfo(ctx, sa.Provider, newToken)
if err != nil {
var reqErr *common.RequestError
if !errors.As(err, &reqErr) {
Expand Down
31 changes: 23 additions & 8 deletions oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (u *UserInfo) Claims(v interface{}) error {

// ParseUserInfo unmarshals the response of the UserInfo endpoint
// and enforces boolean value for the EmailVerified claim.
func ParseUserInfo(body []byte) (*UserInfo, error){
func ParseUserInfo(body []byte) (*UserInfo, error) {

raw := struct {
Subject string `json:"sub"`
Expand Down Expand Up @@ -111,13 +111,33 @@ func ParseUserInfo(body []byte) (*UserInfo, error){
return userInfo, nil
}

// TokenSource is a wrapper around oauth2.Config.TokenSource that additionally
// returns a boolean indicator for a token refresh.
func TokenSource(ctx context.Context, config *oauth2.Config,
token *oauth2.Token) (*oauth2.Token, bool, error) {

tokenSource := config.TokenSource(ctx, token)

newToken, err := tokenSource.Token()
if err != nil {
return nil, false, errors.Errorf("oidc: get access token: %v", err)
}

// Check if access token has been refreshed
if (newToken.AccessToken != token.AccessToken) || (newToken.RefreshToken != token.RefreshToken) {
return newToken, true, nil
}

return token, false, nil
}

// GetUserInfo uses the token source to query the provider's user info endpoint.
// We reimplement UserInfo [1] instead of using the go-oidc's library UserInfo, in
// order to include HTTP response information in case of an error during
// contacting the UserInfo endpoint.
//
// [1]: https://github.com/coreos/go-oidc/blob/v2.1.0/oidc.go#L180
func GetUserInfo(ctx context.Context, provider Provider, tokenSource oauth2.TokenSource) (*UserInfo, error) {
func GetUserInfo(ctx context.Context, provider Provider, token *oauth2.Token) (*UserInfo, error) {

discoveryClaims := &struct {
UserInfoURL string `json:"userinfo_endpoint"`
Expand All @@ -136,11 +156,6 @@ func GetUserInfo(ctx context.Context, provider Provider, tokenSource oauth2.Toke
return nil, errors.Errorf("oidc: create GET request: %v", err)
}

token, err := tokenSource.Token()

if err != nil {
return nil, errors.Errorf("oidc: get access token: %v", err)
}
token.SetAuthHeader(req)

resp, err := common.DoRequest(ctx, req)
Expand Down Expand Up @@ -168,4 +183,4 @@ func GetUserInfo(ctx context.Context, provider Provider, tokenSource oauth2.Toke
}

return userInfo, nil
}
}
4 changes: 2 additions & 2 deletions oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ func TestGetUserInfo_ContextCancelled(t *testing.T) {

// Make a UserInfo request
_, err = GetUserInfo(context.Background(), provider,
oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "test"}))
&oauth2.Token{AccessToken: "test"})

// Check that we find a wrapped requestError
var reqErr *common.RequestError
Expand Down Expand Up @@ -187,4 +187,4 @@ func TestParseUserInfo(t *testing.T){
}
})
}
}
}
6 changes: 4 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,14 +379,16 @@ func (s *server) callback(w http.ResponseWriter, r *http.Request) {

// UserInfo endpoint to get claims
claims := map[string]interface{}{}
oidcUserInfo, err := oidc.GetUserInfo(ctx, s.provider, s.oauth2Config.TokenSource(ctx, oauth2Tokens))

newTokens, _, err := oidc.TokenSource(ctx, s.oauth2Config, oauth2Tokens)
userInfo, err := oidc.GetUserInfo(ctx, s.provider, newTokens)
if err != nil {
logger.Errorf("Not able to fetch userinfo: %v", err)
common.ReturnMessage(w, http.StatusInternalServerError, "Not able to fetch userinfo.")
return
}

if err = oidcUserInfo.Claims(&claims); err != nil {
if err = userInfo.Claims(&claims); err != nil {
logger.Errorf("Problem getting userinfo claims: %v", err)
common.ReturnMessage(w, http.StatusInternalServerError, "Not able to fetch userinfo claims.")
return
Expand Down
28 changes: 27 additions & 1 deletion sessions/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sessions
import (
"context"
"net/http"
"sync"

"github.com/arrikto/oidc-authservice/common"
"github.com/arrikto/oidc-authservice/oidc"
Expand Down Expand Up @@ -125,6 +126,31 @@ func RevokeOIDCSession(ctx context.Context, w http.ResponseWriter,
return revokeSession(ctx, w, session)
}

var mutex sync.Mutex

// SaveToken triggers oidc.TokenSource to refresh access and refresh token
// if they have expired and saves them to the session
func SaveToken(session *sessions.Session, ctx context.Context,
config *oauth2.Config, token *oauth2.Token,
w http.ResponseWriter) (*oauth2.Token, error) {

logger := common.StandardLogger()

newToken, new, err := oidc.TokenSource(ctx, config, token)

if new {
mutex.Lock()
defer mutex.Unlock()
session.Values[UserSessionOAuth2Tokens] = newToken
r := &http.Request{}
if err := session.Save(r.WithContext(ctx), w); err != nil {
logger.Fatalf("Failed to update token in session: %v", err)
}
logger.Infof("Updated token in session")
}
return newToken, err
}

// InitiateSessionStores initiates both the required stores for the:
// * users sessions
// * OIDC states
Expand Down Expand Up @@ -164,4 +190,4 @@ func InitiateSessionStores(c *common.Config) (ClosableStore, ClosableStore) {
}

return store, oidcStateStore
}
}

0 comments on commit 0c4ea9a

Please sign in to comment.