-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmodel_gemini.py
76 lines (61 loc) · 2.35 KB
/
model_gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import contextlib
from typing import AsyncGenerator, AsyncIterator
from google import genai
from av import AudioFrame, AudioResampler
from PIL.Image import Image
import io
from model import Model, Input, Output
# Sample rate, in Hz, that outgoing audio is resampled to; also advertised in
# the "audio/pcm;rate=..." MIME type of every audio blob sent to the session.
SAMPLE_RATE = 16_000

# Packet time: SAMPLE_RATE * AUDIO_PTIME is the number of samples per outgoing
# audio frame (320 samples with the values here).
AUDIO_PTIME = 0.02
class Gemini(Model):
    """Adapter exposing a Gemini live session through the ``Model`` interface.

    Text, audio frames and images given to :meth:`send` are forwarded to the
    wrapped session; :meth:`recv` converts the audio payloads received from
    the session into ``AudioFrame`` objects.
    """

    # Image modes Pillow's JPEG writer can encode directly; anything else
    # (e.g. RGBA, P, LA) must be converted before saving.
    _JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(self, session):
        """Wrap *session* and set up the resampler used for outgoing audio.

        Args:
            session: Live session object. It must provide the ``send``,
                ``receive`` and ``close`` methods used by this class.
        """
        self.session = session
        # Outgoing audio is normalised to signed 16-bit mono at SAMPLE_RATE and
        # re-chunked into frames of SAMPLE_RATE * AUDIO_PTIME samples.
        self.resampler = AudioResampler(
            format="s16",
            layout="mono",
            rate=SAMPLE_RATE,
            frame_size=int(SAMPLE_RATE * AUDIO_PTIME),
        )

    async def send(self, input: Input) -> None:
        """Forward one input item to the session.

        Args:
            input: A ``str`` (sent with ``end_of_turn=True``), an
                ``AudioFrame`` (resampled, then sent as raw PCM blobs) or a
                PIL ``Image`` (sent as a JPEG blob). Values of any other
                type are ignored.
        """
        if isinstance(input, str):
            await self.session.send(input=input, end_of_turn=True)
        elif isinstance(input, AudioFrame):
            # Each frame produced by the resampler is sent as its own blob.
            for frame in self.resampler.resample(input):
                blob = genai.types.BlobDict(
                    data=frame.to_ndarray().tobytes(),
                    mime_type=f"audio/pcm;rate={SAMPLE_RATE}",
                )
                await self.session.send(input=blob)
        elif isinstance(input, Image):
            # JPEG has no alpha channel or palette support, so saving such an
            # image directly raises OSError; convert those to RGB first.
            image = input
            if image.mode not in self._JPEG_MODES:
                image = image.convert("RGB")
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            blob = genai.types.BlobDict(
                data=buffer.getvalue(),
                mime_type="image/jpeg",
            )
            await self.session.send(input=blob)

    async def recv(self) -> AsyncIterator[Output]:
        """Yield the audio received from the session as ``AudioFrame`` objects.

        Messages without an audio payload are skipped. The sample rate of
        each frame is read from the ``rate=`` parameter of the payload's MIME
        type.

        Raises:
            IndexError: If the MIME type carries no ``rate=`` parameter.
            ValueError: If the ``rate=`` value is not an integer.
        """
        async for event in self.session.receive():
            pcm = event.data
            if not pcm:
                # Message without audio bytes: there is nothing to turn into
                # a frame (an empty payload would give a frame with no plane).
                continue
            mime_type = event.server_content.model_turn.parts[0].inline_data.mime_type
            # Take only the value of "rate=", ignoring any parameter after it.
            sample_rate = int(mime_type.split("rate=")[1].split(";")[0])
            # s16 mono is two bytes per sample; the sample count must be an
            # int, hence floor division rather than "/".
            frame = AudioFrame(format="s16", layout="mono", samples=len(pcm) // 2)
            frame.sample_rate = sample_rate
            frame.planes[0].update(pcm)
            yield frame

    async def close(self):
        """Close the wrapped session; calling it again afterwards is a no-op."""
        if self.session is None:
            return
        await self.session.close()
        self.session = None
# Module-level client shared by connect_gemini(); it is created at import time
# and pinned to the "v1alpha" API version.
# NOTE(review): no API key is passed here, so credentials presumably come from
# the environment - confirm.
client = genai.Client(
    http_options={"api_version": "v1alpha"},
)
@contextlib.asynccontextmanager
async def connect_gemini(
    model: str = "gemini-2.0-flash-exp",
) -> AsyncGenerator[Gemini, None]:
    """Open a live session and yield it wrapped in a ``Gemini`` adapter.

    The session is requested with audio as its only response modality and
    lives for the duration of the ``async with`` block.

    Args:
        model: Name of the model to connect to. Defaults to the model this
            module has always used, so existing callers are unaffected.

    Yields:
        A ``Gemini`` instance bound to the open session.
    """
    config = {"generation_config": {"response_modalities": ["AUDIO"]}}
    async with client.aio.live.connect(model=model, config=config) as session:
        yield Gemini(session)