
Commit: Add mypy checker (#346)
leehuwuj authored Oct 8, 2024
1 parent 0e78ba4 commit c60182a
Showing 23 changed files with 339 additions and 179 deletions.
1 change: 1 addition & 0 deletions .husky/pre-commit
@@ -1,2 +1,3 @@
pnpm format
pnpm lint
uvx ruff format --check templates/
238 changes: 147 additions & 91 deletions e2e/python/resolve_dependencies.spec.ts
@@ -15,6 +15,8 @@ const dataSource: string = process.env.DATASOURCE
? process.env.DATASOURCE
: "--example-file";

// TODO: add support for other templates

if (
dataSource === "--example-file" // XXX: this test provides its own data source - only trigger it on one data source (usually the CI matrix will trigger multiple data sources)
) {
@@ -45,102 +47,138 @@

const observabilityOptions = ["llamatrace", "traceloop"];

// Run separate tests for each observability option to reduce CI runtime
test.describe("Test resolve python dependencies with observability", () => {
// Testing with streaming template, vectorDb: none, tools: none, and dataSource: --example-file
for (const observability of observabilityOptions) {
test(`observability: ${observability}`, async () => {
test.describe("Mypy check", () => {
test.describe.configure({ retries: 0 });

// Test vector databases
for (const vectorDb of vectorDbs) {
test(`Mypy check for vectorDB: ${vectorDb}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb,
tools: "none",
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (vectorDb !== "none") {
if (vectorDb === "pg") {
expect(pyprojectContent).toContain(
"llama-index-vector-stores-postgres",
);
} else {
expect(pyprojectContent).toContain(
`llama-index-vector-stores-${vectorDb}`,
);
}
}
});
}

await createAndCheckLlamaProject({
// Test tools
for (const tool of toolOptions) {
test(`Mypy check for tool: ${tool}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb: "none",
tools: tool,
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (tool === "wikipedia.WikipediaToolSpec") {
expect(pyprojectContent).toContain("wikipedia");
}
if (tool === "google.GoogleSearchToolSpec") {
expect(pyprojectContent).toContain("google");
}
});
}

// Test data sources
for (const dataSource of dataSources) {
const dataSourceType = dataSource.split(" ")[0];
test(`Mypy check for data source: ${dataSourceType}`, async () => {
const cwd = await createTestDir();
const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource,
vectorDb: "none",
tools: "none",
port: 3000, // port, not used
externalPort: 8000, // externalPort, not used
postInstallAction: "none", // postInstallAction
templateUI: undefined, // ui
appType: "--no-frontend", // appType
llamaCloudProjectName: undefined, // llamaCloudProjectName
llamaCloudIndexName: undefined, // llamaCloudIndexName
observability,
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability: undefined,
},
});

const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (dataSource.includes("--web-source")) {
expect(pyprojectContent).toContain("llama-index-readers-web");
}
if (dataSource.includes("--db-source")) {
expect(pyprojectContent).toContain("llama-index-readers-database");
}
});
}
});

test.describe("Test resolve python dependencies", () => {
for (const vectorDb of vectorDbs) {
for (const tool of toolOptions) {
for (const dataSource of dataSources) {
const dataSourceType = dataSource.split(" ")[0];
const toolDescription = tool === "none" ? "no tools" : tool;
const optionDescription = `vectorDb: ${vectorDb}, ${toolDescription}, dataSource: ${dataSourceType}`;

test(`options: ${optionDescription}`, async () => {
const cwd = await createTestDir();

const { pyprojectPath, projectPath } =
await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource,
vectorDb,
tools: tool,
port: 3000, // port, not used
externalPort: 8000, // externalPort, not used
postInstallAction: "none", // postInstallAction
templateUI: undefined, // ui
appType: "--no-frontend", // appType
llamaCloudProjectName: undefined, // llamaCloudProjectName
llamaCloudIndexName: undefined, // llamaCloudIndexName
observability: undefined, // observability
},
});

// Additional checks for specific dependencies

// Verify that specific dependencies are in pyproject.toml
const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8");
if (vectorDb !== "none") {
if (vectorDb === "pg") {
expect(pyprojectContent).toContain(
"llama-index-vector-stores-postgres",
);
} else {
expect(pyprojectContent).toContain(
`llama-index-vector-stores-${vectorDb}`,
);
}
}
if (tool !== "none") {
if (tool === "wikipedia.WikipediaToolSpec") {
expect(pyprojectContent).toContain("wikipedia");
}
if (tool === "google.GoogleSearchToolSpec") {
expect(pyprojectContent).toContain("google");
}
}

// Check for data source specific dependencies
if (dataSource.includes("--web-source")) {
expect(pyprojectContent).toContain("llama-index-readers-web");
}
if (dataSource.includes("--db-source")) {
expect(pyprojectContent).toContain(
"llama-index-readers-database ",
);
}
});
}
}
// Test observability options
for (const observability of observabilityOptions) {
test(`Mypy check for observability: ${observability}`, async () => {
const cwd = await createTestDir();

const { pyprojectPath } = await createAndCheckLlamaProject({
options: {
cwd,
templateType: "streaming",
templateFramework,
dataSource: "--example-file",
vectorDb: "none",
tools: "none",
port: 3000,
externalPort: 8000,
postInstallAction: "none",
templateUI: undefined,
appType: "--no-frontend",
llamaCloudProjectName: undefined,
llamaCloudIndexName: undefined,
observability,
},
});
});
}
});
}
@@ -161,21 +199,39 @@ async function createAndCheckLlamaProject({
const pyprojectPath = path.join(projectPath, "pyproject.toml");
expect(fs.existsSync(pyprojectPath)).toBeTruthy();

// Run poetry lock
const env = {
...process.env,
POETRY_VIRTUALENVS_IN_PROJECT: "true",
};

// Run poetry install
try {
const { stdout: installStdout, stderr: installStderr } = await execAsync(
"poetry install",
{ cwd: projectPath, env },
);
console.log("poetry install stdout:", installStdout);
console.error("poetry install stderr:", installStderr);
} catch (error) {
console.error("Error running poetry install:", error);
throw error;
}

// Run poetry run mypy
try {
const { stdout, stderr } = await execAsync(
"poetry config virtualenvs.in-project true && poetry lock --no-update",
{ cwd: projectPath },
const { stdout: mypyStdout, stderr: mypyStderr } = await execAsync(
"poetry run mypy .",
{ cwd: projectPath, env },
);
console.log("poetry lock stdout:", stdout);
console.error("poetry lock stderr:", stderr);
console.log("poetry run mypy stdout:", mypyStdout);
console.error("poetry run mypy stderr:", mypyStderr);
} catch (error) {
console.error("Error running poetry lock:", error);
console.error("Error running mypy:", error);
throw error;
}

// Check if poetry.lock file was created
expect(fs.existsSync(path.join(projectPath, "poetry.lock"))).toBeTruthy();
// If we reach this point without throwing an error, the test passes
expect(true).toBeTruthy();

return { pyprojectPath, projectPath };
}
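A note on the template changes that follow: `poetry run mypy .` is strict about unannotated empty containers, which is why the Python templates below gain explicit annotations such as `tools: List[BaseTool] = []`. A minimal, self-contained illustration of the error and its fix (not code from this commit):

```python
from typing import List

from llama_index.core.tools import BaseTool
from llama_index.core.tools.function_tool import FunctionTool


def echo(text: str) -> str:
    """Return the input text unchanged."""
    return text


def collect_tools() -> List[BaseTool]:
    # With a bare `tools = []`, mypy reports:
    #   Need type annotation for "tools"
    # The explicit annotation below is the fix the templates adopt.
    tools: List[BaseTool] = []
    tools.append(FunctionTool.from_defaults(fn=echo))
    return tools
```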
2 changes: 1 addition & 1 deletion helpers/python.ts
@@ -123,7 +123,7 @@ const getAdditionalDependencies = (
extras: ["rsa"],
});
dependencies.push({
name: "psycopg2",
name: "psycopg2-binary",
version: "^2.9.9",
});
break;
7 changes: 5 additions & 2 deletions templates/components/engines/python/agent/engine.py
@@ -1,17 +1,19 @@
import os
from typing import List

from app.engine.index import IndexConfig, get_index
from app.engine.tools import ToolFactory
from llama_index.core.agent import AgentRunner
from llama_index.core.callbacks import CallbackManager
from llama_index.core.settings import Settings
from llama_index.core.tools import BaseTool
from llama_index.core.tools.query_engine import QueryEngineTool


def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = int(os.getenv("TOP_K", 0))
tools = []
tools: List[BaseTool] = []
callback_manager = CallbackManager(handlers=event_handlers or [])

# Add query tool if index exists
@@ -25,7 +27,8 @@ def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs):
tools.append(query_engine_tool)

# Add additional tools
tools += ToolFactory.from_env()
configured_tools: List[BaseTool] = ToolFactory.from_env()
tools.extend(configured_tools)

return AgentRunner.from_llm(
llm=Settings.llm,
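For orientation, a hedged usage sketch of the engine above (the import path assumes the generated project's `app/engine` package, and the prompt string is illustrative; in the templates the FastAPI routes do the equivalent):

```python
# Hypothetical usage; assumes a generated project where this template lands
# under app/engine and Settings.llm has already been configured.
from app.engine.engine import get_chat_engine

engine = get_chat_engine(event_handlers=[])
response = engine.chat("What is in the example file?")
print(response)
```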
34 changes: 22 additions & 12 deletions templates/components/engines/python/agent/tools/__init__.py
@@ -1,7 +1,8 @@
import importlib
import os
from typing import Dict, List, Union

import yaml
import yaml # type: ignore
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec

@@ -17,7 +18,8 @@ class ToolFactory:
ToolType.LOCAL: "app.engine.tools",
}

def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
@staticmethod
def load_tools(tool_type: str, tool_name: str, config: dict) -> List[FunctionTool]:
source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
try:
if "ToolSpec" in tool_name:
@@ -43,24 +45,32 @@ def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
@staticmethod
def from_env(
map_result: bool = False,
) -> list[FunctionTool] | dict[str, FunctionTool]:
) -> Union[Dict[str, List[FunctionTool]], List[FunctionTool]]:
"""
Load tools from the configured file.
Params:
- use_map: if True, return map of tool name and the tool itself
Args:
map_result: If True, return a map of tool names to their corresponding tools.
Returns:
A dictionary of tool names to lists of FunctionTools if map_result is True,
otherwise a list of FunctionTools.
"""
if map_result:
tools = {}
else:
tools = []
tools: Union[Dict[str, List[FunctionTool]], List[FunctionTool]] = (
{} if map_result else []
)

if os.path.exists("config/tools.yaml"):
with open("config/tools.yaml", "r") as f:
tool_configs = yaml.safe_load(f)
for tool_type, config_entries in tool_configs.items():
for tool_name, config in config_entries.items():
tool = ToolFactory.load_tools(tool_type, tool_name, config)
loaded_tools = ToolFactory.load_tools(
tool_type, tool_name, config
)
if map_result:
tools[tool_name] = tool
tools[tool_name] = loaded_tools # type: ignore
else:
tools.extend(tool)
tools.extend(loaded_tools) # type: ignore

return tools
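To make the two return shapes of `from_env` concrete, a hedged sketch (the `config/tools.yaml` layout is inferred from the loop above; the `local`/`interpreter` keys are assumptions, not confirmed by this diff):

```python
# Hypothetical config/tools.yaml (keys assumed for illustration):
#
#   local:
#     interpreter:
#       api_key: "..."
#
from app.engine.tools import ToolFactory

flat = ToolFactory.from_env()                    # List[FunctionTool]
by_name = ToolFactory.from_env(map_result=True)  # Dict[str, List[FunctionTool]]
```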
@@ -87,7 +87,7 @@ def artifact(self, query: str, old_code: Optional[str] = None) -> Dict:
ChatMessage(role="user", content=user_message),
]
try:
sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact)
sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) # type: ignore
response = sllm.chat(messages)
data: CodeArtifact = response.raw
return data.model_dump()
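The `# type: ignore` presumably quiets mypy because `as_structured_llm` is not part of the statically declared type of `Settings.llm`. For reference, a hedged sketch of the structured-output pattern with a stand-in Pydantic model (`Summary` is hypothetical, not the commit's `CodeArtifact`):

```python
from pydantic import BaseModel

from llama_index.core.llms import ChatMessage
from llama_index.core.settings import Settings


class Summary(BaseModel):
    """Hypothetical output schema for illustration."""

    title: str
    bullet_points: list[str]


# Wrap the configured LLM so chat() parses the response into Summary;
# the parsed object is available on response.raw.
sllm = Settings.llm.as_structured_llm(output_cls=Summary)  # type: ignore
response = sllm.chat(
    [ChatMessage(role="user", content="Summarize mypy in two bullets.")]
)
summary: Summary = response.raw
print(summary.model_dump())
```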