Skip to content

Commit caea5e0

Browse files
separate org and user search
1 parent 02f0a75 commit caea5e0

File tree

3 files changed

+245
-78
lines changed

3 files changed

+245
-78
lines changed

pkg/github/search.go

Lines changed: 118 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -160,94 +160,134 @@ type MinimalSearchUsersResult struct {
160160
}
161161

162162
// SearchUsers creates a tool to search for GitHub users.
163-
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
164-
return mcp.NewTool("search_users",
165-
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users")),
166-
mcp.WithToolAnnotation(mcp.ToolAnnotation{
167-
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
168-
ReadOnlyHint: toBoolPtr(true),
169-
}),
170-
mcp.WithString("q",
171-
mcp.Required(),
172-
mcp.Description("Search query using GitHub users search syntax"),
173-
),
174-
mcp.WithString("sort",
175-
mcp.Description("Sort field by category"),
176-
mcp.Enum("followers", "repositories", "joined"),
177-
),
178-
mcp.WithString("order",
179-
mcp.Description("Sort order"),
180-
mcp.Enum("asc", "desc"),
181-
),
182-
WithPagination(),
183-
),
184-
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
185-
query, err := requiredParam[string](request, "q")
186-
if err != nil {
187-
return mcp.NewToolResultError(err.Error()), nil
188-
}
189-
sort, err := OptionalParam[string](request, "sort")
190-
if err != nil {
191-
return mcp.NewToolResultError(err.Error()), nil
192-
}
193-
order, err := OptionalParam[string](request, "order")
194-
if err != nil {
195-
return mcp.NewToolResultError(err.Error()), nil
196-
}
197-
pagination, err := OptionalPaginationParams(request)
198-
if err != nil {
199-
return mcp.NewToolResultError(err.Error()), nil
200-
}
163+
func userOrOrgHandler(accountType string, getClient GetClientFn) server.ToolHandlerFunc {
164+
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
165+
query, err := requiredParam[string](request, "q")
166+
if err != nil {
167+
return mcp.NewToolResultError(err.Error()), nil
168+
}
169+
sort, err := OptionalParam[string](request, "sort")
170+
if err != nil {
171+
return mcp.NewToolResultError(err.Error()), nil
172+
}
173+
order, err := OptionalParam[string](request, "order")
174+
if err != nil {
175+
return mcp.NewToolResultError(err.Error()), nil
176+
}
177+
pagination, err := OptionalPaginationParams(request)
178+
if err != nil {
179+
return mcp.NewToolResultError(err.Error()), nil
180+
}
201181

202-
opts := &github.SearchOptions{
203-
Sort: sort,
204-
Order: order,
205-
ListOptions: github.ListOptions{
206-
PerPage: pagination.perPage,
207-
Page: pagination.page,
208-
},
209-
}
182+
opts := &github.SearchOptions{
183+
Sort: sort,
184+
Order: order,
185+
ListOptions: github.ListOptions{
186+
PerPage: pagination.perPage,
187+
Page: pagination.page,
188+
},
189+
}
210190

211-
client, err := getClient(ctx)
212-
if err != nil {
213-
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
214-
}
191+
client, err := getClient(ctx)
192+
if err != nil {
193+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
194+
}
215195

216-
result, resp, err := client.Search.Users(ctx, "type:user "+query, opts)
196+
searchQuery := "type:" + accountType + " " + query
197+
result, resp, err := client.Search.Users(ctx, searchQuery, opts)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to search %ss: %w", accountType, err)
200+
}
201+
defer func() { _ = resp.Body.Close() }()
202+
203+
if resp.StatusCode != 200 {
204+
body, err := io.ReadAll(resp.Body)
217205
if err != nil {
218-
return nil, fmt.Errorf("failed to search users: %w", err)
206+
return nil, fmt.Errorf("failed to read response body: %w", err)
219207
}
220-
defer func() { _ = resp.Body.Close() }()
208+
return mcp.NewToolResultError(fmt.Sprintf("failed to search %ss: %s", accountType, string(body))), nil
209+
}
221210

222-
if resp.StatusCode != 200 {
223-
body, err := io.ReadAll(resp.Body)
224-
if err != nil {
225-
return nil, fmt.Errorf("failed to read response body: %w", err)
226-
}
227-
return mcp.NewToolResultError(fmt.Sprintf("failed to search users: %s", string(body))), nil
228-
}
211+
minimalUsers := make([]MinimalUser, 0, len(result.Users))
229212

