Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow standard schemas to validate endpoint values #4864

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/toolkit/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
"tsup": "^8.2.3",
"tsx": "^4.19.0",
"typescript": "^5.8.2",
"valibot": "^1.0.0",
"vite-tsconfig-paths": "^4.3.1",
"vitest": "^1.6.0",
"yargs": "^15.3.1"
Expand All @@ -124,6 +125,8 @@
"react"
],
"dependencies": {
"@standard-schema/spec": "^1.0.0",
"@standard-schema/utils": "^0.3.0",
"immer": "^10.0.3",
"redux": "^5.0.1",
"redux-thunk": "^3.1.0",
Expand Down
14 changes: 9 additions & 5 deletions packages/toolkit/src/query/core/apiState.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ export type MutationKeys<Definitions extends EndpointDefinitions> = {
}[keyof Definitions]

type BaseQuerySubState<
D extends BaseEndpointDefinition<any, any, any>,
D extends BaseEndpointDefinition<any, any, any, any>,
DataType = ResultTypeFrom<D>,
> = {
/**
Expand Down Expand Up @@ -222,7 +222,7 @@ type BaseQuerySubState<
}

export type QuerySubState<
D extends BaseEndpointDefinition<any, any, any>,
D extends BaseEndpointDefinition<any, any, any, any>,
DataType = ResultTypeFrom<D>,
> = Id<
| ({
Expand Down Expand Up @@ -252,15 +252,17 @@ export type QuerySubState<
export type InfiniteQueryDirection = 'forward' | 'backward'

export type InfiniteQuerySubState<
D extends BaseEndpointDefinition<any, any, any>,
D extends BaseEndpointDefinition<any, any, any, any>,
> =
D extends InfiniteQueryDefinition<any, any, any, any, any>
? QuerySubState<D, InfiniteData<ResultTypeFrom<D>, PageParamFrom<D>>> & {
direction?: InfiniteQueryDirection
}
: never

type BaseMutationSubState<D extends BaseEndpointDefinition<any, any, any>> = {
type BaseMutationSubState<
D extends BaseEndpointDefinition<any, any, any, any>,
> = {
requestId: string
data?: ResultTypeFrom<D>
error?:
Expand All @@ -273,7 +275,9 @@ type BaseMutationSubState<D extends BaseEndpointDefinition<any, any, any>> = {
fulfilledTimeStamp?: number
}

export type MutationSubState<D extends BaseEndpointDefinition<any, any, any>> =
export type MutationSubState<
D extends BaseEndpointDefinition<any, any, any, any>,
> =
| (({
status: QueryStatus.fulfilled
} & WithRequiredProp<
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import type { ThunkDispatch, UnknownAction } from '@reduxjs/toolkit'
import type { BaseQueryFn, BaseQueryMeta } from '../../baseQueryTypes'
import type {
BaseQueryFn,
BaseQueryMeta,
BaseQueryResult,
} from '../../baseQueryTypes'
import type { BaseEndpointDefinition } from '../../endpointDefinitions'
import { DefinitionType } from '../../endpointDefinitions'
import type { QueryCacheKey, RootState } from '../apiState'
Expand Down Expand Up @@ -32,7 +36,8 @@ export interface QueryBaseLifecycleApi<
{ type: DefinitionType.query } & BaseEndpointDefinition<
QueryArg,
BaseQuery,
ResultType
ResultType,
BaseQueryResult<BaseQuery>
>
>
/**
Expand All @@ -55,7 +60,8 @@ export type MutationBaseLifecycleApi<
{ type: DefinitionType.mutation } & BaseEndpointDefinition<
QueryArg,
BaseQuery,
ResultType
ResultType,
BaseQueryResult<BaseQuery>
>
>
}
Expand Down
143 changes: 112 additions & 31 deletions packages/toolkit/src/query/core/buildThunks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import type {
QueryDefinition,
ResultDescription,
ResultTypeFrom,
SchemaFailureHandler,
SchemaFailureInfo,
} from '../endpointDefinitions'
import {
calculateProvidedBy,
Expand Down Expand Up @@ -65,6 +67,7 @@ import {
isRejectedWithValue,
SHOULD_AUTOBATCH,
} from './rtkImports'
import { parseWithSchema, NamedSchemaError } from '../standardSchema'

export type BuildThunksApiEndpointQuery<
Definition extends QueryDefinition<any, any, any, any, any>,
Expand Down Expand Up @@ -338,6 +341,7 @@ export function buildThunks<
api,
assertTagType,
selectors,
onSchemaFailure,
}: {
baseQuery: BaseQuery
reducerPath: ReducerPath
Expand All @@ -346,6 +350,7 @@ export function buildThunks<
api: Api<BaseQuery, Definitions, ReducerPath, any>
assertTagType: AssertTagTypes
selectors: AllSelectors
onSchemaFailure: SchemaFailureHandler | undefined
}) {
type State = RootState<any, string, ReducerPath>

Expand Down Expand Up @@ -502,10 +507,13 @@ export function buildThunks<
},
) => {
const endpointDefinition = endpointDefinitions[arg.endpointName]
const { metaSchema } = endpointDefinition

try {
let transformResponse: TransformCallback =
getTransformCallbackForEndpoint(endpointDefinition, 'transformResponse')
let transformResponse = getTransformCallbackForEndpoint(
endpointDefinition,
'transformResponse',
)

const baseQueryApi = {
signal,
Expand Down Expand Up @@ -562,7 +570,16 @@ export function buildThunks<
finalQueryArg: unknown,
): Promise<QueryReturnValue> {
let result: QueryReturnValue
const { extraOptions } = endpointDefinition
const { extraOptions, argSchema, rawResponseSchema, responseSchema } =
endpointDefinition

if (argSchema) {
finalQueryArg = await parseWithSchema(
argSchema,
finalQueryArg,
'argSchema',
)
}

if (forceQueryFn) {
// upsertQueryData relies on this to pass in the user-provided value
Expand Down Expand Up @@ -617,12 +634,30 @@ export function buildThunks<

if (result.error) throw new HandledError(result.error, result.meta)

const transformedResponse = await transformResponse(
result.data,
let { data } = result

if (rawResponseSchema) {
data = await parseWithSchema(
rawResponseSchema,
result.data,
'rawResponseSchema',
)
}

let transformedResponse = await transformResponse(
data,
result.meta,
finalQueryArg,
)

if (responseSchema) {
transformedResponse = await parseWithSchema(
responseSchema,
transformedResponse,
'responseSchema',
)
}

return {
...result,
data: transformedResponse,
Expand Down Expand Up @@ -712,6 +747,14 @@ export function buildThunks<
finalQueryReturnValue = await executeRequest(arg.originalArgs)
}

if (metaSchema && finalQueryReturnValue.meta) {
finalQueryReturnValue.meta = await parseWithSchema(
metaSchema,
finalQueryReturnValue.meta,
'metaSchema',
)
}

// console.log('Final result: ', transformedData)
return fulfillWithValue(
finalQueryReturnValue.data,
Expand All @@ -721,40 +764,78 @@ export function buildThunks<
}),
)
} catch (error) {
let catchedError = error
if (catchedError instanceof HandledError) {
let transformErrorResponse: TransformCallback =
getTransformCallbackForEndpoint(
try {
let caughtError = error
if (caughtError instanceof HandledError) {
let transformErrorResponse = getTransformCallbackForEndpoint(
endpointDefinition,
'transformErrorResponse',
)
const { rawErrorResponseSchema, errorResponseSchema } =
endpointDefinition

let { value, meta } = caughtError

if (rawErrorResponseSchema) {
value = await parseWithSchema(
rawErrorResponseSchema,
value,
'rawErrorResponseSchema',
)
}

if (metaSchema) {
meta = await parseWithSchema(metaSchema, meta, 'metaSchema')
}

try {
return rejectWithValue(
await transformErrorResponse(
catchedError.value,
catchedError.meta,
try {
let transformedErrorResponse = await transformErrorResponse(
value,
meta,
arg.originalArgs,
),
addShouldAutoBatch({ baseQueryMeta: catchedError.meta }),
)
} catch (e) {
catchedError = e
)
if (errorResponseSchema) {
transformedErrorResponse = await parseWithSchema(
errorResponseSchema,
transformedErrorResponse,
'errorResponseSchema',
)
}

return rejectWithValue(
transformedErrorResponse,
addShouldAutoBatch({ baseQueryMeta: meta }),
)
} catch (e) {
caughtError = e
}
}
}
if (
typeof process !== 'undefined' &&
process.env.NODE_ENV !== 'production'
) {
console.error(
`An unhandled error occurred processing a request for the endpoint "${arg.endpointName}".
if (
typeof process !== 'undefined' &&
process.env.NODE_ENV !== 'production'
) {
console.error(
`An unhandled error occurred processing a request for the endpoint "${arg.endpointName}".
In the case of an unhandled error, no tags will be "provided" or "invalidated".`,
catchedError,
)
} else {
console.error(catchedError)
caughtError,
)
} else {
console.error(caughtError)
}
throw caughtError
} catch (error) {
if (error instanceof NamedSchemaError) {
const info: SchemaFailureInfo = {
endpoint: arg.endpointName,
arg: arg.originalArgs,
type: arg.type,
queryCacheKey: arg.type === 'query' ? arg.queryCacheKey : undefined,
}
endpointDefinition.onSchemaFailure?.(error, info)
onSchemaFailure?.(error, info)
}
throw error
}
throw catchedError
}
}

Expand Down
2 changes: 2 additions & 0 deletions packages/toolkit/src/query/core/module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ export const coreModule = ({
refetchOnFocus,
refetchOnReconnect,
invalidationBehavior,
onSchemaFailure,
},
context,
) {
Expand Down Expand Up @@ -582,6 +583,7 @@ export const coreModule = ({
serializeQueryArgs,
assertTagType,
selectors,
onSchemaFailure,
})

const { reducer, actions: sliceActions } = buildSlice({
Expand Down
3 changes: 3 additions & 0 deletions packages/toolkit/src/query/createApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { defaultSerializeQueryArgs } from './defaultSerializeQueryArgs'
import type {
EndpointBuilder,
EndpointDefinitions,
SchemaFailureHandler,
} from './endpointDefinitions'
import {
DefinitionType,
Expand Down Expand Up @@ -212,6 +213,8 @@ export interface CreateApiOptions<
NoInfer<TagTypes>,
NoInfer<ReducerPath>
>

onSchemaFailure?: SchemaFailureHandler
}

export type CreateApi<Modules extends ModuleName> = {
Expand Down
Loading