diff --git a/Cargo.toml b/Cargo.toml
index 1dc28c4f..357d0cb3 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -106,3 +106,4 @@ bytes = "1.10.1"
 rand = "0.9.0"
 indoc = "2.0.6"
 owo-colors = "4.2.0"
+json5 = "0.4.1"
diff --git a/docs/docs/ai/llm.mdx b/docs/docs/ai/llm.mdx
index 02ced152..d4e751eb 100644
--- a/docs/docs/ai/llm.mdx
+++ b/docs/docs/ai/llm.mdx
@@ -97,4 +97,27 @@ cocoindex.LlmSpec(
 </TabItem>
 </Tabs>
 
-You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).
\ No newline at end of file
+You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).
+
+### Anthropic
+
+To use the Anthropic LLM API, you need to set the environment variable `ANTHROPIC_API_KEY`.
+You can generate an API key from the [Anthropic console](https://console.anthropic.com/settings/keys).
+
+A spec for Anthropic looks like this:
+
+<Tabs>
+<TabItem value="python" label="Python">
+
+```python
+cocoindex.LlmSpec(
+    api_type=cocoindex.LlmApiType.ANTHROPIC,
+    model="claude-3-5-sonnet-latest",
+)
+```
+
+</TabItem>
+</Tabs>
+
+You can find the full list of models supported by Anthropic [here](https://docs.anthropic.com/en/docs/about-claude/models/all-models).
+
diff --git a/examples/manuals_llm_extraction/main.py b/examples/manuals_llm_extraction/main.py
index 2547c3fa..2115731f 100644
--- a/examples/manuals_llm_extraction/main.py
+++ b/examples/manuals_llm_extraction/main.py
@@ -98,6 +98,10 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
                 # Replace by this spec below, to use Gemini API model
                 # llm_spec=cocoindex.LlmSpec(
                 #     api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),
+
+                # Replace by this spec below, to use Anthropic API model
+                # llm_spec=cocoindex.LlmSpec(
+                #     api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"),
                 output_type=ModuleInfo,
                 instruction="Please extract Python module information from the manual."))
         doc["module_summary"] = doc["module_info"].transform(summarize_module)
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index 9db6da41..5fcd0a73 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -6,6 +6,7 @@ class LlmApiType(Enum):
     OPENAI = "OpenAi"
     OLLAMA = "Ollama"
     GEMINI = "Gemini"
+    ANTHROPIC = "Anthropic"
 
 @dataclass
 class LlmSpec:
diff --git a/src/llm/anthropic.rs b/src/llm/anthropic.rs
new file mode 100644
index 00000000..65cfcb49
--- /dev/null
+++ b/src/llm/anthropic.rs
@@ -0,0 +1,137 @@
+use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
+use serde_json::Value;
+use json5;
+
+use crate::api_bail;
+use urlencoding::encode;
+
+pub struct Client {
+    model: String,
+    api_key: String,
+    client: reqwest::Client,
+}
+
+impl Client {
+    pub async fn new(spec: LlmSpec) -> Result<Self> {
+        let api_key = match std::env::var("ANTHROPIC_API_KEY") {
+            Ok(val) => val,
+            Err(_) => api_bail!("ANTHROPIC_API_KEY environment variable must be set"),
+        };
+        Ok(Self {
+            model: spec.model,
+            api_key,
+            client: reqwest::Client::new(),
+        })
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: LlmGenerateRequest<'req>,
+    ) -> Result<LlmGenerateResponse> {
+        let messages = vec![serde_json::json!({
+            "role": "user",
+            "content": request.user_prompt
+        })];
+
+        let mut payload = serde_json::json!({
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": 4096
+        });
+
+        // Add the system prompt as a top-level field if present (Anthropic expects it outside `messages`)
+        if let Some(system) = request.system_prompt {
+            payload["system"] = serde_json::json!(system);
+        }
+
+        // Extract the schema from output_format; error out if it is not JsonSchema
+        let schema = match request.output_format.as_ref() {
+            Some(OutputFormat::JsonSchema { schema, .. }) => schema,
+            _ => api_bail!("Anthropic client expects OutputFormat::JsonSchema for all requests"),
+        };
+
+        // Expose the schema as a tool so the model returns structured output via tool use
+        let schema_json = serde_json::to_value(schema)?;
+        payload["tools"] = serde_json::json!([
+            { "type": "custom", "name": "report_result", "input_schema": schema_json }
+        ]);
+
+        let url = "https://api.anthropic.com/v1/messages";
+
+        let encoded_api_key = encode(&self.api_key);
+
+        let resp = self.client
+            .post(url)
+            .header("x-api-key", encoded_api_key.as_ref())
+            .header("anthropic-version", "2023-06-01")
+            .json(&payload)
+            .send()
+            .await
+            .context("HTTP error")?;
+        let mut resp_json: Value = resp.json().await.context("Invalid JSON")?;
+        if let Some(error) = resp_json.get("error") {
+            bail!("Anthropic API error: {:?}", error);
+        }
+
+        // Debug print full response
+        // println!("Anthropic API full response: {resp_json:?}");
+
+        // Prefer the structured `tool_use` block named "report_result", if the model produced one
+        let resp_content = &resp_json["content"];
+        let tool_name = "report_result";
+        let mut extracted_json: Option<Value> = None;
+        if let Some(array) = resp_content.as_array() {
+            for item in array {
+                if item.get("type") == Some(&Value::String("tool_use".to_string()))
+                    && item.get("name") == Some(&Value::String(tool_name.to_string()))
+                {
+                    if let Some(input) = item.get("input") {
+                        extracted_json = Some(input.clone());
+                        break;
+                    }
+                }
+            }
+        }
+        let text = if let Some(json) = extracted_json {
+            // Serialize the structured tool output back to a JSON string
+            serde_json::to_string(&json)?
+        } else {
+            // Fallback: use the plain text content if no tool output was found
+            match &mut resp_json["content"][0]["text"] {
+                Value::String(s) => {
+                    // Try strict JSON parsing first
+                    match serde_json::from_str::<serde_json::Value>(s) {
+                        Ok(_) => std::mem::take(s),
+                        Err(e) => {
+                            // Try permissive JSON5 parsing as a fallback
+                            match json5::from_str::<serde_json::Value>(s) {
+                                Ok(value) => {
+                                    println!("[Anthropic] Used permissive JSON5 parser for output");
+                                    serde_json::to_string(&value)?
+                                },
+                                Err(e2) => return Err(anyhow::anyhow!(format!("No structured tool output or text found in response, and permissive JSON5 parsing also failed: {e}; {e2}")))
+                            }
+                        }
+                    }
+                },
+                _ => return Err(anyhow::anyhow!("No structured tool output or text found in response")),
+            }
+        };
+
+        Ok(LlmGenerateResponse {
+            text,
+        })
+    }
+
+    fn json_schema_options(&self) -> ToJsonSchemaOptions {
+        ToJsonSchemaOptions {
+            fields_always_required: false,
+            supports_format: false,
+            extract_descriptions: false,
+            top_level_must_be_object: true,
+        }
+    }
+}
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index f06423d7..b91eed1c 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -12,6 +12,7 @@ pub enum LlmApiType {
     Ollama,
     OpenAi,
     Gemini,
+    Anthropic,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -54,6 +55,7 @@ pub trait LlmGenerationClient: Send + Sync {
 mod ollama;
 mod openai;
 mod gemini;
+mod anthropic;
 
 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
@@ -66,6 +68,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
         LlmApiType::Gemini => {
             Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::Anthropic => {
+            Box::new(anthropic::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }
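For context, a minimal sketch of how the new `ANTHROPIC` API type would be selected from a flow, mirroring the commented-out spec added to `examples/manuals_llm_extraction/main.py`. It assumes `ANTHROPIC_API_KEY` is set in the environment and that extraction goes through `cocoindex.functions.ExtractByLlm`; the `ModuleInfo` dataclass below is an illustrative stand-in, not the one defined in the example file.

```python
import dataclasses
import cocoindex

@dataclasses.dataclass
class ModuleInfo:
    """Illustrative stand-in for the output dataclass used by the example flow."""
    name: str
    description: str

# Select the Anthropic backend via the new LlmApiType.ANTHROPIC variant;
# the Rust client added above reads ANTHROPIC_API_KEY and calls the Messages API.
extract_module_info = cocoindex.functions.ExtractByLlm(
    llm_spec=cocoindex.LlmSpec(
        api_type=cocoindex.LlmApiType.ANTHROPIC,
        model="claude-3-5-sonnet-latest",
    ),
    output_type=ModuleInfo,
    instruction="Please extract Python module information from the manual.",
)
```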