230-
minimalUsers := make([]MinimalUser, 0, len(result.Users))
231-
for _, user := range result.Users {
232-
mu := MinimalUser{
233-
Login: user.GetLogin(),
234-
ID: user.GetID(),
235-
ProfileURL: user.GetHTMLURL(),
236-
AvatarURL: user.GetAvatarURL(),
213+
for _, user := range result.Users {
214+
if user.Login != nil {
215+
mu := MinimalUser{Login: *user.Login}
216+
if user.ID != nil {
217+
mu.ID = *user.ID
218+
}
219+
if user.HTMLURL != nil {
220+
mu.ProfileURL = *user.HTMLURL
221+
}
222+
if user.AvatarURL != nil {
223+
mu.AvatarURL = *user.AvatarURL
237224
}
238-
239225
minimalUsers = append(minimalUsers, mu)
240226
}
227+
}
228+
minimalResp := &MinimalSearchUsersResult{
229+
TotalCount: result.GetTotal(),
230+
IncompleteResults: result.GetIncompleteResults(),
231+
Items: minimalUsers,
232+
}
233+
if result.Total != nil {
234+
minimalResp.TotalCount = *result.Total
235+
}
236+
if result.IncompleteResults != nil {
237+
minimalResp.IncompleteResults = *result.IncompleteResults
238+
}
241239

242-
minimalResp := MinimalSearchUsersResult{
243-
TotalCount: result.GetTotal(),
244-
IncompleteResults: result.GetIncompleteResults(),
245-
Items: minimalUsers,
246-
}
247-
r, err := json.Marshal(minimalResp)
248-
if err != nil {
249-
return nil, fmt.Errorf("failed to marshal response: %w", err)
250-
}
251-
return mcp.NewToolResultText(string(r)), nil
240+
r, err := json.Marshal(minimalResp)
241+
if err != nil {
242+
return nil, fmt.Errorf("failed to marshal response: %w", err)
252243
}
244+
return mcp.NewToolResultText(string(r)), nil
245+
}
246+
}
247+
248+
func SearchUsers(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
249+
return mcp.NewTool("search_users",
250+
mcp.WithDescription(t("TOOL_SEARCH_USERS_DESCRIPTION", "Search for GitHub users exclusively")),
251+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
252+
Title: t("TOOL_SEARCH_USERS_USER_TITLE", "Search users"),
253+
ReadOnlyHint: toBoolPtr(true),
254+
}),
255+
mcp.WithString("q",
256+
mcp.Required(),
257+
mcp.Description("Search query using GitHub users search syntax scoped to type:user"),
258+
),
259+
mcp.WithString("sort",
260+
mcp.Description("Sort field by category"),
261+
mcp.Enum("followers", "repositories", "joined"),
262+
),
263+
mcp.WithString("order",
264+
mcp.Description("Sort order"),
265+
mcp.Enum("asc", "desc"),
266+
),
267+
WithPagination(),
268+
), userOrOrgHandler("user", getClient)
269+
}
270+
271+
// SearchOrgs creates a tool to search for GitHub organizations.
272+
func SearchOrgs(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
273+
return mcp.NewTool("search_orgs",
274+
mcp.WithDescription(t("TOOL_SEARCH_ORGS_DESCRIPTION", "Search for GitHub organizations exclusively")),
275+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
276+
Title: t("TOOL_SEARCH_ORGS_USER_TITLE", "Search organizations"),
277+
ReadOnlyHint: toBoolPtr(true),
278+
}),
279+
mcp.WithString("q",
280+
mcp.Required(),
281+
mcp.Description("Search query using GitHub organizations search syntax scoped to type:org"),
282+
),
283+
mcp.WithString("sort",
284+
mcp.Description("Sort field by category"),
285+
mcp.Enum("followers", "repositories", "joined"),
286+
),
287+
mcp.WithString("order",
288+
mcp.Description("Sort order"),
289+
mcp.Enum("asc", "desc"),
290+
),
291+
WithPagination(),
292+
), userOrOrgHandler("org", getClient)
253293
}

