From cbc913bdc0b7b7ae2c8738dc26ff9dcecef46f62 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 17 May 2024 08:51:53 +0200 Subject: [PATCH] Update dall-e openai support (#198) --- examples/llm/openai/thread/main.go | 65 ++++++++++----------- examples/transformer/dalle/main.go | 7 ++- llm/openai/function.go | 13 +++-- transformer/dall-e.go | 94 +++++++++++++++++++++--------- 4 files changed, 111 insertions(+), 68 deletions(-) diff --git a/examples/llm/openai/thread/main.go b/examples/llm/openai/thread/main.go index cc21df67..5180349c 100644 --- a/examples/llm/openai/thread/main.go +++ b/examples/llm/openai/thread/main.go @@ -3,17 +3,27 @@ package main import ( "context" "fmt" + "strings" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/transformer" ) -type Answer struct { - Answer string `json:"answer" jsonschema:"description=the pirate answer"` +type Image struct { + Description string `json:"description" jsonschema:"description=the description of the image that should be created"` } -func getAnswer(a Answer) string { - return "🦜 ☠️ " + a.Answer +func crateImage(i Image) string { + d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512) + imageURL, err := d.Transform(context.Background(), i.Description) + if err != nil { + return fmt.Errorf("error creating image: %w", err).Error() + } + + fmt.Println("Image created with url:", imageURL) + + return imageURL.(string) } func newStr(str string) *string { @@ -21,12 +31,12 @@ func newStr(str string) *string { } func main() { - openaillm := openai.New() - openaillm.WithToolChoice(newStr("getPirateAnswer")) + openaillm := openai.New().WithModel(openai.GPT4o) + openaillm.WithToolChoice(newStr("auto")) err := openaillm.BindFunction( - getAnswer, - "getPirateAnswer", - "use this function to get the pirate answer", + crateImage, + "createImage", + "use this function to create an image from a description", ) if err != nil { panic(err) @@ -34,41 +44,28 @@ func main() { t := thread.New().AddMessage( thread.NewUserMessage().AddContent( - thread.NewTextContent("Hello, I'm a user"), - ).AddContent( - thread.NewTextContent("Can you greet me?"), - ), - ).AddMessage( - thread.NewUserMessage().AddContent( - thread.NewTextContent("please greet me as a pirate."), + thread.NewTextContent("Please, create an image that inspires you"), ), ) - fmt.Println(t) - err = openaillm.Generate(context.Background(), t) if err != nil { panic(err) } - t.AddMessage(thread.NewUserMessage().AddContent( - thread.NewTextContent("now translate to italian as a poem"), - )) + if t.LastMessage().Role == thread.RoleTool { + t.AddMessage(thread.NewUserMessage().AddContent( + thread.NewImageContentFromURL( + strings.ReplaceAll(t.LastMessage().Contents[0].AsToolResponseData().Result, `"`, ""), + ), + ).AddContent( + thread.NewTextContent("can you describe the image?"), + )) - fmt.Println(t) - // disable functions - openaillm.WithToolChoice(nil) - openaillm.WithStream(true, func(a string) { - if a == openai.EOS { - fmt.Printf("\n") - return + err = openaillm.Generate(context.Background(), t) + if err != nil { + panic(err) } - fmt.Printf("%s", a) - }) - - err = openaillm.Generate(context.Background(), t) - if err != nil { - panic(err) } fmt.Println(t) diff --git a/examples/transformer/dalle/main.go b/examples/transformer/dalle/main.go index 47127cc3..cecbea34 100644 --- a/examples/transformer/dalle/main.go +++ b/examples/transformer/dalle/main.go @@ -2,16 +2,19 @@ package main import ( "context" + "fmt" "github.com/henomis/lingoose/transformer" ) func main() { - d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024).AsFile("test.png") + d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize1024x1024) - _, err := d.Transform(context.Background(), "a goose working with pipelines") + imageURL, err := d.Transform(context.Background(), "a goose working with pipelines") if err != nil { panic(err) } + + fmt.Println("Image created:", imageURL) } diff --git a/llm/openai/function.go b/llm/openai/function.go index 37b413b7..4b9c77bb 100644 --- a/llm/openai/function.go +++ b/llm/openai/function.go @@ -1,10 +1,12 @@ package openai import ( + "bytes" "encoding/json" "errors" "fmt" "reflect" + "strings" "github.com/invopop/jsonschema" "github.com/sashabaranov/go-openai" @@ -187,11 +189,14 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er // Marshal the function result to JSON if len(result) > 0 { - jsonResultData, errMarshal := json.Marshal(result[0].Interface()) - if errMarshal != nil { - return "", fmt.Errorf("error marshaling result: %w", errMarshal) + var resultBytes bytes.Buffer + enc := json.NewEncoder(&resultBytes) + enc.SetEscapeHTML(false) + err = enc.Encode(result[0].Interface()) + if err != nil { + return "", fmt.Errorf("error marshaling result: %w", err) } - return string(jsonResultData), nil + return strings.TrimSpace(resultBytes.String()), nil } return "", nil diff --git a/transformer/dall-e.go b/transformer/dall-e.go index 1c38aa74..64574b15 100644 --- a/transformer/dall-e.go +++ b/transformer/dall-e.go @@ -17,9 +17,11 @@ type DallEImageOutput any type DallEImageSize string const ( - DallEImageSize256 DallEImageSize = openai.CreateImageSize256x256 - DallEImageSize512 DallEImageSize = openai.CreateImageSize512x512 - DallEImageSize1024 DallEImageSize = openai.CreateImageSize1024x1024 + DallEImageSize256x256 DallEImageSize = openai.CreateImageSize256x256 + DallEImageSize512x512 DallEImageSize = openai.CreateImageSize512x512 + DallEImageSize1024x1024 DallEImageSize = openai.CreateImageSize1024x1024 + DallEImageSize1792x104 DallEImageSize = openai.CreateImageSize1792x1024 + DallEImageSize1024x1792 DallEImageSize = openai.CreateImageSize1024x1792 ) type DallEImageFormat string @@ -30,19 +32,45 @@ const ( DallEImageFormatImage DallEImageFormat = "image" ) +type DallEModel string + +const ( + DallEModel2 DallEModel = openai.CreateImageModelDallE2 + DallEModel3 DallEModel = openai.CreateImageModelDallE3 +) + +type DallEImageQuality string + +const ( + DallEImageQualityHD DallEImageQuality = openai.CreateImageQualityHD + DallEImageQualityStandard DallEImageQuality = openai.CreateImageQualityStandard +) + +type DallEImageStyle string + +const ( + DallEImageStyleVivid DallEImageStyle = openai.CreateImageStyleVivid + DallEImageStyleNatural DallEImageStyle = openai.CreateImageStyleNatural +) + type DallE struct { openAIClient *openai.Client + model DallEModel imageSize DallEImageSize imageFormat DallEImageFormat - imageFile string + imageStyle DallEImageStyle + imageQuality DallEImageQuality } func NewDallE() *DallE { openAIKey := os.Getenv("OPENAI_API_KEY") return &DallE{ openAIClient: openai.NewClient(openAIKey), - imageSize: DallEImageSize256, + model: DallEModel2, + imageSize: DallEImageSize256x256, imageFormat: DallEImageFormatURL, + imageStyle: DallEImageStyleNatural, + imageQuality: DallEImageQualityStandard, } } @@ -56,73 +84,83 @@ func (d *DallE) WithImageSize(imageSize DallEImageSize) *DallE { return d } -func (d *DallE) AsURL() *DallE { - d.imageFormat = DallEImageFormatURL +func (d *DallE) WithImageStyle(imageStyle DallEImageStyle) *DallE { + d.imageStyle = imageStyle return d } -func (d *DallE) AsFile(path string) *DallE { - d.imageFormat = DallEImageFormatFile - d.imageFile = path +func (d *DallE) WithImageQuality(imageQuality DallEImageQuality) *DallE { + d.imageQuality = imageQuality return d } -func (d *DallE) AsImage() *DallE { - d.imageFormat = DallEImageFormatImage +func (d *DallE) WithModel(model DallEModel) *DallE { + d.model = model + return d +} + +func (d *DallE) WithImageFormat(imageFormat DallEImageFormat) *DallE { + d.imageFormat = imageFormat return d } func (d *DallE) Transform(ctx context.Context, input string) (any, error) { switch d.imageFormat { case DallEImageFormatURL: - return d.transformToURL(ctx, input) + return d.TransformAsURL(ctx, input) case DallEImageFormatFile: - return d.transformToFile(ctx, input) + return d.TransformAsFile(ctx, input, nil) case DallEImageFormatImage: - return d.transformToImage(ctx, input) + return d.TransformToImage(ctx, input) default: return "", fmt.Errorf("unknown image format: %s", d.imageFormat) } } -func (d *DallE) transformToURL(ctx context.Context, input string) (any, error) { +func (d *DallE) TransformAsURL(ctx context.Context, input string) (string, error) { reqURL := openai.ImageRequest{ Prompt: input, + Model: string(d.model), Size: string(d.imageSize), + Quality: string(d.imageQuality), + Style: string(d.imageStyle), ResponseFormat: openai.CreateImageResponseFormatURL, N: 1, } respURL, err := d.openAIClient.CreateImage(ctx, reqURL) if err != nil { - return nil, err + return "", err } return respURL.Data[0].URL, nil } -func (d *DallE) transformToFile(ctx context.Context, input string) (any, error) { - imgData, err := d.transformToImage(ctx, input) +func (d *DallE) TransformAsFile(ctx context.Context, input string, file *os.File) (string, error) { + imgData, err := d.TransformToImage(ctx, input) if err != nil { - return nil, err + return "", err } - file, err := os.Create(d.imageFile) - if err != nil { - return nil, err + if file == nil { + // create a temporary file + file, err = os.CreateTemp("", "dall-e-*.png") + if err != nil { + return "", err + } } + defer file.Close() - err = png.Encode(file, imgData.(image.Image)) + err = png.Encode(file, imgData) if err != nil { - return nil, err + return "", err } - var output interface{} - return output, nil + return file.Name(), nil } -func (d *DallE) transformToImage(ctx context.Context, input string) (any, error) { +func (d *DallE) TransformToImage(ctx context.Context, input string) (image.Image, error) { reqBase64 := openai.ImageRequest{ Prompt: input, Size: string(d.imageSize),