Add mypy checker #346

Merged 16 commits on Oct 8, 2024
1 change: 1 addition & 0 deletions .husky/pre-commit
@@ -1,2 +1,3 @@
pnpm format
pnpm lint
uvx ruff format --check templates/
238 changes: 147 additions & 91 deletions e2e/python/resolve_dependencies.spec.ts
@@ -15,6 +15,8 @@ const dataSource: string = process.env.DATASOURCE
? process.env.DATASOURCE
: "--example-file";

// TODO: add support for other templates

if (
dataSource === "--example-file" // XXX: this test provides its own data source - only trigger it on one data source (usually the CI matrix will trigger multiple data sources)
) {
@@ -45,102 +47,138 @@ if (

const observabilityOptions = ["llamatrace", "traceloop"];

// Run separate tests for each observability option to reduce CI runtime
test.describe("Test resolve python dependencies with observability", () => {
// Testing with streaming template, vectorDb: none, tools: none, and dataSource: --example-file
for (const observability of observabilityOptions) {
test(`observability: ${observability}`, async () => {
test.describe("Mypy check", () => {
test.describe.configure({ retries: 0 });

// Test vector databases
for (const vectorDb of vectorDbs) {
test(`Mypy check for vectorDB: ${vectorDb}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb,
tools: "none",
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (vectorDb !== "none") {
if (vectorDb === "pg") {
expect(pyprojectContent).toContain(
"llama-index-vector-stores-postgres",
);
} else {
expect(pyprojectContent).toContain(
`llama-index-vector-stores-${vectorDb}`,
);
}
}
});
}

await createAndCheckLlamaProject({
// Test tools
for (const tool of toolOptions) {
test(`Mypy check for tool: ${tool}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb: "none",
tools: tool,
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (tool === "wikipedia.WikipediaToolSpec") {
expect(pyprojectContent).toContain("wikipedia");
}
if (tool === "google.GoogleSearchToolSpec") {
expect(pyprojectContent).toContain("google");
}
});
}

// Test data sources
for (const dataSource of dataSources) {
const dataSourceType = dataSource.split(" ")[0];
test(`Mypy check for data source: ${dataSourceType}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource,
vectorDb: "none",
tools: "none",
port: 3000, // port, not used
externalPort: 8000, // externalPort, not used
postInstallAction: "none", // postInstallAction
templateUI: undefined, // ui
appType: "--no-frontend", // appType
llamaCloudProjectName: undefined, // llamaCloudProjectName
llamaCloudIndexName: undefined, // llamaCloudIndexName
observability,
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (dataSource.includes("--web-source")) {
expect(pyprojectContent).toContain("llama-index-readers-web");
}
if (dataSource.includes("--db-source")) {
expect(pyprojectContent).toContain("llama-index-readers-database");
}
});
}
});
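
For reference, the `poetry run mypy .` step these tests exercise can also be driven from Python through mypy's programmatic API. A minimal sketch (illustrative only; the harness shells out through Poetry instead):

import sys

from mypy import api

def run_mypy(target: str = ".") -> None:
    # api.run returns (report, errors, exit_status), mirroring the CLI
    report, errors, exit_status = api.run([target])
    print(report)
    if exit_status != 0:
        print(errors, file=sys.stderr)
        raise RuntimeError(f"mypy found type errors (exit {exit_status})")

if __name__ == "__main__":
    run_mypy()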

test.describe("Test resolve python dependencies", () => {
for (const vectorDb of vectorDbs) {
for (const tool of toolOptions) {
for (const dataSource of dataSources) {
const dataSourceType = dataSource.split(" ")[0];
const toolDescription = tool === "none" ? "no tools" : tool;
const optionDescription = `vectorDb: ${vectorDb}, ${toolDescription}, dataSource: ${dataSourceType}`;

test(`options: ${optionDescription}`, async () => {
const cwd = await createTestDir();

const { pyprojectPath, projectPath } =
await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource,
vectorDb,
tools: tool,
port: 3000, // port, not used
externalPort: 8000, // externalPort, not used
postInstallAction: "none", // postInstallAction
templateUI: undefined, // ui
appType: "--no-frontend", // appType
llamaCloudProjectName: undefined, // llamaCloudProjectName
llamaCloudIndexName: undefined, // llamaCloudIndexName
observability: undefined, // observability
},
});

// Additional checks for specific dependencies

// Verify that specific dependencies are in pyproject.toml
const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (vectorDb !== "none") {
if (vectorDb === "pg") {
expect(pyprojectContent).toContain(
"llama-index-vector-stores-postgres",
);
} else {
expect(pyprojectContent).toContain(
`llama-index-vector-stores-${vectorDb}`,
);
}
}
if (tool !== "none") {
if (tool === "wikipedia.WikipediaToolSpec") {
expect(pyprojectContent).toContain("wikipedia");
}
if (tool === "google.GoogleSearchToolSpec") {
expect(pyprojectContent).toContain("google");
}
}

// Check for data source specific dependencies
if (dataSource.includes("--web-source")) {
expect(pyprojectContent).toContain("llama-index-readers-web");
}
if (dataSource.includes("--db-source")) {
expect(pyprojectContent).toContain(
"llama-index-readers-database ",
);
}
});
}
}
// Test observability options
for (const observability of observabilityOptions) {
test(`Mypy check for observability: ${observability}`, async () => {
const cwd = await createTestDir();

const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb: "none",
tools: "none",
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability,
},
});
});
}
});
}
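
A side note on the assertions above: they check for dependencies with substring matches against the raw pyproject.toml text. A structured alternative (an assumption on my part, not part of this PR) is parsing the file with the standard-library tomllib (Python 3.11+):

import tomllib

def has_dependency(pyproject_path: str, package: str) -> bool:
    """Check a Poetry-style pyproject.toml for a dependency by exact name."""
    with open(pyproject_path, "rb") as f:
        data = tomllib.load(f)
    deps = data.get("tool", {}).get("poetry", {}).get("dependencies", {})
    return package in deps

# e.g. has_dependency("pyproject.toml", "llama-index-vector-stores-postgres")
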
@@ -161,21 +199,39 @@ async function createAndCheckLlamaProject({
const pyprojectPath = path.join(projectPath, "pyproject.toml");
expect(fs.existsSync(pyprojectPath)).toBeTruthy();

// Configure Poetry to create the virtualenv inside the project directory
const env = {
...process.env,
POETRY_VIRTUALENVS_IN_PROJECT: "true",
};

// Run poetry install
try {
const { stdout: installStdout, stderr: installStderr } = await execAsync(
"poetry install",
{ cwd: projectPath, env },
);
console.log("poetry install stdout:", installStdout);
console.error("poetry install stderr:", installStderr);
} catch (error) {
console.error("Error running poetry install:", error);
throw error;
}

try {
const { stdout, stderr } = await execAsync(
"poetry config virtualenvs.in-project true && poetry lock --no-update",
{ cwd: projectPath },
);
console.log("poetry lock stdout:", stdout);
console.error("poetry lock stderr:", stderr);
} catch (error) {
console.error("Error running poetry lock:", error);
throw error;
}
// Run poetry run mypy
try {
const { stdout: mypyStdout, stderr: mypyStderr } = await execAsync(
"poetry run mypy .",
{ cwd: projectPath, env },
);
console.log("poetry run mypy stdout:", mypyStdout);
console.error("poetry run mypy stderr:", mypyStderr);
} catch (error) {
console.error("Error running mypy:", error);
throw error;
}

// Check if poetry.lock file was created
expect(fs.existsSync(path.join(projectPath, "poetry.lock"))).toBeTruthy();
// If we reach this point without throwing an error, the test passes
expect(true).toBeTruthy();

return { pyprojectPath, projectPath };
}
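
The helper's two shell steps, `poetry install` with an in-project virtualenv and then `poetry run mypy .`, look roughly like this as a standalone Python sketch (assuming Poetry is on PATH; error handling simplified):

import os
import subprocess

def install_and_typecheck(project_path: str) -> None:
    # Mirror the helper: in-project virtualenv, install, then type-check
    env = {**os.environ, "POETRY_VIRTUALENVS_IN_PROJECT": "true"}
    subprocess.run(["poetry", "install"], cwd=project_path, env=env, check=True)
    # check=True raises CalledProcessError on a non-zero mypy exit
    subprocess.run(["poetry", "run", "mypy", "."], cwd=project_path, env=env, check=True)
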
2 changes: 1 addition & 1 deletion helpers/python.ts
Expand Up @@ -123,7 +123,7 @@ const getAdditionalDependencies = (
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
name: "psycopg2-binary",
version: "^2.9.9",
});
break;
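
A note on this swap: psycopg2-binary ships prebuilt wheels, so installing the generated project does not require libpq headers or a C toolchain, which matters in CI. The module name is unaffected. A small usage sketch (the DSN is a placeholder):

import psycopg2  # distributed as psycopg2-binary; the import name is unchanged

# Placeholder connection string, for illustration only
conn = psycopg2.connect("postgresql://user:password@localhost:5432/example")
with conn.cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchone())
conn.close()
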
7 changes: 5 additions & 2 deletions templates/components/engines/python/agent/engine.py
@@ -1,17 +1,19 @@
import os
from typing import List

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from llama_index.core.agent import AgentRunner
from llama_index.core.callbacks import CallbackManager
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.tools.query_engine import QueryEngineTool


def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", 0))
tools = []
tools: List[BaseTool] = []
callback_manager = CallbackManager(handlers=event_handlers or [])

# Add query tool if index exists
@@ -25,7 +27,8 @@ def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
tools.append(query_engine_tool)

# Add additional tools
tools += ToolFactory.from_env()
configured_tools: List[BaseTool] = ToolFactory.from_env()
tools.extend(configured_tools)

return AgentRunner.from_llm(
llm=Settings.llm,
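
The new `tools: List[BaseTool]` annotation is what lets mypy accept appending tool subclasses of different concrete types to the same list. A self-contained illustration with stand-in classes (these are not the template's real types):

from typing import List

class BaseTool:
    """Stand-in for llama_index.core.tools.BaseTool."""

class QueryEngineTool(BaseTool):
    """Stand-in for one concrete tool type."""

class FunctionTool(BaseTool):
    """Stand-in for another concrete tool type."""

def get_tools() -> List[BaseTool]:
    # With a bare `tools = []`, mypy may infer the element type from the
    # first append and then reject tools of other subclasses; annotating
    # with the common base type keeps the list open to any BaseTool.
    tools: List[BaseTool] = []
    tools.append(QueryEngineTool())
    tools.extend([FunctionTool()])
    return tools
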
34 changes: 22 additions & 12 deletions templates/components/engines/python/agent/tools/__init__.py
@@ -1,7 +1,8 @@
import importlib
import os
from typing import Dict, List, Union

import yaml
import yaml # type: ignore
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec

@@ -17,7 +18,8 @@ class ToolFactory:
ToolType.LOCAL: "app.engine.tools",
}

def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
@staticmethod
def load_tools(tool_type: str, tool_name: str, config: dict) -> List[FunctionTool]:
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
try:
if "ToolSpec" in tool_name:
@@ -43,24 +45,32 @@ def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
@staticmethod
def from_env(
map_result: bool = False,
) -> list[FunctionTool] | dict[str, FunctionTool]:
) -> Union[Dict[str, List[FunctionTool]], List[FunctionTool]]:
"""
Load tools from the configured file.
Params:
- use_map: if True, return map of tool name and the tool itself

Args:
map_result: If True, return a map of tool names to their corresponding tools.

Returns:
A dictionary of tool names to lists of FunctionTools if map_result is True,
otherwise a list of FunctionTools.
"""
if map_result:
tools = {}
else:
tools = []
tools: Union[Dict[str, List[FunctionTool]], List[FunctionTool]] = (
{} if map_result else []
)

if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for tool_type, config_entries in tool_configs.items():
for tool_name, config in config_entries.items():
tool = ToolFactory.load_tools(tool_type, tool_name, config)
loaded_tools = ToolFactory.load_tools(
tool_type, tool_name, config
)
if map_result:
tools[tool_name] = tool
tools[tool_name] = loaded_tools # type: ignore
else:
tools.extend(tool)
tools.extend(loaded_tools) # type: ignore

return tools
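
Because `from_env` returns a Union, both branches above still need `# type: ignore` comments. One possible alternative (a sketch, not what this PR implements) is declaring `@overload`s keyed on a Literal flag, which gives callers precise types and keeps each branch separately typed:

from typing import Dict, List, Literal, Union, overload

class FunctionTool:
    """Stand-in for llama_index.core.tools.function_tool.FunctionTool."""

@overload
def from_env(map_result: Literal[True]) -> Dict[str, List[FunctionTool]]: ...
@overload
def from_env(map_result: Literal[False] = ...) -> List[FunctionTool]: ...

def from_env(
    map_result: bool = False,
) -> Union[Dict[str, List[FunctionTool]], List[FunctionTool]]:
    if map_result:
        tools_map: Dict[str, List[FunctionTool]] = {}
        # ... populate the map from the config file ...
        return tools_map
    tools_list: List[FunctionTool] = []
    # ... extend the list from the config file ...
    return tools_list
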
@@ -87,7 +87,7 @@ def artifact(self, query: str, old_code: Optional[str] = None) -> Dict:
ChatMessage(role="user", content=user_message),
]
try:
sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact)
sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) # type: ignore
response = sllm.chat(messages)
data: CodeArtifact = response.raw
return data.model_dump()