pkg/github/search_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,3 +461,125 @@ func Test_SearchUsers(t *testing.T) {
461461
})
462462
}
463463
}
464+
465+
func Test_SearchOrgs(t *testing.T) {
466+
// Verify tool definition once
467+
mockClient := github.NewClient(nil)
468+
tool, _ := SearchOrgs(stubGetClientFn(mockClient), translations.NullTranslationHelper)
469+
470+
assert.Equal(t, "search_orgs", tool.Name)
471+
assert.NotEmpty(t, tool.Description)
472+
assert.Contains(t, tool.InputSchema.Properties, "q")
473+
assert.Contains(t, tool.InputSchema.Properties, "sort")
474+
assert.Contains(t, tool.InputSchema.Properties, "order")
475+
assert.Contains(t, tool.InputSchema.Properties, "perPage")
476+
assert.Contains(t, tool.InputSchema.Properties, "page")
477+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"q"})
478+
479+
// Setup mock search results
480+
mockSearchResult := &github.UsersSearchResult{
481+
Total: github.Ptr(int(2)),
482+
IncompleteResults: github.Ptr(false),
483+
Users: []*github.User{
484+
{
485+
Login: github.Ptr("org-1"),
486+
ID: github.Ptr(int64(111)),
487+
HTMLURL: github.Ptr("https://github.com/org-1"),
488+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/111?v=4"),
489+
},
490+
{
491+
Login: github.Ptr("org-2"),
492+
ID: github.Ptr(int64(222)),
493+
HTMLURL: github.Ptr("https://github.com/org-2"),
494+
AvatarURL: github.Ptr("https://avatars.githubusercontent.com/u/222?v=4"),
495+
},
496+
},
497+
}
498+
499+
tests := []struct {
500+
name string
501+
mockedClient *http.Client
502+
requestArgs map[string]interface{}
503+
expectError bool
504+
expectedResult *github.UsersSearchResult
505+
expectedErrMsg string
506+
}{
507+
{
508+
name: "successful org search",
509+
mockedClient: mock.NewMockedHTTPClient(
510+
mock.WithRequestMatchHandler(
511+
mock.GetSearchUsers,
512+
expectQueryParams(t, map[string]string{
513+
"q": "type:org github",
514+
"page": "1",
515+
"per_page": "30",
516+
}).andThen(
517+
mockResponse(t, http.StatusOK, mockSearchResult),
518+
),
519+
),
520+
),
521+
requestArgs: map[string]interface{}{
522+
"q": "github",
523+
},
524+
expectError: false,
525+
expectedResult: mockSearchResult,
526+
},
527+
{
528+
name: "org search fails",
529+
mockedClient: mock.NewMockedHTTPClient(
530+
mock.WithRequestMatchHandler(
531+
mock.GetSearchUsers,
532+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
533+
w.WriteHeader(http.StatusBadRequest)
534+
_, _ = w.Write([]byte(`{"message": "Validation Failed"}`))
535+
}),
536+
),
537+
),
538+
requestArgs: map[string]interface{}{
539+
"q": "invalid:query",
540+
},
541+
expectError: true,
542+
expectedErrMsg: "failed to search orgs",
543+
},
544+
}
545+
546+
for _, tc := range tests {
547+
t.Run(tc.name, func(t *testing.T) {
548+
// Setup client with mock
549+
client := github.NewClient(tc.mockedClient)
550+
_, handler := SearchOrgs(stubGetClientFn(client), translations.NullTranslationHelper)
551+
552+
// Create call request
553+
request := createMCPRequest(tc.requestArgs)
554+
555+
// Call handler
556+
result, err := handler(context.Background(), request)
557+
558+
// Verify results
559+
if tc.expectError {
560+
require.Error(t, err)
561+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
562+
return
563+
}
564+
565+
require.NoError(t, err)
566+
require.NotNil(t, result)
567+
568+
textContent := getTextResult(t, result)
569+
570+
// Unmarshal and verify the result
571+
var returnedResult MinimalSearchUsersResult
572+
err = json.Unmarshal([]byte(textContent.Text), &returnedResult)
573+
require.NoError(t, err)
574+
assert.Equal(t, *tc.expectedResult.Total, returnedResult.TotalCount)
575+
assert.Equal(t, *tc.expectedResult.IncompleteResults, returnedResult.IncompleteResults)
576+
assert.Len(t, returnedResult.Items, len(tc.expectedResult.Users))
577+
for i, org := range returnedResult.Items {
578+
assert.Equal(t, *tc.expectedResult.Users[i].Login, org.Login)
579+
assert.Equal(t, *tc.expectedResult.Users[i].ID, org.ID)
580+
assert.Equal(t, *tc.expectedResult.Users[i].HTMLURL, org.ProfileURL)
581+
assert.Equal(t, *tc.expectedResult.Users[i].AvatarURL, org.AvatarURL)
582+
}
583+
})
584+
}
585+
}

pkg/github/tools.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
5757
AddReadTools(
5858
toolsets.NewServerTool(SearchUsers(getClient, t)),
5959
)
60+
orgs := toolsets.NewToolset("orgs", "GitHub Organization related tools").
61+
AddReadTools(
62+
toolsets.NewServerTool(SearchOrgs(getClient, t)),
63+
)
6064
pullRequests := toolsets.NewToolset("pull_requests", "GitHub Pull Request related tools").
6165
AddReadTools(
6266
toolsets.NewServerTool(GetPullRequest(getClient, t)),
@@ -111,6 +115,7 @@ func InitToolsets(passedToolsets []string, readOnly bool, getClient GetClientFn,
111115
tsg.AddToolset(repos)
112116
tsg.AddToolset(issues)
113117
tsg.AddToolset(users)
118+
tsg.AddToolset(orgs)
114119
tsg.AddToolset(pullRequests)
115120
tsg.AddToolset(codeSecurity)
116121
tsg.AddToolset(secretProtection)

0 commit comments

Comments
 (0)