From 786c342c104d4f5ce8f9d8401e8012996b14bf37 Mon Sep 17 00:00:00 2001 From: Juan Broullon Date: Mon, 9 Jun 2025 01:24:57 +0200 Subject: [PATCH] feat: add UpdateCodeScanningAlert and ListOrgCodeScanningAlerts tools to code_security --- pkg/github/code_scanning.go | 168 ++++++++++++++++++++++++ pkg/github/code_scanning_test.go | 219 +++++++++++++++++++++++++++++++ pkg/github/tools.go | 5 + 3 files changed, 392 insertions(+) diff --git a/pkg/github/code_scanning.go b/pkg/github/code_scanning.go index 1886b6342..d4d11d22a 100644 --- a/pkg/github/code_scanning.go +++ b/pkg/github/code_scanning.go @@ -158,3 +158,171 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel return mcp.NewToolResultText(string(r)), nil } } + +func ListOrgCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("list_org_code_scanning_alerts", + mcp.WithDescription(t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts for a GitHub organization.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_USER_TITLE", "List org code scanning alerts"), + ReadOnlyHint: toBoolPtr(true), + }), + mcp.WithString("org", + mcp.Required(), + mcp.Description("The organization of the repository."), + ), + mcp.WithString("sort", + mcp.Description("Sort by"), + mcp.Enum("created", "updated"), + ), + mcp.WithString("severity", + mcp.Description("Filter code scanning alerts by severity"), + mcp.Enum("critical", "high", "medium", "low", "warning", "note", "error"), + ), + mcp.WithString("tool_name", + mcp.Description("The name of the tool used for code scanning."), + ), + mcp.WithString("state", + mcp.Description("Filter code scanning alerts by state. Defaults to open"), + mcp.DefaultString("open"), + mcp.Enum("open", "closed", "dismissed", "fixed"), + ), + WithPagination(), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + org, err := requiredParam[string](request, "org") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + sort, err := OptionalParam[string](request, "sort") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + severity, err := OptionalParam[string](request, "severity") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + toolName, err := OptionalParam[string](request, "tool_name") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := OptionalParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + alerts, resp, err := client.CodeScanning.ListAlertsForOrg(ctx, org, &github.AlertListOptions{Sort: sort, State: state, Severity: severity, ToolName: toolName}) + if err != nil { + return nil, fmt.Errorf("failed to list organization alerts: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to list organization alerts: %s", string(body))), nil + } + + r, err := json.Marshal(alerts) + if err != nil { + return nil, fmt.Errorf("failed to marshal alerts: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} + +func UpdateCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) { + return mcp.NewTool("update_code_scanning_alert", + mcp.WithDescription(t("TOOL_UPDATE_CODE_SCANNING_ALERT_DESCRIPTION", "Update details of a specific code scanning alert in a GitHub repository.")), + mcp.WithToolAnnotation(mcp.ToolAnnotation{ + Title: t("TOOL_UPDATE_CODE_SCANNING_ALERT_USER_TITLE", "Update code scanning alert"), + ReadOnlyHint: toBoolPtr(false), + }), + mcp.WithString("owner", + mcp.Required(), + mcp.Description("The owner of the repository."), + ), + mcp.WithString("repo", + mcp.Required(), + mcp.Description("The name of the repository."), + ), + mcp.WithNumber("alertNumber", + mcp.Required(), + mcp.Description("The number of the alert."), + ), + mcp.WithString("state", + mcp.Required(), + mcp.Description("State of the alert"), + mcp.Enum("open", "dismissed"), + ), + mcp.WithString("dismissed_reason", + mcp.Description("Reason for dismissing or closing the alert"), + mcp.Enum("false positive", "won't fix", "used in tests"), + ), + mcp.WithString("dismissed_comment", + mcp.Description("Dismissal comment associated with the dismissal of the alert"), + ), + ), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + owner, err := requiredParam[string](request, "owner") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + repo, err := requiredParam[string](request, "repo") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + alertNumber, err := RequiredInt(request, "alertNumber") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + state, err := requiredParam[string](request, "state") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_reason, err := OptionalParam[string](request, "dismissed_reason") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + dismissed_comment, err := OptionalParam[string](request, "dismissed_comment") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + if state == "dismissed" && dismissed_reason == "" { + return nil, fmt.Errorf("dismissed_reason required for 'dismissed' state ") + } + + client, err := getClient(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GitHub client: %w", err) + } + + alert, resp, err := client.CodeScanning.UpdateAlert(ctx, owner, repo, int64(alertNumber), &github.CodeScanningAlertState{State: state, DismissedReason: &dismissed_reason, DismissedComment: &dismissed_comment}) + if err != nil { + return nil, fmt.Errorf("failed to update alert: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + return mcp.NewToolResultError(fmt.Sprintf("failed to update alert: %s", string(body))), nil + } + + r, err := json.Marshal(alert) + if err != nil { + return nil, fmt.Errorf("failed to marshal alert: %w", err) + } + + return mcp.NewToolResultText(string(r)), nil + } +} diff --git a/pkg/github/code_scanning_test.go b/pkg/github/code_scanning_test.go index b5facbf6b..92bb4ce6a 100644 --- a/pkg/github/code_scanning_test.go +++ b/pkg/github/code_scanning_test.go @@ -238,3 +238,222 @@ func Test_ListCodeScanningAlerts(t *testing.T) { }) } } + +func Test_UpdateCodeScanningAlert(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := UpdateCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "update_code_scanning_alert", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "owner") + assert.Contains(t, tool.InputSchema.Properties, "repo") + assert.Contains(t, tool.InputSchema.Properties, "alertNumber") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber", "state"}) + + // Mock alert for success + mockAlert := &github.Alert{ + Number: github.Ptr(42), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("rule-id"), Description: github.Ptr("desc")}, + HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"), + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlert *github.Alert + expectedErrMsg string + }{ + { + name: "successful alert update", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatch( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + mockAlert, + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "open", + }, + expectError: false, + expectedAlert: mockAlert, + }, + { + name: "update fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"message": "Invalid request"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(9999), + "state": "open", + }, + expectError: true, + expectedErrMsg: "failed to update alert", + }, + { + name: "error when dismissed_reason not provided", + mockedClient: nil, // early exit happens before any HTTP call + requestArgs: map[string]interface{}{ + "owner": "owner", + "repo": "repo", + "alertNumber": float64(42), + "state": "dismissed", + "dismissed_reason": "", + }, + expectError: true, + expectedErrMsg: "dismissed_reason required for 'dismissed' state", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := UpdateCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + var got github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + + assert.Equal(t, *tc.expectedAlert.Number, *got.Number) + assert.Equal(t, *tc.expectedAlert.State, *got.State) + assert.Equal(t, *tc.expectedAlert.Rule.ID, *got.Rule.ID) + assert.Equal(t, *tc.expectedAlert.HTMLURL, *got.HTMLURL) + }) + } +} + +func Test_ListOrgCodeScanningAlerts(t *testing.T) { + // Verify tool definition + mockClient := github.NewClient(nil) + tool, _ := ListOrgCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper) + + assert.Equal(t, "list_org_code_scanning_alerts", tool.Name) + assert.NotEmpty(t, tool.Description) + assert.Contains(t, tool.InputSchema.Properties, "org") + assert.Contains(t, tool.InputSchema.Properties, "sort") + assert.Contains(t, tool.InputSchema.Properties, "severity") + assert.Contains(t, tool.InputSchema.Properties, "tool_name") + assert.Contains(t, tool.InputSchema.Properties, "state") + assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"}) + + // Mock alerts for success + mockAlerts := []*github.Alert{ + { + Number: github.Ptr(100), + State: github.Ptr("open"), + Rule: &github.Rule{ID: github.Ptr("org-rule-1"), Description: github.Ptr("desc1")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/100"), + }, + { + Number: github.Ptr(101), + State: github.Ptr("dismissed"), + Rule: &github.Rule{ID: github.Ptr("org-rule-2"), Description: github.Ptr("desc2")}, + HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/101"), + }, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]interface{} + expectError bool + expectedAlerts []*github.Alert + expectedErrMsg string + }{ + { + name: "successful org alerts listing", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + expectQueryParams(t, map[string]string{ + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }).andThen( + mockResponse(t, http.StatusOK, mockAlerts), + ), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + "state": "open", + "severity": "high", + "tool_name": "codeql", + "sort": "created", + }, + expectError: false, + expectedAlerts: mockAlerts, + }, + { + name: "org alerts listing fails", + mockedClient: mock.NewMockedHTTPClient( + mock.WithRequestMatchHandler( + mock.GetOrgsCodeScanningAlertsByOrg, + http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"Forbidden"}`)) + }), + ), + ), + requestArgs: map[string]interface{}{ + "org": "org", + }, + expectError: true, + expectedErrMsg: "failed to list organization alerts", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := github.NewClient(tc.mockedClient) + _, handler := ListOrgCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper) + request := createMCPRequest(tc.requestArgs) + + result, err := handler(context.Background(), request) + if tc.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErrMsg) + return + } + + require.NoError(t, err) + text := getTextResult(t, result) + + var got []*github.Alert + require.NoError(t, json.Unmarshal([]byte(text.Text), &got)) + assert.Len(t, got, len(tc.expectedAlerts)) + + for i := range got { + assert.Equal(t, *tc.expectedAlerts[i].Number, *got[i].Number) + assert.Equal(t, *tc.expectedAlerts[i].State, *got[i].State) + assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *got[i].Rule.ID) + assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *got[i].HTMLURL) + } + }) + } +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index f8e05fc85..687053eb5 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -84,7 +84,12 @@ func DefaultToolsetGroup(readOnly bool, getClient GetClientFn, getGQLClient GetG AddReadTools( toolsets.NewServerTool(GetCodeScanningAlert(getClient, t)), toolsets.NewServerTool(ListCodeScanningAlerts(getClient, t)), + toolsets.NewServerTool(ListOrgCodeScanningAlerts(getClient, t)), + ). + AddWriteTools( + toolsets.NewServerTool(UpdateCodeScanningAlert(getClient, t)), ) + secretProtection := toolsets.NewToolset("secret_protection", "Secret protection related tools, such as GitHub Secret Scanning"). AddReadTools( toolsets.NewServerTool(GetSecretScanningAlert(getClient, t)),