Commit d303296

Feat: Add Gemini Support to DataFlow (Rust) (#360)
* feat: add gemini support to dataflow
* feat: add docs for gemini
* Refactor Gemini client
* use context and merge payload
1 parent 1d2dd80 commit d303296

File tree

* docs/docs/ai/llm.mdx
* examples/manuals_llm_extraction/main.py
* python/cocoindex/llm.py
* src/lib_context.rs
* src/llm/gemini.rs
* src/llm/mod.rs

6 files changed: +145 −2 lines changed

docs/docs/ai/llm.mdx

Lines changed: 22 additions & 1 deletion

@@ -21,7 +21,7 @@ It has the following fields:
 * `address` (optional): The address of the LLM API.


-## LLM API integrations
+## LLM API Integrations

 CocoIndex integrates with various LLM APIs for these functions.

@@ -77,3 +77,24 @@ cocoindex.LlmSpec(
 </TabItem>
 </Tabs>

+### Google Gemini
+
+To use the Gemini LLM API, you need to set the environment variable `GEMINI_API_KEY`.
+You can generate the API key from [Google AI Studio](https://aistudio.google.com/apikey).
+
+A spec for Gemini looks like this:
+
+<Tabs>
+<TabItem value="python" label="Python" default>
+
+```python
+cocoindex.LlmSpec(
+    api_type=cocoindex.LlmApiType.GEMINI,
+    model="gemini-2.0-flash",
+)
+```
+
+</TabItem>
+</Tabs>
+
+You can find the full list of models supported by Gemini [here](https://ai.google.dev/gemini-api/docs/models).

examples/manuals_llm_extraction/main.py

Lines changed: 5 additions & 0 deletions

@@ -90,9 +90,14 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
             # See the full list of models: https://ollama.com/library
             model="llama3.2"
         ),
+
         # Replace by this spec below, to use OpenAI API model instead of ollama
         # llm_spec=cocoindex.LlmSpec(
         #     api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
+
+        # Replace by this spec below, to use Gemini API model
+        # llm_spec=cocoindex.LlmSpec(
+        #     api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"),
         output_type=ModuleInfo,
         instruction="Please extract Python module information from the manual."))
     doc["module_summary"] = doc["module_info"].transform(summarize_module)

python/cocoindex/llm.py

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ class LlmApiType(Enum):
     """The type of LLM API to use."""
     OPENAI = "OpenAi"
     OLLAMA = "Ollama"
+    GEMINI = "Gemini"

 @dataclass
 class LlmSpec:
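The string values on this Python enum ("OpenAi", "Ollama", "Gemini") are what cross the Python/Rust boundary, so each one has to deserialize into the matching variant of the Rust `LlmApiType` enum in `src/llm/mod.rs` below. A minimal sketch of that round trip, assuming the Rust enum derives serde's `Deserialize` with default variant naming (the derive list for `LlmApiType` is not visible in this diff):

```rust
use serde::Deserialize;

// Hypothetical stand-in for the real LlmApiType in src/llm/mod.rs,
// shown here only to illustrate the naming contract with the Python enum.
#[derive(Debug, Deserialize, PartialEq)]
enum LlmApiType {
    Ollama,
    OpenAi,
    Gemini,
}

fn main() -> anyhow::Result<()> {
    // The Python side sends the enum's string value, e.g. "Gemini";
    // serde's default behavior maps it onto the variant of the same name.
    let api_type: LlmApiType = serde_json::from_str("\"Gemini\"")?;
    assert_eq!(api_type, LlmApiType::Gemini);
    Ok(())
}
```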

src/lib_context.rs

Lines changed: 2 additions & 1 deletion

@@ -95,7 +95,8 @@ static LIB_INIT: OnceLock<()> = OnceLock::new();
 pub fn create_lib_context(settings: settings::Settings) -> Result<LibContext> {
     LIB_INIT.get_or_init(|| {
         console_subscriber::init();
-        env_logger::init();
+        let _ = env_logger::try_init();
+
         pyo3_async_runtimes::tokio::init_with_runtime(get_runtime()).unwrap();
     });
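The one-line change from `env_logger::init()` to `env_logger::try_init()` is a hardening fix: `init()` panics if a global logger has already been registered, which can happen when `create_lib_context` runs inside a host process that set up logging first. `try_init()` returns a `Result` instead, which the commit deliberately discards with `let _ =`. A minimal standalone sketch of the difference:

```rust
fn main() {
    // First registration installs the global logger and succeeds.
    assert!(env_logger::try_init().is_ok());

    // A second registration returns Err instead of panicking; this is the
    // case that plain env_logger::init() would abort the process on.
    assert!(env_logger::try_init().is_err());
}
```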

src/llm/gemini.rs

Lines changed: 110 additions & 0 deletions

@@ -0,0 +1,110 @@
+use async_trait::async_trait;
+use crate::llm::{LlmGenerationClient, LlmSpec, LlmGenerateRequest, LlmGenerateResponse, ToJsonSchemaOptions, OutputFormat};
+use anyhow::{Result, bail, Context};
+use serde_json::Value;
+use crate::api_bail;
+use urlencoding::encode;
+
+pub struct Client {
+    model: String,
+    api_key: String,
+    client: reqwest::Client,
+}
+
+impl Client {
+    pub async fn new(spec: LlmSpec) -> Result<Self> {
+        let api_key = match std::env::var("GEMINI_API_KEY") {
+            Ok(val) => val,
+            Err(_) => api_bail!("GEMINI_API_KEY environment variable must be set"),
+        };
+        Ok(Self {
+            model: spec.model,
+            api_key,
+            client: reqwest::Client::new(),
+        })
+    }
+}
+
+// Recursively remove all `additionalProperties` fields from a JSON value
+fn remove_additional_properties(value: &mut Value) {
+    match value {
+        Value::Object(map) => {
+            map.remove("additionalProperties");
+            for v in map.values_mut() {
+                remove_additional_properties(v);
+            }
+        }
+        Value::Array(arr) => {
+            for v in arr {
+                remove_additional_properties(v);
+            }
+        }
+        _ => {}
+    }
+}
+
+#[async_trait]
+impl LlmGenerationClient for Client {
+    async fn generate<'req>(
+        &self,
+        request: LlmGenerateRequest<'req>,
+    ) -> Result<LlmGenerateResponse> {
+        // Compose the prompt/messages
+        let contents = vec![serde_json::json!({
+            "role": "user",
+            "parts": [{ "text": request.user_prompt }]
+        })];
+
+        // Prepare payload
+        let mut payload = serde_json::json!({ "contents": contents });
+        if let Some(system) = request.system_prompt {
+            payload["systemInstruction"] = serde_json::json!({
+                "parts": [ { "text": system } ]
+            });
+        }
+
+        // If structured output is requested, add schema and responseMimeType
+        if let Some(OutputFormat::JsonSchema { schema, .. }) = &request.output_format {
+            let mut schema_json = serde_json::to_value(schema)?;
+            remove_additional_properties(&mut schema_json);
+            payload["generationConfig"] = serde_json::json!({
+                "responseMimeType": "application/json",
+                "responseSchema": schema_json
+            });
+        }
+
+        let api_key = &self.api_key;
+        let url = format!(
+            "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent?key={}",
+            encode(&self.model), encode(api_key)
+        );
+
+        let resp = self.client.post(&url)
+            .json(&payload)
+            .send()
+            .await
+            .context("HTTP error")?;
+
+        let resp_json: Value = resp.json().await.context("Invalid JSON")?;
+
+        if let Some(error) = resp_json.get("error") {
+            bail!("Gemini API error: {:?}", error);
+        }
+        let mut resp_json = resp_json;
+        let text = match &mut resp_json["candidates"][0]["content"]["parts"][0]["text"] {
+            Value::String(s) => std::mem::take(s),
+            _ => bail!("No text in response"),
+        };
+
+        Ok(LlmGenerateResponse { text })
+    }
+
+    fn json_schema_options(&self) -> ToJsonSchemaOptions {
+        ToJsonSchemaOptions {
+            fields_always_required: false,
+            supports_format: false,
+            extract_descriptions: false,
+            top_level_must_be_object: true,
+        }
+    }
+}
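The most subtle part of the new client is `remove_additional_properties`: Gemini's `responseSchema` does not accept the `additionalProperties` keyword that common JSON Schema generators emit, so the helper walks the whole schema and drops the key at every nesting level before the request is sent. A test-style sketch of its effect, using a made-up schema for illustration:

```rust
use serde_json::{json, Value};

// Same logic as in src/llm/gemini.rs: recurse through objects and arrays,
// removing every `additionalProperties` key along the way.
fn remove_additional_properties(value: &mut Value) {
    match value {
        Value::Object(map) => {
            map.remove("additionalProperties");
            for v in map.values_mut() {
                remove_additional_properties(v);
            }
        }
        Value::Array(arr) => arr.iter_mut().for_each(remove_additional_properties),
        _ => {}
    }
}

fn main() {
    // Hypothetical schema with the offending keyword at two nesting levels.
    let mut schema = json!({
        "type": "object",
        "additionalProperties": false,
        "properties": {
            "name": { "type": "string" },
            "tags": {
                "type": "array",
                "items": { "type": "object", "additionalProperties": false }
            }
        }
    });

    remove_additional_properties(&mut schema);
    assert!(!schema.to_string().contains("additionalProperties"));
}
```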

src/llm/mod.rs

Lines changed: 5 additions & 0 deletions

@@ -11,6 +11,7 @@ use crate::base::json_schema::ToJsonSchemaOptions;
 pub enum LlmApiType {
     Ollama,
     OpenAi,
+    Gemini,
 }

 #[derive(Debug, Clone, Serialize, Deserialize)]

@@ -52,6 +53,7 @@ pub trait LlmGenerationClient: Send + Sync {

 mod ollama;
 mod openai;
+mod gemini;

 pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGenerationClient>> {
     let client = match spec.api_type {

@@ -61,6 +63,9 @@ pub async fn new_llm_generation_client(spec: LlmSpec) -> Result<Box<dyn LlmGener
         LlmApiType::OpenAi => {
             Box::new(openai::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
         }
+        LlmApiType::Gemini => {
+            Box::new(gemini::Client::new(spec).await?) as Box<dyn LlmGenerationClient>
+        }
     };
     Ok(client)
 }
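Taken together, the new match arm is all the wiring Gemini needs: the factory maps `LlmApiType::Gemini` to `gemini::Client::new`, and everything downstream talks only to the `LlmGenerationClient` trait. A hedged usage sketch of that path; the exact field sets and types of `LlmSpec` and `LlmGenerateRequest` are assumptions here, based only on what this diff touches (`api_type`, `model`, `address`, `system_prompt`, `user_prompt`, `output_format`):

```rust
use crate::llm::{new_llm_generation_client, LlmApiType, LlmGenerateRequest, LlmSpec};

// A sketch, not the crate's actual code: field names and types beyond
// what this commit shows are assumed for illustration.
async fn summarize(doc: &str) -> anyhow::Result<String> {
    let client = new_llm_generation_client(LlmSpec {
        api_type: LlmApiType::Gemini, // dispatches to gemini::Client::new
        model: "gemini-2.0-flash".to_string(),
        address: None, // the Gemini client in this commit never reads it
    })
    .await?; // fails here if GEMINI_API_KEY is unset (api_bail! in Client::new)

    let response = client
        .generate(LlmGenerateRequest {
            system_prompt: Some("You are a concise summarizer.".into()),
            user_prompt: doc.into(),
            output_format: None, // plain text; no generationConfig/responseSchema sent
        })
        .await?;
    Ok(response.text)
}
```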
