Skip to content

Commit

Permalink
update unit test for models' getModelProviderData update
Browse files Browse the repository at this point in the history
  • Loading branch information
odilitime committed Dec 20, 2024
1 parent 132d4ef commit 45eb657
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 54 deletions.
8 changes: 5 additions & 3 deletions packages/core/src/tests/generation.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { ModelProviderName, IAgentRuntime } from "../types";
import { models } from "../models";
import { getModelProviderData } from "../models";
import {
generateText,
generateTrueOrFalse,
Expand All @@ -15,6 +15,7 @@ vi.mock("../index.ts", () => ({
log: vi.fn(),
info: vi.fn(),
error: vi.fn(),
debug: vi.fn(),
},
}));

Expand Down Expand Up @@ -76,9 +77,10 @@ describe("Generation", () => {
expect(result).toBe("mocked response");
});

it("should use correct model settings from provider config", () => {
it("should use correct model settings from provider config", async () => {
const modelProvider = mockRuntime.modelProvider;
const modelSettings = models[modelProvider].settings;
console.log('modelProvider', modelProvider)
const modelSettings = (await getModelProviderData(modelProvider)).settings;

expect(modelSettings).toBeDefined();
expect(modelSettings.temperature).toBeDefined();
Expand Down
96 changes: 45 additions & 51 deletions packages/core/src/tests/models.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { getModel, getEndpoint, models } from "../models.ts";
import { getModelProviderData } from "../models.ts";
import { ModelProviderName, ModelClass } from "../types.ts";
import { describe, test, expect, vi } from "vitest";

Expand All @@ -25,21 +25,21 @@ vi.mock("../settings", () => {

describe("Model Provider Configuration", () => {
describe("OpenAI Provider", () => {
test("should have correct endpoint", () => {
expect(models[ModelProviderName.OPENAI].endpoint).toBe("https://api.openai.com/v1");
test("should have correct endpoint", async () => {
expect((await getModelProviderData(ModelProviderName.OPENAI)).endpoint).toBe("https://api.openai.com/v1");
});

test("should have correct model mappings", () => {
const openAIModels = models[ModelProviderName.OPENAI].model;
test("should have correct model mappings", async () => {
const openAIModels = (await getModelProviderData(ModelProviderName.OPENAI)).model;
expect(openAIModels[ModelClass.SMALL]).toBe("gpt-4o-mini");
expect(openAIModels[ModelClass.MEDIUM]).toBe("gpt-4o");
expect(openAIModels[ModelClass.LARGE]).toBe("gpt-4o");
expect(openAIModels[ModelClass.EMBEDDING]).toBe("text-embedding-3-small");
expect(openAIModels[ModelClass.IMAGE]).toBe("dall-e-3");
});

test("should have correct settings configuration", () => {
const settings = models[ModelProviderName.OPENAI].settings;
test("should have correct settings configuration", async () => {
const settings = (await getModelProviderData(ModelProviderName.OPENAI)).settings;
expect(settings.maxInputTokens).toBe(128000);
expect(settings.maxOutputTokens).toBe(8192);
expect(settings.temperature).toBe(0.6);
Expand All @@ -49,19 +49,19 @@ describe("Model Provider Configuration", () => {
});

describe("Anthropic Provider", () => {
test("should have correct endpoint", () => {
expect(models[ModelProviderName.ANTHROPIC].endpoint).toBe("https://api.anthropic.com/v1");
test("should have correct endpoint", async () => {
expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).endpoint).toBe("https://api.anthropic.com/v1");
});

test("should have correct model mappings", () => {
const anthropicModels = models[ModelProviderName.ANTHROPIC].model;
test("should have correct model mappings", async () => {
const anthropicModels = (await getModelProviderData(ModelProviderName.ANTHROPIC)).model;
expect(anthropicModels[ModelClass.SMALL]).toBe("claude-3-haiku-20240307");
expect(anthropicModels[ModelClass.MEDIUM]).toBe("claude-3-5-sonnet-20241022");
expect(anthropicModels[ModelClass.LARGE]).toBe("claude-3-5-sonnet-20241022");
});

test("should have correct settings configuration", () => {
const settings = models[ModelProviderName.ANTHROPIC].settings;
test("should have correct settings configuration", async () => {
const settings = (await getModelProviderData(ModelProviderName.ANTHROPIC)).settings;
expect(settings.maxInputTokens).toBe(200000);
expect(settings.maxOutputTokens).toBe(4096);
expect(settings.temperature).toBe(0.7);
Expand All @@ -71,21 +71,21 @@ describe("Model Provider Configuration", () => {
});

describe("LlamaCloud Provider", () => {
test("should have correct endpoint", () => {
expect(models[ModelProviderName.LLAMACLOUD].endpoint).toBe("https://api.llamacloud.com/v1");
test("should have correct endpoint", async () => {
expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).endpoint).toBe("https://api.llamacloud.com/v1");
});

test("should have correct model mappings", () => {
const llamaCloudModels = models[ModelProviderName.LLAMACLOUD].model;
test("should have correct model mappings", async () => {
const llamaCloudModels = (await getModelProviderData(ModelProviderName.LLAMACLOUD)).model;
expect(llamaCloudModels[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo");
expect(llamaCloudModels[ModelClass.MEDIUM]).toBe("meta-llama-3.1-8b-instruct");
expect(llamaCloudModels[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo");
expect(llamaCloudModels[ModelClass.EMBEDDING]).toBe("togethercomputer/m2-bert-80M-32k-retrieval");
expect(llamaCloudModels[ModelClass.IMAGE]).toBe("black-forest-labs/FLUX.1-schnell");
});

test("should have correct settings configuration", () => {
const settings = models[ModelProviderName.LLAMACLOUD].settings;
test("should have correct settings configuration", async () => {
const settings = (await getModelProviderData(ModelProviderName.LLAMACLOUD)).settings;
expect(settings.maxInputTokens).toBe(128000);
expect(settings.maxOutputTokens).toBe(8192);
expect(settings.temperature).toBe(0.7);
Expand All @@ -94,8 +94,8 @@ describe("Model Provider Configuration", () => {
});

describe("Google Provider", () => {
test("should have correct model mappings", () => {
const googleModels = models[ModelProviderName.GOOGLE].model;
test("should have correct model mappings", async () => {
const googleModels = (await getModelProviderData(ModelProviderName.GOOGLE)).model;
expect(googleModels[ModelClass.SMALL]).toBe("gemini-1.5-flash-latest");
expect(googleModels[ModelClass.MEDIUM]).toBe("gemini-1.5-flash-latest");
expect(googleModels[ModelClass.LARGE]).toBe("gemini-1.5-pro-latest");
Expand All @@ -104,42 +104,36 @@ describe("Model Provider Configuration", () => {
});

describe("Model Retrieval Functions", () => {
describe("getModel function", () => {
test("should retrieve correct models for different providers and classes", () => {
expect(getModel(ModelProviderName.OPENAI, ModelClass.SMALL)).toBe("gpt-4o-mini");
expect(getModel(ModelProviderName.ANTHROPIC, ModelClass.LARGE)).toBe("claude-3-5-sonnet-20241022");
expect(getModel(ModelProviderName.LLAMACLOUD, ModelClass.MEDIUM)).toBe("meta-llama-3.1-8b-instruct");
describe("getModelProviderData function", () => {
test("should retrieve correct models for different providers and classes", async () => {
expect((await getModelProviderData(ModelProviderName.OPENAI)).model[ModelClass.SMALL]).toBe("gpt-4o-mini");
expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).model[ModelClass.LARGE]).toBe("claude-3-5-sonnet-20241022");
expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).model[ModelClass.MEDIUM]).toBe("meta-llama-3.1-8b-instruct");
});

test("should handle environment variable overrides", () => {
expect(getModel(ModelProviderName.OPENROUTER, ModelClass.SMALL)).toBe("mock-small-model");
expect(getModel(ModelProviderName.OPENROUTER, ModelClass.LARGE)).toBe("mock-large-model");
expect(getModel(ModelProviderName.ETERNALAI, ModelClass.SMALL)).toBe("mock-eternal-model");
test("should handle environment variable overrides", async () => {
expect((await getModelProviderData(ModelProviderName.OPENROUTER)).model[ModelClass.SMALL]).toBe("mock-small-model");
expect((await getModelProviderData(ModelProviderName.OPENROUTER)).model[ModelClass.LARGE]).toBe("mock-large-model");
expect((await getModelProviderData(ModelProviderName.ETERNALAI)).model[ModelClass.SMALL]).toBe("mock-eternal-model");
});

test("should throw error for invalid model provider", () => {
expect(() => getModel("INVALID_PROVIDER" as any, ModelClass.SMALL)).toThrow();
});
});

describe("getEndpoint function", () => {
test("should retrieve correct endpoints for different providers", () => {
expect(getEndpoint(ModelProviderName.OPENAI)).toBe("https://api.openai.com/v1");
expect(getEndpoint(ModelProviderName.ANTHROPIC)).toBe("https://api.anthropic.com/v1");
expect(getEndpoint(ModelProviderName.LLAMACLOUD)).toBe("https://api.llamacloud.com/v1");
expect(getEndpoint(ModelProviderName.ETERNALAI)).toBe("https://mock.eternal.ai");
test("should throw error for invalid model provider", async () => {
await expect(getModelProviderData("INVALID_PROVIDER" as any)).rejects.toThrow();
});

test("should throw error for invalid provider", () => {
expect(() => getEndpoint("INVALID_PROVIDER" as any)).toThrow();
test("should retrieve correct endpoints for different providers", async () => {
expect((await getModelProviderData(ModelProviderName.OPENAI)).endpoint).toBe("https://api.openai.com/v1");
expect((await getModelProviderData(ModelProviderName.ANTHROPIC)).endpoint).toBe("https://api.anthropic.com/v1");
expect((await getModelProviderData(ModelProviderName.LLAMACLOUD)).endpoint).toBe("https://api.llamacloud.com/v1");
expect((await getModelProviderData(ModelProviderName.ETERNALAI)).endpoint).toBe("https://mock.eternal.ai");
});
});
});

describe("Model Settings Validation", () => {
test("all providers should have required settings", () => {
Object.values(ModelProviderName).forEach(provider => {
const providerConfig = models[provider];
Object.values(ModelProviderName).forEach(async provider => {
const providerConfig = await getModelProviderData(provider);
expect(providerConfig.settings).toBeDefined();
expect(providerConfig.settings.maxInputTokens).toBeGreaterThan(0);
expect(providerConfig.settings.maxOutputTokens).toBeGreaterThan(0);
Expand All @@ -148,8 +142,8 @@ describe("Model Settings Validation", () => {
});

test("all providers should have model mappings for basic model classes", () => {
Object.values(ModelProviderName).forEach(provider => {
const providerConfig = models[provider];
Object.values(ModelProviderName).forEach(async provider => {
const providerConfig = await getModelProviderData(provider);
expect(providerConfig.model).toBeDefined();
expect(providerConfig.model[ModelClass.SMALL]).toBeDefined();
expect(providerConfig.model[ModelClass.MEDIUM]).toBeDefined();
Expand All @@ -159,14 +153,14 @@ describe("Model Settings Validation", () => {
});

describe("Environment Variable Integration", () => {
test("should use environment variables for LlamaCloud models", () => {
const llamaConfig = models[ModelProviderName.LLAMACLOUD];
test("should use environment variables for LlamaCloud models", async () => {
const llamaConfig = await getModelProviderData(ModelProviderName.LLAMACLOUD);
expect(llamaConfig.model[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo");
expect(llamaConfig.model[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo");
});

test("should use environment variables for Together models", () => {
const togetherConfig = models[ModelProviderName.TOGETHER];
test("should use environment variables for Together models", async () => {
const togetherConfig = await getModelProviderData(ModelProviderName.TOGETHER);
expect(togetherConfig.model[ModelClass.SMALL]).toBe("meta-llama/Llama-3.2-3B-Instruct-Turbo");
expect(togetherConfig.model[ModelClass.LARGE]).toBe("meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo");
});
Expand Down

0 comments on commit 45eb657

Please sign in to comment.