Skip to content

Commit 677194c

Browse files
Aman Shawaandrew-me
Aman Shaw
authored andcommitted
feat(imagegen): add support for quite flag for image generation
- feat: add support for quite flag during image generation - refactor: move imaggen logic to imagegen package - exit with exit code 1 when image generation fail
1 parent dcd145f commit 677194c

File tree

3 files changed

+146
-113
lines changed

3 files changed

+146
-113
lines changed

helper.go

+1-96
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55
"encoding/json"
66
"fmt"
77
"io"
8-
url_package "net/url"
98
"os"
109
"os/exec"
1110
"path/filepath"
@@ -17,7 +16,6 @@ import (
1716
"github.com/aandrew-me/tgpt/v2/providers"
1817
"github.com/aandrew-me/tgpt/v2/providers/gemini"
1918
"github.com/aandrew-me/tgpt/v2/structs"
20-
"github.com/aandrew-me/tgpt/v2/utils"
2119
http "github.com/bogdanfinn/fhttp"
2220

2321
"github.com/olekukonko/ts"
@@ -547,90 +545,6 @@ func handleStatus400(resp *http.Response) {
547545
// }
548546
// }
549547

550-
551-
func generateImagePollinations(prompt string) {
552-
bold.Println("Generating image with pollinations.ai...")
553-
554-
client, err := client.NewClient()
555-
if err != nil {
556-
fmt.Fprintln(os.Stderr, err)
557-
os.Exit(1)
558-
}
559-
560-
full_prompt := url_package.QueryEscape(prompt);
561-
562-
randId := utils.RandomString(20)
563-
filename := randId + ".jpg"
564-
565-
model := "flux"
566-
567-
if *apiModel != "" {
568-
model = *apiModel
569-
}
570-
571-
fmt.Println()
572-
573-
link := fmt.Sprintf("https://image.pollinations.ai/prompt/%v", full_prompt)
574-
575-
params := url_package.Values{}
576-
577-
seed := utils.GenerateRandomNumber(5)
578-
579-
params.Add("model", model)
580-
params.Add("width", "1024")
581-
params.Add("height", "1024")
582-
params.Add("nologo", "true")
583-
params.Add("safe", "false")
584-
params.Add("nsfw", "true")
585-
params.Add("isChild", "false")
586-
params.Add("seed", seed)
587-
588-
urlObj, err := url_package.Parse(link)
589-
if err != nil {
590-
fmt.Println("Error parsing URL:", err)
591-
return
592-
}
593-
594-
urlObj.RawQuery = params.Encode()
595-
596-
597-
req, _ := http.NewRequest("GET", urlObj.String(), nil)
598-
599-
res, err := client.Do(req)
600-
601-
if err != nil {
602-
fmt.Fprint(os.Stderr, err);
603-
}
604-
605-
defer res.Body.Close()
606-
607-
608-
if res.StatusCode == http.StatusOK {
609-
file, err := os.Create(filename)
610-
if err != nil {
611-
fmt.Fprintf(os.Stderr, "Error: %v", err)
612-
613-
return
614-
}
615-
defer file.Close()
616-
617-
// Copy the response body (image data) to the file
618-
_, err = io.Copy(file, res.Body)
619-
if err != nil {
620-
fmt.Fprintf(os.Stderr,"Error: %v", err)
621-
622-
return
623-
}
624-
625-
fmt.Printf("Saved image as %v\n", filename)
626-
} else {
627-
body, _ := io.ReadAll(res.Body)
628-
responseText := string(body)
629-
630-
fmt.Fprintf(os.Stderr,"Some error has occurred. Try again (perhaps with a different model).\nError: %v", responseText)
631-
}
632-
}
633-
634548
func downloadImage(url string, destDir string) error {
635549
client, err := client.NewClient()
636550
if err != nil {
@@ -799,13 +713,4 @@ func makeRequestAndGetData(input string, params structs.Params, extraOptions str
799713
}
800714

801715
return ""
802-
}
803-
804-
func generateImg(prompt string, provider string) {
805-
if provider == "pollinations" || provider == "" {
806-
generateImagePollinations(prompt)
807-
808-
} else {
809-
fmt.Fprintln(os.Stderr, "Such a provider doesn't exist")
810-
}
811-
}
716+
}

imagegen/imagegen.go

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package imagegen
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"os"
7+
8+
url_package "net/url"
9+
10+
http "github.com/bogdanfinn/fhttp"
11+
12+
"github.com/aandrew-me/tgpt/v2/client"
13+
"github.com/aandrew-me/tgpt/v2/structs"
14+
"github.com/aandrew-me/tgpt/v2/utils"
15+
"github.com/fatih/color"
16+
)
17+
18+
var bold = color.New(color.Bold)
19+
20+
func GenerateImg(prompt string, params structs.Params, isQuite bool) {
21+
if params.Provider == "pollinations" || params.Provider == "" {
22+
if !isQuite {
23+
bold.Println("Generating image with pollinations.ai...")
24+
}
25+
filename := generateImagePollinations(prompt, params)
26+
if !isQuite {
27+
fmt.Printf("Saved image as %v\n", filename)
28+
}
29+
30+
} else {
31+
fmt.Fprintln(os.Stderr, "Such a provider doesn't exist")
32+
}
33+
}
34+
35+
func generateImagePollinations(prompt string, params structs.Params) string {
36+
37+
client, err := client.NewClient()
38+
if err != nil {
39+
fmt.Fprintln(os.Stderr, err)
40+
os.Exit(1)
41+
}
42+
43+
full_prompt := url_package.QueryEscape(prompt)
44+
45+
randId := utils.RandomString(20)
46+
filename := randId + ".jpg"
47+
48+
model := "flux"
49+
50+
if params.ApiModel != "" {
51+
model = params.ApiModel
52+
}
53+
54+
link := fmt.Sprintf("https://image.pollinations.ai/prompt/%v", full_prompt)
55+
56+
queryParams := url_package.Values{}
57+
58+
seed := utils.GenerateRandomNumber(5)
59+
60+
queryParams.Add("model", model)
61+
queryParams.Add("width", "1024")
62+
queryParams.Add("height", "1024")
63+
queryParams.Add("nologo", "true")
64+
queryParams.Add("safe", "false")
65+
queryParams.Add("nsfw", "true")
66+
queryParams.Add("isChild", "false")
67+
queryParams.Add("seed", seed)
68+
69+
urlObj, err := url_package.Parse(link)
70+
if err != nil {
71+
fmt.Println("Error parsing URL:", err)
72+
os.Exit(1)
73+
}
74+
75+
urlObj.RawQuery = queryParams.Encode()
76+
77+
req, _ := http.NewRequest("GET", urlObj.String(), nil)
78+
79+
res, err := client.Do(req)
80+
81+
if err != nil {
82+
fmt.Fprint(os.Stderr, err)
83+
}
84+
85+
defer res.Body.Close()
86+
87+
if res.StatusCode != http.StatusOK {
88+
body, _ := io.ReadAll(res.Body)
89+
responseText := string(body)
90+
91+
fmt.Fprintf(os.Stderr, "Some error has occurred. Try again (perhaps with a different model).\nError: %v", responseText)
92+
93+
}
94+
95+
file, err := os.Create(filename)
96+
if err != nil {
97+
fmt.Fprintf(os.Stderr, "Error: %v", err)
98+
os.Exit(1)
99+
}
100+
defer file.Close()
101+
102+
// Copy the response body (image data) to the file
103+
_, err = io.Copy(file, res.Body)
104+
if err != nil {
105+
fmt.Fprintf(os.Stderr, "Error: %v", err)
106+
os.Exit(1)
107+
}
108+
109+
return filename
110+
111+
}

main.go

+34-17
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"strings"
1313
"syscall"
1414

15+
"github.com/aandrew-me/tgpt/v2/imagegen"
1516
"github.com/aandrew-me/tgpt/v2/structs"
1617
"github.com/aandrew-me/tgpt/v2/utils"
1718
"github.com/atotto/clipboard"
@@ -168,6 +169,38 @@ func main() {
168169
fmt.Println("tgpt", localVersion)
169170
case *isChangelog:
170171
getVersionHistory()
172+
case *isImage:
173+
params := structs.Params{
174+
ApiKey: *apiKey,
175+
ApiModel: *apiModel,
176+
Provider: *provider,
177+
Max_length: *max_length,
178+
Temperature: *temperature,
179+
Top_p: *top_p,
180+
Preprompt: *preprompt,
181+
Url: *url,
182+
PrevMessages: "",
183+
ThreadID: "",
184+
}
185+
186+
if len(prompt) > 1 {
187+
trimmedPrompt := strings.TrimSpace(prompt)
188+
if len(trimmedPrompt) < 1 {
189+
fmt.Fprintln(os.Stderr, "You need to provide some text")
190+
fmt.Fprintln(os.Stderr, `Example: tgpt -img "cat"`)
191+
os.Exit(1)
192+
}
193+
194+
imagegen.GenerateImg(trimmedPrompt, params, *isQuiet)
195+
196+
} else {
197+
formattedInput := getFormattedInputStdin()
198+
if !*isQuiet {
199+
fmt.Println()
200+
}
201+
202+
imagegen.GenerateImg(formattedInput, params, *isQuiet)
203+
}
171204
case *isWhole:
172205
if len(prompt) > 1 {
173206
trimmedPrompt := strings.TrimSpace(prompt)
@@ -337,22 +370,6 @@ func main() {
337370

338371
}
339372

340-
case *isImage:
341-
if len(prompt) > 1 {
342-
trimmedPrompt := strings.TrimSpace(prompt)
343-
if len(trimmedPrompt) < 1 {
344-
fmt.Fprintln(os.Stderr, "You need to provide some text")
345-
fmt.Fprintln(os.Stderr, `Example: tgpt -img "cat"`)
346-
os.Exit(1)
347-
}
348-
349-
generateImg(trimmedPrompt, *provider)
350-
} else {
351-
formattedInput := getFormattedInputStdin()
352-
fmt.Println()
353-
354-
generateImg(formattedInput, *provider)
355-
}
356373
case *isHelp:
357374
showHelpMessage()
358375
default:
@@ -549,7 +566,7 @@ func showHelpMessage() {
549566

550567
bold.Println("\nProvider: isou")
551568
fmt.Println("Free provider with web search")
552-
569+
553570
bold.Println("\nProvider: koboldai")
554571
fmt.Println("Uses koboldcpp/HF_SPACE_Tiefighter-13B only, answers from novels")
555572

0 commit comments

Comments
 (0)