diff --git a/package-lock.json b/package-lock.json index e895688619..7515f46162 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "model-share", - "version": "5.0.460", + "version": "5.0.464", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "model-share", - "version": "5.0.460", + "version": "5.0.464", "hasInstallScript": true, "dependencies": { "@aws-sdk/client-s3": "^3.490.0", diff --git a/package.json b/package.json index db229e9c8a..63c0e37e74 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "model-share", - "version": "5.0.460", + "version": "5.0.464", "private": true, "scripts": { "start": "next start", @@ -290,4 +290,4 @@ "overrides": { "@react-aria/interactions": "3.16.0" } -} \ No newline at end of file +} diff --git a/src/components/CardTemplates/AspectRatioImageCard.tsx b/src/components/CardTemplates/AspectRatioImageCard.tsx index a030fee0a9..aa4e793abe 100644 --- a/src/components/CardTemplates/AspectRatioImageCard.tsx +++ b/src/components/CardTemplates/AspectRatioImageCard.tsx @@ -63,6 +63,7 @@ export type AspectRatioImageCardProps = { routedDialog?: RoutedDialogProps; target?: string; isRemix?: boolean; + explain?: boolean; } & ContentTypeProps; const IMAGE_CARD_WIDTH = 450; @@ -83,6 +84,7 @@ export function AspectRatioImageCard({ routedDialog, target, isRemix, + explain, }: AspectRatioImageCardProps) { const { ref, inView } = useInView({ key: cosmetic ? 1 : 0 }); @@ -97,7 +99,7 @@ export function AspectRatioImageCard({ ref={ref} style={!cosmetic ? wrapperStyle : undefined} className={clsx(className)} - > + >
{inView && ( <> @@ -106,6 +108,7 @@ export function AspectRatioImageCard({ connectId={contentId as any} connectType={contentType as any} image={image} + explain={explain} > {(safe) => ( <> @@ -178,7 +181,7 @@ export function AspectRatioImageCard({ {onSite && } )} -
+ ); } diff --git a/src/components/CivitaiLink/CivitaiLinkResourceManager.tsx b/src/components/CivitaiLink/CivitaiLinkResourceManager.tsx index 0ecffa1925..6fc9892d0c 100644 --- a/src/components/CivitaiLink/CivitaiLinkResourceManager.tsx +++ b/src/components/CivitaiLink/CivitaiLinkResourceManager.tsx @@ -1,10 +1,11 @@ import { showNotification } from '@mantine/notifications'; import { ModelType } from '~/shared/utils/prisma/enums'; -import { useCallback } from 'react'; +import { useCallback, useEffect } from 'react'; import { useCivitaiLink, useCivitaiLinkStore } from '~/components/CivitaiLink/CivitaiLinkProvider'; import { CommandResourcesAdd } from '~/components/CivitaiLink/shared-types'; import { ModelHashModel } from '~/server/selectors/modelHash.selector'; import { trpc } from '~/utils/trpc'; +import { showErrorNotification } from '~/utils/notifications'; const supportedModelTypes: ModelType[] = [ 'Checkpoint', @@ -49,7 +50,7 @@ export function CivitaiLinkResourceManager({ ) ); // const activities: Response[] = []; - const { data, refetch, isFetched, isFetching } = trpc.model.getDownloadCommand.useQuery( + const { data, refetch, isFetched, isFetching, error } = trpc.model.getDownloadCommand.useQuery( { modelId, modelVersionId }, { enabled: false, @@ -59,6 +60,10 @@ export function CivitaiLinkResourceManager({ } ); + useEffect(() => { + if (error) showErrorNotification({ error }); + }, [error]); + if (!connected || !supportedModelTypes.includes(modelType) || !hashes || !hashes.length) return fallback ?? null; @@ -71,7 +76,7 @@ export function CivitaiLinkResourceManager({ if (resource) return; if (!isFetched) refetch(); else if (data) runAddCommands(data.commands); - else showNotification({ message: 'Could not get commands' }); + else showNotification({ message: `Could not get commands` }); }; const cancelDownload = () => { diff --git a/src/components/Image/ById/ImageById.tsx b/src/components/Image/ById/ImageById.tsx index 7518f63dc5..9e311e8dcb 100644 --- a/src/components/Image/ById/ImageById.tsx +++ b/src/components/Image/ById/ImageById.tsx @@ -2,7 +2,15 @@ import { Card, Center, Loader } from '@mantine/core'; import { AspectRatioImageCard } from '~/components/CardTemplates/AspectRatioImageCard'; import { trpc } from '~/utils/trpc'; -export const ImageById = ({ imageId, ...props }: { imageId: number; className?: string }) => { +export const ImageById = ({ + imageId, + explain, + ...props +}: { + imageId: number; + className?: string; + explain?: boolean; +}) => { const { data: image, isLoading } = trpc.image.get.useQuery({ id: imageId }); if (isLoading || !image) { @@ -21,6 +29,7 @@ export const ImageById = ({ imageId, ...props }: { imageId: number; className?: href={`/images/${image.id}`} target="_blank" aspectRatio="square" + explain={explain} {...props} /> ); diff --git a/src/components/ImageGeneration/GenerationForm/GenerationForm2.tsx b/src/components/ImageGeneration/GenerationForm/GenerationForm2.tsx index adda2efc97..49f0e9f6a7 100644 --- a/src/components/ImageGeneration/GenerationForm/GenerationForm2.tsx +++ b/src/components/ImageGeneration/GenerationForm/GenerationForm2.tsx @@ -102,7 +102,6 @@ import { numberWithCommas } from '~/utils/number-helpers'; import { getDisplayName, hashify, parseAIR } from '~/utils/string-helpers'; import { trpc } from '~/utils/trpc'; import { isDefined } from '~/utils/type-guards'; -import { Priority } from '@civitai/client'; let total = 0; const tips = { @@ -640,7 +639,7 @@ export function GenerationFormContent() { if (!remixOfId || !remixPrompt || !remixSimilarity) return <>; return ( -
+
diff --git a/src/components/ImageGeneration/GenerationForm/TextToImageWhatIfProvider.tsx b/src/components/ImageGeneration/GenerationForm/TextToImageWhatIfProvider.tsx index ed45a86c81..71ca8f4908 100644 --- a/src/components/ImageGeneration/GenerationForm/TextToImageWhatIfProvider.tsx +++ b/src/components/ImageGeneration/GenerationForm/TextToImageWhatIfProvider.tsx @@ -14,9 +14,9 @@ import { import { trpc } from '~/utils/trpc'; import { UseTRPCQueryResult } from '@trpc/react-query/shared'; +import { useCurrentUser } from '~/hooks/useCurrentUser'; import { GenerationWhatIfResponse } from '~/server/services/orchestrator/types'; import { parseAIR } from '~/utils/string-helpers'; -import { useCurrentUser } from '~/hooks/useCurrentUser'; import { isDefined } from '~/utils/type-guards'; // import { useFeatureFlags } from '~/providers/FeatureFlagsProvider'; @@ -44,8 +44,10 @@ export function TextToImageWhatIfProvider({ children }: { children: React.ReactN const { model, resources = [], vae, ...params } = watched; if (params.aspectRatio) { const size = getSizeFromAspectRatio(Number(params.aspectRatio), params.baseModel); - params.width = size.width; - params.height = size.height; + if (size) { + params.width = size.width; + params.height = size.height; + } } let modelId = defaultModel.id; diff --git a/src/components/ImageGeneration/GenerationForm/VideoGenerationForm.tsx b/src/components/ImageGeneration/GenerationForm/VideoGenerationForm.tsx index c23b99065e..cb9fd9f5d0 100644 --- a/src/components/ImageGeneration/GenerationForm/VideoGenerationForm.tsx +++ b/src/components/ImageGeneration/GenerationForm/VideoGenerationForm.tsx @@ -410,6 +410,10 @@ function MinimaxImg2VidGenerationForm() { ); } +function LightricksPromptDescription() { + const url = 'https://education.civitai.com/civitais-quickstart-guide-to-lightricks-ltxv/#prompting' + return If you see poor results, please refer to the prompt guide +} function LightricksTxt2VidGenerationForm() { return ( @@ -419,6 +423,7 @@ function LightricksTxt2VidGenerationForm() { label="Prompt" placeholder="Your prompt goes here..." autosize + description={LightricksPromptDescription()} /> - +
Duration diff --git a/src/components/Signals/SignalsProvider.tsx b/src/components/Signals/SignalsProvider.tsx index 5f023df2ba..249dacd0fc 100644 --- a/src/components/Signals/SignalsProvider.tsx +++ b/src/components/Signals/SignalsProvider.tsx @@ -1,25 +1,22 @@ import { MantineColor, Notification, NotificationProps } from '@mantine/core'; -import { useSession } from 'next-auth/react'; import { createContext, useContext, useEffect, useRef, useState } from 'react'; import { SignalNotifications } from '~/components/Signals/SignalsNotifications'; import { SignalsRegistrar } from '~/components/Signals/SignalsRegistrar'; -import { useFeatureFlags } from '~/providers/FeatureFlagsProvider'; import { SignalMessages } from '~/server/common/enums'; +import { SignalStatus } from '~/utils/signals/types'; // import { createSignalWorker, SignalWorker } from '~/utils/signals'; -import { useSignalsWorker, SignalWorker, SignalStatus } from '~/utils/signals/useSignalsWorker'; +import { useSignalsWorker, SignalWorker } from '~/utils/signals/useSignalsWorker'; import { trpc } from '~/utils/trpc'; type SignalState = { connected: boolean; - status?: SignalStatus; + status: SignalStatus | null; worker: SignalWorker | null; }; const signalStatusDictionary: Record = { connected: 'green', - reconnected: 'green', reconnecting: 'yellow', - error: 'red', closed: 'red', }; @@ -50,38 +47,30 @@ export const useSignalConnection = (message: SignalMessages, cb: SignalCallback) }; export function SignalProvider({ children }: { children: React.ReactNode }) { - const session = useSession(); const queryUtils = trpc.useUtils(); - const features = useFeatureFlags(); - - const [status, setStatus] = useState(); - - const { data } = trpc.signals.getToken.useQuery(undefined, { - enabled: !!session.data?.user && features.signal, + const prevStatusRef = useRef(null); + const hasConnectedAtLeastOnceRef = useRef(false); + + const [status, setStatus] = useState(null); + prevStatusRef.current = status ?? null; + + const worker = useSignalsWorker({ + onStateChange: ({ state }) => { + const prevStatus = prevStatusRef.current; + const hasConnectedAtLeastOnce = hasConnectedAtLeastOnceRef.current; + if (prevStatus !== state && state === 'connected' && hasConnectedAtLeastOnce) { + queryUtils.buzz.getBuzzAccount.invalidate(); + queryUtils.orchestrator.queryGeneratedImages.invalidate(); + } + + if (state === 'connected') hasConnectedAtLeastOnceRef.current = true; + setStatus(state); + }, }); - const accessToken = data?.accessToken; - const userId = session.data?.user?.id; - - const worker = useSignalsWorker( - { accessToken }, - { - onReconnected: () => { - if (userId) { - queryUtils.buzz.getBuzzAccount.invalidate(); - queryUtils.orchestrator.queryGeneratedImages.invalidate(); - } - }, - onError: () => { - queryUtils.signals.getToken.invalidate(); - }, - onStatusChange: ({ status }) => setStatus(status), - } - ); - - const connected = status === 'connected' || status === 'reconnected'; + const connected = status === 'connected'; - return features.signal ? ( + return ( {children} - ) : ( - - {children} - ); } diff --git a/src/pages/api/admin/test.ts b/src/pages/api/admin/test.ts index 172adcfab6..807357a46e 100644 --- a/src/pages/api/admin/test.ts +++ b/src/pages/api/admin/test.ts @@ -10,7 +10,7 @@ import { getSystemPermissions } from '~/server/services/system-cache'; import { addGenerationEngine } from '~/server/services/generation/engines'; import { dbWrite, dbRead } from '~/server/db/client'; import { limitConcurrency, Task } from '~/server/utils/concurrency-helpers'; -import { getGenerationResourceData } from '~/server/services/generation/generation.service'; +import { getResourceData } from '~/server/services/generation/generation.service'; import { Prisma } from '@prisma/client'; import { getCommentsThreadDetails2 } from '~/server/services/commentsv2.service'; diff --git a/src/server/controllers/model-version.controller.ts b/src/server/controllers/model-version.controller.ts index 05b5828eec..bc44de8502 100644 --- a/src/server/controllers/model-version.controller.ts +++ b/src/server/controllers/model-version.controller.ts @@ -58,7 +58,7 @@ import { dbRead } from '../db/client'; import { modelFileSelect } from '../selectors/modelFile.selector'; import { getFilesByEntity } from '../services/file.service'; import { createFile } from '../services/model-file.service'; -import { getGenerationResourceData } from './../services/generation/generation.service'; +import { getResourceData } from './../services/generation/generation.service'; export const getModelVersionRunStrategiesHandler = ({ input: { id } }: { input: GetByIdInput }) => { try { @@ -152,7 +152,7 @@ export const getModelVersionHandler = async ({ }); const recommendedResourceIds = version?.recommendedResources.map((x) => x.id) ?? []; - const generationResources = await getGenerationResourceData({ + const generationResources = await getResourceData({ ids: recommendedResourceIds, user: ctx?.user, }).then((data) => diff --git a/src/server/controllers/model.controller.ts b/src/server/controllers/model.controller.ts index 39d8427555..6120c088c6 100644 --- a/src/server/controllers/model.controller.ts +++ b/src/server/controllers/model.controller.ts @@ -120,7 +120,7 @@ import { isDefined } from '~/utils/type-guards'; import { redis, REDIS_KEYS } from '../redis/client'; import { BountyDetailsSchema } from '../schema/bounty.schema'; import { - getGenerationResourceData, + getResourceData, getUnavailableResources, } from '../services/generation/generation.service'; @@ -187,7 +187,7 @@ export const getModelHandler = async ({ input, ctx }: { input: GetByIdInput; ctx const recommendedResourceIds = model.modelVersions.flatMap((version) => version?.recommendedResources.map((x) => x.id)) ?? []; - const generationResources = await getGenerationResourceData({ + const generationResources = await getResourceData({ ids: recommendedResourceIds, user: ctx?.user, }); @@ -775,9 +775,11 @@ export const getDownloadCommandHandler = async ({ }); if (!modelVersion) throw throwNotFoundError(); - const isDownloadable = modelVersion.usageControl !== ModelUsageControl.Download; + const isOwner = ctx.user?.id === modelVersion.model.userId; + const isDownloadable = + modelVersion.usageControl === ModelUsageControl.Download || isOwner || ctx.user?.isModerator; - if (!isDownloadable && !(modelVersion.model.userId === ctx.user?.id || ctx.user?.isModerator)) { + if (!isDownloadable) { throw throwAuthorizationError(); } diff --git a/src/server/orchestrator/lightricks/lightricks.schema.ts b/src/server/orchestrator/lightricks/lightricks.schema.ts index b32962b46d..fb756b7f39 100644 --- a/src/server/orchestrator/lightricks/lightricks.schema.ts +++ b/src/server/orchestrator/lightricks/lightricks.schema.ts @@ -9,14 +9,14 @@ import { } from '~/server/orchestrator/infrastructure/base.schema'; import { numberEnum } from '~/utils/zod-helpers'; -export const lightricksAspectRatios = ['16:9', '1:1', '9:16'] as const; +export const lightricksAspectRatios = ['16:9', '1:1'] as const; export const lightricksDuration = [5, 10] as const; const lightricksTxt2VidSchema = textEnhancementSchema.extend({ engine: z.literal('lightricks'), workflow: z.string(), negativePrompt: negativePromptSchema, - aspectRatio: z.enum(lightricksAspectRatios).default('1:1').catch('1:1'), + aspectRatio: z.enum(lightricksAspectRatios).default('16:9').catch('16:9'), duration: numberEnum(lightricksDuration).default(5).catch(5), cfgScale: z.number().min(3).max(3.5).default(3).catch(3), steps: z.number().min(20).max(30).default(25).catch(25), diff --git a/src/server/routers/signals.router.ts b/src/server/routers/signals.router.ts index a13838efd7..c2c13a460b 100644 --- a/src/server/routers/signals.router.ts +++ b/src/server/routers/signals.router.ts @@ -1,6 +1,6 @@ import { getUserAccountHandler } from '~/server/controllers/signals.controller'; -import { isFlagProtected, protectedProcedure, router } from '~/server/trpc'; +import { protectedProcedure, router } from '~/server/trpc'; export const signalsRouter = router({ - getToken: protectedProcedure.use(isFlagProtected('signal')).query(getUserAccountHandler), + getToken: protectedProcedure.query(getUserAccountHandler), }); diff --git a/src/server/schema/model.schema.ts b/src/server/schema/model.schema.ts index 69726e4bc9..a8c6fea8e8 100644 --- a/src/server/schema/model.schema.ts +++ b/src/server/schema/model.schema.ts @@ -138,8 +138,8 @@ export const deleteModelSchema = getByIdSchema.extend({ permanently: z.boolean() export type DeleteModelSchema = z.infer; export const getDownloadSchema = z.object({ - modelId: z.preprocess((val) => Number(val), z.number()), - modelVersionId: z.preprocess((val) => Number(val), z.number()).optional(), + modelId: z.coerce.number(), + modelVersionId: z.coerce.number().optional(), type: z.enum(constants.modelFileTypes).optional(), format: z.enum(constants.modelFileFormats).optional(), }); diff --git a/src/server/services/feature-flags.service.ts b/src/server/services/feature-flags.service.ts index 7206ed8cc4..72266de7fb 100644 --- a/src/server/services/feature-flags.service.ts +++ b/src/server/services/feature-flags.service.ts @@ -32,7 +32,7 @@ const featureFlags = createFeatureFlags({ articles: ['blue', 'red', 'public'], articleCreate: ['public'], adminTags: ['mod', 'granted'], - civitaiLink: isDev ? ['granted'] : ['mod', 'member'], + civitaiLink: ['mod', 'member'], stripe: ['mod'], imageTraining: ['user'], imageTrainingResults: ['user'], @@ -59,7 +59,6 @@ const featureFlags = createFeatureFlags({ profileCollections: ['public'], imageSearch: ['public'], buzz: ['public'], - signal: isDev ? ['granted', 'user'] : ['user'], recommenders: isDev ? ['granted', 'dev', 'mod'] : ['dev', 'mod'], assistant: { toggleable: true, diff --git a/src/server/services/generation/generation.service.ts b/src/server/services/generation/generation.service.ts index 6593ebd919..9464de3d98 100644 --- a/src/server/services/generation/generation.service.ts +++ b/src/server/services/generation/generation.service.ts @@ -37,9 +37,12 @@ import { fromJson, toJson } from '~/utils/json-helpers'; import { getPagedData } from '~/server/utils/pagination-helpers'; import { + baseModelResourceTypes, fluxUltraAir, getBaseModelFromResources, getBaseModelSet, + getBaseModelSetType, + SupportedBaseModel, } from '~/shared/constants/generation.constants'; import { isFutureDate } from '~/utils/date-helpers'; import { cleanPrompt } from '~/utils/metadata/audit'; @@ -509,7 +512,7 @@ export type GenerationResource = GenerationResourceBase & { const explicitCoveredModelAirs = [fluxUltraAir]; const explicitCoveredModelVersionIds = explicitCoveredModelAirs.map((air) => parseAIR(air).version); -export async function getGenerationResourceData({ +export async function getResourceData({ ids, user, }: { @@ -651,3 +654,18 @@ export async function getGenerationResourceData({ }); }); } + +export async function getGenerationResourceData(args: { + ids: number[]; + user?: { + id?: number; + isModerator?: boolean; + }; +}) { + return await getResourceData(args).then((data) => + data.filter((resource) => { + const baseModel = getBaseModelSetType(resource.baseModel) as SupportedBaseModel; + return !!baseModelResourceTypes[baseModel]; + }) + ); +} diff --git a/src/shared/constants/generation.constants.ts b/src/shared/constants/generation.constants.ts index e8079460de..73a32d8f15 100644 --- a/src/shared/constants/generation.constants.ts +++ b/src/shared/constants/generation.constants.ts @@ -4,15 +4,16 @@ import { baseModelSets, BaseModelSetType, generation, + generationConfig, getGenerationConfig, Sampler, } from '~/server/common/constants'; +import { videoGenerationConfig } from '~/server/orchestrator/generation/generation.config'; import { GenerationLimits } from '~/server/schema/generation.schema'; import { TextToImageParams } from '~/server/schema/orchestrator/textToImage.schema'; import { WorkflowDefinition } from '~/server/services/orchestrator/types'; import { ModelType } from '~/shared/utils/prisma/enums'; import { findClosest } from '~/utils/number-helpers'; -import { videoGenerationConfig } from '~/server/orchestrator/generation/generation.config'; export const WORKFLOW_TAGS = { GENERATION: 'gen', @@ -339,7 +340,7 @@ export function sanitizeTextToImageParams>( export function getSizeFromAspectRatio(aspectRatio: number | string, baseModel?: string) { const numberAspectRatio = typeof aspectRatio === 'string' ? Number(aspectRatio) : aspectRatio; const config = getGenerationConfig(baseModel); - return config.aspectRatios[numberAspectRatio]; + return config.aspectRatios[numberAspectRatio] ?? generationConfig.SD1.aspectRatios[0]; } export const getClosestAspectRatio = (width?: number, height?: number, baseModel?: string) => { diff --git a/src/utils/delivery-worker.ts b/src/utils/delivery-worker.ts index ce48f757e8..0a4adb5684 100644 --- a/src/utils/delivery-worker.ts +++ b/src/utils/delivery-worker.ts @@ -47,7 +47,7 @@ export async function getDownloadUrl(fileUrl: string, fileName?: string) { } if (!response.ok) { - throw new Error(response.statusText); + throw new Error(`Delivery worker error: ${response.statusText}`); } const result = await response.json(); return result as DownloadInfo; diff --git a/src/utils/signals/index.ts b/src/utils/signals/index.ts deleted file mode 100644 index fbbdda57c6..0000000000 --- a/src/utils/signals/index.ts +++ /dev/null @@ -1,137 +0,0 @@ -import SharedWorker from '@okikio/sharedworker'; -import { createStore } from 'zustand/vanilla'; -import type { WorkerIncomingMessage, WorkerOutgoingMessage } from './types'; -import { Deferred, EventEmitter } from './utils'; - -// Debugging -const logs: Record = {}; - -type State = { available: boolean }; -type Store = State & { update: (fn: (args: State) => State) => void }; - -export type SignalWorker = ReturnType; -export const createSignalWorker = ({ - onConnected, - onClosed, - onError, - onReconnected, - onReconnecting, -}: { - onConnected?: () => void; - onReconnected?: () => void; - onReconnecting?: () => void; - /** A closed connection will not recover on its own. */ - onClosed?: (message?: string) => void; - onError?: (message?: string) => void; -}) => { - const deferred = new Deferred(); - const emitter = new EventEmitter(); - let pingDeferred: Deferred | undefined; - - const { getState, subscribe } = createStore((set) => ({ - available: false, - signal: 'closed', - update: (fn) => set((args) => ({ ...fn(args) })), - })); - - const worker = new SharedWorker(new URL('./worker.v1.2.ts', import.meta.url), { - name: 'civitai-signals:1.2.3', - type: 'module', - }); - - worker.port.onmessage = async ({ data }: { data: WorkerOutgoingMessage }) => { - if (data.type === 'worker:ready') deferred.resolve(); - else if (data.type === 'connection:ready') onConnected?.(); - else if (data.type === 'connection:closed') onClosed?.(data.message); - else if (data.type === 'connection:error') onError?.(data.message); - else if (data.type === 'connection:reconnected') onReconnected?.(); - else if (data.type === 'connection:reconnecting') onReconnecting?.(); - else if (data.type === 'event:received') emitter.emit(data.target, data.payload); - else if (data.type === 'pong') pingDeferred?.resolve(); - }; - - const postMessage = (message: WorkerIncomingMessage) => worker.port.postMessage(message); - - const on = (target: string, cb: (data: unknown) => void) => { - postMessage({ type: 'event:register', target }); - emitter.on(target, cb); - }; - - const off = (target: string, cb: (data: unknown) => void) => { - emitter.off(target, cb); - }; - - const unload = () => { - postMessage({ type: 'beforeunload' }); - emitter.stop(); - }; - - const ping = async () => { - if (!pingDeferred && document.visibilityState === 'visible') { - pingDeferred = new Deferred(); - postMessage({ type: 'ping' }); - setTimeout(() => { - if (pingDeferred) pingDeferred.reject(); - }, 1000); - - await pingDeferred.promise - .then(() => getState().update((state) => ({ ...state, available: true }))) - .catch(() => { - getState().update((state) => ({ ...state, available: false })); - onClosed?.('connection to shared worker lost'); - }); - pingDeferred = undefined; - } - }; - - if (typeof window !== 'undefined') { - window.logSignal = (target, selector) => { - function logFn(args: unknown) { - if (selector) { - const result = [args].find(selector); - if (result) console.log(result); - } else console.log(args); - } - - if (!logs[target]) { - logs[target] = true; - on(target, logFn); - console.log(`begin logging: ${target}`); - } - }; - - window.ping = () => { - window.logSignal('pong'); - postMessage({ type: 'ping' }); - }; - } - - const close = () => { - document.removeEventListener('visibilitychange', ping); - window.removeEventListener('beforeunload', unload); - unload(); - }; - - // fire off an event to remove this port from the worker - window.addEventListener('beforeunload', unload, { once: true }); - // ping-pong with worker to check for worker availability - document.addEventListener('visibilitychange', ping); - - async function init(token: string, userId: number) { - await deferred.promise; - postMessage({ type: 'connection:init', token, userId }); - } - - function send(target: string, args: Record) { - postMessage({ type: 'send', target, args }); - } - - return { - on, - off, - close, - subscribe, - init, - send, - }; -}; diff --git a/src/utils/signals/types.ts b/src/utils/signals/types.ts index 5f451b2559..cee5fcc656 100644 --- a/src/utils/signals/types.ts +++ b/src/utils/signals/types.ts @@ -1,32 +1,7 @@ -import { HubConnectionState } from '@microsoft/signalr'; - type SignalWorkerReady = { type: 'worker:ready'; }; -type SignalConnectionStarted = { - type: 'connection:ready'; -}; - -type SignalConnectionClosed = { - type: 'connection:closed'; - message?: string; -}; - -type SignalWorkerError = { - type: 'connection:error'; - message?: string; -}; - -type SignalWorkerReconnected = { - type: 'connection:reconnected'; -}; - -type SignalWorkerReconnecting = { - type: 'connection:reconnecting'; - message?: string; -}; - type SignalWorkerPong = { type: 'pong' }; type SignalEventReceived = { @@ -35,22 +10,20 @@ type SignalEventReceived = { payload: T; }; -type SignalStatus = { - type: 'connection:state'; - state?: HubConnectionState; +export type SignalStatus = 'connected' | 'closed' | 'reconnecting'; +export type SignalConnectionState = { + state: SignalStatus | null; message?: string; }; +type SignalWorkerState = { + type: 'connection:state'; +} & SignalConnectionState; export type WorkerOutgoingMessage = | SignalWorkerReady - | SignalConnectionStarted - | SignalConnectionClosed - | SignalWorkerError - | SignalWorkerReconnected - | SignalWorkerReconnecting | SignalEventReceived | SignalWorkerPong - | SignalStatus; + | SignalWorkerState; export type WorkerIncomingMessage = | { type: 'connection:init'; token: string; userId: number } diff --git a/src/utils/signals/useSignalsWorker.ts b/src/utils/signals/useSignalsWorker.ts index e70534bda7..8637d3578e 100644 --- a/src/utils/signals/useSignalsWorker.ts +++ b/src/utils/signals/useSignalsWorker.ts @@ -1,57 +1,49 @@ import { useEffect, useMemo, useRef, useState } from 'react'; import SharedWorker from '@okikio/sharedworker'; -import type { WorkerOutgoingMessage } from './types'; +import type { SignalConnectionState, SignalStatus, WorkerOutgoingMessage } from './types'; import { Deferred, EventEmitter } from './utils'; import { useCurrentUser } from '~/hooks/useCurrentUser'; -import { useFeatureFlags } from '~/providers/FeatureFlagsProvider'; +import { trpc } from '~/utils/trpc'; -export type SignalStatus = 'connected' | 'closed' | 'error' | 'reconnected' | 'reconnecting'; export type SignalWorker = NonNullable>; -type SignalState = { - status: SignalStatus; - message?: string; -}; const logs: Record = {}; +let logConnectionState = false; -export function useSignalsWorker( - args: { accessToken?: string }, - options?: { - onConnected?: () => void; - onReconnected?: () => void; - onReconnecting?: () => void; - /** A closed connection will not recover on its own. */ - onClosed?: (message?: string) => void; - onError?: (message?: string) => void; - onStatusChange?: (args: SignalState) => void; - } -) { +export function useSignalsWorker(options?: { + onStateChange?: (args: SignalConnectionState) => void; +}) { const currentUser = useCurrentUser(); - const features = useFeatureFlags(); - const { accessToken } = args; - const { onConnected, onClosed, onError, onReconnected, onReconnecting, onStatusChange } = - options ?? {}; + const userId = currentUser?.id; + const { onStateChange } = options ?? {}; - const [state, setState] = useState(); + const [connection, setConnection] = useState(); const [ready, setReady] = useState(false); const [worker, setWorker] = useState(null); + const shouldInitialize = connection === 'closed'; + + const queryUtils = trpc.useUtils(); + const { data } = trpc.signals.getToken.useQuery(undefined, { + enabled: !!userId && shouldInitialize, + }); + const accessToken = data?.accessToken; const emitterRef = useRef(new EventEmitter()); const deferredRef = useRef(new Deferred()); // handle init worker useEffect(() => { - if (worker || !features.signal) return; + if (worker) return; setReady(false); setWorker( (worker) => worker ?? - new SharedWorker(new URL('./worker.v1.2.ts', import.meta.url), { - name: 'civitai-signals:1.2.6', + new SharedWorker(new URL('./worker.ts', import.meta.url), { + name: 'civitai-signals:2', type: 'module', }) ); - }, [features.signal, worker]); + }, [worker]); // handle register worker events useEffect(() => { @@ -59,45 +51,17 @@ export function useSignalsWorker( worker.port.onmessage = async ({ data }: { data: WorkerOutgoingMessage }) => { if (data.type === 'worker:ready') setReady(true); - else if (data.type === 'connection:ready') - setState((prev) => { - if ( - prev?.status === 'closed' || - prev?.status === 'error' || - prev?.status === 'reconnecting' - ) - return { status: 'reconnected' }; - else return { status: 'connected' }; - }); - else if (data.type === 'connection:closed') - setState({ status: 'closed', message: data.message }); - else if (data.type === 'connection:error') - setState({ status: 'error', message: data.message }); - else if (data.type === 'connection:reconnected') setState({ status: 'reconnected' }); - else if (data.type === 'connection:reconnecting') setState({ status: 'reconnecting' }); else if (data.type === 'event:received') emitterRef.current.emit(data.target, data.payload); else if (data.type === 'pong') deferredRef.current.resolve(); + else if (data.type === 'connection:state') { + setConnection(data.state ?? 'closed'); + onStateChange?.({ state: data.state, message: data.message }); + if (data.state === 'closed') queryUtils.signals.getToken.invalidate(); + if (logConnectionState) console.log({ state: data.state }, new Date().toLocaleTimeString()); + } }; }, [worker]); - useEffect(() => { - if (!state) return; - console.debug(`SignalService :: ${state.status}`); - onStatusChange?.(state); - switch (state.status) { - case 'connected': - return onConnected?.(); - case 'reconnected': - return onReconnected?.(); - case 'reconnecting': - return onReconnecting?.(); - case 'closed': - return onClosed?.(state.message); - case 'error': - return onError?.(state.message); - } - }, [state]); - // handle tab close useEffect(() => { function unload() { @@ -113,18 +77,18 @@ export function useSignalsWorker( // init useEffect(() => { - if (worker && ready && accessToken && currentUser?.id) + if (worker && ready && accessToken && userId) worker.port.postMessage({ type: 'connection:init', token: accessToken, - userId: currentUser.id, + userId, }); - }, [worker, accessToken, ready, currentUser?.id]); + }, [worker, accessToken, ready, userId]); // ping useEffect(() => { function handleVisibilityChange() { - if (document.visibilityState !== 'visible') return; + if (document.visibilityState !== 'visible' || !worker) return; deferredRef.current = new Deferred(); worker?.port.postMessage({ type: 'ping' }); const timeout = setTimeout(() => deferredRef.current.reject(), 1000); @@ -135,7 +99,7 @@ export function useSignalsWorker( }) .catch(() => { setReady(false); - setState({ status: 'closed', message: 'connection to shared worker lost' }); + setConnection('closed'); }); } @@ -184,8 +148,8 @@ export function useSignalsWorker( }; window.ping = () => { - window.logSignal('pong'); worker?.port.postMessage({ type: 'ping' }); + logConnectionState = true; }; } }, [workerMethods]); diff --git a/src/utils/signals/worker.ts b/src/utils/signals/worker.ts new file mode 100644 index 0000000000..9acba19ee9 --- /dev/null +++ b/src/utils/signals/worker.ts @@ -0,0 +1,163 @@ +import { + HttpTransportType, + HubConnection, + HubConnectionBuilder, + HubConnectionState, + LogLevel, +} from '@microsoft/signalr'; +import { env } from '~/env/client'; +import type { + SignalConnectionState, + SignalStatus, + WorkerIncomingMessage, + WorkerOutgoingMessage, +} from './types'; +import { EventEmitter } from './utils'; + +// -------------------------------- +// Types +// -------------------------------- +interface SharedWorkerGlobalScope { + onconnect: (event: MessageEvent) => void; +} + +const _self: SharedWorkerGlobalScope = self as any; + +let connectionState: SignalConnectionState = { state: null }; +let connectedUserId: number | null = null; +let connection: HubConnection | null = null; +// let pingInterval: NodeJS.Timer | null = null; +const events: Record void> = {}; + +const emitter = new EventEmitter<{ + eventReceived: { target: string; payload: any }; + stateChanged: SignalConnectionState; + pong: undefined; +}>(); + +function setConnectionState(args: { state: SignalStatus; message?: string }) { + emitter.emit('stateChanged', args); +} + +function emitCurrentConnectionState() { + emitter.emit('stateChanged', connectionState); +} + +emitter.on('stateChanged', ({ state, message }) => { + connectionState = { state, message }; + if (state === 'closed') connection = null; + console.log(`SignalR status: ${state}`, message); +}); + +// let interval: NodeJS.Timer | undefined; +// if (interval) clearInterval(interval); +// interval = setInterval(emitCurrentConnectionState, 10 * 1000); + +async function connect() { + try { + if (!connection) throw new Error('missing SignalR connection'); + // don't try to connect unless the connection is closed + if (connection.state !== HubConnectionState.Disconnected) return; + try { + await connection.start(); + setConnectionState({ state: 'connected' }); + } catch (err) { + console.log(err); + setTimeout(() => connect(), 5000); + } + } catch (e) { + setConnectionState({ state: 'closed', message: (e as Error).message }); + } +} + +const buildHubConnection = async ({ userId, token }: { token: string; userId: number }) => { + if (userId !== connectedUserId) { + connectedUserId = userId; + if (connection) { + (connection as any)._closedCallbacks = []; + await connection.stop(); + connection = null; + } + } + + if (connection) return connection; + + connection = new HubConnectionBuilder() + .withUrl(`${env.NEXT_PUBLIC_SIGNALS_ENDPOINT}/hub`, { + accessTokenFactory: () => token, + skipNegotiation: true, + transport: HttpTransportType.WebSockets, + }) + .configureLogging(LogLevel.Information) + .withAutomaticReconnect([0, 2, 10, 18, 30, 45, 60, 90]) + .build(); + + connection.onreconnected(() => { + setConnectionState({ state: 'connected' }); + }); + connection.onreconnecting((error) => { + setConnectionState({ state: 'reconnecting', message: JSON.stringify(error) }); + }); + connection.onclose((error) => { + setConnectionState({ state: 'closed', message: JSON.stringify(error) }); + }); + connection.on('Pong', () => console.log('pong')); + + for (const [target, event] of Object.entries(events)) { + connection.on(target, event); + } + return connection; +}; + +async function registerEvents(targets: string[]) { + for (const target of targets) { + if (!events[target]) { + events[target] = (payload) => emitter.emit('eventReceived', { target, payload }); + if (connection) { + connection.on(target, events[target]); + } + } + } +} + +const start = async (port: MessagePort) => { + if (!port.postMessage) return; + if (port.start) port.start(); + + const postMessage = (req: WorkerOutgoingMessage) => port.postMessage(req); + postMessage({ type: 'worker:ready' }); + postMessage({ type: 'connection:state', ...connectionState }); + + const emitterOffHandlers = [ + emitter.on('stateChanged', ({ state, message }) => + postMessage({ type: 'connection:state', state, message }) + ), + emitter.on('eventReceived', ({ target, payload }) => + postMessage({ type: 'event:received', target, payload }) + ), + emitter.on('pong', () => postMessage({ type: 'pong' })), + ]; + + // incoming messages + port.onmessage = async ({ data }: { data: WorkerIncomingMessage }) => { + if (data.type === 'connection:init') { + await buildHubConnection({ token: data.token, userId: data.userId }); + await connect(); + } else if (data.type === 'event:register') registerEvents([data.target]); + else if (data.type === 'beforeunload') { + emitterOffHandlers.forEach((fn) => fn()); + port.close(); + } else if (data.type === 'ping') { + emitter.emit('pong', undefined); + emitCurrentConnectionState(); + } else if (data.type === 'send') connection?.send(data.target, data.args); + }; +}; + +_self.onconnect = (e) => { + const [port] = e.ports; + start(port); +}; + +// This is the fallback for WebWorkers, in case the browser doesn't support SharedWorkers natively +if (!('SharedWorkerGlobalScope' in _self)) start(_self as any); diff --git a/src/utils/signals/worker.v1.2.ts b/src/utils/signals/worker.v1.2.ts deleted file mode 100644 index 66c1b5c219..0000000000 --- a/src/utils/signals/worker.v1.2.ts +++ /dev/null @@ -1,169 +0,0 @@ -import { - HttpTransportType, - HubConnection, - HubConnectionBuilder, - HubConnectionState, - LogLevel, - // HubConnectionState, -} from '@microsoft/signalr'; -import { env } from '~/env/client'; -import type { WorkerIncomingMessage, WorkerOutgoingMessage } from './types'; -import { EventEmitter } from './utils'; - -// -------------------------------- -// Types -// -------------------------------- -interface SharedWorkerGlobalScope { - onconnect: (event: MessageEvent) => void; -} - -const _self: SharedWorkerGlobalScope = self as any; - -let connectedUserId: number | null = null; -let connection: HubConnection | null = null; -// let pingInterval: NodeJS.Timer | null = null; -const events: Record void> = {}; - -const emitter = new EventEmitter<{ - connectionReady: undefined; - connectionClosed: { message?: string }; - connectionError: { message?: string }; - connectionReconnecting: { message?: string }; - connectionReconnected: undefined; - eventReceived: { target: string; payload: any }; - pong: undefined; -}>(); - -async function connect(args?: { retryAttempts?: number; timeout?: number }) { - const { retryAttempts = 10, timeout = 5000 } = args ?? {}; - try { - if (!connection) throw new Error('no connection to start'); - if (connection.state !== HubConnectionState.Disconnected) { - throw new Error( - `cannot start new connection :: current connection status: ${connection.state}` - ); - } - - await connection.start(); - emitter.emit('connectionReady', undefined); - console.log('SignalR Connected.'); - } catch (err) { - console.log(err); - setTimeout(() => { - if (retryAttempts > 0) connect({ retryAttempts: retryAttempts - 1, timeout: timeout * 1.4 }); - else throw new Error('failed to connect to signal service'); - }, timeout); - } -} - -const getConnection = async ({ token, userId }: { token: string; userId: number }) => { - if (connection && userId === connectedUserId) { - emitter.emit('connectionReady', undefined); - return connection; - } - if (userId !== connectedUserId) { - connectedUserId = userId; - if (connection) { - (connection as any)._closedCallbacks = []; - await connection.stop(); - connection = null; - } - } - - connection = new HubConnectionBuilder() - .withUrl(`${env.NEXT_PUBLIC_SIGNALS_ENDPOINT}/hub`, { - accessTokenFactory: () => token, - skipNegotiation: true, - transport: HttpTransportType.WebSockets, - }) - .configureLogging(LogLevel.Information) - .withAutomaticReconnect() - .build(); - - try { - connection.onreconnected(() => { - emitter.emit('connectionReconnected', undefined); - }); - connection.onreconnecting((error) => { - emitter.emit('connectionReconnecting', { message: JSON.stringify(error) }); - }); - connection.onclose((error) => { - emitter.emit('connectionClosed', { message: JSON.stringify(error) }); - setTimeout(() => connect(), 5000); - }); - connection.on('Pong', () => { - console.log('pong'); - }); - - for (const [target, event] of Object.entries(events)) { - connection.on(target, event); - } - - await connect(); - } catch (error) { - console.log(error); - emitter.emit('connectionError', { message: (error as Error).message ?? '' }); - connection = null; - } - - return connection; -}; - -const registerEvents = async (targets: string[]) => { - for (const target of targets) { - if (!events[target]) { - events[target] = (payload) => emitter.emit('eventReceived', { target, payload }); - if (connection) { - connection.on(target, events[target]); - } - } - } -}; - -const start = async (port: MessagePort) => { - if (!port.postMessage) return; - if (port.start) port.start(); - - const postMessage = (req: WorkerOutgoingMessage) => port.postMessage(req); - postMessage({ type: 'worker:ready' }); - - const emitterOffHandlers = [ - emitter.on('connectionReconnected', () => postMessage({ type: 'connection:reconnected' })), - emitter.on('connectionReady', () => postMessage({ type: 'connection:ready' })), - emitter.on('connectionClosed', ({ message }) => - postMessage({ type: 'connection:closed', message }) - ), - emitter.on('connectionError', ({ message }) => - postMessage({ type: 'connection:error', message }) - ), - emitter.on('connectionReconnecting', ({ message }) => - postMessage({ type: 'connection:reconnecting', message }) - ), - emitter.on('eventReceived', ({ target, payload }) => - postMessage({ type: 'event:received', target, payload }) - ), - emitter.on('pong', () => { - postMessage({ type: 'pong' }); - postMessage({ type: 'event:received', target: 'pong', payload: connection?.state }); - }), - ]; - - // incoming messages - port.onmessage = async ({ data }: { data: WorkerIncomingMessage }) => { - if (data.type === 'connection:init') getConnection({ token: data.token, userId: data.userId }); - else if (data.type === 'event:register') registerEvents([data.target]); - else if (data.type === 'beforeunload') { - emitterOffHandlers.forEach((fn) => fn()); - port.close(); - } else if (data.type === 'ping') emitter.emit('pong', undefined); - else if (data.type === 'send') connection?.send(data.target, data.args); - }; -}; - -_self.onconnect = (e) => { - const [port] = e.ports; - start(port); -}; - -// This is the fallback for WebWorkers, in case the browser doesn't support SharedWorkers natively -if (!('SharedWorkerGlobalScope' in _self)) start(_self as any); diff --git a/tailwind.config.js b/tailwind.config.js index 53cb6e3945..a996627c3d 100644 --- a/tailwind.config.js +++ b/tailwind.config.js @@ -121,15 +121,15 @@ module.exports = { 9: '#d9480f', }, lime: { - 0: '#f4fce3', - 1: '#e9fac8', - 2: '#d8f5a2', - 3: '#c0eb75', - 4: '#a9e34b', - 5: '#94d82d', - 6: '#82c91e', - 7: '#74b816', - 8: '#66a80f', + 0: '#f4fce3', + 1: '#e9fac8', + 2: '#d8f5a2', + 3: '#c0eb75', + 4: '#a9e34b', + 5: '#94d82d', + 6: '#82c91e', + 7: '#74b816', + 8: '#66a80f', 9: '#5c940d', } }, diff --git a/tsconfig.json b/tsconfig.json index a6a0035ff7..02b42eb020 100644 --- a/tsconfig.json +++ b/tsconfig.json @@ -1,11 +1,7 @@ { "compilerOptions": { "target": "ES2018", - "lib": [ - "dom", - "dom.iterable", - "esnext" - ], + "lib": ["dom", "dom.iterable", "esnext"], // "types": ["offscreencanvas"], "allowJs": true, "skipLibCheck": true, @@ -23,13 +19,9 @@ "noUncheckedIndexedAccess": false, // TODO swap to true "baseUrl": "src", "paths": { - "~/*": [ - "./*" - ] + "~/*": ["./*"] }, - "typeRoots": [ - "./types" - ], + "typeRoots": ["./types"], "noErrorTruncation": true, "plugins": [ { @@ -44,12 +36,9 @@ // "**/*.cjs", // "**/*.mjs", "scripts", - "src" + "src", // ".next/types/**/*.ts" - , ".next/types/**/*.ts" ], - "exclude": [ - "node_modules" - ] + "exclude": ["node_modules"] }