Skip to content

Commit a45bd90

Browse files
authored
refactor assistant streaming and create OpenAI compliant base class (#425)
1 parent 84cf4f6 commit a45bd90

File tree

12 files changed

+241
-173
lines changed

12 files changed

+241
-173
lines changed

docs/examples/gallery_streaming.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
# - [OpenAI](https://openai.com/)
3030
# - [ragna.assistants.Gpt35Turbo16k][]
3131
# - [ragna.assistants.Gpt4][]
32+
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
33+
# - [ragna.assistants.LlamafileAssistant][]
3234

3335
from ragna import assistants
3436

docs/tutorials/gallery_python_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,10 +85,12 @@
8585
# - [ragna.assistants.Gpt4][]
8686
# - [AI21 Labs](https://www.ai21.com/)
8787
# - [ragna.assistants.Jurassic2Ultra][]
88+
# - [llamafile](https://github.com/Mozilla-Ocho/llamafile)
89+
# - [ragna.assistants.LlamafileAssistant][]
8890
#
8991
# !!! note
9092
#
91-
# To use any of builtin assistants, you need to
93+
# To use some of the builtin assistants, you need to
9294
# [procure API keys](../../references/faq.md#where-do-i-get-api-keys-for-the-builtin-assistants)
9395
# first and set the corresponding environment variables.
9496

ragna/assistants/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"Gpt35Turbo16k",
1010
"Gpt4",
1111
"Jurassic2Ultra",
12+
"LlamafileAssistant",
1213
"RagnaDemoAssistant",
1314
]
1415

@@ -17,6 +18,7 @@
1718
from ._cohere import Command, CommandLight
1819
from ._demo import RagnaDemoAssistant
1920
from ._google import GeminiPro, GeminiUltra
21+
from ._llamafile import LlamafileAssistant
2022
from ._openai import Gpt4, Gpt35Turbo16k
2123

2224
# isort: split

ragna/assistants/_ai21labs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22

33
from ragna.core import Source
44

5-
from ._api import ApiAssistant
5+
from ._http_api import HttpApiAssistant
66

77

8-
class Ai21LabsAssistant(ApiAssistant):
8+
class Ai21LabsAssistant(HttpApiAssistant):
99
_API_KEY_ENV_VAR = "AI21_API_KEY"
1010
_MODEL_TYPE: str
1111

@@ -21,8 +21,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
2121
)
2222
return instruction + "\n\n".join(source.content for source in sources)
2323

24-
async def _call_api(
25-
self, prompt: str, sources: list[Source], *, max_new_tokens: int
24+
async def answer(
25+
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
2626
) -> AsyncIterator[str]:
2727
# See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
2828
# See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters

ragna/assistants/_anthropic.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import json
21
from typing import AsyncIterator, cast
32

43
from ragna.core import PackageRequirement, RagnaException, Requirement, Source
54

6-
from ._api import ApiAssistant
5+
from ._http_api import HttpApiAssistant
76

87

9-
class AnthropicApiAssistant(ApiAssistant):
8+
class AnthropicAssistant(HttpApiAssistant):
109
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
1110
_MODEL: str
1211

@@ -36,15 +35,12 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
3635
+ "</documents>"
3736
)
3837

