diff --git a/examples/llm/openai/thread/main.go b/examples/llm/openai/thread/main.go index 5180349c..695f4714 100644 --- a/examples/llm/openai/thread/main.go +++ b/examples/llm/openai/thread/main.go @@ -2,11 +2,12 @@ package main import ( "context" + "encoding/json" "fmt" - "strings" "github.com/henomis/lingoose/llm/openai" "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/tools/dalle" "github.com/henomis/lingoose/transformer" ) @@ -32,15 +33,7 @@ func newStr(str string) *string { func main() { openaillm := openai.New().WithModel(openai.GPT4o) - openaillm.WithToolChoice(newStr("auto")) - err := openaillm.BindFunction( - crateImage, - "createImage", - "use this function to create an image from a description", - ) - if err != nil { - panic(err) - } + openaillm.WithToolChoice(newStr("auto")).WithTools(dalle.New()) t := thread.New().AddMessage( thread.NewUserMessage().AddContent( @@ -48,15 +41,22 @@ func main() { ), ) - err = openaillm.Generate(context.Background(), t) + err := openaillm.Generate(context.Background(), t) if err != nil { panic(err) } if t.LastMessage().Role == thread.RoleTool { + var output dalle.Output + + err = json.Unmarshal([]byte(t.LastMessage().Contents[0].AsToolResponseData().Result), &output) + if err != nil { + panic(err) + } + t.AddMessage(thread.NewUserMessage().AddContent( thread.NewImageContentFromURL( - strings.ReplaceAll(t.LastMessage().Contents[0].AsToolResponseData().Result, `"`, ""), + output.ImageURL, ), ).AddContent( thread.NewTextContent("can you describe the image?"), diff --git a/examples/llm/openai/tools/python/main.go b/examples/llm/openai/tools/python/main.go new file mode 100644 index 00000000..1e133410 --- /dev/null +++ b/examples/llm/openai/tools/python/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "context" + "fmt" + + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/thread" + "github.com/henomis/lingoose/tools/python" +) + +func main() { + newStr := func(str string) *string { + return &str + } + llm := openai.New().WithModel(openai.GPT3Dot5Turbo0613).WithToolChoice(newStr("auto")).WithTools( + python.New(), + ) + + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("calculate reverse string of 'ailatiditalia', don't try to guess, let's use appropriate tool"), + ), + ) + + llm.Generate(context.Background(), t) + if t.LastMessage().Role == thread.RoleTool { + llm.Generate(context.Background(), t) + } + + fmt.Println(t) +} diff --git a/examples/llm/openai/tools/rag/main.go b/examples/llm/openai/tools/rag/main.go new file mode 100644 index 00000000..29deadc9 --- /dev/null +++ b/examples/llm/openai/tools/rag/main.go @@ -0,0 +1,66 @@ +package main + +import ( + "context" + "fmt" + "os" + + openaiembedder "github.com/henomis/lingoose/embedder/openai" + "github.com/henomis/lingoose/index" + "github.com/henomis/lingoose/index/vectordb/jsondb" + "github.com/henomis/lingoose/llm/openai" + "github.com/henomis/lingoose/rag" + "github.com/henomis/lingoose/thread" + ragtool "github.com/henomis/lingoose/tools/rag" + "github.com/henomis/lingoose/tools/serpapi" + "github.com/henomis/lingoose/tools/shell" +) + +func main() { + + rag := rag.New( + index.New( + jsondb.New().WithPersist("index.json"), + openaiembedder.New(openaiembedder.AdaEmbeddingV2), + ), + ).WithChunkSize(1000).WithChunkOverlap(0) + + _, err := os.Stat("index.json") + if os.IsNotExist(err) { + err = rag.AddSources(context.Background(), "state_of_the_union.txt") + if err != nil { + panic(err) + } + } + + newStr := func(str string) *string { + return &str + } + llm := openai.New().WithModel(openai.GPT4o).WithToolChoice(newStr("auto")).WithTools( + ragtool.New(rag, "US covid vaccines"), + serpapi.New(), + shell.New(), + ) + + topics := []string{ + "how many covid vaccine doses US has donated to other countries.", + "who's the author of LinGoose github project.", + "which process is consuming the most memory.", + } + + for _, topic := range topics { + t := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent("Please tell me " + topic), + ), + ) + + llm.Generate(context.Background(), t) + if t.LastMessage().Role == thread.RoleTool { + llm.Generate(context.Background(), t) + } + + fmt.Println(t) + } + +} diff --git a/examples/tools/duckduckgo/main.go b/examples/tools/duckduckgo/main.go new file mode 100644 index 00000000..7be5f65c --- /dev/null +++ b/examples/tools/duckduckgo/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/duckduckgo" +) + +func main() { + + t := duckduckgo.New().WithMaxResults(5) + f := t.Fn().(duckduckgo.FnPrototype) + + fmt.Println(f(duckduckgo.Input{Query: "Simone Vellei"})) +} diff --git a/examples/tools/python/main.go b/examples/tools/python/main.go new file mode 100644 index 00000000..57052fb6 --- /dev/null +++ b/examples/tools/python/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/python" +) + +func main() { + t := python.New().WithPythonPath("python3") + + pythonScript := `print("Hello from Python!")` + f := t.Fn().(python.FnPrototype) + + fmt.Println(f(python.Input{PythonCode: pythonScript})) +} diff --git a/examples/tools/serpapi/main.go b/examples/tools/serpapi/main.go new file mode 100644 index 00000000..d0578d3a --- /dev/null +++ b/examples/tools/serpapi/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/serpapi" +) + +func main() { + + t := serpapi.New() + f := t.Fn().(serpapi.FnPrototype) + + fmt.Println(f(serpapi.Input{Query: "Simone Vellei"})) +} diff --git a/examples/tools/shell/main.go b/examples/tools/shell/main.go new file mode 100644 index 00000000..3ec73377 --- /dev/null +++ b/examples/tools/shell/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + + "github.com/henomis/lingoose/tools/shell" +) + +func main() { + t := shell.New() + + bashScript := `echo "Hello from $SHELL!"` + f := t.Fn().(shell.FnPrototype) + + fmt.Println(f(shell.Input{BashScript: bashScript})) +} diff --git a/go.mod b/go.mod index 25c30a82..4131f84e 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/henomis/restclientgo v1.2.0 github.com/invopop/jsonschema v0.7.0 github.com/sashabaranov/go-openai v1.24.0 + golang.org/x/net v0.25.0 ) require ( diff --git a/go.sum b/go.sum index e2d3391c..a79137ce 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/llm/openai/function.go b/llm/openai/function.go index 4b9c77bb..9c5953bd 100644 --- a/llm/openai/function.go +++ b/llm/openai/function.go @@ -79,6 +79,25 @@ func (o *OpenAI) BindFunction( return nil } +type Tool interface { + Description() string + Name() string + Fn() any +} + +func (o OpenAI) WithTools(tools ...Tool) OpenAI { + for _, tool := range tools { + function, err := bindFunction(tool.Fn(), tool.Name(), tool.Description()) + if err != nil { + fmt.Println(err) + } + + o.functions[tool.Name()] = *function + } + + return o +} + func (o *Legacy) getFunctions() []openai.FunctionDefinition { var functions []openai.FunctionDefinition diff --git a/tools/dalle/dalle.go b/tools/dalle/dalle.go new file mode 100644 index 00000000..03e3cb5e --- /dev/null +++ b/tools/dalle/dalle.go @@ -0,0 +1,56 @@ +package dalle + +import ( + "context" + "fmt" + "time" + + "github.com/henomis/lingoose/transformer" +) + +const ( + defaultTimeoutInSeconds = 60 +) + +type Tool struct { +} + +type Input struct { + Description string `json:"description" jsonschema:"description=the description of the image that should be created"` +} + +type Output struct { + Error string `json:"error,omitempty"` + ImageURL string `json:"imageURL,omitempty"` +} + +type FnPrototype func(Input) Output + +func New() *Tool { + return &Tool{} +} + +func (t *Tool) Name() string { + return "dalle" +} + +func (t *Tool) Description() string { + return "A tool that creates an image from a description." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + + d := transformer.NewDallE().WithImageSize(transformer.DallEImageSize512x512) + imageURL, err := d.Transform(ctx, i.Description) + if err != nil { + return Output{Error: fmt.Sprintf("error creating image: %v", err)} + } + + return Output{ImageURL: imageURL.(string)} +} diff --git a/tools/duckduckgo/api.go b/tools/duckduckgo/api.go new file mode 100644 index 00000000..a301c5e7 --- /dev/null +++ b/tools/duckduckgo/api.go @@ -0,0 +1,168 @@ +package duckduckgo + +import ( + "bytes" + "io" + "regexp" + "strings" + + "github.com/henomis/restclientgo" + "golang.org/x/net/html" +) + +const ( + class = "class" +) + +type request struct { + Query string +} + +type response struct { + MaxResults uint + HTTPStatusCode int + RawBody []byte + Results []result +} + +type result struct { + Title string + Info string + URL string +} + +func (r *request) Path() (string, error) { + return "/html/?q=" + r.Query, nil +} + +func (r *request) Encode() (io.Reader, error) { + return nil, nil +} + +func (r *request) ContentType() string { + return "" +} + +func (r *response) Decode(body io.Reader) error { + results, err := r.parseBody(body) + if err != nil { + return err + } + + r.Results = results + return nil +} + +func (r *response) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) + return nil +} + +func (r *response) AcceptContentType() string { + return "text/html" +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } + +func (r *response) parseBody(body io.Reader) ([]result, error) { + doc, err := html.Parse(body) + if err != nil { + return nil, err + } + ch := make(chan result) + go r.findWebResults(ch, doc) + + results := []result{} + for n := range ch { + results = append(results, n) + } + + return results, nil +} + +func (r *response) findWebResults(ch chan result, doc *html.Node) { + var results uint + var f func(*html.Node) + f = func(n *html.Node) { + if results >= r.MaxResults { + return + } + if n.Type == html.ElementNode && n.Data == "div" { + for _, div := range n.Attr { + if div.Key == class && strings.Contains(div.Val, "web-result") { + info, href := r.findInfo(n) + ch <- result{ + Title: r.findTitle(n), + Info: info, + URL: href, + } + results++ + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(doc) + close(ch) +} + +func (r *response) findTitle(n *html.Node) string { + var title string + var f func(*html.Node) + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "a" { + for _, a := range n.Attr { + if a.Key == class && strings.Contains(a.Val, "result__a") { + title = n.FirstChild.Data + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(n) + return title +} + +//nolint:gocognit +func (r *response) findInfo(n *html.Node) (string, string) { + var info string + var link string + var f func(*html.Node) + f = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "a" { + for _, a := range n.Attr { + if a.Key == class && strings.Contains(a.Val, "result__snippet") { + var b bytes.Buffer + _ = html.Render(&b, n) + + re := regexp.MustCompile("<.*?>") + info = html.UnescapeString(re.ReplaceAllString(b.String(), "")) + + for _, h := range n.Attr { + if h.Key == "href" { + link = "https:" + h.Val + break + } + } + break + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + f(c) + } + } + f(n) + return info, link +} diff --git a/tools/duckduckgo/duckduckgo.go b/tools/duckduckgo/duckduckgo.go new file mode 100644 index 00000000..ec374942 --- /dev/null +++ b/tools/duckduckgo/duckduckgo.go @@ -0,0 +1,85 @@ +package duckduckgo + +import ( + "context" + "fmt" + "net/http" + "time" + + "github.com/henomis/restclientgo" +) + +const ( + defaultTimeoutInSeconds = 60 +) + +type Tool struct { + maxResults uint + userAgent string + restClient *restclientgo.RestClient +} + +type Input struct { + Query string `json:"query" jsonschema:"description=the query to search for"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Results []result `json:"results,omitempty"` +} + +type FnPrototype func(Input) Output + +func New() *Tool { + t := &Tool{ + maxResults: 1, + } + + restClient := restclientgo.New("https://html.duckduckgo.com"). + WithRequestModifier( + func(r *http.Request) *http.Request { + r.Header.Add("User-Agent", t.userAgent) + return r + }, + ) + + t.restClient = restClient + return t +} + +func (t *Tool) WithUserAgent(userAgent string) *Tool { + t.userAgent = userAgent + return t +} + +func (t *Tool) WithMaxResults(maxResults uint) *Tool { + t.maxResults = maxResults + return t +} + +func (t *Tool) Name() string { + return "duckduckgo" +} + +func (t *Tool) Description() string { + return "A tool that uses the DuckDuckGo internet search engine for a query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + + req := &request{Query: i.Query} + res := &response{MaxResults: t.maxResults} + + err := t.restClient.Get(ctx, req, res) + if err != nil { + return Output{Error: fmt.Sprintf("failed to search DuckDuckGo: %v", err)} + } + + return Output{Results: res.Results} +} diff --git a/tools/llm/llm.go b/tools/llm/llm.go new file mode 100644 index 00000000..4f6190cb --- /dev/null +++ b/tools/llm/llm.go @@ -0,0 +1,68 @@ +package llm + +import ( + "context" + "time" + + "github.com/henomis/lingoose/thread" +) + +const ( + defaultTimeoutInMinutes = 6 +) + +type LLM interface { + Generate(context.Context, *thread.Thread) error +} + +type Tool struct { + llm LLM +} + +func New(llm LLM) *Tool { + return &Tool{ + llm: llm, + } +} + +type Input struct { + Query string `json:"query" jsonschema:"description=user query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "llm" +} + +func (t *Tool) Description() string { + return "A tool that uses a language model to generate a response to a user query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + + th := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent(i.Query), + ), + ) + + err := t.llm.Generate(ctx, th) + if err != nil { + return Output{Error: err.Error()} + } + + return Output{Result: th.LastMessage().Contents[0].AsString()} +} diff --git a/tools/python/python.go b/tools/python/python.go new file mode 100644 index 00000000..2ac686f2 --- /dev/null +++ b/tools/python/python.go @@ -0,0 +1,68 @@ +package python + +import ( + "bytes" + "fmt" + "os/exec" +) + +type Tool struct { + pythonPath string +} + +func New() *Tool { + return &Tool{ + pythonPath: "python3", + } +} + +func (t *Tool) WithPythonPath(pythonPath string) *Tool { + t.pythonPath = pythonPath + return t +} + +type Input struct { + PythonCode string `json:"python_code" jsonschema:"description=python code that prints the final result to stdout."` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype = func(Input) Output + +func (t *Tool) Name() string { + return "python" +} + +func (t *Tool) Description() string { + return "A tool that runs Python code using the Python interpreter. The code should print the final result to stdout." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + // Create a command to run the Python interpreter with the script. + cmd := exec.Command(t.pythonPath, "-c", i.PythonCode) + + // Create a buffer to capture the output. + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + // Run the command. + err := cmd.Run() + if err != nil { + return Output{ + Error: fmt.Sprintf("failed to run script: %v, stderr: %v", err, stderr.String()), + } + } + + // Return the output as a string. + return Output{Result: out.String()} +} diff --git a/tools/rag/rag.go b/tools/rag/rag.go new file mode 100644 index 00000000..c46788cf --- /dev/null +++ b/tools/rag/rag.go @@ -0,0 +1,62 @@ +package rag + +import ( + "context" + "strings" + "time" + + "github.com/henomis/lingoose/rag" +) + +const ( + defaultTimeoutInMinutes = 6 +) + +type Tool struct { + rag *rag.RAG + topic string +} + +func New(rag *rag.RAG, topic string) *Tool { + return &Tool{ + rag: rag, + topic: topic, + } +} + +type Input struct { + Query string `json:"rag_query" jsonschema:"description=search query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype = func(Input) Output + +func (t *Tool) Name() string { + return "rag" +} + +func (t *Tool) Description() string { + return "A tool that searches information ONLY for this topic: " + t.topic + ". DO NOT use this tool for other topics." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + + results, err := t.rag.Retrieve(ctx, i.Query) + if err != nil { + return Output{Error: err.Error()} + } + + // Return the output as a string. + return Output{Result: strings.Join(results, "\n")} +} diff --git a/tools/serpapi/api.go b/tools/serpapi/api.go new file mode 100644 index 00000000..944b71fe --- /dev/null +++ b/tools/serpapi/api.go @@ -0,0 +1,119 @@ +package serpapi + +import ( + "encoding/json" + "io" + + "github.com/henomis/restclientgo" +) + +type request struct { + Query string + GoogleDomain string + CountryCode string + LanguageCode string + APIKey string +} + +type response struct { + HTTPStatusCode int + Map map[string]interface{} + RawBody []byte + apiResponse apiResponse + Results []result +} + +type apiResponse struct { + OrganicResults []OrganicResults `json:"organic_results"` +} + +type Top struct { + Extensions []string `json:"extensions"` +} + +type RichSnippet struct { + Top Top `json:"top"` +} + +type OrganicResults struct { + Position int `json:"position"` + Title string `json:"title"` + Link string `json:"link"` + RedirectLink string `json:"redirect_link"` + DisplayedLink string `json:"displayed_link"` + Thumbnail string `json:"thumbnail,omitempty"` + Favicon string `json:"favicon"` + Snippet string `json:"snippet"` + Source string `json:"source"` + RichSnippet RichSnippet `json:"rich_snippet,omitempty"` + SnippetHighlightedWords []string `json:"snippet_highlighted_words,omitempty"` +} + +type result struct { + Title string + Info string + URL string +} + +func (r *request) Path() (string, error) { + urlValues := restclientgo.NewURLValues() + urlValues.Add("q", &r.Query) + urlValues.Add("api_key", &r.APIKey) + + if r.GoogleDomain != "" { + urlValues.Add("google_domain", &r.GoogleDomain) + } + + if r.CountryCode != "" { + urlValues.Add("gl", &r.CountryCode) + } + + if r.LanguageCode != "" { + urlValues.Add("hl", &r.LanguageCode) + } + + params := urlValues.Encode() + + return "/search?" + params, nil +} + +func (r *request) Encode() (io.Reader, error) { + return nil, nil +} + +func (r *request) ContentType() string { + return "" +} + +func (r *response) Decode(body io.Reader) error { + err := json.NewDecoder(body).Decode(&r.apiResponse) + if err != nil { + return err + } + + for _, res := range r.apiResponse.OrganicResults { + r.Results = append(r.Results, result{ + Title: res.Title, + Info: res.Snippet, + URL: res.Link, + }) + } + + return nil +} + +func (r *response) SetBody(body io.Reader) error { + r.RawBody, _ = io.ReadAll(body) + return nil +} + +func (r *response) AcceptContentType() string { + return "application/json" +} + +func (r *response) SetStatusCode(code int) error { + r.HTTPStatusCode = code + return nil +} + +func (r *response) SetHeaders(_ restclientgo.Headers) error { return nil } diff --git a/tools/serpapi/serpapi.go b/tools/serpapi/serpapi.go new file mode 100644 index 00000000..9637e691 --- /dev/null +++ b/tools/serpapi/serpapi.go @@ -0,0 +1,98 @@ +package serpapi + +import ( + "context" + "fmt" + "os" + "time" + + "github.com/henomis/restclientgo" +) + +const ( + defaultTimeoutInSeconds = 60 +) + +type Tool struct { + restClient *restclientgo.RestClient + googleDomain string + countryCode string + languageCode string + apiKey string +} + +type Input struct { + Query string `json:"query" jsonschema:"description=the query to search for"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Results []result `json:"results,omitempty"` +} + +type FnPrototype = func(Input) Output + +func New() *Tool { + t := &Tool{ + apiKey: os.Getenv("SERPAPI_API_KEY"), + restClient: restclientgo.New("https://serpapi.com"), + googleDomain: "google.com", + countryCode: "us", + languageCode: "en", + } + + return t +} + +func (t *Tool) WithGoogleDomain(googleDomain string) *Tool { + t.googleDomain = googleDomain + return t +} + +func (t *Tool) WithCountryCode(countryCode string) *Tool { + t.countryCode = countryCode + return t +} + +func (t *Tool) WithLanguageCode(languageCode string) *Tool { + t.languageCode = languageCode + return t +} + +func (t *Tool) WithAPIKey(apiKey string) *Tool { + t.apiKey = apiKey + return t +} + +func (t *Tool) Name() string { + return "google" +} + +func (t *Tool) Description() string { + return "A tool that uses the Google internet search engine for a query." +} + +func (t *Tool) Fn() any { + return t.fn +} + +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInSeconds*time.Second) + defer cancel() + + req := &request{ + Query: i.Query, + GoogleDomain: t.googleDomain, + CountryCode: t.countryCode, + LanguageCode: t.languageCode, + APIKey: t.apiKey, + } + res := &response{} + + err := t.restClient.Get(ctx, req, res) + if err != nil { + return Output{Error: fmt.Sprintf("failed to search serpapi: %v", err)} + } + + return Output{Results: res.Results} +} diff --git a/tools/shell/shell.go b/tools/shell/shell.go new file mode 100644 index 00000000..c5c5e308 --- /dev/null +++ b/tools/shell/shell.go @@ -0,0 +1,91 @@ +package shell + +import ( + "bytes" + "fmt" + "os/exec" +) + +type Tool struct { + shell string + askForConfirm bool +} + +func New() *Tool { + return &Tool{ + shell: "bash", + askForConfirm: true, + } +} + +func (t *Tool) WithShell(shell string) *Tool { + t.shell = shell + return t +} + +func (t *Tool) WithAskForConfirm(askForConfirm bool) *Tool { + t.askForConfirm = askForConfirm + return t +} + +type Input struct { + BashScript string `json:"bash_code" jsonschema:"description=shell script"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result string `json:"result,omitempty"` +} + +type FnPrototype = func(Input) Output + +func (t *Tool) Name() string { + return "bash" +} + +func (t *Tool) Description() string { + return "A tool that runs a shell script using the " + t.shell + " interpreter. Use it to interact with the OS." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + // Ask for confirmation if the flag is set. + if t.askForConfirm { + fmt.Println("Are you sure you want to run the following script?") + fmt.Println("-------------------------------------------------") + fmt.Println(i.BashScript) + fmt.Println("-------------------------------------------------") + fmt.Print("Type 'yes' to confirm > ") + var confirm string + fmt.Scanln(&confirm) + if confirm != "yes" { + return Output{ + Error: "script execution aborted", + } + } + } + + // Create a command to run the Bash interpreter with the script. + cmd := exec.Command(t.shell, "-c", i.BashScript) + + // Create a buffer to capture the output. + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + // Run the command. + err := cmd.Run() + if err != nil { + return Output{ + Error: fmt.Sprintf("failed to run script: %v, stderr: %v", err, stderr.String()), + } + } + + // Return the output as a string. + return Output{Result: out.String()} +} diff --git a/tools/tool_router/tool_router.go b/tools/tool_router/tool_router.go new file mode 100644 index 00000000..17d35bcd --- /dev/null +++ b/tools/tool_router/tool_router.go @@ -0,0 +1,84 @@ +package toolrouter + +import ( + "context" + "time" + + "github.com/henomis/lingoose/thread" +) + +const ( + defaultTimeoutInMinutes = 6 +) + +type TTool interface { + Description() string + Name() string + Fn() any +} + +type Tool struct { + llm LLM + tools []TTool +} + +type LLM interface { + Generate(context.Context, *thread.Thread) error +} + +func New(llm LLM, tools ...TTool) *Tool { + return &Tool{ + tools: tools, + llm: llm, + } +} + +type Input struct { + Query string `json:"query" jsonschema:"description=user query"` +} + +type Output struct { + Error string `json:"error,omitempty"` + Result any `json:"result,omitempty"` +} + +type FnPrototype func(Input) Output + +func (t *Tool) Name() string { + return "query_router" +} + +func (t *Tool) Description() string { + return "A tool that select the right tool to answer to user queries." +} + +func (t *Tool) Fn() any { + return t.fn +} + +//nolint:gosec +func (t *Tool) fn(i Input) Output { + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeoutInMinutes*time.Minute) + defer cancel() + + query := "Here's a list of available tools:\n\n" + for _, tool := range t.tools { + query += "Name: " + tool.Name() + "\nDescription: " + tool.Description() + "\n\n" + } + + query += "\nPlease select the right tool that can better answer the query '" + i.Query + + "'. Give me only the name of the tool, nothing else." + + th := thread.New().AddMessage( + thread.NewUserMessage().AddContent( + thread.NewTextContent(query), + ), + ) + + err := t.llm.Generate(ctx, th) + if err != nil { + return Output{Error: err.Error()} + } + + return Output{Result: th.LastMessage().Contents[0].AsString()} +}