|
1 | 1 | from llama_cpp import Llama
|
2 | 2 |
|
3 | 3 | from langchain_openai.chat_models.base import BaseChatOpenAI
|
4 |
| -from pydantic import Field |
5 | 4 |
|
6 |
| - |
7 |
| -class LLamaOpenAIClientProxy: |
8 |
| - def __init__(self, llama: Llama): |
9 |
| - self.llama = llama |
10 |
| - |
11 |
| - def create(self, **kwargs): |
12 |
| - proxy = LlamaCreateContextManager(llama=self.llama, **kwargs) |
13 |
| - if "stream" in kwargs and kwargs["stream"] is True: |
14 |
| - return proxy |
15 |
| - else: |
16 |
| - return proxy() |
17 |
| - |
18 |
| - |
19 |
| -class LlamaCreateContextManager: |
20 |
| - |
21 |
| - def __init__(self, llama: Llama, **kwargs): |
22 |
| - self.llama = llama |
23 |
| - self.kwargs = kwargs |
24 |
| - self.response = None |
25 |
| - |
26 |
| - def __call__(self): |
27 |
| - self.kwargs.pop("n", None) |
28 |
| - self.kwargs.pop( |
29 |
| - "parallel_tool_calls", None |
30 |
| - ) # LLamaCPP does not support parallel tool calls for now |
31 |
| - |
32 |
| - self.response = self.llama.create_chat_completion(**self.kwargs) |
33 |
| - return self.response |
34 |
| - |
35 |
| - def __enter__(self): |
36 |
| - return self() |
37 |
| - |
38 |
| - def __exit__(self, exception_type, exception_value, exception_traceback): |
39 |
| - if hasattr(self.response, "close"): |
40 |
| - self.response.close() |
41 |
| - return False |
| 5 | +from .llama_client_proxy import LLamaOpenAIClientProxy |
| 6 | +from .llama_client_async_proxy import LLamaOpenAIClientAsyncProxy |
42 | 7 |
|
43 | 8 |
|
44 | 9 | class LlamaChatModel(BaseChatOpenAI):
|
45 |
| - model_name: str = Field(default="", alias="model") |
| 10 | + model_name: str = "unknown" |
46 | 11 |
|
47 | 12 | def __init__(
|
48 | 13 | self,
|
49 | 14 | llama: Llama,
|
50 | 15 | **kwargs,
|
51 | 16 | ):
|
52 |
| - super().__init__(**kwargs, client=LLamaOpenAIClientProxy(llama=llama)) |
| 17 | + super().__init__( |
| 18 | + **kwargs, |
| 19 | + client=LLamaOpenAIClientProxy(llama=llama), |
| 20 | + async_client=LLamaOpenAIClientAsyncProxy(llama=llama), |
| 21 | + ) |
0 commit comments