39-
async def _call_api(
40-
self, prompt: str, sources: list[Source], *, max_new_tokens: int
38+
async def answer(
39+
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
4140
) -> AsyncIterator[str]:
42-
import httpx_sse
43-
4441
# See https://docs.anthropic.com/claude/reference/messages_post
4542
# See https://docs.anthropic.com/claude/reference/streaming
46-
async with httpx_sse.aconnect_sse(
47-
self._client,
43+
async for data in self._stream_sse(
4844
"POST",
4945
"https://api.anthropic.com/v1/messages",
5046
headers={
@@ -61,23 +57,19 @@ async def _call_api(
6157
"temperature": 0.0,
6258
"stream": True,
6359
},
64-
) as event_source:
65-
await self._assert_api_call_is_success(event_source.response)
66-
67-
async for sse in event_source.aiter_sse():
68-
data = json.loads(sse.data)
69-
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
70-
if "error" in data:
71-
raise RagnaException(data["error"].pop("message"), **data["error"])
72-
elif data["type"] == "message_stop":
73-
break
74-
elif data["type"] != "content_block_delta":
75-
continue
60+
):
61+
# See https://docs.anthropic.com/claude/reference/messages-streaming#raw-http-stream-response
62+
if "error" in data:
63+
raise RagnaException(data["error"].pop("message"), **data["error"])
64+
elif data["type"] == "message_stop":
65+
break
66+
elif data["type"] != "content_block_delta":
67+
continue
7668

77-
yield cast(str, data["delta"].pop("text"))
69+
yield cast(str, data["delta"].pop("text"))
7870

7971

80-
class ClaudeOpus(AnthropicApiAssistant):
72+
class ClaudeOpus(AnthropicAssistant):
8173
"""[Claude 3 Opus](https://docs.anthropic.com/claude/docs/models-overview)
8274
8375
!!! info "Required environment variables"
@@ -92,7 +84,7 @@ class ClaudeOpus(AnthropicApiAssistant):
9284
_MODEL = "claude-3-opus-20240229"
9385

9486

95-
class ClaudeSonnet(AnthropicApiAssistant):
87+
class ClaudeSonnet(AnthropicAssistant):
9688
"""[Claude 3 Sonnet](https://docs.anthropic.com/claude/docs/models-overview)
9789
9890
!!! info "Required environment variables"
@@ -107,7 +99,7 @@ class ClaudeSonnet(AnthropicApiAssistant):
10799
_MODEL = "claude-3-sonnet-20240229"
108100

109101

110-
class ClaudeHaiku(AnthropicApiAssistant):
102+
class ClaudeHaiku(AnthropicAssistant):
111103
"""[Claude 3 Haiku](https://docs.anthropic.com/claude/docs/models-overview)
112104
113105
!!! info "Required environment variables"

ragna/assistants/_api.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

ragna/assistants/_cohere.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
1-
import json
21
from typing import AsyncIterator, cast
32

43
from ragna.core import RagnaException, Source
54

6-
from ._api import ApiAssistant
5+
from ._http_api import HttpApiAssistant
76

87

9-
class CohereApiAssistant(ApiAssistant):
8+
class CohereAssistant(HttpApiAssistant):
109
_API_KEY_ENV_VAR = "COHERE_API_KEY"
1110
_MODEL: str
1211

@@ -24,13 +23,13 @@ def _make_preamble(self) -> str:
2423
def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
2524
return [{"title": source.id, "snippet": source.content} for source in sources]
2625

27-
async def _call_api(
28-
self, prompt: str, sources: list[Source], *, max_new_tokens: int
26+
async def answer(
27+
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
2928
) -> AsyncIterator[str]:
3029
# See https://docs.cohere.com/docs/cochat-beta
3130
# See https://docs.cohere.com/reference/chat
3231
# See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
33-
async with self._client.stream(
32+
async for event in self._stream_jsonl(
3433
"POST",
3534
"https://api.cohere.ai/v1/chat",
3635
headers={
@@ -47,21 +46,17 @@ async def _call_api(
4746
"max_tokens": max_new_tokens,
4847
"documents": self._make_source_documents(sources),
4948
},
50-
) as response:
51-
await self._assert_api_call_is_success(response)
49+
):
50+
if event["event_type"] == "stream-end":
51+
if event["event_type"] == "COMPLETE":
52+
break
5253

53-
async for chunk in response.aiter_lines():
54-
event = json.loads(chunk)
55-
if event["event_type"] == "stream-end":
56-
if event["event_type"] == "COMPLETE":
57-
break
54+
raise RagnaException(event["error_message"])
55+
if "text" in event:
56+
yield cast(str, event["text"])
5857

59-
raise RagnaException(event["error_message"])
60-
if "text" in event:
61-
yield cast(str, event["text"])
6258

63-
64-
class Command(CohereApiAssistant):
59+
class Command(CohereAssistant):
6560
"""
6661
[Cohere Command](https://docs.cohere.com/docs/models#command)
6762
@@ -73,7 +68,7 @@ class Command(CohereApiAssistant):
7368
_MODEL = "command"
7469

7570

76-
class CommandLight(CohereApiAssistant):
71+
class CommandLight(CohereAssistant):
7772
"""
7873
[Cohere Command-Light](https://docs.cohere.com/docs/models#command)
7974

ragna/assistants/_google.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from ragna._compat import anext
44
from ragna.core import PackageRequirement, Requirement, Source
55

6-
from ._api import ApiAssistant
6+
from ._http_api import HttpApiAssistant
77

88

99
# ijson does not support reading from an (async) iterator, but only from file-like
@@ -25,7 +25,7 @@ async def read(self, n: int) -> bytes:
2525
return await anext(self._ait, b"") # type: ignore[call-arg]
2626

2727

28-
class GoogleApiAssistant(ApiAssistant):
28+
class GoogleAssistant(HttpApiAssistant):
2929
_API_KEY_ENV_VAR = "GOOGLE_API_KEY"
3030
_MODEL: str
3131

@@ -48,8 +48,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
4848
]
4949
)
5050

51-
async def _call_api(
52-
self, prompt: str, sources: list[Source], *, max_new_tokens: int
51+
async def answer(
52+
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
5353
) -> AsyncIterator[str]:
5454
import ijson
5555

@@ -88,7 +88,7 @@ async def _call_api(
8888
yield chunk
8989

9090

91-
class GeminiPro(GoogleApiAssistant):
91+
class GeminiPro(GoogleAssistant):
9292
"""[Google Gemini Pro](https://ai.google.dev/models/gemini)
9393
9494
!!! info "Required environment variables"
@@ -103,7 +103,7 @@ class GeminiPro(GoogleApiAssistant):
103103
_MODEL = "gemini-pro"
104104

105105

106-
class GeminiUltra(GoogleApiAssistant):
106+
class GeminiUltra(GoogleAssistant):
107107
"""[Google Gemini Ultra](https://ai.google.dev/models/gemini)
108108
109109
!!! info "Required environment variables"

0 commit comments

Comments
 (0)