From eb2bfa92d1519bafadfabb96b40ac933d7314c31 Mon Sep 17 00:00:00 2001 From: Nathan Sarrazin Date: Wed, 12 Jun 2024 18:39:22 +0200 Subject: [PATCH] Add SD3 as a tool --- src/lib/server/tools/images/generation.ts | 64 ++++++++++------------- 1 file changed, 28 insertions(+), 36 deletions(-) diff --git a/src/lib/server/tools/images/generation.ts b/src/lib/server/tools/images/generation.ts index 9e1cecaf47b..27f37c6bc7f 100644 --- a/src/lib/server/tools/images/generation.ts +++ b/src/lib/server/tools/images/generation.ts @@ -3,14 +3,8 @@ import { uploadFile } from "../../files/uploadFile"; import { MessageUpdateType } from "$lib/types/MessageUpdate"; import { callSpace, getIpToken, type GradioImage } from "../utils"; -type ImageGenerationInput = [ - number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */, - number /* number in 'Image Height' Number component */, - number /* number in 'Image Width' Number component */, - string /* prompt */, - number /* seed random */ -]; -type ImageGenerationOutput = [{ image: GradioImage }[]]; +type ImageGenerationInput = [string, string, number, boolean, number, number, number, number]; +type ImageGenerationOutput = [GradioImage, unknown]; const imageGeneration: BackendTool = { name: "image_generation", @@ -24,11 +18,12 @@ const imageGeneration: BackendTool = { type: "string", required: true, }, - numberOfImages: { - description: "Number of images to generate, between 1 and 8.", - type: "number", + negativePrompt: { + description: + "A prompt for things that should not be in the image. Simple terms, separate terms with a comma.", + type: "string", required: false, - default: 1, + default: "", }, width: { description: "Width of the generated image.", @@ -43,41 +38,38 @@ const imageGeneration: BackendTool = { default: 1024, }, }, - async *call({ prompt, numberOfImages, width, height }, { conv, ip, username }) { + async *call({ prompt, negativePrompt, width, height }, { conv, ip, username }) { const ipToken = await getIpToken(ip, username); const outputs = await callSpace( - "ByteDance/Hyper-SDXL-1Step-T2I", - "/process_image", + "stabilityai/stable-diffusion-3-medium", + "/infer", [ - Number(numberOfImages), // number (numeric value between 1 and 8) in 'Number of Images' Slider component - Number(height), // number in 'Image Height' Number component - Number(width), // number in 'Image Width' Number component String(prompt), // prompt + String(negativePrompt), Math.floor(Math.random() * 1000), // seed random + true, // randomize seed + Number(width), // number in 'Image Width' Number component + Number(height), // number in 'Image Height' Number component + 5, + 28, ], ipToken ); - const imageBlobs = await Promise.all( - outputs[0].map((output) => - fetch(output.image.url) - .then((res) => res.blob()) - .then( - (blob) => - new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type }) - ) - .then((file) => uploadFile(file, conv)) + const image = await fetch(outputs[0].url) + .then((res) => res.blob()) + .then( + (blob) => + new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type }) ) - ); + .then((file) => uploadFile(file, conv)); - for (const image of imageBlobs) { - yield { - type: MessageUpdateType.File, - name: image.name, - sha: image.value, - mime: image.mime, - }; - } + yield { + type: MessageUpdateType.File, + name: image.name, + sha: image.value, + mime: image.mime, + }; return { outputs: [