Commit c258948

feat: llama_cpp proxy chat model
1 parent e5f528d commit c258948

8 files changed: +77, -203 lines

Diff for: langchain_llamacpp_chat_model/__init__.py

+4
@@ -0,0 +1,4 @@
+from .llama_chat_model import LlamaChatModel
+from .llama_proxy_chat_model import LlamaProxyChatModel
+
+__all__ = ["LlamaChatModel", "LlamaProxyChatModel"]
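With this package-level export, both chat models can be imported directly from the package root, as the updated tests below do:

from langchain_llamacpp_chat_model import LlamaChatModel, LlamaProxyChatModel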

Diff for: langchain_llamacpp_chat_model/llama_chat_model.py

+3, -1

@@ -35,7 +35,9 @@ def __call__(self):
     def __enter__(self):
         return self()
 
-    def __exit__(self):
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        if hasattr(self.response, "close"):
+            self.response.close()
         return False
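The previous __exit__ signature was invalid (Python passes the exception triple to __exit__); the fix accepts it and closes the held streaming response when one is open. A minimal sketch of the same pattern, using a hypothetical StreamContext wrapper in place of the helper class in llama_chat_model.py:

class StreamContext:
    """Illustrative sketch only: mirrors the close-on-exit behaviour added above."""

    def __init__(self, response):
        self.response = response

    def __enter__(self):
        return self.response

    def __exit__(self, exception_type, exception_value, exception_traceback):
        # Guard with hasattr so responses without a close() method pass through.
        if hasattr(self.response, "close"):
            self.response.close()
        return False  # never swallow exceptions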

Diff for: langchain_llamacpp_chat_model/llama_proxy_chat_model.py

+3, -53

@@ -1,67 +1,17 @@
-from pydantic import Field
-from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence
-from langchain_openai.chat_models.base import BaseChatOpenAI
-from langchain_core.pydantic_v1 import BaseModel
-
-from langchain_core.callbacks import (
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models import BaseChatModel
-from langchain_core.tools import BaseTool
-from langchain_core.runnables import Runnable
-from langchain_core.language_models.base import LanguageModelInput
-from langchain_core.messages import AIMessageChunk, BaseMessage, AIMessage
-from langchain_core.utils.function_calling import convert_to_openai_tool
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from llama_cpp import (
-    CreateCompletionResponse,
-    CreateCompletionStreamResponse,
-    Literal,
-    LlamaGrammar,
-    LogitsProcessorList,
-    StoppingCriteriaList,
-    Type,
-    Union,
-)
+from typing import Any, Dict
 from llama_cpp.server.app import LlamaProxy
 
 from langchain_llamacpp_chat_model.llama_chat_model import LlamaChatModel
 
-# Use this class until it's implemented in LangChain Community
-
 
 class LlamaProxyChatModel(LlamaChatModel):
-    model_name: str = Field(default="", alias="model")
-
-    suffix: Optional[str] = None
-    max_tokens: Optional[int] = 2048
-    temperature: float = 0.8
-    top_p: float = 0.95
-    min_p: float = 0.05
-    typical_p: float = 1.0
-    logprobs: Optional[int] = None
-    echo: bool = False
-    stop: Optional[Union[str, List[str]]] = []
-    frequency_penalty: float = 0.0
-    presence_penalty: float = 0.0
-    repeat_penalty: float = 1.1
-    top_k: int = 40
-    seed: Optional[int] = None
-    tfs_z: float = 1.0
-    mirostat_mode: int = 0
-    mirostat_tau: float = 5.0
-    mirostat_eta: float = 0.1
-    stopping_criteria: Optional[StoppingCriteriaList] = None
-    logits_processor: Optional[LogitsProcessorList] = None
-    grammar: Optional[LlamaGrammar] = None
-    logit_bias: Optional[Dict[str, float]] = None
-
     def __init__(
         self,
         llama_proxy: LlamaProxy,
         **kwargs,
     ):
-        llama = llama_proxy(self.model_name)
+        model = kwargs.get("model_name", kwargs.get("model"))
+        llama = llama_proxy(model)
         super().__init__(**kwargs, llama=llama)
 
     @property
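The constructor now resolves the backing Llama instance through the proxy: it reads the alias from model_name (or model) in kwargs and asks llama_proxy for the matching model before delegating to LlamaChatModel. A usage sketch, assuming the GGUF files referenced by the test configuration below are already cached locally:

from langchain_llamacpp_chat_model import LlamaProxyChatModel
from tests.test_functional.models_configuration import create_llama_proxy

# create_llama_proxy() registers the "llama3" and "phi3" aliases (see the test
# configuration below); the proxy loads the model that matches the given alias.
llama_proxy = create_llama_proxy()

chat_model = LlamaProxyChatModel(llama_proxy=llama_proxy, model_name="llama3")
result = chat_model.invoke("Say hello in one short sentence.")
print(result.content)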

Diff for: tests/test_functional/models_configuration.py

+28, -4

@@ -1,26 +1,50 @@
 import os
 from llama_cpp import Llama
+from llama_cpp.server.app import LlamaProxy
+from llama_cpp.server.settings import ModelSettings
 
 models_to_test = [
     {
         "repo_id": "lmstudio-community/Meta-Llama-3-8B-Instruct-GGUF",
         "filename": "Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
+        "alias": "llama3",
     },
     {
         "repo_id": "microsoft/Phi-3-mini-4k-instruct-gguf",
         "filename": "Phi-3-mini-4k-instruct-q4.gguf",
+        "alias": "phi3",
     },
 ]
 
 
-def create_llama(request) -> Llama:
-    local_path = os.path.join(
+def _model_local_path(model) -> str:
+    return os.path.join(
         os.path.expanduser("~/.cache/lm-studio/models"),
-        request.param["repo_id"],
-        request.param["filename"],
+        model["repo_id"],
+        model["filename"],
     )
 
+
+def _create_models_settings():
+    models: list[ModelSettings] = []
+    for model in models_to_test:
+        local_path = _model_local_path(model)
+        models.append(
+            ModelSettings(model=local_path, model_alias=model["alias"], n_gpu_layers=-1)
+        )
+
+    return models
+
+
+def create_llama(request) -> Llama:
+    local_path = _model_local_path(request.param)
+
     return Llama(
         model_path=local_path,
         n_gpu_layers=-1,
     )
+
+
+def create_llama_proxy() -> LlamaProxy:
+    models = _create_models_settings()
+    return LlamaProxy(models=models)
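create_llama_proxy() turns the test model catalogue into LlamaProxy settings: each entry becomes a ModelSettings with its local GGUF path and an alias, and the proxy can later be called with an alias to obtain the loaded model. A hedged single-model sketch, with a placeholder path standing in for the LM Studio cache location:

from llama_cpp.server.app import LlamaProxy
from llama_cpp.server.settings import ModelSettings

# Placeholder path for illustration; the real path comes from _model_local_path().
settings = ModelSettings(
    model="/path/to/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
    model_alias="llama3",
    n_gpu_layers=-1,
)

proxy = LlamaProxy(models=[settings])
llama = proxy("llama3")  # returns the llama_cpp.Llama loaded for that alias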

Diff for: tests/test_functional/test_invoke.py

+1, -1

@@ -2,7 +2,7 @@
 import pytest
 from langchain_core.messages import AIMessage, HumanMessage
 
-from langchain_llamacpp_chat_model.llama_chat_model import LlamaChatModel
+from langchain_llamacpp_chat_model import LlamaChatModel
 
 from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.tools import tool

Diff for: tests/test_functional/test_llama_proxy.py

+37
@@ -0,0 +1,37 @@
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from langchain_llamacpp_chat_model import LlamaProxyChatModel
+
+from tests.test_functional.models_configuration import (
+    create_llama_proxy,
+    models_to_test,
+)
+from llama_cpp.server.app import LlamaProxy
+
+
+@pytest.fixture
+def llama_proxy() -> LlamaProxy:
+    return create_llama_proxy()
+
+
+class TestLlamaProxyChat:
+
+    @pytest.fixture(
+        params=models_to_test, ids=[config["alias"] for config in models_to_test]
+    )
+    def instance(self, llama_proxy: LlamaProxy, request):
+        return LlamaProxyChatModel(
+            llama_proxy=llama_proxy, model_name=request.param["alias"]
+        )
+
+    def test_conversation_memory(self, instance: LlamaProxyChatModel):
+        result = instance.invoke(
+            input=[
+                HumanMessage(content="Remember that I like bananas"),
+                AIMessage(content="Okay"),
+                HumanMessage(content="What do I like?"),
+            ]
+        )
+
+        assert "banana" in result.content
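One way to run just this functional test, assuming the models listed in models_configuration.py are present in the local cache (equivalent to invoking pytest on that file from the repository root):

import pytest

raise SystemExit(pytest.main(["tests/test_functional/test_llama_proxy.py", "-v"]))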

Diff for: tests/test_functional/test_stream.py

+1, -1

@@ -2,7 +2,7 @@
 import pytest
 from langchain_core.messages import AIMessage, HumanMessage
 
-from langchain_llamacpp_chat_model.llama_chat_model import LlamaChatModel
+from langchain_llamacpp_chat_model import LlamaChatModel
 from tests.test_functional.models_configuration import create_llama, models_to_test

Diff for: tests/test_llama_proxy_cpp_chat_model.py

-143
This file was deleted.
