From 3182ca950706d556e702556069042b0ad2d25b28 Mon Sep 17 00:00:00 2001 From: Amir Hardon Date: Mon, 12 May 2025 22:43:29 -0700 Subject: [PATCH] feat: Allow renaming tools in ClientGroup When using tools from multiple servers a name conflict might happen. In this change we: 1. Throw when a ClientGroup is created with conflicting tool names. 2. Allow the user to specify a component renaming function to fix name conflicts. We also add a `name` getter in the `Client` class as it provides useful information for solving tool name conflicts. --- src/client/clientGroup.test.ts | 163 +++++++++++++++++++++-- src/client/clientGroup.ts | 89 ++++++++++--- src/client/index.ts | 7 +- src/examples/client/clientGroupSample.ts | 20 ++- 4 files changed, 247 insertions(+), 32 deletions(-) diff --git a/src/client/clientGroup.test.ts b/src/client/clientGroup.test.ts index f53c682a..11ae0f3d 100644 --- a/src/client/clientGroup.test.ts +++ b/src/client/clientGroup.test.ts @@ -1,6 +1,11 @@ -import { ClientGroup } from "./clientGroup.js"; +import { ClientGroup, ComponentRenamer } from "./clientGroup.js"; import { Client } from "./index.js"; -import { Tool, CallToolRequest, CallToolResultSchema, Implementation } from "../types.js"; +import { + Tool, + CallToolRequest, + CallToolResultSchema, + Implementation, +} from "../types.js"; // Mock Client class for testing ClientGroup export class MockClient extends Client { @@ -32,7 +37,6 @@ export class MockClient extends Client { override assertRequestHandlerCapability() {} } - describe("ClientGroup", () => { let mockClient1: MockClient; let mockClient2: MockClient; @@ -43,8 +47,18 @@ describe("ClientGroup", () => { }); test("should list tools from all clients", async () => { - const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } }; - const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } }; + const tool1: Tool = { + name: "tool1", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const tool2: Tool = { + name: "tool2", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); @@ -58,8 +72,18 @@ describe("ClientGroup", () => { }); test("should call the correct tool on the correct client", async () => { - const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } }; - const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } }; + const tool1: Tool = { + name: "tool1", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const tool2: Tool = { + name: "tool2", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); @@ -74,7 +98,7 @@ describe("ClientGroup", () => { expect(mockClient1.mockCallTool).toHaveBeenCalledWith( params, CallToolResultSchema, - undefined + undefined, ); expect(mockClient2.mockCallTool).not.toHaveBeenCalled(); expect(result).toEqual({ result: "mock result for tool1" }); @@ -91,8 +115,127 @@ describe("ClientGroup", () => { parameters: {}, }; - await expect(clientGroup.callTool(params, CallToolResultSchema)).rejects.toThrow( - "Trying to call too nonExistentTool which is not provided by the client group" + await expect( + clientGroup.callTool(params, CallToolResultSchema), + ).rejects.toThrow( + "Trying to call too nonExistentTool which is not provided by the client group", + ); + }); + + test("should throw error on tool conflict", async () => { + const tool1: Tool = { + name: "tool", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const tool2: Tool = { + name: "tool", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); + + await expect( + ClientGroup.create([mockClient1, mockClient2]), + ).rejects.toThrow( + "Tool name: tool (original: tool) is available on multiple servers", + ); + }); + + test("should list renamed tools", async () => { + const tool1: Tool = { + name: "tool", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const tool2: Tool = { + name: "tool", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); + + const renamer: ComponentRenamer = function ( + clientName: string, + componentName: string, + ) { + return `${clientName}.${componentName}`; + }; + + const clientGroup = await ClientGroup.create( + [mockClient1, mockClient2], + renamer, + ); + + const tools = await clientGroup.listTools(); + + expect(tools).toHaveLength(2); + const expectedTool1: Tool = { + name: "client1.tool", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const expectedTool2: Tool = { + name: "client2.tool", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; + expect(tools).toEqual( + expect.arrayContaining([expectedTool1, expectedTool2]), + ); + }); + + test("should call renamed tool", async () => { + const tool1: Tool = { + name: "tool", + description: "description1", + parameters: {}, + inputSchema: { type: "object" }, + }; + const tool2: Tool = { + name: "tool", + description: "description2", + parameters: {}, + inputSchema: { type: "object" }, + }; + mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] }); + mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] }); + + const renamer: ComponentRenamer = function ( + clientName: string, + componentName: string, + ) { + return `${clientName}.${componentName}`; + }; + + const clientGroup = await ClientGroup.create( + [mockClient1, mockClient2], + renamer, + ); + + const params: CallToolRequest["params"] = { + name: "client1.tool", + parameters: { arg: "value" }, + }; + + const result = await clientGroup.callTool(params, CallToolResultSchema); + + const expectedCallPArams: CallToolRequest["params"] = { + name: "tool", + parameters: { arg: "value" }, + }; + expect(mockClient1.mockCallTool).toHaveBeenCalledWith( + expectedCallPArams, + CallToolResultSchema, + undefined, ); }); diff --git a/src/client/clientGroup.ts b/src/client/clientGroup.ts index c9821848..61c1c7c5 100644 --- a/src/client/clientGroup.ts +++ b/src/client/clientGroup.ts @@ -1,7 +1,17 @@ import { RequestOptions } from "../shared/protocol.js"; -import { Tool, CallToolRequest, CallToolResultSchema, CompatibilityCallToolResultSchema } from "../types.js"; +import { + Tool, + CallToolRequest, + CallToolResultSchema, + CompatibilityCallToolResultSchema, +} from "../types.js"; import { Client } from "./index.js"; +export type ComponentRenamer = ( + clientName: string, + componentName: string, +) => string; + /** * A group of MCP clients. * @@ -23,29 +33,61 @@ import { Client } from "./index.js"; * // Close all clients * await clientGroup.close(); * ``` + * + * Example with renaming components to fix tool name conflicts: + * + * ```typescript + * + * // Define a renaming function + * const renamer = (clientName: string, componentName: string) => { + * if (clientName === "client-2" && componentName === "ping") { + * return "ping2"; + * } + * return componentName; + * }; + * + * // Create a client group with the renamer + * // In this case both clients provide a `ping` tool. + * const clientGroup = await ClientGroup.create([client1, client2], renamer); + * + * // List tools (will include "ping" and "ping2") + * const tools = await clientGroup.listTools(); + * + * // Call the renamed tool + * const result = await clientGroup.callTool({ name: "ping2", params: {} }); + * ``` */ export class ClientGroup { private _clients: Client[]; private _allTools: Tool[]; - private _toolToClient: { [key: string]: Client; } = {}; + private _toolToClient: { + [key: string]: { client: Client; origToolName: string }; + } = {}; + private _componentRename: ( + componentName: string, + clientName: string, + ) => string; - private constructor( - clients: Client[] - ) { + private constructor(clients: Client[], componentRename?: ComponentRenamer) { this._clients = clients; this._allTools = []; + this._componentRename = + componentRename ?? + ((clientName: string, componentName: string) => componentName); } /** * Creates a new ClientGroup. - * + * * @param clients The list of clients to include in the group. + * @param componentRename An optional function to rename components (like tools or resources) to avoid name conflicts when combining multiple clients. The function takes the original component name and the client name as arguments and should return the new, unique component name. Defaults to using the original component name. */ static async create( clients: Client[], - options?: RequestOptions + componentRename?: (componentName: string, clientName: string) => string, + options?: RequestOptions, ): Promise { - const group = new ClientGroup(clients); + const group = new ClientGroup(clients, componentRename); await group.update(options); return group; } @@ -55,13 +97,17 @@ export class ClientGroup { this._toolToClient = {}; for (const client of this._clients) { for (const tool of (await client.listTools(options)).tools) { + const origName = tool.name; + tool.name = this._componentRename(client.name, tool.name); if (this._toolToClient[tool.name]) { - // TODO(amirh): we should allow the users to configure tool renames. - console.warn( - `Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one` - ); + throw new Error(` + Tool name: ${tool.name} (original: ${origName}) is available on multiple servers + `); } - this._toolToClient[tool.name] = client; + this._toolToClient[tool.name] = { + client: client, + origToolName: origName, + }; this._allTools.push(tool); } } @@ -88,16 +134,23 @@ export class ClientGroup { */ async callTool( params: CallToolRequest["params"], - resultSchema: typeof CallToolResultSchema | - typeof CompatibilityCallToolResultSchema = CallToolResultSchema, - options?: RequestOptions + resultSchema: + | typeof CallToolResultSchema + | typeof CompatibilityCallToolResultSchema = CallToolResultSchema, + options?: RequestOptions, ) { if (!this._toolToClient[params.name]) { throw new Error( - `Trying to call too ${params.name} which is not provided by the client group` + `Trying to call too ${params.name} which is not provided by the client group`, ); } - return this._toolToClient[params.name].callTool(params, resultSchema, options); + const actualParams = structuredClone(params); + actualParams.name = this._toolToClient[params.name].origToolName; + return this._toolToClient[params.name].client.callTool( + actualParams, + resultSchema, + options, + ); } /** diff --git a/src/client/index.ts b/src/client/index.ts index 69612b77..0c66f8bc 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -434,4 +434,9 @@ export class Client< async sendRootsListChanged() { return this.notification({ method: "notifications/roots/list_changed" }); } -} \ No newline at end of file + + get name(): string { + return this._clientInfo.name; + } +} + diff --git a/src/examples/client/clientGroupSample.ts b/src/examples/client/clientGroupSample.ts index 86b7c0aa..81c5dab9 100644 --- a/src/examples/client/clientGroupSample.ts +++ b/src/examples/client/clientGroupSample.ts @@ -1,6 +1,6 @@ import { Tool } from "../../types.js"; import { Client } from "../../client/index.js"; -import { ClientGroup } from "../../client/clientGroup.js"; +import { ClientGroup, ComponentRenamer } from "../../client/clientGroup.js"; import { InMemoryTransport } from "../../inMemory.js"; import { McpServer, ToolCallback } from "../../server/mcp.js"; import { Transport } from "../../shared/transport.js"; @@ -29,8 +29,17 @@ async function main(): Promise { }); client3.connect(clientTransports[2]); + const renamer: ComponentRenamer = function ( + clientName: string, + componentName: string, + ) { + if (clientName === "client-3" && componentName === "ping") { + return "ping2"; + } + return componentName; + }; const allClients = [client1, client2, client3]; - const clientGroup = await ClientGroup.create(allClients); + const clientGroup = await ClientGroup.create(allClients, renamer); const allResources = []; allResources.push(...(await client1.listResources()).resources); @@ -40,10 +49,15 @@ async function main(): Promise { const toolName = simulatePromptModel(await clientGroup.listTools()); console.log(`Invoking tool: ${toolName}`); - const toolResult = await clientGroup.callTool({ + let toolResult = await clientGroup.callTool({ name: toolName, }); + console.log(toolResult); + console.log("Invoking tool: ping2"); + toolResult = await clientGroup.callTool({ + name: "ping2", + }); console.log(toolResult); clientGroup.close();