Feat: Add Gemini Support to DataFlow (Rust) #360

Merged
merged 4 commits on Apr 22, 2025

23 changes: 22 additions & 1 deletion docs/docs/ai/llm.mdx
@@ -21,7 +21,7 @@ It has the following fields:
* `address` (optional): The address of the LLM API.


## LLM API integrations
## LLM API Integrations

CocoIndex integrates with various LLM APIs for these functions.

@@ -77,3 +77,24 @@ cocoindex.LlmSpec(
</TabItem>
</Tabs>

### Google Gemini

To use the Gemini LLM API, you need to set the environment variable `GEMINI_API_KEY`.
You can generate the API key from [Google AI Studio](https://aistudio.google.com/apikey).

A spec for Gemini looks like this:

<Tabs>
<TabItem value="python" label="Python" default>

```python
cocoindex.LlmSpec(
    api_type=cocoindex.LlmApiType.GEMINI,
    model="gemini-2.0-flash",
)
```

</TabItem>
</Tabs>

You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).
5 changes: 5 additions & 0 deletions examples/manuals_llm_extraction/main.py
@@ -90,9 +90,14 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
# See the full list of models: https://ollama.com/library
model="llama3.2"
),

# Replace with the spec below to use an OpenAI API model instead of Ollama:
# llm_spec=cocoindex.LlmSpec(
#     api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),

# Replace with the spec below to use a Gemini API model:
# llm_spec=cocoindex.LlmSpec(
#     api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),
output_type=ModuleInfo,
instruction="Please extract Python module information from the manual."))
doc["module_summary"] = doc["module_info"].transform(summarize_module)
1 change: 1 addition & 0 deletions python/cocoindex/llm.py
@@ -5,6 +5,7 @@ class LlmApiType(Enum):
"""The type of LLM API to use."""
OPENAI = "OpenAi"
OLLAMA = "Ollama"
GEMINI = "Gemini"

@dataclass
class LlmSpec:
3 changes: 2 additions & 1 deletion src/lib_context.rs
@@ -95,7 +95,8 @@ static LIB_INIT: OnceLock<()> = OnceLock::new();
pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
    LIB_INIT.get_or_init(|| {
        console_subscriber::init();
        env_logger::init();
        // try_init() returns an Err instead of panicking if a global logger is already installed.
        let _ = env_logger::try_init();

        pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
    });

Expand Down
110 changes: 110 additions & 0 deletions src/llm/gemini.rs
@@ -0,0 +1,110 @@
use async_trait::async_trait;
use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
use anyhow::{Result, bail, Context};
use serde_json::Value;
use crate::api_bail;
use urlencoding::encode;

pub struct Client {
    model: String,
    api_key: String,
    client: reqwest::Client,
}

impl Client {
    pub async fn new(spec: LlmSpec) -> Result<Self> {
        let api_key = match std::env::var("GEMINI_API_KEY") {
            Ok(val) => val,
            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
        };
        Ok(Self {
            model: spec.model,
            api_key,
            client: reqwest::Client::new(),
        })
    }
}

// Recursively remove all `additionalProperties` fields from a JSON value
fn remove_additional_properties(value: &mut Value) {
    match value {
        Value::Object(map) => {
            map.remove("additionalProperties");
            for v in map.values_mut() {
                remove_additional_properties(v);
            }
        }
        Value::Array(arr) => {
            for v in arr {
                remove_additional_properties(v);
            }
        }
        _ => {}
    }
}

#[async_trait]
impl LlmGenerationClient for Client {
    async fn generate<'req>(
        &self,
        request: LlmGenerateRequest<'req>,
    ) -> Result<LlmGenerateResponse> {
        // Compose the prompt/messages
        let contents = vec![serde_json::json!({
            "role": "user",
            "parts": [{ "text": request.user_prompt }]
        })];

        // Prepare payload
        let mut payload = serde_json::json!({ "contents": contents });
        if let Some(system) = request.system_prompt {
            payload["systemInstruction"] = serde_json::json!({
                "parts": [ { "text": system } ]
            });
        }

        // If structured output is requested, add schema and responseMimeType
        if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
            let mut schema_json = serde_json::to_value(schema)?;
            remove_additional_properties(&mut schema_json);
            payload["generationConfig"] = serde_json::json!({
                "responseMimeType": "application/json",
                "responseSchema": schema_json
            });
        }

        let api_key = &self.api_key;
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
            encode(&self.model), encode(api_key)
        );

        let resp = self.client.post(&url)
            .json(&payload)
            .send()
            .await
            .context("HTTP error")?;

        let resp_json: Value = resp.json().await.context("Invalid JSON")?;

        if let Some(error) = resp_json.get("error") {
            bail!("Gemini API error: {:?}", error);
        }
        let mut resp_json = resp_json;
        let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
            Value::String(s) => std::mem::take(s),
            _ => bail!("No text in response"),
        };

        Ok(LlmGenerateResponse { text })
    }

    fn json_schema_options(&self) -> ToJsonSchemaOptions {
        ToJsonSchemaOptions {
            fields_always_required: false,
            supports_format: false,
            extract_descriptions: false,
            top_level_must_be_object: true,
        }
    }
}
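
For context, the request body this client sends to Gemini's `generateContent` endpoint takes the shape below. This is an illustrative sketch only — the prompt text and response schema are hypothetical, not taken from the PR. It mirrors how `generate` composes `contents`, `systemInstruction`, and `generationConfig`, with `additionalProperties` already stripped by `remove_additional_properties`, since Gemini's response-schema dialect does not accept that keyword.

```rust
use serde_json::json;

fn example_gemini_payload() -> serde_json::Value {
    // Illustrative only: mirrors the payload assembled in `generate` above.
    json!({
        "contents": [
            { "role": "user", "parts": [ { "text": "Extract the module name from this manual." } ] }
        ],
        "systemInstruction": {
            "parts": [ { "text": "You extract structured data from documents." } ]
        },
        "generationConfig": {
            "responseMimeType": "application/json",
            // Schema shown after remove_additional_properties has run on it.
            "responseSchema": {
                "type": "object",
                "properties": { "name": { "type": "string" } },
                "required": ["name"]
            }
        }
    })
}
```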
5 changes: 5 additions & 0 deletions src/llm/mod.rs
@@ -11,6 +11,7 @@ use crate::base::json_schema::ToJsonSchemaOptions;
pub enum LlmApiType {
    Ollama,
    OpenAi,
    Gemini,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -52,6 +53,7 @@ pub trait LlmGenerationClient: Send + Sync {

mod ollama;
mod openai;
mod gemini;

pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
    let client = match spec.api_type {
@@ -61,6 +63,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
        LlmApiType::OpenAi => {
            Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
        }
        LlmApiType::Gemini => {
            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
        }
    };
    Ok(client)
}
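
Downstream callers go through `new_llm_generation_client` rather than constructing `gemini::Client` directly. A minimal sketch, assuming the Rust `LlmSpec` mirrors the fields documented for the Python spec (`api_type`, `model`, optional `address`) — those field names are an assumption, not shown in this diff:

```rust
use crate::llm::{new_llm_generation_client, LlmApiType, LlmSpec};

async fn make_gemini_client() -> anyhow::Result<()> {
    // Assumed field layout for LlmSpec; `address` is ignored for Gemini,
    // whose endpoint is fixed in gemini.rs.
    let spec = LlmSpec {
        api_type: LlmApiType::Gemini,
        model: "gemini-2.0-flash".to_string(),
        address: None,
    };
    // Requires GEMINI_API_KEY to be set, otherwise Client::new bails.
    let _client = new_llm_generation_client(spec).await?;
    Ok(())
}
```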