Skip to content

Commit ed6b068

Browse files
Aman Shawaandrew-me
Aman Shaw
authored andcommitted
feat: add support for image generation with custom dimention and custom output filepath
1 parent 677194c commit ed6b068

File tree

3 files changed

+54
-26
lines changed

3 files changed

+54
-26
lines changed

Diff for: imagegen/imagegen.go

+15-11
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"io"
66
"os"
7+
"strconv"
78

89
url_package "net/url"
910

@@ -17,7 +18,7 @@ import (
1718

1819
var bold = color.New(color.Bold)
1920

20-
func GenerateImg(prompt string, params structs.Params, isQuite bool) {
21+
func GenerateImg(prompt string, params structs.ImageParams, isQuite bool) {
2122
if params.Provider == "pollinations" || params.Provider == "" {
2223
if !isQuite {
2324
bold.Println("Generating image with pollinations.ai...")
@@ -29,10 +30,11 @@ func GenerateImg(prompt string, params structs.Params, isQuite bool) {
2930

3031
} else {
3132
fmt.Fprintln(os.Stderr, "Such a provider doesn't exist")
33+
os.Exit(1)
3234
}
3335
}
3436

35-
func generateImagePollinations(prompt string, params structs.Params) string {
37+
func generateImagePollinations(prompt string, params structs.ImageParams) string {
3638

3739
client, err := client.NewClient()
3840
if err != nil {
@@ -42,11 +44,13 @@ func generateImagePollinations(prompt string, params structs.Params) string {
4244

4345
full_prompt := url_package.QueryEscape(prompt)
4446

45-
randId := utils.RandomString(20)
46-
filename := randId + ".jpg"
47+
filepath := params.Out
48+
if filepath == "" {
49+
randId := utils.RandomString(20)
50+
filepath = randId + ".jpg"
51+
}
4752

4853
model := "flux"
49-
5054
if params.ApiModel != "" {
5155
model = params.ApiModel
5256
}
@@ -58,8 +62,8 @@ func generateImagePollinations(prompt string, params structs.Params) string {
5862
seed := utils.GenerateRandomNumber(5)
5963

6064
queryParams.Add("model", model)
61-
queryParams.Add("width", "1024")
62-
queryParams.Add("height", "1024")
65+
queryParams.Add("width", strconv.Itoa(params.Width))
66+
queryParams.Add("height", strconv.Itoa(params.Height))
6367
queryParams.Add("nologo", "true")
6468
queryParams.Add("safe", "false")
6569
queryParams.Add("nsfw", "true")
@@ -80,6 +84,7 @@ func generateImagePollinations(prompt string, params structs.Params) string {
8084

8185
if err != nil {
8286
fmt.Fprint(os.Stderr, err)
87+
os.Exit(1)
8388
}
8489

8590
defer res.Body.Close()
@@ -89,10 +94,10 @@ func generateImagePollinations(prompt string, params structs.Params) string {
8994
responseText := string(body)
9095

9196
fmt.Fprintf(os.Stderr, "Some error has occurred. Try again (perhaps with a different model).\nError: %v", responseText)
92-
97+
os.Exit(1)
9398
}
9499

95-
file, err := os.Create(filename)
100+
file, err := os.Create(filepath)
96101
if err != nil {
97102
fmt.Fprintf(os.Stderr, "Error: %v", err)
98103
os.Exit(1)
@@ -106,6 +111,5 @@ func generateImagePollinations(prompt string, params structs.Params) string {
106111
os.Exit(1)
107112
}
108113

109-
return filename
110-
114+
return filepath
111115
}

Diff for: main.go

+32-15
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ var preprompt *string
4545
var url *string
4646
var logFile *string
4747
var shouldExecuteCommand *bool
48+
var out *string
49+
var height *int
50+
var width *int
4851

4952
func main() {
5053
execPath, err := os.Executable()
@@ -67,6 +70,9 @@ func main() {
6770
top_p = flag.String("top_p", "", "Set top_p")
6871
max_length = flag.String("max_length", "", "Set max length of response")
6972
preprompt = flag.String("preprompt", "", "Set preprompt")
73+
out = flag.String("out", "", "Output file path")
74+
width = flag.Int("width", 1024, "Output image width")
75+
height = flag.Int("height", 1024, "Output image height")
7076

7177
defaultUrl := ""
7278
if *provider == "openai" {
@@ -170,17 +176,22 @@ func main() {
170176
case *isChangelog:
171177
getVersionHistory()
172178
case *isImage:
173-
params := structs.Params{
174-
ApiKey: *apiKey,
175-
ApiModel: *apiModel,
176-
Provider: *provider,
177-
Max_length: *max_length,
178-
Temperature: *temperature,
179-
Top_p: *top_p,
180-
Preprompt: *preprompt,
181-
Url: *url,
182-
PrevMessages: "",
183-
ThreadID: "",
179+
params := structs.ImageParams{
180+
Params: structs.Params{
181+
ApiKey: *apiKey,
182+
ApiModel: *apiModel,
183+
Provider: *provider,
184+
Max_length: *max_length,
185+
Temperature: *temperature,
186+
Top_p: *top_p,
187+
Preprompt: *preprompt,
188+
Url: *url,
189+
PrevMessages: "",
190+
ThreadID: "",
191+
},
192+
Width: *width,
193+
Height: *height,
194+
Out: *out,
184195
}
185196

186197
if len(prompt) > 1 {
@@ -223,11 +234,11 @@ func main() {
223234
fmt.Fprintln(os.Stderr, `Example: tgpt -q "What is encryption?"`)
224235
os.Exit(1)
225236
}
226-
getSilentText(*preprompt + trimmedPrompt + contextText + pipedInput, structs.ExtraOptions{})
237+
getSilentText(*preprompt+trimmedPrompt+contextText+pipedInput, structs.ExtraOptions{})
227238
} else {
228239
formattedInput := getFormattedInputStdin()
229240
fmt.Println()
230-
getSilentText(*preprompt + formattedInput + cleanPipedInput, structs.ExtraOptions{})
241+
getSilentText(*preprompt+formattedInput+cleanPipedInput, structs.ExtraOptions{})
231242
}
232243
case *isShell:
233244
if len(prompt) > 1 {
@@ -381,7 +392,7 @@ func main() {
381392
os.Exit(1)
382393
}
383394

384-
getData(*preprompt+formattedInput+contextText+pipedInput, structs.Params{}, structs.ExtraOptions{IsNormal: true, IsInteractive: false, })
395+
getData(*preprompt+formattedInput+contextText+pipedInput, structs.Params{}, structs.ExtraOptions{IsNormal: true, IsInteractive: false})
385396
}
386397

387398
} else {
@@ -390,7 +401,7 @@ func main() {
390401
input := scanner.Text()
391402
go loading(&stopSpin)
392403
formattedInput := strings.TrimSpace(input)
393-
getData(*preprompt+formattedInput+pipedInput, structs.Params{}, structs.ExtraOptions{IsInteractive: false, })
404+
getData(*preprompt+formattedInput+pipedInput, structs.Params{}, structs.ExtraOptions{IsInteractive: false})
394405
}
395406
}
396407

@@ -537,6 +548,11 @@ func showHelpMessage() {
537548
fmt.Printf("%-50v Set preprompt\n", "--preprompt")
538549
fmt.Printf("%-50v Execute shell command without confirmation\n", "-y")
539550

551+
boldBlue.Println("\nOptions supported for image generation (with -image flag)")
552+
fmt.Printf("%-50v Output image filename\n", "-s, --out")
553+
fmt.Printf("%-50v Output image height\n", "-s, --height")
554+
fmt.Printf("%-50v Output image width\n", "-s, --width")
555+
540556
boldBlue.Println("\nOptions:")
541557
fmt.Printf("%-50v Print version \n", "-v, --version")
542558
fmt.Printf("%-50v Print help message \n", "-h, --help")
@@ -593,6 +609,7 @@ func showHelpMessage() {
593609
fmt.Println(`tgpt -s "How to update my system?"`)
594610
fmt.Println(`tgpt --provider duckduckgo "What is 1+1"`)
595611
fmt.Println(`tgpt --img "cat"`)
612+
fmt.Println(`tgpt --img --out ~/my-cat.jpg --height 256 --width 256 "cat"`)
596613
fmt.Println(`tgpt --provider openai --key "sk-xxxx" --model "gpt-3.5-turbo" "What is 1+1"`)
597614
fmt.Println(`cat install.sh | tgpt "Explain the code"`)
598615
}

Diff for: structs/structs.go

+7
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,10 @@ type CommonResponse struct {
3030
} `json:"delta"`
3131
} `json:"choices"`
3232
}
33+
34+
type ImageParams struct {
35+
Params
36+
Height int
37+
Width int
38+
Out string
39+
}

0 commit comments

Comments
 (0)