Skip to content

Commit

Permalink
Add SD3 as a tool
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarrazin committed Jun 12, 2024
1 parent bfd6b5f commit eb2bfa9
Showing 1 changed file with 28 additions and 36 deletions.
64 changes: 28 additions & 36 deletions src/lib/server/tools/images/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@ import { uploadFile } from "../../files/uploadFile";
import { MessageUpdateType } from "$lib/types/MessageUpdate";
import { callSpace, getIpToken, type GradioImage } from "../utils";

type ImageGenerationInput = [
number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */,
number /* number in 'Image Height' Number component */,
number /* number in 'Image Width' Number component */,
string /* prompt */,
number /* seed random */
];
type ImageGenerationOutput = [{ image: GradioImage }[]];
type ImageGenerationInput = [string, string, number, boolean, number, number, number, number];
type ImageGenerationOutput = [GradioImage, unknown];

const imageGeneration: BackendTool = {
name: "image_generation",
Expand All @@ -24,11 +18,12 @@ const imageGeneration: BackendTool = {
type: "string",
required: true,
},
numberOfImages: {
description: "Number of images to generate, between 1 and 8.",
type: "number",
negativePrompt: {
description:
"A prompt for things that should not be in the image. Simple terms, separate terms with a comma.",
type: "string",
required: false,
default: 1,
default: "",
},
width: {
description: "Width of the generated image.",
Expand All @@ -43,41 +38,38 @@ const imageGeneration: BackendTool = {
default: 1024,
},
},
async *call({ prompt, numberOfImages, width, height }, { conv, ip, username }) {
async *call({ prompt, negativePrompt, width, height }, { conv, ip, username }) {
const ipToken = await getIpToken(ip, username);

const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
"ByteDance/Hyper-SDXL-1Step-T2I",
"/process_image",
"stabilityai/stable-diffusion-3-medium",
"/infer",
[
Number(numberOfImages), // number (numeric value between 1 and 8) in 'Number of Images' Slider component
Number(height), // number in 'Image Height' Number component
Number(width), // number in 'Image Width' Number component
String(prompt), // prompt
String(negativePrompt),
Math.floor(Math.random() * 1000), // seed random
true, // randomize seed
Number(width), // number in 'Image Width' Number component
Number(height), // number in 'Image Height' Number component
5,
28,
],
ipToken
);
const imageBlobs = await Promise.all(
outputs[0].map((output) =>
fetch(output.image.url)
.then((res) => res.blob())
.then(
(blob) =>
new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type })
)
.then((file) => uploadFile(file, conv))
const image = await fetch(outputs[0].url)
.then((res) => res.blob())
.then(
(blob) =>
new File([blob], `${prompt}.${blob.type.split("/")[1] ?? "png"}`, { type: blob.type })
)
);
.then((file) => uploadFile(file, conv));

for (const image of imageBlobs) {
yield {
type: MessageUpdateType.File,
name: image.name,
sha: image.value,
mime: image.mime,
};
}
yield {
type: MessageUpdateType.File,
name: image.name,
sha: image.value,
mime: image.mime,
};

return {
outputs: [
Expand Down

0 comments on commit eb2bfa9

Please sign in to comment.