Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pull] main from civitai:main #330

Merged
merged 13 commits into from
Feb 6, 2025
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "model-share",
"version": "5.0.460",
"version": "5.0.464",
"private": true,
"scripts": {
"start": "next start",
Expand Down Expand Up @@ -290,4 +290,4 @@
"overrides": {
"@react-aria/interactions": "3.16.0"
}
}
}
7 changes: 5 additions & 2 deletions src/components/CardTemplates/AspectRatioImageCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export type AspectRatioImageCardProps<T extends DialogKey> = {
routedDialog?: RoutedDialogProps<T>;
target?: string;
isRemix?: boolean;
explain?: boolean;
} & ContentTypeProps;

const IMAGE_CARD_WIDTH = 450;
Expand All @@ -83,6 +84,7 @@ export function AspectRatioImageCard<T extends DialogKey>({
routedDialog,
target,
isRemix,
explain,
}: AspectRatioImageCardProps<T>) {
const { ref, inView } = useInView({ key: cosmetic ? 1 : 0 });

Expand All @@ -97,7 +99,7 @@ export function AspectRatioImageCard<T extends DialogKey>({
ref={ref}
style={!cosmetic ? wrapperStyle : undefined}
className={clsx(className)}
>
>
<div className={clsx(styles.content, { [styles.inView]: inView })}>
{inView && (
<>
Expand All @@ -106,6 +108,7 @@ export function AspectRatioImageCard<T extends DialogKey>({
connectId={contentId as any}
connectType={contentType as any}
image={image}
explain={explain}
>
{(safe) => (
<>
Expand Down Expand Up @@ -178,7 +181,7 @@ export function AspectRatioImageCard<T extends DialogKey>({
{onSite && <OnsiteIndicator isRemix={isRemix} />}
</>
)}
</div>
</div>
</CosmeticCard>
);
}
Expand Down
11 changes: 8 additions & 3 deletions src/components/CivitaiLink/CivitaiLinkResourceManager.tsx
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import { showNotification } from '@mantine/notifications';
import { ModelType } from '~/shared/utils/prisma/enums';
import { useCallback } from 'react';
import { useCallback, useEffect } from 'react';
import { useCivitaiLink, useCivitaiLinkStore } from '~/components/CivitaiLink/CivitaiLinkProvider';
import { CommandResourcesAdd } from '~/components/CivitaiLink/shared-types';
import { ModelHashModel } from '~/server/selectors/modelHash.selector';
import { trpc } from '~/utils/trpc';
import { showErrorNotification } from '~/utils/notifications';

const supportedModelTypes: ModelType[] = [
'Checkpoint',
Expand Down Expand Up @@ -49,7 +50,7 @@ export function CivitaiLinkResourceManager({
)
);
// const activities: Response[] = [];
const { data, refetch, isFetched, isFetching } = trpc.model.getDownloadCommand.useQuery(
const { data, refetch, isFetched, isFetching, error } = trpc.model.getDownloadCommand.useQuery(
{ modelId, modelVersionId },
{
enabled: false,
Expand All @@ -59,6 +60,10 @@ export function CivitaiLinkResourceManager({
}
);

useEffect(() => {
if (error) showErrorNotification({ error });
}, [error]);

if (!connected || !supportedModelTypes.includes(modelType) || !hashes || !hashes.length)
return fallback ?? null;

Expand All @@ -71,7 +76,7 @@ export function CivitaiLinkResourceManager({
if (resource) return;
if (!isFetched) refetch();
else if (data) runAddCommands(data.commands);
else showNotification({ message: 'Could not get commands' });
else showNotification({ message: `Could not get commands` });
};

const cancelDownload = () => {
Expand Down
11 changes: 10 additions & 1 deletion src/components/Image/ById/ImageById.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@ import { Card, Center, Loader } from '@mantine/core';
import { AspectRatioImageCard } from '~/components/CardTemplates/AspectRatioImageCard';
import { trpc } from '~/utils/trpc';

export const ImageById = ({ imageId, ...props }: { imageId: number; className?: string }) => {
export const ImageById = ({
imageId,
explain,
...props
}: {
imageId: number;
className?: string;
explain?: boolean;
}) => {
const { data: image, isLoading } = trpc.image.get.useQuery({ id: imageId });

if (isLoading || !image) {
Expand All @@ -21,6 +29,7 @@ export const ImageById = ({ imageId, ...props }: { imageId: number; className?:
href={`/images/${image.id}`}
target="_blank"
aspectRatio="square"
explain={explain}
{...props}
/>
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ import { numberWithCommas } from '~/utils/number-helpers';
import { getDisplayName, hashify, parseAIR } from '~/utils/string-helpers';
import { trpc } from '~/utils/trpc';
import { isDefined } from '~/utils/type-guards';
import { Priority } from '@civitai/client';

let total = 0;
const tips = {
Expand Down Expand Up @@ -640,7 +639,7 @@ export function GenerationFormContent() {
if (!remixOfId || !remixPrompt || !remixSimilarity) return <></>;

return (
<div className="radius-md my-2 flex flex-col gap-2 overflow-hidden">
<div className="my-2 flex flex-col gap-2 overflow-hidden rounded-md">
<div
className={clsx('flex rounded-md', {
'border-2 border-red-500': remixSimilarity < 0.75,
Expand All @@ -649,7 +648,8 @@ export function GenerationFormContent() {
<div className=" flex-none">
<ImageById
imageId={remixOfId}
className="h-28 rounded-none rounded-l-md"
className="h-28 rounded-none rounded-l-md"
explain={false}
/>
</div>
<div className="h-28 flex-1">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ import {
import { trpc } from '~/utils/trpc';

import { UseTRPCQueryResult } from '@trpc/react-query/shared';
import { useCurrentUser } from '~/hooks/useCurrentUser';
import { GenerationWhatIfResponse } from '~/server/services/orchestrator/types';
import { parseAIR } from '~/utils/string-helpers';
import { useCurrentUser } from '~/hooks/useCurrentUser';
import { isDefined } from '~/utils/type-guards';
// import { useFeatureFlags } from '~/providers/FeatureFlagsProvider';

Expand Down Expand Up @@ -44,8 +44,10 @@ export function TextToImageWhatIfProvider({ children }: { children: React.ReactN
const { model, resources = [], vae, ...params } = watched;
if (params.aspectRatio) {
const size = getSizeFromAspectRatio(Number(params.aspectRatio), params.baseModel);
params.width = size.width;
params.height = size.height;
if (size) {
params.width = size.width;
params.height = size.height;
}
}

let modelId = defaultModel.id;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,10 @@ function MinimaxImg2VidGenerationForm() {
);
}

function LightricksPromptDescription() {
const url = 'https://education.civitai.com/civitais-quickstart-guide-to-lightricks-ltxv/#prompting'
return <span>If you see poor results, please refer to the <Anchor href={url} target="_blank">prompt guide</Anchor></span>
}
function LightricksTxt2VidGenerationForm() {
return (
<FormWrapper engine="lightricks">
Expand All @@ -419,6 +423,7 @@ function LightricksTxt2VidGenerationForm() {
label="Prompt"
placeholder="Your prompt goes here..."
autosize
description={LightricksPromptDescription()}
/>
<InputTextArea name="negativePrompt" label="Negative Prompt" autosize />
<InputAspectRatioColonDelimited
Expand Down Expand Up @@ -493,7 +498,7 @@ function LightricksTxt2VidGenerationForm() {
function LightricksImg2VidGenerationForm() {
return (
<FormWrapper engine="lightricks">
<InputTextArea name="prompt" label="Prompt" placeholder="Your prompt goes here..." autosize />
<InputTextArea name="prompt" label="Prompt" placeholder="Your prompt goes here..." autosize description={LightricksPromptDescription()}/>
<InputTextArea name="negativePrompt" label="Negative Prompt" autosize />
<div className="flex flex-col gap-0.5">
<Input.Label>Duration</Input.Label>
Expand Down
61 changes: 23 additions & 38 deletions src/components/Signals/SignalsProvider.tsx
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
import { MantineColor, Notification, NotificationProps } from '@mantine/core';
import { useSession } from 'next-auth/react';
import { createContext, useContext, useEffect, useRef, useState } from 'react';
import { SignalNotifications } from '~/components/Signals/SignalsNotifications';
import { SignalsRegistrar } from '~/components/Signals/SignalsRegistrar';
import { useFeatureFlags } from '~/providers/FeatureFlagsProvider';
import { SignalMessages } from '~/server/common/enums';
import { SignalStatus } from '~/utils/signals/types';
// import { createSignalWorker, SignalWorker } from '~/utils/signals';
import { useSignalsWorker, SignalWorker, SignalStatus } from '~/utils/signals/useSignalsWorker';
import { useSignalsWorker, SignalWorker } from '~/utils/signals/useSignalsWorker';
import { trpc } from '~/utils/trpc';

type SignalState = {
connected: boolean;
status?: SignalStatus;
status: SignalStatus | null;
worker: SignalWorker | null;
};

const signalStatusDictionary: Record<SignalStatus, MantineColor> = {
connected: 'green',
reconnected: 'green',
reconnecting: 'yellow',
error: 'red',
closed: 'red',
};

Expand Down Expand Up @@ -50,38 +47,30 @@ export const useSignalConnection = (message: SignalMessages, cb: SignalCallback)
};

export function SignalProvider({ children }: { children: React.ReactNode }) {
const session = useSession();
const queryUtils = trpc.useUtils();
const features = useFeatureFlags();

const [status, setStatus] = useState<SignalStatus>();

const { data } = trpc.signals.getToken.useQuery(undefined, {
enabled: !!session.data?.user && features.signal,
const prevStatusRef = useRef<SignalStatus | null>(null);
const hasConnectedAtLeastOnceRef = useRef(false);

const [status, setStatus] = useState<SignalStatus | null>(null);
prevStatusRef.current = status ?? null;

const worker = useSignalsWorker({
onStateChange: ({ state }) => {
const prevStatus = prevStatusRef.current;
const hasConnectedAtLeastOnce = hasConnectedAtLeastOnceRef.current;
if (prevStatus !== state && state === 'connected' && hasConnectedAtLeastOnce) {
queryUtils.buzz.getBuzzAccount.invalidate();
queryUtils.orchestrator.queryGeneratedImages.invalidate();
}

if (state === 'connected') hasConnectedAtLeastOnceRef.current = true;
setStatus(state);
},
});

const accessToken = data?.accessToken;
const userId = session.data?.user?.id;

const worker = useSignalsWorker(
{ accessToken },
{
onReconnected: () => {
if (userId) {
queryUtils.buzz.getBuzzAccount.invalidate();
queryUtils.orchestrator.queryGeneratedImages.invalidate();
}
},
onError: () => {
queryUtils.signals.getToken.invalidate();
},
onStatusChange: ({ status }) => setStatus(status),
}
);

const connected = status === 'connected' || status === 'reconnected';
const connected = status === 'connected';

return features.signal ? (
return (
<SignalContext.Provider
value={{
connected,
Expand All @@ -93,10 +82,6 @@ export function SignalProvider({ children }: { children: React.ReactNode }) {
<SignalsRegistrar />
{children}
</SignalContext.Provider>
) : (
<SignalContext.Provider value={{ connected: false, worker: null }}>
{children}
</SignalContext.Provider>
);
}

Expand Down
2 changes: 1 addition & 1 deletion src/pages/api/admin/test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import { getSystemPermissions } from '~/server/services/system-cache';
import { addGenerationEngine } from '~/server/services/generation/engines';
import { dbWrite, dbRead } from '~/server/db/client';
import { limitConcurrency, Task } from '~/server/utils/concurrency-helpers';
import { getGenerationResourceData } from '~/server/services/generation/generation.service';
import { getResourceData } from '~/server/services/generation/generation.service';
import { Prisma } from '@prisma/client';
import { getCommentsThreadDetails2 } from '~/server/services/commentsv2.service';

Expand Down
4 changes: 2 additions & 2 deletions src/server/controllers/model-version.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import { dbRead } from '../db/client';
import { modelFileSelect } from '../selectors/modelFile.selector';
import { getFilesByEntity } from '../services/file.service';
import { createFile } from '../services/model-file.service';
import { getGenerationResourceData } from './../services/generation/generation.service';
import { getResourceData } from './../services/generation/generation.service';

export const getModelVersionRunStrategiesHandler = ({ input: { id } }: { input: GetByIdInput }) => {
try {
Expand Down Expand Up @@ -152,7 +152,7 @@ export const getModelVersionHandler = async ({
});

const recommendedResourceIds = version?.recommendedResources.map((x) => x.id) ?? [];
const generationResources = await getGenerationResourceData({
const generationResources = await getResourceData({
ids: recommendedResourceIds,
user: ctx?.user,
}).then((data) =>
Expand Down
10 changes: 6 additions & 4 deletions src/server/controllers/model.controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ import { isDefined } from '~/utils/type-guards';
import { redis, REDIS_KEYS } from '../redis/client';
import { BountyDetailsSchema } from '../schema/bounty.schema';
import {
getGenerationResourceData,
getResourceData,
getUnavailableResources,
} from '../services/generation/generation.service';

Expand Down Expand Up @@ -187,7 +187,7 @@ export const getModelHandler = async ({ input, ctx }: { input: GetByIdInput; ctx
const recommendedResourceIds =
model.modelVersions.flatMap((version) => version?.recommendedResources.map((x) => x.id)) ??
[];
const generationResources = await getGenerationResourceData({
const generationResources = await getResourceData({
ids: recommendedResourceIds,
user: ctx?.user,
});
Expand Down Expand Up @@ -775,9 +775,11 @@ export const getDownloadCommandHandler = async ({
});
if (!modelVersion) throw throwNotFoundError();

const isDownloadable = modelVersion.usageControl !== ModelUsageControl.Download;
const isOwner = ctx.user?.id === modelVersion.model.userId;
const isDownloadable =
modelVersion.usageControl === ModelUsageControl.Download || isOwner || ctx.user?.isModerator;

if (!isDownloadable && !(modelVersion.model.userId === ctx.user?.id || ctx.user?.isModerator)) {
if (!isDownloadable) {
throw throwAuthorizationError();
}

Expand Down
4 changes: 2 additions & 2 deletions src/server/orchestrator/lightricks/lightricks.schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import {
} from '~/server/orchestrator/infrastructure/base.schema';
import { numberEnum } from '~/utils/zod-helpers';

export const lightricksAspectRatios = ['16:9', '1:1', '9:16'] as const;
export const lightricksAspectRatios = ['16:9', '1:1'] as const;
export const lightricksDuration = [5, 10] as const;

const lightricksTxt2VidSchema = textEnhancementSchema.extend({
engine: z.literal('lightricks'),
workflow: z.string(),
negativePrompt: negativePromptSchema,
aspectRatio: z.enum(lightricksAspectRatios).default('1:1').catch('1:1'),
aspectRatio: z.enum(lightricksAspectRatios).default('16:9').catch('16:9'),
duration: numberEnum(lightricksDuration).default(5).catch(5),
cfgScale: z.number().min(3).max(3.5).default(3).catch(3),
steps: z.number().min(20).max(30).default(25).catch(25),
Expand Down
Loading