From fbc2c9473ef8025140a4bcaaeabc4475979f3da2 Mon Sep 17 00:00:00 2001
From: Param Arora
Date: Mon, 21 Apr 2025 03:30:12 +0530
Subject: [PATCH 1/4] feat: add gemini support to dataflow

---
 examples/manuals_llm_extraction/main.py |   5 +
 python/cocoindex/llm.py                 |   1 +
 src/lib_context.rs                      |   3 +-
 src/llm/gemini.rs                       | 116 ++++++++++++++++++++++++
 src/llm/mod.rs                          |   5 +
 5 files changed, 129 insertions(+), 1 deletion(-)
 create mode 100644 src/llm/gemini.rs

diff --git a/examples/manuals_llm_extraction/main.py b/examples/manuals_llm_extraction/main.py
index 7cc8a1ad..2547c3fa 100644
--- a/examples/manuals_llm_extraction/main.py
+++ b/examples/manuals_llm_extraction/main.py
@@ -90,9 +90,14 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
                     # See the full list of models: https://ollama.com/library
                     model="llama3.2"
                 ),
+
                 # Replace by this spec below, to use OpenAI API model instead of ollama
                 # llm_spec=cocoindex.LlmSpec(
                 #     api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
+
+                # Replace with the spec below to use a Gemini API model
+                # llm_spec=cocoindex.LlmSpec(
+                #     api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),
                 output_type=ModuleInfo,
                 instruction="Please extract Python module information from the manual."))
         doc["module_summary"] = doc["module_info"].transform(summarize_module)

diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index ab1a6f04..9db6da41 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -5,6 +5,7 @@ class LlmApiType(Enum):
     """The type of LLM API to use."""
     OPENAI = "OpenAi"
     OLLAMA = "Ollama"
+    GEMINI = "Gemini"
 
 @dataclass
 class LlmSpec:

