Skip to content

Commit 786c342

Browse files
committed
feat: add UpdateCodeScanningAlert and ListOrgCodeScanningAlerts tools to code_security
1 parent c17ebfe commit 786c342

File tree

3 files changed

+392
-0
lines changed

3 files changed

+392
-0
lines changed

pkg/github/code_scanning.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,171 @@ func ListCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHel
158158
return mcp.NewToolResultText(string(r)), nil
159159
}
160160
}
161+
162+
func ListOrgCodeScanningAlerts(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
163+
return mcp.NewTool("list_org_code_scanning_alerts",
164+
mcp.WithDescription(t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_DESCRIPTION", "List code scanning alerts for a GitHub organization.")),
165+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
166+
Title: t("TOOL_LIST_ORG_CODE_SCANNING_ALERTS_USER_TITLE", "List org code scanning alerts"),
167+
ReadOnlyHint: toBoolPtr(true),
168+
}),
169+
mcp.WithString("org",
170+
mcp.Required(),
171+
mcp.Description("The organization of the repository."),
172+
),
173+
mcp.WithString("sort",
174+
mcp.Description("Sort by"),
175+
mcp.Enum("created", "updated"),
176+
),
177+
mcp.WithString("severity",
178+
mcp.Description("Filter code scanning alerts by severity"),
179+
mcp.Enum("critical", "high", "medium", "low", "warning", "note", "error"),
180+
),
181+
mcp.WithString("tool_name",
182+
mcp.Description("The name of the tool used for code scanning."),
183+
),
184+
mcp.WithString("state",
185+
mcp.Description("Filter code scanning alerts by state. Defaults to open"),
186+
mcp.DefaultString("open"),
187+
mcp.Enum("open", "closed", "dismissed", "fixed"),
188+
),
189+
WithPagination(),
190+
),
191+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
192+
org, err := requiredParam[string](request, "org")
193+
if err != nil {
194+
return mcp.NewToolResultError(err.Error()), nil
195+
}
196+
sort, err := OptionalParam[string](request, "sort")
197+
if err != nil {
198+
return mcp.NewToolResultError(err.Error()), nil
199+
}
200+
severity, err := OptionalParam[string](request, "severity")
201+
if err != nil {
202+
return mcp.NewToolResultError(err.Error()), nil
203+
}
204+
toolName, err := OptionalParam[string](request, "tool_name")
205+
if err != nil {
206+
return mcp.NewToolResultError(err.Error()), nil
207+
}
208+
state, err := OptionalParam[string](request, "state")
209+
if err != nil {
210+
return mcp.NewToolResultError(err.Error()), nil
211+
}
212+
213+
client, err := getClient(ctx)
214+
if err != nil {
215+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
216+
}
217+
alerts, resp, err := client.CodeScanning.ListAlertsForOrg(ctx, org, &github.AlertListOptions{Sort: sort, State: state, Severity: severity, ToolName: toolName})
218+
if err != nil {
219+
return nil, fmt.Errorf("failed to list organization alerts: %w", err)
220+
}
221+
defer func() { _ = resp.Body.Close() }()
222+
223+
if resp.StatusCode != http.StatusOK {
224+
body, err := io.ReadAll(resp.Body)
225+
if err != nil {
226+
return nil, fmt.Errorf("failed to read response body: %w", err)
227+
}
228+
return mcp.NewToolResultError(fmt.Sprintf("failed to list organization alerts: %s", string(body))), nil
229+
}
230+
231+
r, err := json.Marshal(alerts)
232+
if err != nil {
233+
return nil, fmt.Errorf("failed to marshal alerts: %w", err)
234+
}
235+
236+
return mcp.NewToolResultText(string(r)), nil
237+
}
238+
}
239+
240+
func UpdateCodeScanningAlert(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
241+
return mcp.NewTool("update_code_scanning_alert",
242+
mcp.WithDescription(t("TOOL_UPDATE_CODE_SCANNING_ALERT_DESCRIPTION", "Update details of a specific code scanning alert in a GitHub repository.")),
243+
mcp.WithToolAnnotation(mcp.ToolAnnotation{
244+
Title: t("TOOL_UPDATE_CODE_SCANNING_ALERT_USER_TITLE", "Update code scanning alert"),
245+
ReadOnlyHint: toBoolPtr(false),
246+
}),
247+
mcp.WithString("owner",
248+
mcp.Required(),
249+
mcp.Description("The owner of the repository."),
250+
),
251+
mcp.WithString("repo",
252+
mcp.Required(),
253+
mcp.Description("The name of the repository."),
254+
),
255+
mcp.WithNumber("alertNumber",
256+
mcp.Required(),
257+
mcp.Description("The number of the alert."),
258+
),
259+
mcp.WithString("state",
260+
mcp.Required(),
261+
mcp.Description("State of the alert"),
262+
mcp.Enum("open", "dismissed"),
263+
),
264+
mcp.WithString("dismissed_reason",
265+
mcp.Description("Reason for dismissing or closing the alert"),
266+
mcp.Enum("false positive", "won't fix", "used in tests"),
267+
),
268+
mcp.WithString("dismissed_comment",
269+
mcp.Description("Dismissal comment associated with the dismissal of the alert"),
270+
),
271+
),
272+
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
273+
owner, err := requiredParam[string](request, "owner")
274+
if err != nil {
275+
return mcp.NewToolResultError(err.Error()), nil
276+
}
277+
repo, err := requiredParam[string](request, "repo")
278+
if err != nil {
279+
return mcp.NewToolResultError(err.Error()), nil
280+
}
281+
alertNumber, err := RequiredInt(request, "alertNumber")
282+
if err != nil {
283+
return mcp.NewToolResultError(err.Error()), nil
284+
}
285+
state, err := requiredParam[string](request, "state")
286+
if err != nil {
287+
return mcp.NewToolResultError(err.Error()), nil
288+
}
289+
dismissed_reason, err := OptionalParam[string](request, "dismissed_reason")
290+
if err != nil {
291+
return mcp.NewToolResultError(err.Error()), nil
292+
}
293+
dismissed_comment, err := OptionalParam[string](request, "dismissed_comment")
294+
if err != nil {
295+
return mcp.NewToolResultError(err.Error()), nil
296+
}
297+
298+
if state == "dismissed" && dismissed_reason == "" {
299+
return nil, fmt.Errorf("dismissed_reason required for 'dismissed' state ")
300+
}
301+
302+
client, err := getClient(ctx)
303+
if err != nil {
304+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
305+
}
306+
307+
alert, resp, err := client.CodeScanning.UpdateAlert(ctx, owner, repo, int64(alertNumber), &github.CodeScanningAlertState{State: state, DismissedReason: &dismissed_reason, DismissedComment: &dismissed_comment})
308+
if err != nil {
309+
return nil, fmt.Errorf("failed to update alert: %w", err)
310+
}
311+
defer func() { _ = resp.Body.Close() }()
312+
313+
if resp.StatusCode != http.StatusOK {
314+
body, err := io.ReadAll(resp.Body)
315+
if err != nil {
316+
return nil, fmt.Errorf("failed to read response body: %w", err)
317+
}
318+
return mcp.NewToolResultError(fmt.Sprintf("failed to update alert: %s", string(body))), nil
319+
}
320+
321+
r, err := json.Marshal(alert)
322+
if err != nil {
323+
return nil, fmt.Errorf("failed to marshal alert: %w", err)
324+
}
325+
326+
return mcp.NewToolResultText(string(r)), nil
327+
}
328+
}

