Skip to content

feat: Allow renaming tools in ClientGroup #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: session_group_2_listtools
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 153 additions & 10 deletions src/client/clientGroup.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { ClientGroup } from "./clientGroup.js";
import { ClientGroup, ComponentRenamer } from "./clientGroup.js";
import { Client } from "./index.js";
import { Tool, CallToolRequest, CallToolResultSchema, Implementation } from "../types.js";
import {
Tool,
CallToolRequest,
CallToolResultSchema,
Implementation,
} from "../types.js";

// Mock Client class for testing ClientGroup
export class MockClient extends Client {
Expand Down Expand Up @@ -32,7 +37,6 @@ export class MockClient extends Client {
override assertRequestHandlerCapability() {}
}


describe("ClientGroup", () => {
let mockClient1: MockClient;
let mockClient2: MockClient;
Expand All @@ -43,8 +47,18 @@ describe("ClientGroup", () => {
});

test("should list tools from all clients", async () => {
const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } };
const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } };
const tool1: Tool = {
name: "tool1",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const tool2: Tool = {
name: "tool2",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

Expand All @@ -58,8 +72,18 @@ describe("ClientGroup", () => {
});

test("should call the correct tool on the correct client", async () => {
const tool1: Tool = { name: "tool1", description: "description1", parameters: {}, inputSchema: { type: 'object' } };
const tool2: Tool = { name: "tool2", description: "description2", parameters: {}, inputSchema: { type: 'object' } };
const tool1: Tool = {
name: "tool1",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const tool2: Tool = {
name: "tool2",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

Expand All @@ -74,7 +98,7 @@ describe("ClientGroup", () => {
expect(mockClient1.mockCallTool).toHaveBeenCalledWith(
params,
CallToolResultSchema,
undefined
undefined,
);
expect(mockClient2.mockCallTool).not.toHaveBeenCalled();
expect(result).toEqual({ result: "mock result for tool1" });
Expand All @@ -91,8 +115,127 @@ describe("ClientGroup", () => {
parameters: {},
};

await expect(clientGroup.callTool(params, CallToolResultSchema)).rejects.toThrow(
"Trying to call too nonExistentTool which is not provided by the client group"
await expect(
clientGroup.callTool(params, CallToolResultSchema),
).rejects.toThrow(
"Trying to call too nonExistentTool which is not provided by the client group",
);
});

test("should throw error on tool conflict", async () => {
const tool1: Tool = {
name: "tool",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const tool2: Tool = {
name: "tool",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

await expect(
ClientGroup.create([mockClient1, mockClient2]),
).rejects.toThrow(
"Tool name: tool (original: tool) is available on multiple servers",
);
});

test("should list renamed tools", async () => {
const tool1: Tool = {
name: "tool",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const tool2: Tool = {
name: "tool",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

const renamer: ComponentRenamer = function (
clientName: string,
componentName: string,
) {
return `${clientName}.${componentName}`;
};

const clientGroup = await ClientGroup.create(
[mockClient1, mockClient2],
renamer,
);

const tools = await clientGroup.listTools();

expect(tools).toHaveLength(2);
const expectedTool1: Tool = {
name: "client1.tool",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const expectedTool2: Tool = {
name: "client2.tool",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
expect(tools).toEqual(
expect.arrayContaining([expectedTool1, expectedTool2]),
);
});

test("should call renamed tool", async () => {
const tool1: Tool = {
name: "tool",
description: "description1",
parameters: {},
inputSchema: { type: "object" },
};
const tool2: Tool = {
name: "tool",
description: "description2",
parameters: {},
inputSchema: { type: "object" },
};
mockClient1.mockListTools.mockResolvedValueOnce({ tools: [tool1] });
mockClient2.mockListTools.mockResolvedValueOnce({ tools: [tool2] });

const renamer: ComponentRenamer = function (
clientName: string,
componentName: string,
) {
return `${clientName}.${componentName}`;
};

const clientGroup = await ClientGroup.create(
[mockClient1, mockClient2],
renamer,
);

const params: CallToolRequest["params"] = {
name: "client1.tool",
parameters: { arg: "value" },
};

const result = await clientGroup.callTool(params, CallToolResultSchema);

const expectedCallPArams: CallToolRequest["params"] = {
name: "tool",
parameters: { arg: "value" },
};
expect(mockClient1.mockCallTool).toHaveBeenCalledWith(
expectedCallPArams,
CallToolResultSchema,
undefined,
);
});

Expand Down
89 changes: 71 additions & 18 deletions src/client/clientGroup.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
import { RequestOptions } from "../shared/protocol.js";
import { Tool, CallToolRequest, CallToolResultSchema, CompatibilityCallToolResultSchema } from "../types.js";
import {
Tool,
CallToolRequest,
CallToolResultSchema,
CompatibilityCallToolResultSchema,
} from "../types.js";
import { Client } from "./index.js";

export type ComponentRenamer = (
clientName: string,
componentName: string,
) => string;

/**
* A group of MCP clients.
*
Expand All @@ -23,29 +33,61 @@ import { Client } from "./index.js";
* // Close all clients
* await clientGroup.close();
* ```
*
* Example with renaming components to fix tool name conflicts:
*
* ```typescript
*
* // Define a renaming function
* const renamer = (clientName: string, componentName: string) => {
* if (clientName === "client-2" && componentName === "ping") {
* return "ping2";
* }
* return componentName;
* };
*
* // Create a client group with the renamer
* // In this case both clients provide a `ping` tool.
* const clientGroup = await ClientGroup.create([client1, client2], renamer);
*
* // List tools (will include "ping" and "ping2")
* const tools = await clientGroup.listTools();
*
* // Call the renamed tool
* const result = await clientGroup.callTool({ name: "ping2", params: {} });
* ```
*/
export class ClientGroup {
private _clients: Client[];
private _allTools: Tool[];
private _toolToClient: { [key: string]: Client; } = {};
private _toolToClient: {
[key: string]: { client: Client; origToolName: string };
} = {};
private _componentRename: (
componentName: string,
clientName: string,
) => string;

private constructor(
clients: Client[]
) {
private constructor(clients: Client[], componentRename?: ComponentRenamer) {
this._clients = clients;
this._allTools = [];
this._componentRename =
componentRename ??
((clientName: string, componentName: string) => componentName);
}

/**
* Creates a new ClientGroup.
*
*
* @param clients The list of clients to include in the group.
* @param componentRename An optional function to rename components (like tools or resources) to avoid name conflicts when combining multiple clients. The function takes the original component name and the client name as arguments and should return the new, unique component name. Defaults to using the original component name.
*/
static async create(
clients: Client[],
options?: RequestOptions
componentRename?: (componentName: string, clientName: string) => string,
options?: RequestOptions,
): Promise<ClientGroup> {
const group = new ClientGroup(clients);
const group = new ClientGroup(clients, componentRename);
await group.update(options);
return group;
}
Expand All @@ -55,13 +97,17 @@ export class ClientGroup {
this._toolToClient = {};
for (const client of this._clients) {
for (const tool of (await client.listTools(options)).tools) {
const origName = tool.name;
tool.name = this._componentRename(client.name, tool.name);
if (this._toolToClient[tool.name]) {
// TODO(amirh): we should allow the users to configure tool renames.
console.warn(
`Tool name: ${tool.name} is available on multiple servers, picking an arbitrary one`
);
throw new Error(`
Tool name: ${tool.name} (original: ${origName}) is available on multiple servers
`);
}
this._toolToClient[tool.name] = client;
this._toolToClient[tool.name] = {
client: client,
origToolName: origName,
};
this._allTools.push(tool);
}
}
Expand All @@ -88,16 +134,23 @@ export class ClientGroup {
*/
async callTool(
params: CallToolRequest["params"],
resultSchema: typeof CallToolResultSchema |
typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions
resultSchema:
| typeof CallToolResultSchema
| typeof CompatibilityCallToolResultSchema = CallToolResultSchema,
options?: RequestOptions,
) {
if (!this._toolToClient[params.name]) {
throw new Error(
`Trying to call too ${params.name} which is not provided by the client group`
`Trying to call too ${params.name} which is not provided by the client group`,
);
}
return this._toolToClient[params.name].callTool(params, resultSchema, options);
const actualParams = structuredClone(params);
actualParams.name = this._toolToClient[params.name].origToolName;
return this._toolToClient[params.name].client.callTool(
actualParams,
resultSchema,
options,
);
}

/**
Expand Down
7 changes: 6 additions & 1 deletion src/client/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,4 +434,9 @@ export class Client<
async sendRootsListChanged() {
return this.notification({ method: "notifications/roots/list_changed" });
}
}

get name(): string {
return this._clientInfo.name;
}
}

20 changes: 17 additions & 3 deletions src/examples/client/clientGroupSample.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Tool } from "../../types.js";
import { Client } from "../../client/index.js";
import { ClientGroup } from "../../client/clientGroup.js";
import { ClientGroup, ComponentRenamer } from "../../client/clientGroup.js";
import { InMemoryTransport } from "../../inMemory.js";
import { McpServer, ToolCallback } from "../../server/mcp.js";
import { Transport } from "../../shared/transport.js";
Expand Down Expand Up @@ -29,8 +29,17 @@ async function main(): Promise<void> {
});
client3.connect(clientTransports[2]);

const renamer: ComponentRenamer = function (
clientName: string,
componentName: string,
) {
if (clientName === "client-3" && componentName === "ping") {
return "ping2";
}
return componentName;
};
const allClients = [client1, client2, client3];
const clientGroup = await ClientGroup.create(allClients);
const clientGroup = await ClientGroup.create(allClients, renamer);

const allResources = [];
allResources.push(...(await client1.listResources()).resources);
Expand All @@ -40,10 +49,15 @@ async function main(): Promise<void> {
const toolName = simulatePromptModel(await clientGroup.listTools());

console.log(`Invoking tool: ${toolName}`);
const toolResult = await clientGroup.callTool({
let toolResult = await clientGroup.callTool({
name: toolName,
});
console.log(toolResult);

console.log("Invoking tool: ping2");
toolResult = await clientGroup.callTool({
name: "ping2",
});
console.log(toolResult);

clientGroup.close();
Expand Down