diff --git a/src/lib_context.rs b/src/lib_context.rs
index 1e3b99b4..5f8ebd25 100644
--- a/src/lib_context.rs
+++ b/src/lib_context.rs
@@ -95,7 +95,8 @@ static LIB_INIT: OnceLock<()> = OnceLock::new();
 pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
     LIB_INIT.get_or_init(|| {
         console_subscriber::init();
-        env_logger::init();
+        let _ = env_logger::try_init();
+
         pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
     });
 
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
new file mode 100644
index 00000000..cb0c9bf1
--- /dev/null
+++ b/src/llm/gemini.rs
@@ -0,0 +1,116 @@
+use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, anyhow};
+use serde_json;
+use reqwest::Client as HttpClient;
+use serde_json::{json, Value};
+
+pub struct Client {
+    model: String,
+}
+
+impl Client {
+    pub async fn new(spec: LlmSpec) -> Result<Self> {
+        if std::env::var("GEMINI_API_KEY").is_err() {
+            anyhow::bail!("GEMINI_API_KEY environment variable must be set");
+        }
+        Ok(Self {
+            model: spec.model,
+        })
+    }
+}
+
+// Recursively remove all `additionalProperties` fields from a JSON value
+fn remove_additional_properties(value: &mut Value) {
+    match value {
+        Value::Object(map) => {
+            map.remove("additionalProperties");
+            for v in map.values_mut() {
+                remove_additional_properties(v);
+            }
+        }
+        Value::Array(arr) => {
+            for v in arr {
+                remove_additional_properties(v);
+            }
+        }
+        _ => {}
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: LlmGenerateRequest<'req>,
+    ) -> Result<LlmGenerateResponse> {
+        // Compose the prompt/messages
+        let mut contents = vec![serde_json::json!({
+            "role": "user",
+            "parts": [{ "text": request.user_prompt }]
+        })];
+
+        // Optionally add system prompt
+        let mut system_instruction = None;
+        if let Some(system) = request.system_prompt {
+            system_instruction = Some(serde_json::json!({
+                "parts": [{ "text": system }]
+            }));
+        }
+
+        // Prepare payload
+        let mut payload = serde_json::json!({ "contents": contents });
+        if let Some(system) = system_instruction {
+            payload["systemInstruction"] = system;
+        }
+
+        // If structured output is requested, add schema and responseMimeType
+        if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            let mut schema_json = serde_json::to_value(schema)?;
+            remove_additional_properties(&mut schema_json);
+            payload["generationConfig"] = serde_json::json!({
+                "responseMimeType": "application/json",
+                "responseSchema": schema_json
+            });
+        }
+
+        let api_key = std::env::var("GEMINI_API_KEY")
+            .map_err(|_| anyhow!("GEMINI_API_KEY environment variable must be set"))?;
+        let url = format!(
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
+            self.model, api_key
+        );
+
+        let client = HttpClient::new();
+        let resp = client.post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .map_err(|e| anyhow!("HTTP error: {e}"))?;
+
+        let resp_json: Value = resp.json().await.map_err(|e| anyhow!("Invalid JSON: {e}"))?;
+
+        // Debug log Gemini response
+        println!("Gemini request payload: {:#}", payload);
+        println!("Gemini response JSON: {:#}", resp_json);
+
+        if let Some(error) = resp_json.get("error") {
+            return Err(anyhow!("Gemini API error: {:?}", error));
+        }
+        let text = resp_json["candidates"][0]["content"]["parts"][0]["text"]
+            .as_str()
+            .unwrap_or("")
+            .to_string();
+
+        Ok(LlmGenerateResponse { text })
+    }
+
+    fn json_schema_options(&self) -> ToJsonSchemaOptions {
+        ToJsonSchemaOptions {
+            fields_always_required: false,
+            supports_format: false,
+            extract_descriptions: false,
+            top_level_must_be_object: true,
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
index e362fb31..f06423d7 100644
--- a/src/llm/mod.rs
+++ b/src/llm/mod.rs
@@ -11,6 +11,7 @@ use crate::base::json_schema::ToJsonSchemaOptions;
 pub enum LlmApiType {
     Ollama,
     OpenAi,
+    Gemini,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -52,6 +53,7 @@ pub trait LlmGenerationClient: Send + Sync {
 
 mod ollama;
 mod openai;
+mod gemini;
 
 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {
@@ -61,6 +63,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
         LlmApiType::OpenAi => {
             Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::Gemini => {
+            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }

From 05de0417bdd2aa5f88c70286d9341d2c2d73b10f Mon Sep 17 00:00:00 2001
From: Param Arora
Date: Mon, 21 Apr 2025 04:29:47 +0530
Subject: [PATCH 2/4] feat: add docs for gemini

---
 docs/docs/ai/llm.mdx | 23 ++++++++++++++++++++++-
 src/llm/gemini.rs    |  8 ++------
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/docs/docs/ai/llm.mdx b/docs/docs/ai/llm.mdx
index 899c295e..02ced152 100644
--- a/docs/docs/ai/llm.mdx
+++ b/docs/docs/ai/llm.mdx
@@ -21,7 +21,7 @@ It has the following fields:
 
 *   `address` (optional): The address of the LLM API.
 
-## LLM API integrations
+## LLM API Integrations
 
 CocoIndex integrates with various LLM APIs for these functions.
 
@@ -77,3 +77,24 @@ cocoindex.LlmSpec(
 
 </TabItem>
 </Tabs>
+
+### Google Gemini
+
+To use the Gemini LLM API, you need to set the environment variable `GEMINI_API_KEY`.
+You can generate the API key from [Google AI Studio](https://aistudio.google.com/apikey).
+
+A spec for Gemini looks like this:
+
+<Tabs>
+<TabItem value="python" label="Python" default>
+
+```python
+cocoindex.LlmSpec(
+    api_type=cocoindex.LlmApiType.GEMINI,
+    model="gemini-2.0-flash",
+)
+```
+
+</TabItem>
+</Tabs>
+You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).
\ No newline at end of file
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index cb0c9bf1..7d60c8d1 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -3,7 +3,7 @@ use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateRe
 use anyhow::{Result, anyhow};
 use serde_json;
 use reqwest::Client as HttpClient;
-use serde_json::{json, Value};
+use serde_json::Value;
 
 pub struct Client {
     model: String,
@@ -45,7 +45,7 @@ impl LlmGenerationClient for Client {
         request: LlmGenerateRequest<'req>,
     ) -> Result<LlmGenerateResponse> {
         // Compose the prompt/messages
-        let mut contents = vec![serde_json::json!({
+        let contents = vec![serde_json::json!({
             "role": "user",
             "parts": [{ "text": request.user_prompt }]
         })];
@@ -90,10 +90,6 @@ impl LlmGenerationClient for Client {
 
         let resp_json: Value = resp.json().await.map_err(|e| anyhow!("Invalid JSON: {e}"))?;
 
-        // Debug log Gemini response
-        println!("Gemini request payload: {:#}", payload);
-        println!("Gemini response JSON: {:#}", resp_json);
-
         if let Some(error) = resp_json.get("error") {
             return Err(anyhow!("Gemini API error: {:?}", error));
         }

From 0ff37d5a774c41cfd9c9ef4757a1ea20c33cb265 Mon Sep 17 00:00:00 2001
From: Param Arora
Date: Tue, 22 Apr 2025 03:11:21 +0530
Subject: [PATCH 3/4] Refactor Gemini client

---
 src/llm/gemini.rs | 58 +++++++++++++++++++++++++++--------------------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
index 7d60c8d1..cf9ed9d7 100644
--- a/src/llm/gemini.rs
+++ b/src/llm/gemini.rs
@@ -1,21 +1,26 @@
 use async_trait::async_trait;
 use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
-use anyhow::{Result, anyhow};
-use serde_json;
-use reqwest::Client as HttpClient;
+use anyhow::{Result, bail};
 use serde_json::Value;
+use crate::api_bail;
+use urlencoding::encode;
 
 pub struct Client {
     model: String,
+    api_key: String,
+    client: reqwest::Client,
 }
 
 impl Client {
     pub async fn new(spec: LlmSpec) -> Result<Self> {
-        if std::env::var("GEMINI_API_KEY").is_err() {
-            anyhow::bail!("GEMINI_API_KEY environment variable must be set");
-        }
+        let api_key = match std::env::var("GEMINI_API_KEY") {
+            Ok(val) => val,
+            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
+        };
         Ok(Self {
             model: spec.model,
+            api_key,
+            client: reqwest::Client::new(),
         })
     }
 }
@@ -51,12 +56,11 @@ impl LlmGenerationClient for Client {
         })];
 
         // Optionally add system prompt
-        let mut system_instruction = None;
-        if let Some(system) = request.system_prompt {
-            system_instruction = Some(serde_json::json!({
-                "parts": [{ "text": system }]
-            }));
-        }
+        let system_instruction = request.system_prompt.map(|system|
+            serde_json::json!({
+                "parts": [ { "text": system } ]
+            })
+        );
 
         // Prepare payload
         let mut payload = serde_json::json!({ "contents": contents });
@@ -74,29 +78,33 @@ impl LlmGenerationClient for Client {
             });
         }
 
-        let api_key = std::env::var("GEMINI_API_KEY")
-            .map_err(|_| anyhow!("GEMINI_API_KEY environment variable must be set"))?;
+        let api_key = &self.api_key;
         let url = format!(
             "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
-            self.model, api_key
+            encode(&self.model), encode(api_key)
        );
 
-        let client = HttpClient::new();
-        let resp = client.post(&url)
let resp = client.post(&url) + let resp = match self.client.post(&url) .json(&payload) .send() - .await - .map_err(|e| anyhow!("HTTP error: {e}"))?; + .await { + Ok(resp) => resp, + Err(e) => api_bail!("HTTP error: {e}"), + }; - let resp_json: Value = resp.json().await.map_err(|e| anyhow!("Invalid JSON: {e}"))?; + let resp_json: Value = match resp.json().await { + Ok(json) => json, + Err(e) => api_bail!("Invalid JSON: {e}"), + }; if let Some(error) = resp_json.get("error") { - return Err(anyhow!("Gemini API error: {:?}", error)); + bail!("Gemini API error: {:?}", error); } - let text = resp_json["candidates"][0]["content"]["parts"][0]["text"] - .as_str() - .unwrap_or("") - .to_string(); + let mut resp_json = resp_json; + let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] { + Value::String(s) => std::mem::take(s), + _ => bail!("No text in response"), + }; Ok(LlmGenerateResponse { text }) } From 3af31043297121c9fa3e5bd72bed8e555f926fa3 Mon Sep 17 00:00:00 2001 From: Param Arora Date: Tue, 22 Apr 2025 12:53:00 +0530 Subject: [PATCH 4/4] use context and merge payload --- src/llm/gemini.rs | 28 +++++++++------------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs index cf9ed9d7..f5e274ad 100644 --- a/src/llm/gemini.rs +++ b/src/llm/gemini.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat}; -use anyhow::{Result, bail}; +use anyhow::{Result, bail, Context}; use serde_json::Value; use crate::api_bail; use urlencoding::encode; @@ -55,17 +55,12 @@ impl LlmGenerationClient for Client { "parts": [{ "text": request.user_prompt }] })]; - // Optionally add system prompt - let system_instruction = request.system_prompt.map(|system| - serde_json::json!({ - "parts": [ { "text": system } ] - }) - ); - // Prepare payload let mut payload = serde_json::json!({ "contents": contents }); - if let Some(system) = system_instruction { - payload["systemInstruction"] = system; + if let Some(system) = request.system_prompt { + payload["systemInstruction"] = serde_json::json!({ + "parts": [ { "text": system } ] + }); } // If structured output is requested, add schema and responseMimeType @@ -84,18 +79,13 @@ impl LlmGenerationClient for Client { encode(&self.model), encode(api_key) ); - let resp = match self.client.post(&url) + let resp = self.client.post(&url) .json(&payload) .send() - .await { - Ok(resp) => resp, - Err(e) => api_bail!("HTTP error: {e}"), - }; + .await + .context("HTTP error")?; - let resp_json: Value = match resp.json().await { - Ok(json) => json, - Err(e) => api_bail!("Invalid JSON: {e}"), - }; + let resp_json: Value = resp.json().await.context("Invalid JSON")?; if let Some(error) = resp_json.get("error") { bail!("Gemini API error: {:?}", error);