Skip to content

Commit da1bcc2

Browse files
[ENH] Add support for Ollama assistants (#376)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent a45bd90 commit da1bcc2

File tree

11 files changed

+315
-126
lines changed

11 files changed

+315
-126
lines changed

docs/examples/gallery_streaming.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@
3131
# - [ragna.assistants.Gpt4][]
3232
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
3333
# - [ragna.assistants.LlamafileAssistant][]
34+
# - [Ollama](https://ollama.com/)
35+
# - [ragna.assistants.OllamaGemma2B][]
36+
# - [ragna.assistants.OllamaLlama2][]
37+
# - [ragna.assistants.OllamaLlava][]
38+
# - [ragna.assistants.OllamaMistral][]
39+
# - [ragna.assistants.OllamaMixtral][]
40+
# - [ragna.assistants.OllamaOrcaMini][]
41+
# - [ragna.assistants.OllamaPhi2][]
3442

3543
from ragna import assistants
3644

docs/tutorials/gallery_python_api.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,14 @@
8787
# - [ragna.assistants.Jurassic2Ultra][]
8888
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
8989
# - [ragna.assistants.LlamafileAssistant][]
90+
# - [Ollama](https://ollama.com/)
91+
# - [ragna.assistants.OllamaGemma2B][]
92+
# - [ragna.assistants.OllamaLlama2][]
93+
# - [ragna.assistants.OllamaLlava][]
94+
# - [ragna.assistants.OllamaMistral][]
95+
# - [ragna.assistants.OllamaMixtral][]
96+
# - [ragna.assistants.OllamaOrcaMini][]
97+
# - [ragna.assistants.OllamaPhi2][]
9098
#
9199
# !!! note
92100
#

ragna/assistants/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
"CommandLight",
77
"GeminiPro",
88
"GeminiUltra",
9+
"OllamaGemma2B",
10+
"OllamaPhi2",
11+
"OllamaLlama2",
12+
"OllamaLlava",
13+
"OllamaMistral",
14+
"OllamaMixtral",
15+
"OllamaOrcaMini",
916
"Gpt35Turbo16k",
1017
"Gpt4",
1118
"Jurassic2Ultra",
@@ -19,6 +26,15 @@
1926
from ._demo import RagnaDemoAssistant
2027
from ._google import GeminiPro, GeminiUltra
2128
from ._llamafile import LlamafileAssistant
29+
from ._ollama import (
30+
OllamaGemma2B,
31+
OllamaLlama2,
32+
OllamaLlava,
33+
OllamaMistral,
34+
OllamaMixtral,
35+
OllamaOrcaMini,
36+
OllamaPhi2,
37+
)
2238
from ._openai import Gpt4, Gpt35Turbo16k
2339

2440
# isort: split

ragna/assistants/_ai21labs.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
class Ai21LabsAssistant(HttpApiAssistant):
99
_API_KEY_ENV_VAR = "AI21_API_KEY"
10+
_STREAMING_PROTOCOL = None
1011
_MODEL_TYPE: str
1112

1213
@classmethod
@@ -27,7 +28,8 @@ async def answer(
2728
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
2829
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
2930
# See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
30-
response = await self._client.post(
31+
async for data in self._call_api(
32+
"POST",
3133
f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
3234
headers={
3335
"accept": "application/json",
@@ -46,10 +48,8 @@ async def answer(
4648
],
4749
"system": self._make_system_content(sources),
4850
},
49-
)
50-
await self._assert_api_call_is_success(response)
51-
52-
yield cast(str, response.json()["outputs"][0]["text"])
51+
):
52+
yield cast(str, data["outputs"][0]["text"])
5353

5454

5555
# The Jurassic2Mid assistant receives a 500 internal service error from the remote

ragna/assistants/_anthropic.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from ragna.core import PackageRequirement, RagnaException, Requirement, Source
44

5-
from ._http_api import HttpApiAssistant
5+
from ._http_api import HttpApiAssistant, HttpStreamingProtocol
66

77

88
class AnthropicAssistant(HttpApiAssistant):
99
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
10+
_STREAMING_PROTOCOL = HttpStreamingProtocol.SSE
1011
_MODEL: str
1112

1213
@classmethod
@@ -40,7 +41,7 @@ async def answer(
4041
) -> AsyncIterator[str]:
4142
# See https://docs.anthropic.com/claude/reference/messages_post
4243
# See https://docs.anthropic.com/claude/reference/streaming
43-
async for data in self._stream_sse(
44+
async for data in self._call_api(
4445
"POST",
4546
"https://api.anthropic.com/v1/messages",
4647
headers={

ragna/assistants/_cohere.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22

33
from ragna.core import RagnaException, Source
44

5-
from ._http_api import HttpApiAssistant
5+
from ._http_api import HttpApiAssistant, HttpStreamingProtocol
66

77

88
class CohereAssistant(HttpApiAssistant):
99
_API_KEY_ENV_VAR = "COHERE_API_KEY"
10+
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSONL
1011
_MODEL: str
1112

1213
@classmethod
@@ -29,7 +30,7 @@ async def answer(
2930
# See https://docs.cohere.com/docs/cochat-beta
3031
# See https://docs.cohere.com/reference/chat
3132
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
32-
async for event in self._stream_jsonl(
33+
async for event in self._call_api(
3334
"POST",
3435
"https://api.cohere.ai/v1/chat",
3536
headers={

ragna/assistants/_google.py

Lines changed: 11 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,15 @@
11
from typing import AsyncIterator
22

3-
from ragna._compat import anext
4-
from ragna.core import PackageRequirement, Requirement, Source
3+
from ragna.core import Source
54

6-
from ._http_api import HttpApiAssistant
7-
8-
9-
# ijson does not support reading from an (async) iterator, but only from file-like
10-
# objects, i.e. https://docs.python.org/3/tutorial/inputoutput.html#methods-of-file-objects.
11-
# See https://github.com/ICRAR/ijson/issues/44 for details.
12-
# ijson actually doesn't care about most of the file interface and only requires the
13-
# read() method to be present.
14-
class AsyncIteratorReader:
15-
def __init__(self, ait: AsyncIterator[bytes]) -> None:
16-
self._ait = ait
17-
18-
async def read(self, n: int) -> bytes:
19-
# n is usually used to indicate how many bytes to read, but since we want to
20-
# return a chunk as soon as it is available, we ignore the value of n. The only
21-
# exception is n == 0, which is used by ijson to probe the return type and
22-
# set up decoding.
23-
if n == 0:
24-
return b""
25-
return await anext(self._ait, b"") # type: ignore[call-arg]
5+
from ._http_api import HttpApiAssistant, HttpStreamingProtocol
266

277

288
class GoogleAssistant(HttpApiAssistant):
299
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
10+
_STREAMING_PROTOCOL = HttpStreamingProtocol.JSON
3011
_MODEL: str
3112

32-
@classmethod
33-
def _extra_requirements(cls) -> list[Requirement]:
34-
return [PackageRequirement("ijson")]
35-
3613
@classmethod
3714
def display_name(cls) -> str:
3815
return f"Google/{cls._MODEL}"
@@ -51,9 +28,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
5128
async def answer(
5229
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
5330
) -> AsyncIterator[str]:
54-
import ijson
55-
56-
async with self._client.stream(
31+
async for chunk in self._call_api(
5732
"POST",
5833
f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
5934
params={"key": self._api_key},
@@ -64,7 +39,10 @@ async def answer(
6439
],
6540
# https://ai.google.dev/docs/safety_setting_gemini
6641
"safetySettings": [
67-
{"category": f"HARM_CATEGORY_{category}", "threshold": "BLOCK_NONE"}
42+
{
43+
"category": f"HARM_CATEGORY_{category}",
44+
"threshold": "BLOCK_NONE",
45+
}
6846
for category in [
6947
"HARASSMENT",
7048
"HATE_SPEECH",
@@ -78,14 +56,9 @@ async def answer(
7856
"maxOutputTokens": max_new_tokens,
7957
},
8058
},
81-
) as response:
82-
await self._assert_api_call_is_success(response)
83-
84-
async for chunk in ijson.items(
85-
AsyncIteratorReader(response.aiter_bytes(1024)),
86-
"item.candidates.item.content.parts.item.text",
87-
):
88-
yield chunk
59+
parse_kwargs=dict(item="item.candidates.item.content.parts.item.text"),
60+
):
61+
yield chunk
8962

9063

9164
class GeminiPro(GoogleAssistant):

0 commit comments

Comments (0)