diff --git a/docs/docs/ai/llm.mdx b/docs/docs/ai/llm.mdx
index 899c295e..02ced152 100644
--- a/docs/docs/ai/llm.mdx
+++ b/docs/docs/ai/llm.mdx
@@ -21,7 +21,7 @@ It has the following fields:
 
 * `address` (optional): The address of the LLM API.
 
-## LLM API integrations
+## LLM API Integrations
 
 CocoIndex integrates with various LLM APIs for these functions.
@@ -77,3 +77,24 @@ cocoindex.LlmSpec(
+### Google Gemini
+
+To use the Gemini LLM API, you need to set the environment variable `GEMINI_API_KEY`.
+You can generate the API key from [Google AI Studio](https://aistudio.google.com/apikey).
+
+A spec for Gemini looks like this:
+
+```python
+cocoindex.LlmSpec(
+    api_type=cocoindex.LlmApiType.GEMINI,
+    model="gemini-2.0-flash",
+)
+```
+
+You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).
\ No newline at end of file
diff --git a/examples/manuals_llm_extraction/main.py b/examples/manuals_llm_extraction/main.py
index 7cc8a1ad..2547c3fa 100644
--- a/examples/manuals_llm_extraction/main.py
+++ b/examples/manuals_llm_extraction/main.py
@@ -90,9 +90,14 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
                     # See the full list of models: https://ollama.com/library
                     model="llama3.2"
                 ),
+                # Replace with the spec below to use an OpenAI API model instead of Ollama:
                 # llm_spec=cocoindex.LlmSpec(
                 #     api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
+
+                # Replace with the spec below to use a Gemini API model:
+                # llm_spec=cocoindex.LlmSpec(
+                #     api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),
                 output_type=ModuleInfo,
                 instruction="Please extract Python module information from the manual."))
         doc["module_summary"] = doc["module_info"].transform(summarize_module)
diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py
index ab1a6f04..9db6da41 100644
--- a/python/cocoindex/llm.py
+++ b/python/cocoindex/llm.py
@@ -5,6 +5,7 @@ class LlmApiType(Enum):
     """The type of LLM API to use."""
     OPENAI = "OpenAi"
     OLLAMA = "Ollama"
+    GEMINI = "Gemini"
 
 @dataclass
 class LlmSpec:
diff --git a/src/lib_context.rs b/src/lib_context.rs
index 1e3b99b4..5f8ebd25 100644
--- a/src/lib_context.rs
+++ b/src/lib_context.rs
@@ -95,7 +95,8 @@ static LIB_INIT: OnceLock<()> = OnceLock::new();
 
 pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
     LIB_INIT.get_or_init(|| {
         console_subscriber::init();
-        env_logger::init();
+        let _ = env_logger::try_init();
+
         pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
     });
diff --git a/src/llm/gemini.rs b/src/llm/gemini.rs
new file mode 100644
index 00000000..f5e274ad
--- /dev/null
+++ b/src/llm/gemini.rs
@@ -0,0 +1,110 @@
+use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
+use serde_json::Value;
+use crate::api_bail;
+use urlencoding::encode;
+
+pub struct Client {
+    model: String,
+    api_key: String,
+    client: reqwest::Client,
+}
+
+impl Client {
+    pub async fn new(spec: LlmSpec) -> Result<Self> {
+        let api_key = match std::env::var("GEMINI_API_KEY") {
+            Ok(val) => val,
+            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
+        };
+        Ok(Self {
+            model: spec.model,
+            api_key,
+            client: reqwest::Client::new(),
+        })
+    }
+}
+
+// Recursively remove all `additionalProperties` fields from a JSON value
+fn remove_additional_properties(value: &mut Value) {
+    match value {
+        Value::Object(map) => {
map.remove("additionalProperties"); + for v in map.values_mut() { + remove_additional_properties(v); + } + } + Value::Array(arr) => { + for v in arr { + remove_additional_properties(v); + } + } + _ => {} + } +} + +#[async_trait] +impl LlmGenerationClient for Client { + async fn generate<'req>( + &self, + request: LlmGenerateRequest<'req>, + ) -> Result { + // Compose the prompt/messages + let contents = vec![serde_json::json!({ + "role": "user", + "parts": [{ "text": request.user_prompt }] + })]; + + // Prepare payload + let mut payload = serde_json::json!({ "contents": contents }); + if let Some(system) = request.system_prompt { + payload["systemInstruction"] = serde_json::json!({ + "parts": [ { "text": system } ] + }); + } + + // If structured output is requested, add schema and responseMimeType + if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format { + let mut schema_json = serde_json::to_value(schema)?; + remove_additional_properties(&mut schema_json); + payload["generationConfig"] = serde_json::json!({ + "responseMimeType": "application/json", + "responseSchema": schema_json + }); + } + + let api_key = &self.api_key; + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}", + encode(&self.model), encode(api_key) + ); + + let resp = self.client.post(&url) + .json(&payload) + .send() + .await + .context("HTTP error")?; + + let resp_json: Value = resp.json().await.context("Invalid JSON")?; + + if let Some(error) = resp_json.get("error") { + bail!("Gemini API error: {:?}", error); + } + let mut resp_json = resp_json; + let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] { + Value::String(s) => std::mem::take(s), + _ => bail!("No text in response"), + }; + + Ok(LlmGenerateResponse { text }) + } + + fn json_schema_options(&self) -> ToJsonSchemaOptions { + ToJsonSchemaOptions { + fields_always_required: false, + supports_format: false, + extract_descriptions: false, + top_level_must_be_object: true, + } + } +} \ No newline at end of file diff --git a/src/llm/mod.rs b/src/llm/mod.rs index e362fb31..f06423d7 100644 --- a/src/llm/mod.rs +++ b/src/llm/mod.rs @@ -11,6 +11,7 @@ use crate::base::json_schema::ToJsonSchemaOptions; pub enum LlmApiType { Ollama, OpenAi, + Gemini, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -52,6 +53,7 @@ pub trait LlmGenerationClient: Send + Sync { mod ollama; mod openai; +mod gemini; pub async fn new_llm_generation_client(spec: LlmSpec) -> Result> { let client = match spec.api_type { @@ -61,6 +63,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result { Box::new(openai::Client::new(spec).await?) as Box } + LlmApiType::Gemini => { + Box::new(gemini::Client::new(spec).await?) as Box + } }; Ok(client) }