Skip to content

Commit

Permalink
Fix the upstream OAuth 2.0 callback form deserialisation (#4010)
Browse files Browse the repository at this point in the history
Fixes #3957

This was broken since #3893
  • Loading branch information
sandhose authored Feb 11, 2025
2 parents eafe017 + 8dac005 commit b0bc692
Showing 1 changed file with 43 additions and 38 deletions.
81 changes: 43 additions & 38 deletions crates/handlers/src/upstream_oauth2/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,33 @@ use crate::{impl_from_error_for_route, upstream_oauth2::cache::MetadataCache, Pr

#[derive(Serialize, Deserialize)]
pub struct Params {
state: String,
state: Option<String>,

/// An extra parameter to track whether the POST request was re-made by us
/// to the same URL to escape Same-Site cookies restrictions
#[serde(default)]
did_mas_repost_to_itself: bool,

code: Option<String>,

error: Option<ClientErrorCode>,
error_description: Option<String>,
#[allow(dead_code)]
error_uri: Option<String>,

#[serde(flatten)]
code_or_error: CodeOrError,
extra_callback_parameters: Option<serde_json::Value>,
}

#[derive(Serialize, Deserialize)]
#[serde(untagged)]
enum CodeOrError {
Code {
code: String,

#[serde(flatten)]
extra_callback_parameters: Option<serde_json::Value>,
},
Error {
error: ClientErrorCode,
error_description: Option<String>,
#[allow(dead_code)]
error_uri: Option<String>,
},
impl Params {
/// Returns true if none of the fields are set
pub fn is_empty(&self) -> bool {
self.state.is_none()
&& self.code.is_none()
&& self.error.is_none()
&& self.error_description.is_none()
&& self.error_uri.is_none()
}
}

#[derive(Debug, Error)]
Expand All @@ -86,6 +87,12 @@ pub(crate) enum RouteError {
#[error("State parameter mismatch")]
StateMismatch,

#[error("Missing state parameter")]
MissingState,

#[error("Missing code parameter")]
MissingCode,

#[error("Could not extract subject from ID token")]
ExtractSubject(#[source] minijinja::Error),

Expand Down Expand Up @@ -161,7 +168,7 @@ pub(crate) async fn handler(
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Path(provider_id): Path<Ulid>,
Form(params): Form<Option<Params>>,
Form(params): Form<Params>,
) -> Result<Response, RouteError> {
let provider = repo
.upstream_oauth_provider()
Expand All @@ -172,7 +179,7 @@ pub(crate) async fn handler(

let sessions_cookie = UpstreamSessionsCookie::load(&cookie_jar);

let Some(params) = params else {
if params.is_empty() {
if let Method::GET = method {
return Err(RouteError::MissingQueryParams);
}
Expand Down Expand Up @@ -204,8 +211,19 @@ pub(crate) async fn handler(
(Some(expected), _) => return Err(RouteError::InvalidResponseMode { expected }),
}

if let Some(error) = params.error {
return Err(RouteError::ClientError {
error,
error_description: params.error_description.clone(),
});
}

let Some(state) = params.state else {
return Err(RouteError::MissingState);
};

let (session_id, _post_auth_action) = sessions_cookie
.find_session(provider_id, &params.state)
.find_session(provider_id, &state)
.map_err(|_| RouteError::MissingCookie)?;

let session = repo
Expand All @@ -219,7 +237,7 @@ pub(crate) async fn handler(
return Err(RouteError::ProviderMismatch);
}

if params.state != session.state_str {
if state != session.state_str {
// The state in the session cookie should match the one from the params
return Err(RouteError::StateMismatch);
}
Expand All @@ -230,21 +248,8 @@ pub(crate) async fn handler(
}

// Let's extract the code from the params, and return if there was an error
let (code, extra_callback_parameters) = match params.code_or_error {
CodeOrError::Error {
error,
error_description,
..
} => {
return Err(RouteError::ClientError {
error,
error_description,
})
}
CodeOrError::Code {
code,
extra_callback_parameters,
} => (code, extra_callback_parameters),
let Some(code) = params.code else {
return Err(RouteError::MissingCode);
};

let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
Expand Down Expand Up @@ -326,7 +331,7 @@ pub(crate) async fn handler(
context = context.with_id_token_claims(claims);
}

if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
if let Some(extra_callback_parameters) = params.extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}

Expand Down Expand Up @@ -432,7 +437,7 @@ pub(crate) async fn handler(
session,
&link,
token_response.id_token,
extra_callback_parameters,
params.extra_callback_parameters,
userinfo,
)
.await?;
Expand Down

0 comments on commit b0bc692

Please sign in to comment.