pkg/github/code_scanning_test.go

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,222 @@ func Test_ListCodeScanningAlerts(t *testing.T) {
238238
})
239239
}
240240
}
241+
242+
func Test_UpdateCodeScanningAlert(t *testing.T) {
243+
// Verify tool definition
244+
mockClient := github.NewClient(nil)
245+
tool, _ := UpdateCodeScanningAlert(stubGetClientFn(mockClient), translations.NullTranslationHelper)
246+
247+
assert.Equal(t, "update_code_scanning_alert", tool.Name)
248+
assert.NotEmpty(t, tool.Description)
249+
assert.Contains(t, tool.InputSchema.Properties, "owner")
250+
assert.Contains(t, tool.InputSchema.Properties, "repo")
251+
assert.Contains(t, tool.InputSchema.Properties, "alertNumber")
252+
assert.Contains(t, tool.InputSchema.Properties, "state")
253+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"owner", "repo", "alertNumber", "state"})
254+
255+
// Mock alert for success
256+
mockAlert := &github.Alert{
257+
Number: github.Ptr(42),
258+
State: github.Ptr("open"),
259+
Rule: &github.Rule{ID: github.Ptr("rule-id"), Description: github.Ptr("desc")},
260+
HTMLURL: github.Ptr("https://github.com/owner/repo/security/code-scanning/42"),
261+
}
262+
263+
tests := []struct {
264+
name string
265+
mockedClient *http.Client
266+
requestArgs map[string]interface{}
267+
expectError bool
268+
expectedAlert *github.Alert
269+
expectedErrMsg string
270+
}{
271+
{
272+
name: "successful alert update",
273+
mockedClient: mock.NewMockedHTTPClient(
274+
mock.WithRequestMatch(
275+
mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber,
276+
mockAlert,
277+
),
278+
),
279+
requestArgs: map[string]interface{}{
280+
"owner": "owner",
281+
"repo": "repo",
282+
"alertNumber": float64(42),
283+
"state": "open",
284+
},
285+
expectError: false,
286+
expectedAlert: mockAlert,
287+
},
288+
{
289+
name: "update fails",
290+
mockedClient: mock.NewMockedHTTPClient(
291+
mock.WithRequestMatchHandler(
292+
mock.PatchReposCodeScanningAlertsByOwnerByRepoByAlertNumber,
293+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
294+
w.WriteHeader(http.StatusBadRequest)
295+
_, _ = w.Write([]byte(`{"message": "Invalid request"}`))
296+
}),
297+
),
298+
),
299+
requestArgs: map[string]interface{}{
300+
"owner": "owner",
301+
"repo": "repo",
302+
"alertNumber": float64(9999),
303+
"state": "open",
304+
},
305+
expectError: true,
306+
expectedErrMsg: "failed to update alert",
307+
},
308+
{
309+
name: "error when dismissed_reason not provided",
310+
mockedClient: nil, // early exit happens before any HTTP call
311+
requestArgs: map[string]interface{}{
312+
"owner": "owner",
313+
"repo": "repo",
314+
"alertNumber": float64(42),
315+
"state": "dismissed",
316+
"dismissed_reason": "",
317+
},
318+
expectError: true,
319+
expectedErrMsg: "dismissed_reason required for 'dismissed' state",
320+
},
321+
}
322+
323+
for _, tc := range tests {
324+
t.Run(tc.name, func(t *testing.T) {
325+
client := github.NewClient(tc.mockedClient)
326+
_, handler := UpdateCodeScanningAlert(stubGetClientFn(client), translations.NullTranslationHelper)
327+
request := createMCPRequest(tc.requestArgs)
328+
329+
result, err := handler(context.Background(), request)
330+
if tc.expectError {
331+
require.Error(t, err)
332+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
333+
return
334+
}
335+
336+
require.NoError(t, err)
337+
text := getTextResult(t, result)
338+
var got github.Alert
339+
require.NoError(t, json.Unmarshal([]byte(text.Text), &got))
340+
341+
assert.Equal(t, *tc.expectedAlert.Number, *got.Number)
342+
assert.Equal(t, *tc.expectedAlert.State, *got.State)
343+
assert.Equal(t, *tc.expectedAlert.Rule.ID, *got.Rule.ID)
344+
assert.Equal(t, *tc.expectedAlert.HTMLURL, *got.HTMLURL)
345+
})
346+
}
347+
}
348+
349+
func Test_ListOrgCodeScanningAlerts(t *testing.T) {
350+
// Verify tool definition
351+
mockClient := github.NewClient(nil)
352+
tool, _ := ListOrgCodeScanningAlerts(stubGetClientFn(mockClient), translations.NullTranslationHelper)
353+
354+
assert.Equal(t, "list_org_code_scanning_alerts", tool.Name)
355+
assert.NotEmpty(t, tool.Description)
356+
assert.Contains(t, tool.InputSchema.Properties, "org")
357+
assert.Contains(t, tool.InputSchema.Properties, "sort")
358+
assert.Contains(t, tool.InputSchema.Properties, "severity")
359+
assert.Contains(t, tool.InputSchema.Properties, "tool_name")
360+
assert.Contains(t, tool.InputSchema.Properties, "state")
361+
assert.ElementsMatch(t, tool.InputSchema.Required, []string{"org"})
362+
363+
// Mock alerts for success
364+
mockAlerts := []*github.Alert{
365+
{
366+
Number: github.Ptr(100),
367+
State: github.Ptr("open"),
368+
Rule: &github.Rule{ID: github.Ptr("org-rule-1"), Description: github.Ptr("desc1")},
369+
HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/100"),
370+
},
371+
{
372+
Number: github.Ptr(101),
373+
State: github.Ptr("dismissed"),
374+
Rule: &github.Rule{ID: github.Ptr("org-rule-2"), Description: github.Ptr("desc2")},
375+
HTMLURL: github.Ptr("https://github.com/org/repo/security/code-scanning/101"),
376+
},
377+
}
378+
379+
tests := []struct {
380+
name string
381+
mockedClient *http.Client
382+
requestArgs map[string]interface{}
383+
expectError bool
384+
expectedAlerts []*github.Alert
385+
expectedErrMsg string
386+
}{
387+
{
388+
name: "successful org alerts listing",
389+
mockedClient: mock.NewMockedHTTPClient(
390+
mock.WithRequestMatchHandler(
391+
mock.GetOrgsCodeScanningAlertsByOrg,
392+
expectQueryParams(t, map[string]string{
393+
"state": "open",
394+
"severity": "high",
395+
"tool_name": "codeql",
396+
"sort": "created",
397+
}).andThen(
398+
mockResponse(t, http.StatusOK, mockAlerts),
399+
),
400+
),
401+
),
402+
requestArgs: map[string]interface{}{
403+
"org": "org",
404+
"state": "open",
405+
"severity": "high",
406+
"tool_name": "codeql",
407+
"sort": "created",
408+
},
409+
expectError: false,
410+
expectedAlerts: mockAlerts,
411+
},
412+
{
413+
name: "org alerts listing fails",
414+
mockedClient: mock.NewMockedHTTPClient(
415+
mock.WithRequestMatchHandler(
416+
mock.GetOrgsCodeScanningAlertsByOrg,
417+
http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
418+
w.WriteHeader(http.StatusForbidden)
419+
_, _ = w.Write([]byte(`{"message":"Forbidden"}`))
420+
}),
421+
),
422+
),
423+
requestArgs: map[string]interface{}{
424+
"org": "org",
425+
},
426+
expectError: true,
427+
expectedErrMsg: "failed to list organization alerts",
428+
},
429+
}
430+
431+
for _, tc := range tests {
432+
t.Run(tc.name, func(t *testing.T) {
433+
client := github.NewClient(tc.mockedClient)
434+
_, handler := ListOrgCodeScanningAlerts(stubGetClientFn(client), translations.NullTranslationHelper)
435+
request := createMCPRequest(tc.requestArgs)
436+
437+
result, err := handler(context.Background(), request)
438+
if tc.expectError {
439+
require.Error(t, err)
440+
assert.Contains(t, err.Error(), tc.expectedErrMsg)
441+
return
442+
}
443+
444+
require.NoError(t, err)
445+
text := getTextResult(t, result)
446+
447+
var got []*github.Alert
448+
require.NoError(t, json.Unmarshal([]byte(text.Text), &got))
449+
assert.Len(t, got, len(tc.expectedAlerts))
450+
451+
for i := range got {
452+
assert.Equal(t, *tc.expectedAlerts[i].Number, *got[i].Number)
453+
assert.Equal(t, *tc.expectedAlerts[i].State, *got[i].State)
454+
assert.Equal(t, *tc.expectedAlerts[i].Rule.ID, *got[i].Rule.ID)
455+
assert.Equal(t, *tc.expectedAlerts[i].HTMLURL, *got[i].HTMLURL)
456+
}
457+
})
458+
}
459+
}

0 commit comments

Comments
 (0)