diff --git a/packages/core/src/tests/generation.test.ts b/packages/core/src/tests/generation.test.ts index f1ec8f9bc69..f5e7741c40d 100644 --- a/packages/core/src/tests/generation.test.ts +++ b/packages/core/src/tests/generation.test.ts @@ -1,6 +1,6 @@ import { describe, expect, it, vi, beforeEach } from "vitest"; import { ModelProviderName, IAgentRuntime } from "../types"; -import { models } from "../models"; +import { getModelProviderData } from "../models"; import { generateText, generateTrueOrFalse, @@ -15,6 +15,7 @@ vi.mock("../index.ts", () => ({ log: vi.fn(), info: vi.fn(), error: vi.fn(), + debug: vi.fn(), }, })); @@ -76,9 +77,9 @@ describe("Generation", () => { expect(result).toBe("mocked response"); }); - it("should use correct model settings from provider config", () => { + it("should use correct model settings from provider config", async () => { const modelProvider = mockRuntime.modelProvider; - const modelSettings = models[modelProvider].settings; + const modelSettings = (await getModelProviderData(modelProvider)).settings; expect(modelSettings).toBeDefined(); expect(modelSettings.temperature).toBeDefined(); diff --git a/packages/core/src/tests/models.test.ts b/packages/core/src/tests/models.test.ts index f336093cfdd..21a232c8369 100644 --- a/packages/core/src/tests/models.test.ts +++ b/packages/core/src/tests/models.test.ts @@ -1,4 +1,4 @@ -import { getModel, getEndpoint, models } from "../models.ts"; +import { getModelProviderData } from "../models.ts"; import { ModelProviderName, ModelClass } from "../types.ts"; import { describe, test, expect, vi } from "vitest"; @@ -25,12 +25,12 @@ vi.mock("../settings", () => { describe("Model Provider Configuration", () => { describe("OpenAI Provider", () => { - test("should have correct endpoint", () => { - expect(models[ModelProviderName.OPENAI].endpoint).toBe("https://api.openai.com/v1"); + test("should have correct endpoint", async () => { + expect((await
getModelProviderData(ModelProviderName.OPENAI)).endpoint).toBe("https://api.openai.com/v1"); }); - test("should have correct model mappings", () => { - const openAIModels = models[ModelProviderName.OPENAI].model; + test("should have correct model mappings", async () => { + const openAIModels = (await getModelProviderData(ModelProviderName.OPENAI)).model; expect(openAIModels[ModelClass.SMALL]).toBe("gpt-4o-mini"); expect(openAIModels[ModelClass.MEDIUM]).toBe("gpt-4o"); expect(openAIModels[ModelClass.LARGE]).toBe("gpt-4o"); @@ -38,8 +38,8 @@ describe("Model Provider Configuration", () => { expect(openAIModels[ModelClass.IMAGE]).toBe("dall-e-3"); }); - test("should have correct settings configuration", () => { - const settings = models[ModelProviderName.OPENAI].settings; + test("should have correct settings configuration", async () => { + const settings = (await getModelProviderData(ModelProviderName.OPENAI)).settings; expect(settings.maxInputTokens).toBe(128000); expect(settings.maxOutputTokens).toBe(8192); expect(settings.temperature).toBe(0.6); @@ -49,19 +49,19 @@ describe("Model Provider Configuration", () => { }); describe("Anthropic Provider", () => { - test("should have correct endpoint", () => { - expect(models[ModelProviderName.ANTHROPIC].endpoint).toBe("https://api.anthropic.com/v1"); + test("should have correct endpoint", async () => { + expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).endpoint).toBe("https://api.anthropic.com/v1"); }); - test("should have correct model mappings", () => { - const anthropicModels = models[ModelProviderName.ANTHROPIC].model; + test("should have correct model mappings", async () => { + const anthropicModels = (await getModelProviderData(ModelProviderName.ANTHROPIC)).model; expect(anthropicModels[ModelClass.SMALL]).toBe("claude-3-haiku-20240307"); expect(anthropicModels[ModelClass.MEDIUM]).toBe("claude-3-5-sonnet-20241022"); expect(anthropicModels[ModelClass.LARGE]).toBe("claude-3-5-sonnet-20241022"); }); - 
test("should have correct settings configuration", () => { - const settings = models[ModelProviderName.ANTHROPIC].settings; + test("should have correct settings configuration", async () => { + const settings = (await getModelProviderData(ModelProviderName.ANTHROPIC)).settings; expect(settings.maxInputTokens).toBe(200000); expect(settings.maxOutputTokens).toBe(4096); expect(settings.temperature).toBe(0.7); @@ -71,12 +71,12 @@ describe("Model Provider Configuration", () => { }); describe("LlamaCloud Provider", () => { - test("should have correct endpoint", () => { - expect(models[ModelProviderName.LLAMACLOUD].endpoint).toBe("https://api.llamacloud.com/v1"); + test("should have correct endpoint", async () => { + expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).endpoint).toBe("https://api.llamacloud.com/v1"); }); - test("should have correct model mappings", () => { - const llamaCloudModels = models[ModelProviderName.LLAMACLOUD].model; + test("should have correct model mappings", async () => { + const llamaCloudModels = (await getModelProviderData(ModelProviderName.LLAMACLOUD)).model; expect(llamaCloudModels[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo"); expect(llamaCloudModels[ModelClass.MEDIUM]).toBe("meta-llama-3.1-8b-instruct"); expect(llamaCloudModels[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"); @@ -84,8 +84,8 @@ describe("Model Provider Configuration", () => { expect(llamaCloudModels[ModelClass.IMAGE]).toBe("black-forest-labs/FLUX.1-schnell"); }); - test("should have correct settings configuration", () => { - const settings = models[ModelProviderName.LLAMACLOUD].settings; + test("should have correct settings configuration", async () => { + const settings = (await getModelProviderData(ModelProviderName.LLAMACLOUD)).settings; expect(settings.maxInputTokens).toBe(128000); expect(settings.maxOutputTokens).toBe(8192); expect(settings.temperature).toBe(0.7); @@ -94,8 +94,8 @@ describe("Model Provider 
Configuration", () => { }); describe("Google Provider", () => { - test("should have correct model mappings", () => { - const googleModels = models[ModelProviderName.GOOGLE].model; + test("should have correct model mappings", async () => { + const googleModels = (await getModelProviderData(ModelProviderName.GOOGLE)).model; expect(googleModels[ModelClass.SMALL]).toBe("gemini-1.5-flash-latest"); expect(googleModels[ModelClass.MEDIUM]).toBe("gemini-1.5-flash-latest"); expect(googleModels[ModelClass.LARGE]).toBe("gemini-1.5-pro-latest"); @@ -104,42 +104,36 @@ describe("Model Provider Configuration", () => { }); describe("Model Retrieval Functions", () => { - describe("getModel function", () => { - test("should retrieve correct models for different providers and classes", () => { - expect(getModel(ModelProviderName.OPENAI, ModelClass.SMALL)).toBe("gpt-4o-mini"); - expect(getModel(ModelProviderName.ANTHROPIC, ModelClass.LARGE)).toBe("claude-3-5-sonnet-20241022"); - expect(getModel(ModelProviderName.LLAMACLOUD, ModelClass.MEDIUM)).toBe("meta-llama-3.1-8b-instruct"); + describe("getModelProviderData function", () => { + test("should retrieve correct models for different providers and classes", async () => { + expect((await getModelProviderData(ModelProviderName.OPENAI)).model[ModelClass.SMALL]).toBe("gpt-4o-mini"); + expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).model[ModelClass.LARGE]).toBe("claude-3-5-sonnet-20241022"); + expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).model[ModelClass.MEDIUM]).toBe("meta-llama-3.1-8b-instruct"); }); - test("should handle environment variable overrides", () => { - expect(getModel(ModelProviderName.OPENROUTER, ModelClass.SMALL)).toBe("mock-small-model"); - expect(getModel(ModelProviderName.OPENROUTER, ModelClass.LARGE)).toBe("mock-large-model"); - expect(getModel(ModelProviderName.ETERNALAI, ModelClass.SMALL)).toBe("mock-eternal-model"); + test("should handle environment variable overrides", async () => 
{ + expect((await getModelProviderData(ModelProviderName.OPENROUTER)).model[ModelClass.SMALL]).toBe("mock-small-model"); + expect((await getModelProviderData(ModelProviderName.OPENROUTER)).model[ModelClass.LARGE]).toBe("mock-large-model"); + expect((await getModelProviderData(ModelProviderName.ETERNALAI)).model[ModelClass.SMALL]).toBe("mock-eternal-model"); }); - test("should throw error for invalid model provider", () => { - expect(() => getModel("INVALID_PROVIDER" as any, ModelClass.SMALL)).toThrow(); - }); - }); - - describe("getEndpoint function", () => { - test("should retrieve correct endpoints for different providers", () => { - expect(getEndpoint(ModelProviderName.OPENAI)).toBe("https://api.openai.com/v1"); - expect(getEndpoint(ModelProviderName.ANTHROPIC)).toBe("https://api.anthropic.com/v1"); - expect(getEndpoint(ModelProviderName.LLAMACLOUD)).toBe("https://api.llamacloud.com/v1"); - expect(getEndpoint(ModelProviderName.ETERNALAI)).toBe("https://mock.eternal.ai"); + test("should throw error for invalid model provider", async () => { + await expect(getModelProviderData("INVALID_PROVIDER" as any)).rejects.toThrow(); }); - test("should throw error for invalid provider", () => { - expect(() => getEndpoint("INVALID_PROVIDER" as any)).toThrow(); + test("should retrieve correct endpoints for different providers", async () => { + expect((await getModelProviderData(ModelProviderName.OPENAI)).endpoint).toBe("https://api.openai.com/v1"); + expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).endpoint).toBe("https://api.anthropic.com/v1"); + expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).endpoint).toBe("https://api.llamacloud.com/v1"); + expect((await getModelProviderData(ModelProviderName.ETERNALAI)).endpoint).toBe("https://mock.eternal.ai"); }); }); }); describe("Model Settings Validation", () => { test("all providers should have required settings", () => { - Object.values(ModelProviderName).forEach(provider => { - const 
providerConfig = models[provider]; + Object.values(ModelProviderName).forEach(async provider => { + const providerConfig = await getModelProviderData(provider); expect(providerConfig.settings).toBeDefined(); expect(providerConfig.settings.maxInputTokens).toBeGreaterThan(0); expect(providerConfig.settings.maxOutputTokens).toBeGreaterThan(0); @@ -148,8 +142,8 @@ describe("Model Settings Validation", () => { }); test("all providers should have model mappings for basic model classes", () => { - Object.values(ModelProviderName).forEach(provider => { - const providerConfig = models[provider]; + Object.values(ModelProviderName).forEach(async provider => { + const providerConfig = await getModelProviderData(provider); expect(providerConfig.model).toBeDefined(); expect(providerConfig.model[ModelClass.SMALL]).toBeDefined(); expect(providerConfig.model[ModelClass.MEDIUM]).toBeDefined(); @@ -159,14 +153,14 @@ describe("Model Settings Validation", () => { }); describe("Environment Variable Integration", () => { - test("should use environment variables for LlamaCloud models", () => { - const llamaConfig = models[ModelProviderName.LLAMACLOUD]; + test("should use environment variables for LlamaCloud models", async () => { + const llamaConfig = await getModelProviderData(ModelProviderName.LLAMACLOUD); expect(llamaConfig.model[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo"); expect(llamaConfig.model[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"); }); - test("should use environment variables for Together models", () => { - const togetherConfig = models[ModelProviderName.TOGETHER]; + test("should use environment variables for Together models", async () => { + const togetherConfig = await getModelProviderData(ModelProviderName.TOGETHER); expect(togetherConfig.model[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo"); expect(togetherConfig.model[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo"); });