
Commit def5290
feat: support ainvoke, astream
Parent: 58ad941

6 files changed, +208 -39 lines
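
The change wires both a synchronous and an asynchronous client proxy into LlamaChatModel, so LangChain's ainvoke and astream entry points now work against a local llama.cpp model. A minimal usage sketch of what this enables (the GGUF path is a placeholder, not taken from this commit):

# Hedged usage sketch of the new async entry points.
# The model path below is a placeholder.
import asyncio

from llama_cpp import Llama
from langchain_llamacpp_chat_model import LlamaChatModel


async def main() -> None:
    llama = Llama(model_path="path/to/model.gguf")  # placeholder path
    model = LlamaChatModel(llama=llama)

    # Non-streaming async call.
    result = await model.ainvoke("Say Hi!")
    print(result.content)

    # Streaming async call.
    async for chunk in model.astream("Say Hi!"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())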
Changed file (defines LlamaChatModel):

@@ -1,52 +1,21 @@
 from llama_cpp import Llama
 
 from langchain_openai.chat_models.base import BaseChatOpenAI
-from pydantic import Field
 
-
-class LLamaOpenAIClientProxy:
-    def __init__(self, llama: Llama):
-        self.llama = llama
-
-    def create(self, **kwargs):
-        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
-        if "stream" in kwargs and kwargs["stream"] is True:
-            return proxy
-        else:
-            return proxy()
-
-
-class LlamaCreateContextManager:
-
-    def __init__(self, llama: Llama, **kwargs):
-        self.llama = llama
-        self.kwargs = kwargs
-        self.response = None
-
-    def __call__(self):
-        self.kwargs.pop("n", None)
-        self.kwargs.pop(
-            "parallel_tool_calls", None
-        )  # LLamaCPP does not support parallel tool calls for now
-
-        self.response = self.llama.create_chat_completion(**self.kwargs)
-        return self.response
-
-    def __enter__(self):
-        return self()
-
-    def __exit__(self, exception_type, exception_value, exception_traceback):
-        if hasattr(self.response, "close"):
-            self.response.close()
-        return False
+from .llama_client_proxy import LLamaOpenAIClientProxy
+from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy
 
 
 class LlamaChatModel(BaseChatOpenAI):
-    model_name: str = Field(default="", alias="model")
+    model_name: str = "unknown"
 
     def __init__(
         self,
         llama: Llama,
         **kwargs,
     ):
-        super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama))
+        super().__init__(
+            **kwargs,
+            client=LLamaOpenAIClientProxy(llama=llama),
+            async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
+        )
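
The constructor now hands BaseChatOpenAI an async_client alongside the existing client. This relies on the proxies being drop-in stand-ins for the OpenAI SDK's chat-completions clients; a minimal sketch of the duck-typed surface they appear to cover, assuming (not shown in this diff) that the base class only ever calls client.create(...) and awaits async_client.create(...):

# Hedged sketch of the assumed client surface; these are not langchain_openai's
# actual type definitions.
from typing import Any, Protocol


class SyncChatCompletionsLike(Protocol):
    def create(self, **kwargs: Any) -> Any: ...


class AsyncChatCompletionsLike(Protocol):
    async def create(self, **kwargs: Any) -> Any: ...

Both new proxy modules below implement exactly this create method on top of Llama.create_chat_completion.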
llama_client_async_proxy.py (new file)

@@ -0,0 +1,52 @@
+from llama_cpp import Llama
+
+
+async def to_async_iterator(iterator):
+    for item in iterator:
+        yield item
+
+
+class LlamaCreateAsyncContextManager:
+
+    def __init__(self, llama: Llama, **kwargs):
+        self.llama = llama
+        self.kwargs = kwargs
+        self.response = None
+
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        try:
+            return next(self.response)
+        except Exception:
+            raise StopAsyncIteration()
+
+    def __call__(self):
+        self.kwargs.pop("n", None)
+        self.kwargs.pop(
+            "parallel_tool_calls", None
+        )  # LLamaCPP does not support parallel tool calls for now
+
+        self.response = self.llama.create_chat_completion(**self.kwargs)
+        return self.response
+
+    async def __aenter__(self):
+        return self()
+
+    async def __aexit__(self, exception_type, exception_value, exception_traceback):
+        if hasattr(self.response, "close"):
+            self.response.close()
+        return False
+
+
+class LLamaOpenAIClientAsyncProxy:
+    def __init__(self, llama: Llama):
+        self.llama = llama
+
+    async def create(self, **kwargs):
+        proxy = LlamaCreateAsyncContextManager(llama=self.llama, **kwargs)
+        if "stream" in kwargs and kwargs["stream"] is True:
+            return proxy
+        else:
+            return proxy()
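
For reference, a hedged sketch of driving the async proxy directly, the way the chat model's base class is expected to: awaiting create(..., stream=True) returns the context manager, entering it runs create_chat_completion, and __anext__ then pulls chunks off the underlying synchronous generator. The model path and message payload are placeholders; the chunk shape follows llama-cpp-python's OpenAI-compatible dicts.

# Hedged sketch: exercising LLamaOpenAIClientAsyncProxy (the class added above) on its own.
import asyncio

from llama_cpp import Llama


async def main() -> None:
    llama = Llama(model_path="path/to/model.gguf")  # placeholder path
    client = LLamaOpenAIClientAsyncProxy(llama=llama)  # class from the new module above

    response = await client.create(
        messages=[{"role": "user", "content": "Say Hi!"}],
        stream=True,
    )
    # Entering the context triggers create_chat_completion(stream=True);
    # iterating then wraps each chunk of the synchronous generator.
    async with response:
        async for chunk in response:
            print(chunk["choices"][0]["delta"].get("content", ""), end="")


asyncio.run(main())

Note that the wrapper only adapts the iteration protocol; the underlying llama.cpp call still runs synchronously inside the event loop rather than being offloaded to a thread.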
llama_client_proxy.py (new file)

@@ -0,0 +1,38 @@
+from llama_cpp import Llama
+
+
+class LlamaCreateContextManager:
+
+    def __init__(self, llama: Llama, **kwargs):
+        self.llama = llama
+        self.kwargs = kwargs
+        self.response = None
+
+    def __call__(self):
+        self.kwargs.pop("n", None)
+        self.kwargs.pop(
+            "parallel_tool_calls", None
+        )  # LLamaCPP does not support parallel tool calls for now
+
+        self.response = self.llama.create_chat_completion(**self.kwargs)
+        return self.response
+
+    def __enter__(self):
+        return self()
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if hasattr(self.response, "close"):
+            self.response.close()
+        return False
+
+
+class LLamaOpenAIClientProxy:
+    def __init__(self, llama: Llama):
+        self.llama = llama
+
+    def create(self, **kwargs):
+        proxy = LlamaCreateContextManager(llama=self.llama, **kwargs)
+        if "stream" in kwargs and kwargs["stream"] is True:
+            return proxy
+        else:
+            return proxy()
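
And a matching hedged sketch for the synchronous proxy: a plain create() call returns the finished chat-completion dict, while create(stream=True) hands back the context manager, whose __enter__ starts generation and yields the chunk generator. Model path and messages are placeholders.

# Hedged sketch: exercising LLamaOpenAIClientProxy (the class added above) on its own.
from llama_cpp import Llama

llama = Llama(model_path="path/to/model.gguf")  # placeholder path
client = LLamaOpenAIClientProxy(llama=llama)  # class from the new module above

# Non-streaming: create() runs create_chat_completion and returns its dict.
completion = client.create(messages=[{"role": "user", "content": "Say Hi!"}])
print(completion["choices"][0]["message"]["content"])

# Streaming: create(stream=True) returns the context manager; entering it
# starts generation and yields llama-cpp-python's chunk generator.
with client.create(
    messages=[{"role": "user", "content": "Say Hi!"}], stream=True
) as chunks:
    for chunk in chunks:
        print(chunk["choices"][0]["delta"].get("content", ""), end="")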

tests/test_functional/test_ainvoke.py

+44
@@ -0,0 +1,44 @@
+from llama_cpp import Llama
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_llamacpp_chat_model import LlamaChatModel
+
+from langchain_core.pydantic_v1 import BaseModel, Field
+from tests.test_functional.models_configuration import create_llama, models_to_test
+
+
+class Joke(BaseModel):
+    setup: str = Field(description="The setup of the joke")
+    punchline: str = Field(description="The punchline to the joke")
+
+
+class TestAInvoke:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_ainvoke(self, instance: LlamaChatModel):
+        result = await instance.ainvoke("Say Hi!")
+
+        assert len(result.content) > 0
+
+    @pytest.mark.asyncio
+    async def test_conversation_memory(self, instance: LlamaChatModel):
+        result = await instance.ainvoke(
+            input=[
+                HumanMessage(content="Remember that I like bananas"),
+                AIMessage(content="Okay"),
+                HumanMessage(content="What do I like?"),
+            ]
+        )
+
+        assert "banana" in result.content

tests/test_functional/test_astream.py

+47
@@ -0,0 +1,47 @@
+from llama_cpp import Llama
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_llamacpp_chat_model import LlamaChatModel
+from tests.test_functional.models_configuration import create_llama, models_to_test
+
+
+class TestAStream:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_astream(self, instance: LlamaChatModel):
+
+        chunks = []
+        async for chunk in instance.astream("Say Hi!"):
+            chunks.append(chunk)
+
+        final_content = "".join(chunk.content for chunk in chunks)
+
+        assert len(final_content) > 0
+
+    @pytest.mark.asyncio
+    async def test_conversation_memory(self, instance: LlamaChatModel):
+        stream = instance.astream(
+            input=[
+                HumanMessage(content="Remember that I like bananas"),
+                AIMessage(content="Okay"),
+                HumanMessage(content="What do I like?"),
+            ]
+        )
+
+        final_content = ""
+        async for token in stream:
+            final_content += token.content
+
+        assert len(final_content) > 0
+        assert "banana" in final_content

tests/test_functional/test_invoke.py

+19
@@ -67,3 +67,22 @@ def magic_number_tool(input: int) -> int:
         result = llm_with_tool.invoke("What is the magic mumber of 2?")
 
         assert result.tool_calls[0]["name"] == "magic_number_tool"
+
+
+class TestAInvoke:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
+    )
+    def llama(self, request) -> Llama:
+        return create_llama(request)
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    @pytest.mark.asyncio
+    async def test_ainvoke(self, instance: LlamaChatModel):
+        result = await instance.ainvoke("Say Hi!")
+
+        assert len(result.content) > 0
