From 5f341a91a4a3ee6292194e6054d26fdbea18d7e7 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 29 May 2025 02:43:07 +0200 Subject: [PATCH 1/4] make a working version of completions --- e2e/e2e_test.go | 4 +- go.mod | 2 +- go.sum | 4 +- internal/ghmcp/server.go | 22 +-- pkg/github/code_scanning.go | 4 +- pkg/github/context_tools.go | 4 +- pkg/github/dynamic_tools.go | 4 +- pkg/github/helper_test.go | 2 +- pkg/github/issues.go | 4 +- pkg/github/notifications.go | 4 +- pkg/github/pullrequests.go | 4 +- pkg/github/repositories.go | 5 +- pkg/github/repository_completions.go | 160 ++++++++++++++++++++++ pkg/github/repository_completions_test.go | 31 +++++ pkg/github/repository_resource.go | 6 +- pkg/github/repository_resource_test.go | 2 +- pkg/github/resources.go | 2 +- pkg/github/search.go | 4 +- pkg/github/secret_scanning.go | 4 +- pkg/github/server.go | 8 +- pkg/github/tools.go | 2 +- pkg/toolsets/toolsets.go | 4 +- third-party-licenses.darwin.md | 2 +- third-party-licenses.linux.md | 2 +- third-party-licenses.windows.md | 2 +- 25 files changed, 242 insertions(+), 50 deletions(-) create mode 100644 pkg/github/repository_completions.go create mode 100644 pkg/github/repository_completions_test.go diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 71bd5a8a..52973f6d 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -20,8 +20,8 @@ import ( "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v69/github" - mcpClient "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/mcp" + mcpClient "github.com/sammorrowdrums/mcp-go/client" + "github.com/sammorrowdrums/mcp-go/mcp" "github.com/stretchr/testify/require" ) diff --git a/go.mod b/go.mod index 684ce8f2..5fc5bf65 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,8 @@ go 1.23.7 require ( github.com/google/go-github/v69 v69.2.0 github.com/josephburnett/jd v1.9.2 - github.com/mark3labs/mcp-go v0.30.0 github.com/migueleliasweb/go-github-mock v1.3.0 + github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 diff --git a/go.sum b/go.sum index c2da59f6..1736aff4 100644 --- a/go.sum +++ b/go.sum @@ -47,8 +47,6 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= -github.com/mark3labs/mcp-go v0.30.0 h1:Taz7fiefkxY/l8jz1nA90V+WdM2eoMtlvwfWforVYbo= -github.com/mark3labs/mcp-go v0.30.0/go.mod h1:rXqOudj/djTORU/ThxYx8fqEVj/5pvTuuebQ2RC7uk4= github.com/migueleliasweb/go-github-mock v1.3.0 h1:2sVP9JEMB2ubQw1IKto3/fzF51oFC6eVWOOFDgQoq88= github.com/migueleliasweb/go-github-mock v1.3.0/go.mod h1:ipQhV8fTcj/G6m7BKzin08GaJ/3B5/SonRAkgrk0zCY= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= @@ -62,6 +60,8 @@ github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWN github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= +github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052 h1:c9HI0HGuXED8zwXCdnk2iGyaSC8mvZlBGl+SdHxYJgs= +github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052/go.mod h1:Kwt02UMWGJxJ1IHMO9Wrj4GabTSvv9uVUrpht1vjiuk= github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkvclm+pWm1Lk4YrREb4IOIb/YdFO0p2M= github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a75a9e0c..8b1f0eb5 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -16,8 +16,8 @@ import ( mcplog "github.com/github/github-mcp-server/pkg/log" "github.com/github/github-mcp-server/pkg/translations" gogithub "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" "github.com/shurcooL/githubv4" "github.com/sirupsen/logrus" ) @@ -91,7 +91,15 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { OnBeforeInitialize: []server.OnBeforeInitializeFunc{beforeInit}, } - ghServer := github.NewServer(cfg.Version, server.WithHooks(hooks)) + getClient := func(_ context.Context) (*gogithub.Client, error) { + return restClient, nil // closing over client + } + + getGQLClient := func(_ context.Context) (*githubv4.Client, error) { + return gqlClient, nil // closing over client + } + + ghServer := github.NewServer(getClient, cfg.Version, server.WithHooks(hooks)) enabledToolsets := cfg.EnabledToolsets if cfg.DynamicToolsets { @@ -104,14 +112,6 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) { } } - getClient := func(_ context.Context) (*gogithub.Client, error) { - return restClient, nil // closing over client - } - - getGQLClient := func(_ context.Context) (*githubv4.Client, error) { - return gqlClient, nil // closing over client - } - // Create default toolsets toolsets, err := github.InitToolsets( enabledToolsets, diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 34a1b9ed..3d0c1363 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -9,8 +9,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) func GetCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { diff --git a/pkg/github/context_tools.go b/pkg/github/context_tools.go index 7b8ed249..d3f9e939 100644 --- a/pkg/github/context_tools.go +++ b/pkg/github/context_tools.go @@ -4,8 +4,8 @@ import ( "context" "github.com/github/github-mcp-server/pkg/translations" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) // GetMe creates a tool to get details of the authenticated user. diff --git a/pkg/github/dynamic_tools.go b/pkg/github/dynamic_tools.go index 0b098fb3..ac253244 100644 --- a/pkg/github/dynamic_tools.go +++ b/pkg/github/dynamic_tools.go @@ -7,8 +7,8 @@ import ( "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) func ToolsetEnum(toolsetGroup *toolsets.ToolsetGroup) mcp.PropertyOption { diff --git a/pkg/github/helper_test.go b/pkg/github/helper_test.go index 4b9a243d..8abbb64d 100644 --- a/pkg/github/helper_test.go +++ b/pkg/github/helper_test.go @@ -5,7 +5,7 @@ import ( "net/http" "testing" - "github.com/mark3labs/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/mcp" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) diff --git a/pkg/github/issues.go b/pkg/github/issues.go index 07c76078..b1ed13be 100644 --- a/pkg/github/issues.go +++ b/pkg/github/issues.go @@ -12,8 +12,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" "github.com/shurcooL/githubv4" ) diff --git a/pkg/github/notifications.go b/pkg/github/notifications.go index ba9c6bc2..dfc961df 100644 --- a/pkg/github/notifications.go +++ b/pkg/github/notifications.go @@ -11,8 +11,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) const ( diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index abdf6448..6fb0c335 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -9,8 +9,8 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" "github.com/shurcooL/githubv4" "github.com/github/github-mcp-server/pkg/translations" diff --git a/pkg/github/repositories.go b/pkg/github/repositories.go index 8c337163..7fc8e139 100644 --- a/pkg/github/repositories.go +++ b/pkg/github/repositories.go @@ -9,8 +9,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) func GetCommit(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { @@ -132,6 +132,7 @@ func ListCommits(getClient GetClientFn, t translations.TranslationHelperFunc) (t } client, err := getClient(ctx) + if err != nil { return nil, fmt.Errorf("failed to get GitHub client: %w", err) } diff --git a/pkg/github/repository_completions.go b/pkg/github/repository_completions.go new file mode 100644 index 00000000..fa99e061 --- /dev/null +++ b/pkg/github/repository_completions.go @@ -0,0 +1,160 @@ +package github + +import ( + "context" + "fmt" + "strings" + + "github.com/google/go-github/v69/github" + "github.com/sammorrowdrums/mcp-go/mcp" +) + +// RepositoryResourceCompletionHandler returns a CompletionHandlerFunc for repository resource completions. +func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { + return func(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { + ref, ok := req.Params.Ref.(map[string]any) + if !ok || ref["type"] != "ref/resource" { + return nil, nil // Not a resource completion + } + uri, _ := ref["uri"].(string) + argName := req.Params.Argument.Name + argValue := req.Params.Argument.Value + + client, err := getClient(ctx) + if err != nil { + return nil, err + } + + var values []string + + switch argName { + case "owner": + user, _, err := client.Users.Get(ctx, "") + if err == nil && user.GetLogin() != "" { + values = append(values, user.GetLogin()) + } + orgs, _, _ := client.Organizations.List(ctx, "", nil) + for _, org := range orgs { + values = append(values, org.GetLogin()) + } + case "repo": + // print the whole mcp complete request for debugging + fmt.Printf("MCP Complete Request: %+v\n", req) + + fmt.Printf("URI: %s\n", uri) + owner := getArgFromURI(uri, "owner") + if owner != "" { + repos, _, err := client.Search.Repositories(ctx, fmt.Sprintf("user:%s", owner), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) + if err != nil || repos == nil { + break + } + for _, repo := range repos.Repositories { + if argValue == "" || strings.Contains(repo.GetName(), argValue) { + values = append(values, repo.GetName()) + } + } + } + case "branch": + owner := getArgFromURI(uri, "owner") + repo := getArgFromURI(uri, "repo") + if owner != "" && repo != "" { + branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil) + for _, branch := range branches { + if argValue == "" || strings.Contains(branch.GetName(), argValue) { + values = append(values, branch.GetName()) + } + } + } + case "sha": + owner := getArgFromURI(uri, "owner") + repo := getArgFromURI(uri, "repo") + if owner != "" && repo != "" { + commits, _, _ := client.Repositories.ListCommits(ctx, owner, repo, nil) + for _, commit := range commits { + sha := commit.GetSHA() + if argValue == "" || strings.HasPrefix(sha, argValue) { + values = append(values, sha) + } + } + } + case "tag": + owner := getArgFromURI(uri, "owner") + repo := getArgFromURI(uri, "repo") + if owner != "" && repo != "" { + tags, _, _ := client.Repositories.ListTags(ctx, owner, repo, nil) + for _, tag := range tags { + if argValue == "" || strings.Contains(tag.GetName(), argValue) { + values = append(values, tag.GetName()) + } + } + } + case "prNumber": + owner := getArgFromURI(uri, "owner") + repo := getArgFromURI(uri, "repo") + if owner != "" && repo != "" { + prs, _, _ := client.PullRequests.List(ctx, owner, repo, nil) + for _, pr := range prs { + num := fmt.Sprintf("%d", pr.GetNumber()) + if argValue == "" || strings.HasPrefix(num, argValue) { + values = append(values, num) + } + } + } + case "path": + owner := getArgFromURI(uri, "owner") + repo := getArgFromURI(uri, "repo") + refVal := getArgFromURI(uri, "branch") + if refVal == "" { + refVal = getArgFromURI(uri, "sha") + } + if refVal == "" { + refVal = getArgFromURI(uri, "tag") + } + if refVal == "" { + refVal = "main" + } + if owner != "" && repo != "" { + contents, dirContents, _, _ := client.Repositories.GetContents(ctx, owner, repo, "", &github.RepositoryContentGetOptions{Ref: refVal}) + if dirContents != nil { + for _, entry := range dirContents { + if argValue == "" || strings.HasPrefix(entry.GetName(), argValue) { + values = append(values, entry.GetName()) + } + } + } else if contents != nil { + if argValue == "" || strings.HasPrefix(contents.GetName(), argValue) { + values = append(values, contents.GetName()) + } + } + } + } + + if len(values) > 100 { + values = values[:100] + } + + return &mcp.CompleteResult{ + Completion: struct { + Values []string `json:"values"` + Total int `json:"total,omitempty"` + HasMore bool `json:"hasMore,omitempty"` + }{ + Values: values, + Total: len(values), + HasMore: false, + }, + }, nil + } +} + +func getArgFromURI(uri, name string) string { + trimmed := strings.TrimPrefix(uri, "repo://") + parts := strings.Split(trimmed, "/") + if name == "owner" && len(parts) > 0 && parts[0] != "" { + return parts[0] + } + if name == "repo" && len(parts) > 1 && parts[1] != "" { + return parts[1] + } + return "" +} diff --git a/pkg/github/repository_completions_test.go b/pkg/github/repository_completions_test.go new file mode 100644 index 00000000..d577c3b2 --- /dev/null +++ b/pkg/github/repository_completions_test.go @@ -0,0 +1,31 @@ +package github + +import ( + "context" + "testing" + + "github.com/google/go-github/v69/github" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/stretchr/testify/require" +) + +// Add more fake methods as needed for testing +func TestRepositoryResourceCompletionHandler_Owner(t *testing.T) { + // Stub getClient to return a fake client with a user and orgs + getClient := func(ctx context.Context) (*github.Client, error) { + client := github.NewClient(nil) + // You can use github's testing helpers or mock the methods as needed + return client, nil + } + + handler := RepositoryResourceCompletionHandler(getClient) + request := mcp.CompleteRequest{} + request.Params.Ref = map[string]any{"type": "ref/resource", "uri": "repo://"} + request.Params.Argument.Name = "owner" + request.Params.Argument.Value = "" + + result, err := handler(context.Background(), request) + require.NoError(t, err) + // In a real test, assert on result.Completion.Values + _ = result +} diff --git a/pkg/github/repository_resource.go b/pkg/github/repository_resource.go index fe34689f..3bd2d2b5 100644 --- a/pkg/github/repository_resource.go +++ b/pkg/github/repository_resource.go @@ -13,8 +13,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) // GetRepositoryResourceContent defines the resource template and handler for getting repository content. @@ -66,7 +66,7 @@ func GetRepositoryResourcePrContent(getClient GetClientFn, t translations.Transl func RepositoryResourceContentsHandler(getClient GetClientFn) func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { return func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // the matcher will give []string with one element - // https://github.com/mark3labs/mcp-go/pull/54 + // https://github.com/sammorrowdrums/mcp-go/pull/54 o, ok := request.Params.Arguments["owner"].([]string) if !ok || len(o) == 0 { return nil, errors.New("owner is required") diff --git a/pkg/github/repository_resource_test.go b/pkg/github/repository_resource_test.go index f6a47e8c..bc693a18 100644 --- a/pkg/github/repository_resource_test.go +++ b/pkg/github/repository_resource_test.go @@ -7,8 +7,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" "github.com/migueleliasweb/go-github-mock/src/mock" + "github.com/sammorrowdrums/mcp-go/mcp" "github.com/stretchr/testify/require" ) diff --git a/pkg/github/resources.go b/pkg/github/resources.go index 774261e9..46b0c56a 100644 --- a/pkg/github/resources.go +++ b/pkg/github/resources.go @@ -2,7 +2,7 @@ package github import ( "github.com/github/github-mcp-server/pkg/translations" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/server" ) func RegisterResources(s *server.MCPServer, getClient GetClientFn, t translations.TranslationHelperFunc) { diff --git a/pkg/github/search.go b/pkg/github/search.go index ac5e2994..7b08497d 100644 --- a/pkg/github/search.go +++ b/pkg/github/search.go @@ -8,8 +8,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) // SearchRepositories creates a tool to search for GitHub repositories. diff --git a/pkg/github/secret_scanning.go b/pkg/github/secret_scanning.go index 847fcfc6..94db0d0f 100644 --- a/pkg/github/secret_scanning.go +++ b/pkg/github/secret_scanning.go @@ -9,8 +9,8 @@ import ( "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) func GetSecretScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { diff --git a/pkg/github/server.go b/pkg/github/server.go index b182b8ca..1ca97f6b 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -6,17 +6,17 @@ import ( "fmt" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) // NewServer creates a new GitHub MCP server with the specified GH client and logger. -func NewServer(version string, opts ...server.ServerOption) *server.MCPServer { - // Add default options +func NewServer(getClient GetClientFn, version string, opts ...server.ServerOption) *server.MCPServer { defaultOpts := []server.ServerOption{ server.WithToolCapabilities(true), server.WithResourceCapabilities(true, true), + server.WithCompletion(RepositoryResourceCompletionHandler(getClient)), server.WithLogging(), } opts = append(defaultOpts, opts...) diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 9c1ab34a..6a6ccd43 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -6,7 +6,7 @@ import ( "github.com/github/github-mcp-server/pkg/toolsets" "github.com/github/github-mcp-server/pkg/translations" "github.com/google/go-github/v69/github" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/server" "github.com/shurcooL/githubv4" ) diff --git a/pkg/toolsets/toolsets.go b/pkg/toolsets/toolsets.go index 7400119c..8f0f5105 100644 --- a/pkg/toolsets/toolsets.go +++ b/pkg/toolsets/toolsets.go @@ -3,8 +3,8 @@ package toolsets import ( "fmt" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "github.com/sammorrowdrums/mcp-go/mcp" + "github.com/sammorrowdrums/mcp-go/server" ) func NewServerTool(tool mcp.Tool, handler server.ToolHandlerFunc) server.ServerTool { diff --git a/third-party-licenses.darwin.md b/third-party-licenses.darwin.md index c1f098df..8519e017 100644 --- a/third-party-licenses.darwin.md +++ b/third-party-licenses.darwin.md @@ -18,7 +18,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/josephburnett/jd/v2](https://pkg.go.dev/github.com/josephburnett/jd/v2) ([MIT](https://github.com/josephburnett/jd/blob/v1.9.2/LICENSE)) - [github.com/josharian/intern](https://pkg.go.dev/github.com/josharian/intern) ([MIT](https://github.com/josharian/intern/blob/v1.0.0/license.md)) - [github.com/mailru/easyjson](https://pkg.go.dev/github.com/mailru/easyjson) ([MIT](https://github.com/mailru/easyjson/blob/v0.7.7/LICENSE)) - - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.30.0/LICENSE)) + - [github.com/sammorrowdrums/mcp-go](https://pkg.go.dev/github.com/sammorrowdrums/mcp-go) ([MIT](https://github.com/sammorrowdrums/mcp-go/blob/v0.30.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) diff --git a/third-party-licenses.linux.md b/third-party-licenses.linux.md index c1f098df..8519e017 100644 --- a/third-party-licenses.linux.md +++ b/third-party-licenses.linux.md @@ -18,7 +18,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/josephburnett/jd/v2](https://pkg.go.dev/github.com/josephburnett/jd/v2) ([MIT](https://github.com/josephburnett/jd/blob/v1.9.2/LICENSE)) - [github.com/josharian/intern](https://pkg.go.dev/github.com/josharian/intern) ([MIT](https://github.com/josharian/intern/blob/v1.0.0/license.md)) - [github.com/mailru/easyjson](https://pkg.go.dev/github.com/mailru/easyjson) ([MIT](https://github.com/mailru/easyjson/blob/v0.7.7/LICENSE)) - - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.30.0/LICENSE)) + - [github.com/sammorrowdrums/mcp-go](https://pkg.go.dev/github.com/sammorrowdrums/mcp-go) ([MIT](https://github.com/sammorrowdrums/mcp-go/blob/v0.30.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) diff --git a/third-party-licenses.windows.md b/third-party-licenses.windows.md index f57e547b..32d671a0 100644 --- a/third-party-licenses.windows.md +++ b/third-party-licenses.windows.md @@ -19,7 +19,7 @@ Some packages may only be included on certain architectures or operating systems - [github.com/josephburnett/jd/v2](https://pkg.go.dev/github.com/josephburnett/jd/v2) ([MIT](https://github.com/josephburnett/jd/blob/v1.9.2/LICENSE)) - [github.com/josharian/intern](https://pkg.go.dev/github.com/josharian/intern) ([MIT](https://github.com/josharian/intern/blob/v1.0.0/license.md)) - [github.com/mailru/easyjson](https://pkg.go.dev/github.com/mailru/easyjson) ([MIT](https://github.com/mailru/easyjson/blob/v0.7.7/LICENSE)) - - [github.com/mark3labs/mcp-go](https://pkg.go.dev/github.com/mark3labs/mcp-go) ([MIT](https://github.com/mark3labs/mcp-go/blob/v0.30.0/LICENSE)) + - [github.com/sammorrowdrums/mcp-go](https://pkg.go.dev/github.com/sammorrowdrums/mcp-go) ([MIT](https://github.com/sammorrowdrums/mcp-go/blob/v0.30.0/LICENSE)) - [github.com/pelletier/go-toml/v2](https://pkg.go.dev/github.com/pelletier/go-toml/v2) ([MIT](https://github.com/pelletier/go-toml/blob/v2.2.3/LICENSE)) - [github.com/sagikazarmark/locafero](https://pkg.go.dev/github.com/sagikazarmark/locafero) ([MIT](https://github.com/sagikazarmark/locafero/blob/v0.9.0/LICENSE)) - [github.com/shurcooL/githubv4](https://pkg.go.dev/github.com/shurcooL/githubv4) ([MIT](https://github.com/shurcooL/githubv4/blob/48295856cce7/LICENSE)) From 4e5868163e086dfa37382b3118955c3c3c978bad Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Mon, 2 Jun 2025 23:58:30 +0200 Subject: [PATCH 2/4] more work on completions --- go.mod | 2 +- go.sum | 2 + pkg/github/repository_completions.go | 355 ++++++++++++++++++--------- 3 files changed, 248 insertions(+), 111 deletions(-) diff --git a/go.mod b/go.mod index 5fc5bf65..d5195020 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/google/go-github/v69 v69.2.0 github.com/josephburnett/jd v1.9.2 github.com/migueleliasweb/go-github-mock v1.3.0 - github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052 + github.com/sammorrowdrums/mcp-go v0.0.0-20250602101733-1a4eb277f6a0 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 github.com/spf13/viper v1.20.1 diff --git a/go.sum b/go.sum index 1736aff4..eb606dc4 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFT github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052 h1:c9HI0HGuXED8zwXCdnk2iGyaSC8mvZlBGl+SdHxYJgs= github.com/sammorrowdrums/mcp-go v0.0.0-20250528234530-f0daf2216052/go.mod h1:Kwt02UMWGJxJ1IHMO9Wrj4GabTSvv9uVUrpht1vjiuk= +github.com/sammorrowdrums/mcp-go v0.0.0-20250602101733-1a4eb277f6a0 h1:fmwKUofBVuktOGefuZUbvUCTPXovPDjTeN7X5N2S2GI= +github.com/sammorrowdrums/mcp-go v0.0.0-20250602101733-1a4eb277f6a0/go.mod h1:Kwt02UMWGJxJ1IHMO9Wrj4GabTSvv9uVUrpht1vjiuk= github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7 h1:cYCy18SHPKRkvclm+pWm1Lk4YrREb4IOIb/YdFO0p2M= github.com/shurcooL/githubv4 v0.0.0-20240727222349-48295856cce7/go.mod h1:zqMwyHmnN/eDOZOdiTohqIUKUrTFX62PNlu7IJdu0q8= github.com/shurcooL/graphql v0.0.0-20230722043721-ed46e5a46466 h1:17JxqqJY66GmZVHkmAsGEkcIu0oCe3AM420QDgGwZx0= diff --git a/pkg/github/repository_completions.go b/pkg/github/repository_completions.go index fa99e061..2c86ed34 100644 --- a/pkg/github/repository_completions.go +++ b/pkg/github/repository_completions.go @@ -9,6 +9,8 @@ import ( "github.com/sammorrowdrums/mcp-go/mcp" ) +// RepositoryResourceCompletionHandler returns a CompletionHandlerFunc for repository resource completions. + // RepositoryResourceCompletionHandler returns a CompletionHandlerFunc for repository resource completions. func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { return func(ctx context.Context, req mcp.CompleteRequest) (*mcp.CompleteResult, error) { @@ -16,119 +18,42 @@ func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context if !ok || ref["type"] != "ref/resource" { return nil, nil // Not a resource completion } - uri, _ := ref["uri"].(string) + argName := req.Params.Argument.Name argValue := req.Params.Argument.Value + resolved, ok := any(req.Params.Resolved).(map[string]string) + if !ok && req.Params.Resolved != nil { + return nil, fmt.Errorf(".Resolved must be map[string]string, got %T", req.Params.Resolved) + } + if resolved == nil { + resolved = map[string]string{} + } client, err := getClient(ctx) if err != nil { return nil, err } - var values []string + // Argument resolver functions + resolvers := map[string]func(context.Context, *github.Client, map[string]string, string) ([]string, error){ + "owner": completeOwner, + "repo": completeRepo, + "branch": completeBranch, + "sha": completeSHA, + "tag": completeTag, + "prNumber": completePRNumber, + "path": completePath, + } - switch argName { - case "owner": - user, _, err := client.Users.Get(ctx, "") - if err == nil && user.GetLogin() != "" { - values = append(values, user.GetLogin()) - } - orgs, _, _ := client.Organizations.List(ctx, "", nil) - for _, org := range orgs { - values = append(values, org.GetLogin()) - } - case "repo": - // print the whole mcp complete request for debugging - fmt.Printf("MCP Complete Request: %+v\n", req) - - fmt.Printf("URI: %s\n", uri) - owner := getArgFromURI(uri, "owner") - if owner != "" { - repos, _, err := client.Search.Repositories(ctx, fmt.Sprintf("user:%s", owner), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) - if err != nil || repos == nil { - break - } - for _, repo := range repos.Repositories { - if argValue == "" || strings.Contains(repo.GetName(), argValue) { - values = append(values, repo.GetName()) - } - } - } - case "branch": - owner := getArgFromURI(uri, "owner") - repo := getArgFromURI(uri, "repo") - if owner != "" && repo != "" { - branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil) - for _, branch := range branches { - if argValue == "" || strings.Contains(branch.GetName(), argValue) { - values = append(values, branch.GetName()) - } - } - } - case "sha": - owner := getArgFromURI(uri, "owner") - repo := getArgFromURI(uri, "repo") - if owner != "" && repo != "" { - commits, _, _ := client.Repositories.ListCommits(ctx, owner, repo, nil) - for _, commit := range commits { - sha := commit.GetSHA() - if argValue == "" || strings.HasPrefix(sha, argValue) { - values = append(values, sha) - } - } - } - case "tag": - owner := getArgFromURI(uri, "owner") - repo := getArgFromURI(uri, "repo") - if owner != "" && repo != "" { - tags, _, _ := client.Repositories.ListTags(ctx, owner, repo, nil) - for _, tag := range tags { - if argValue == "" || strings.Contains(tag.GetName(), argValue) { - values = append(values, tag.GetName()) - } - } - } - case "prNumber": - owner := getArgFromURI(uri, "owner") - repo := getArgFromURI(uri, "repo") - if owner != "" && repo != "" { - prs, _, _ := client.PullRequests.List(ctx, owner, repo, nil) - for _, pr := range prs { - num := fmt.Sprintf("%d", pr.GetNumber()) - if argValue == "" || strings.HasPrefix(num, argValue) { - values = append(values, num) - } - } - } - case "path": - owner := getArgFromURI(uri, "owner") - repo := getArgFromURI(uri, "repo") - refVal := getArgFromURI(uri, "branch") - if refVal == "" { - refVal = getArgFromURI(uri, "sha") - } - if refVal == "" { - refVal = getArgFromURI(uri, "tag") - } - if refVal == "" { - refVal = "main" - } - if owner != "" && repo != "" { - contents, dirContents, _, _ := client.Repositories.GetContents(ctx, owner, repo, "", &github.RepositoryContentGetOptions{Ref: refVal}) - if dirContents != nil { - for _, entry := range dirContents { - if argValue == "" || strings.HasPrefix(entry.GetName(), argValue) { - values = append(values, entry.GetName()) - } - } - } else if contents != nil { - if argValue == "" || strings.HasPrefix(contents.GetName(), argValue) { - values = append(values, contents.GetName()) - } - } - } + resolver, ok := resolvers[argName] + if !ok { + return nil, nil // Unknown argument } + values, err := resolver(ctx, client, resolved, argValue) + if err != nil { + return nil, err + } if len(values) > 100 { values = values[:100] } @@ -147,14 +72,224 @@ func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context } } -func getArgFromURI(uri, name string) string { - trimmed := strings.TrimPrefix(uri, "repo://") - parts := strings.Split(trimmed, "/") - if name == "owner" && len(parts) > 0 && parts[0] != "" { - return parts[0] +// --- Per-argument resolver functions --- + +func completeOwner(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + var values []string + user, _, err := client.Users.Get(ctx, "") + if err == nil && user.GetLogin() != "" { + values = append(values, user.GetLogin()) + } + orgs, _, _ := client.Organizations.List(ctx, "", &github.ListOptions{PerPage: 100}) + for _, org := range orgs { + values = append(values, org.GetLogin()) + } + // filter values based on argValue and replace values slice + if argValue != "" { + var filteredValues []string + for _, value := range values { + if strings.Contains(value, argValue) { + filteredValues = append(filteredValues, value) + } + } + values = filteredValues + } + if len(values) > 100 { + values = values[:100] + return values, nil // Limit to 100 results + } + // Else also do a client.Search.Users() + if argValue == "" { + return values, nil // No need to search if no argValue + } + users, _, err := client.Search.Users(ctx, argValue, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100 - len(values)}}) + if err != nil || users == nil { + return nil, err + } + for _, user := range users.Users { + values = append(values, user.GetLogin()) + } + + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completeRepo(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + var values []string + owner := resolved["owner"] + if owner == "" { + return values, nil + } + + repos, _, err := client.Search.Repositories(ctx, fmt.Sprintf("org:%s %s in:name", owner, argValue), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) + if err != nil || repos == nil { + return values, nil + } + + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completeBranch(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + var values []string + owner := resolved["owner"] + repo := resolved["repo"] + if owner == "" || repo == "" { + return values, nil + } + branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil) + + for _, branch := range branches { + if argValue == "" || strings.Contains(branch.GetName(), argValue) { + values = append(values, branch.GetName()) + } } - if name == "repo" && len(parts) > 1 && parts[1] != "" { - return parts[1] + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completeSHA(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + var values []string + owner := resolved["owner"] + repo := resolved["repo"] + if owner == "" || repo == "" { + return values, nil + } + commits, _, _ := client.Repositories.ListCommits(ctx, owner, repo, nil) + + for _, commit := range commits { + sha := commit.GetSHA() + if argValue == "" || strings.HasPrefix(sha, argValue) { + values = append(values, sha) + } + } + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completeTag(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + owner := resolved["owner"] + repo := resolved["repo"] + if owner == "" || repo == "" { + return nil, nil + } + tags, _, _ := client.Repositories.ListTags(ctx, owner, repo, nil) + var values []string + for _, tag := range tags { + if argValue == "" || strings.Contains(tag.GetName(), argValue) { + values = append(values, tag.GetName()) + } + } + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completePRNumber(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + var values []string + owner := resolved["owner"] + repo := resolved["repo"] + if owner == "" || repo == "" { + return values, nil + } + // prs, _, _ := client.PullRequests.List(ctx, owner, repo, &github.PullRequestListOptions{}) + prs, _, _ := client.Search.Issues(ctx, fmt.Sprintf("repo:%s/%s is:open is:pr %s", owner, repo, argValue), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) + for _, pr := range prs.Issues { + num := fmt.Sprintf("%d", pr.GetNumber()) + if argValue == "" || strings.HasPrefix(num, argValue) { + values = append(values, num) + } + } + if len(values) > 100 { + values = values[:100] + } + return values, nil +} + +func completePath(ctx context.Context, client *github.Client, resolved map[string]string, argValue string) ([]string, error) { + owner := resolved["owner"] + repo := resolved["repo"] + if owner == "" || repo == "" { + return nil, nil + } + refVal := resolved["branch"] + if refVal == "" { + refVal = resolved["sha"] + } + if refVal == "" { + refVal = resolved["tag"] + } + if refVal == "" { + refVal = "HEAD" + } + + // Determine the prefix to complete (directory path or file path) + prefix := argValue + if prefix != "" && !strings.HasSuffix(prefix, "/") { + lastSlash := strings.LastIndex(prefix, "/") + if lastSlash >= 0 { + prefix = prefix[:lastSlash+1] + } else { + prefix = "" + } + } + + // Get the tree for the ref (recursive) + tree, _, err := client.Git.GetTree(ctx, owner, repo, refVal, true) + if err != nil || tree == nil { + return nil, nil + } + + // Collect immediate children of the prefix (both files and directories) + children := map[string]struct{}{} + prefixLen := len(prefix) + for _, entry := range tree.Entries { + if !strings.HasPrefix(entry.GetPath(), prefix) { + continue + } + rel := entry.GetPath()[prefixLen:] + if rel == "" { + continue + } + // Only immediate children (no deeper paths) + slashIdx := strings.Index(rel, "/") + if slashIdx >= 0 { + // Directory: only add the directory name (with trailing slash) + rel = rel[:slashIdx+1] + } else { + // File: leave as-is + } + // Optionally filter by argValue (if user is typing after last slash) + if argValue != "" { + afterSlash := argValue + if lastSlash := strings.LastIndex(argValue, "/"); lastSlash >= 0 { + afterSlash = argValue[lastSlash+1:] + } + if afterSlash != "" && !strings.HasPrefix(rel, afterSlash) { + continue + } + } + children[rel] = struct{}{} + } + + var values []string + for name := range children { + if name != "" { + values = append(values, name) + } + } + + if len(values) > 100 { + values = values[:100] } - return "" + return values, nil } From dab3fcafdf2721f1dbe7c811f5b29ceba294aa8b Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 3 Jun 2025 18:17:36 +0200 Subject: [PATCH 3/4] fix repo suggestions --- pkg/github/repository_completions.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/pkg/github/repository_completions.go b/pkg/github/repository_completions.go index 2c86ed34..bf81389e 100644 --- a/pkg/github/repository_completions.go +++ b/pkg/github/repository_completions.go @@ -21,10 +21,7 @@ func RepositoryResourceCompletionHandler(getClient GetClientFn) func(ctx context argName := req.Params.Argument.Name argValue := req.Params.Argument.Value - resolved, ok := any(req.Params.Resolved).(map[string]string) - if !ok && req.Params.Resolved != nil { - return nil, fmt.Errorf(".Resolved must be map[string]string, got %T", req.Params.Resolved) - } + resolved := req.Params.Resolved if resolved == nil { resolved = map[string]string{} } @@ -123,14 +120,23 @@ func completeRepo(ctx context.Context, client *github.Client, resolved map[strin return values, nil } - repos, _, err := client.Search.Repositories(ctx, fmt.Sprintf("org:%s %s in:name", owner, argValue), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) + query := fmt.Sprintf("org:%s", owner) + + if argValue != "" { + query = fmt.Sprintf("%s %s", query, argValue) + } + repos, _, err := client.Search.Repositories(ctx, query, &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) if err != nil || repos == nil { return values, nil } - - if len(values) > 100 { - values = values[:100] + // filter repos based on argValue + for _, repo := range repos.Repositories { + name := repo.GetName() + if argValue == "" || strings.HasPrefix(name, argValue) { + values = append(values, name) + } } + return values, nil } @@ -144,7 +150,7 @@ func completeBranch(ctx context.Context, client *github.Client, resolved map[str branches, _, _ := client.Repositories.ListBranches(ctx, owner, repo, nil) for _, branch := range branches { - if argValue == "" || strings.Contains(branch.GetName(), argValue) { + if argValue == "" || strings.HasPrefix(branch.GetName(), argValue) { values = append(values, branch.GetName()) } } @@ -202,7 +208,7 @@ func completePRNumber(ctx context.Context, client *github.Client, resolved map[s return values, nil } // prs, _, _ := client.PullRequests.List(ctx, owner, repo, &github.PullRequestListOptions{}) - prs, _, _ := client.Search.Issues(ctx, fmt.Sprintf("repo:%s/%s is:open is:pr %s", owner, repo, argValue), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) + prs, _, _ := client.Search.Issues(ctx, fmt.Sprintf("repo:%s/%s is:open is:pr", owner, repo), &github.SearchOptions{ListOptions: github.ListOptions{PerPage: 100}}) for _, pr := range prs.Issues { num := fmt.Sprintf("%d", pr.GetNumber()) if argValue == "" || strings.HasPrefix(num, argValue) { From fe63b685674f11da73607d673f0a5812d6489792 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 3 Jun 2025 18:33:38 +0200 Subject: [PATCH 4/4] make a bunch of improvements to the completions --- pkg/github/repository_completions.go | 64 +++++++++++++++++++--------- 1 file changed, 45 insertions(+), 19 deletions(-) diff --git a/pkg/github/repository_completions.go b/pkg/github/repository_completions.go index bf81389e..cdde9ceb 100644 --- a/pkg/github/repository_completions.go +++ b/pkg/github/repository_completions.go @@ -255,8 +255,9 @@ func completePath(ctx context.Context, client *github.Client, resolved map[strin return nil, nil } - // Collect immediate children of the prefix (both files and directories) - children := map[string]struct{}{} + // Collect immediate children of the prefix (files and directories, no duplicates) + dirs := map[string]struct{}{} + files := map[string]struct{}{} prefixLen := len(prefix) for _, entry := range tree.Entries { if !strings.HasPrefix(entry.GetPath(), prefix) { @@ -266,31 +267,56 @@ func completePath(ctx context.Context, client *github.Client, resolved map[strin if rel == "" { continue } - // Only immediate children (no deeper paths) + // Only immediate children slashIdx := strings.Index(rel, "/") if slashIdx >= 0 { - // Directory: only add the directory name (with trailing slash) - rel = rel[:slashIdx+1] + // Directory: only add the directory name (with trailing slash), prefixed with full path + dirName := prefix + rel[:slashIdx+1] + dirs[dirName] = struct{}{} + } else if entry.GetType() == "blob" { + // File: add as-is, prefixed with full path + fileName := prefix + rel + files[fileName] = struct{}{} + } + } + + // Optionally filter by argValue (if user is typing after last slash) + var filter string + if argValue != "" { + if lastSlash := strings.LastIndex(argValue, "/"); lastSlash >= 0 { + filter = argValue[lastSlash+1:] } else { - // File: leave as-is + filter = argValue } - // Optionally filter by argValue (if user is typing after last slash) - if argValue != "" { - afterSlash := argValue - if lastSlash := strings.LastIndex(argValue, "/"); lastSlash >= 0 { - afterSlash = argValue[lastSlash+1:] + } + + var values []string + // Add directories first, then files, both filtered + for dir := range dirs { + // Only filter on the last segment after the last slash + if filter == "" { + values = append(values, dir) + } else { + last := dir + if idx := strings.LastIndex(strings.TrimRight(dir, "/"), "/"); idx >= 0 { + last = dir[idx+1:] } - if afterSlash != "" && !strings.HasPrefix(rel, afterSlash) { - continue + if strings.HasPrefix(last, filter) { + values = append(values, dir) } } - children[rel] = struct{}{} } - - var values []string - for name := range children { - if name != "" { - values = append(values, name) + for file := range files { + if filter == "" { + values = append(values, file) + } else { + last := file + if idx := strings.LastIndex(file, "/"); idx >= 0 { + last = file[idx+1:] + } + if strings.HasPrefix(last, filter) { + values = append(values, file) + } } }