
Commit 8630b3f

feat: dedicated llm type
1 parent 42c953b commit 8630b3f

7 files changed: +33 -6

langchain_llamacpp_chat_model/llama_chat_model.py (+7)

@@ -8,6 +8,7 @@
 
 class LlamaChatModel(BaseChatOpenAI):
     model_name: str = "unknown"
+    llama: Llama = None
 
     def __init__(
         self,
@@ -19,3 +20,9 @@ def __init__(
             client=LLamaOpenAIClientProxy(llama=llama),
             async_client=LLamaOpenAIClientAsyncProxy(llama=llama),
         )
+        self.llama = llama
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of chat model."""
+        return self.llama.model_path
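
The new llama field matters because BaseChatOpenAI is a pydantic model: pydantic rejects assignment to undeclared attributes, so self.llama = llama in __init__ only works once llama is declared as a field. With the field in place, the _llm_type property identifies the chat model by its underlying GGUF file instead of the generic identifier inherited from BaseChatOpenAI, which is useful where LangChain consumes _llm_type (serialization and cache keys, for instance). A minimal usage sketch (the model path is illustrative, not from the commit, and assumes a GGUF file already on disk):

    from llama_cpp import Llama
    from langchain_llamacpp_chat_model import LlamaChatModel

    # Hypothetical local model file; any GGUF path works here.
    llama = Llama(model_path="./models/mistral-7b.Q4_K_M.gguf")
    chat_model = LlamaChatModel(llama=llama)

    # _llm_type now reports the model file backing this instance.
    print(chat_model._llm_type)  # -> "./models/mistral-7b.Q4_K_M.gguf"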

tests/test_functional/models_configuration.py (+2 -2)

@@ -45,8 +45,8 @@ def _create_models_settings():
     return models
 
 
-def create_llama(request) -> Llama:
-    local_path = _model_local_path(request.param)
+def create_llama(params) -> Llama:
+    local_path = _model_local_path(params)
 
     return Llama(
         model_path=local_path,
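
This signature change decouples create_llama from pytest: it now receives a model-settings entry of models_to_test directly instead of the whole pytest request object it previously unwrapped itself. Parametrized fixtures now do the unwrapping at the call site, and non-parametrized code can call the helper directly, as the new test file further down does. A sketch of the direct call (assuming the repo layout shown in the imports of this commit):

    from tests.test_functional.models_configuration import create_llama, models_to_test

    # Build a Llama for the first configured model, no fixture machinery needed.
    llama = create_llama(models_to_test[0])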

tests/test_functional/test_ainvoke.py (+1 -1)

@@ -19,7 +19,7 @@ class TestAInvoke:
         params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
     )
     def llama(self, request) -> Llama:
-        return create_llama(request)
+        return create_llama(request.param)
 
     @pytest.fixture
     def instance(self, llama):
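
The same one-line fix recurs in the other parametrized test files below: in a fixture declared with @pytest.fixture(params=...), pytest exposes the current parameter as request.param, so each fixture must pass request.param rather than the FixtureRequest object itself to match the new create_llama signature.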

tests/test_functional/test_astream.py (+1 -1)

@@ -12,7 +12,7 @@ class TestAStream:
         params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
     )
     def llama(self, request) -> Llama:
-        return create_llama(request)
+        return create_llama(request.param)
 
     @pytest.fixture
     def instance(self, llama):

tests/test_functional/test_invoke.py (+1 -1)

@@ -20,7 +20,7 @@ class TestInvoke:
         params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
     )
     def llama(self, request) -> Llama:
-        return create_llama(request)
+        return create_llama(request.param)
 
     @pytest.fixture
    def instance(self, llama):
New file (+20)

@@ -0,0 +1,20 @@
+from llama_cpp import Llama
+import pytest
+from langchain_llamacpp_chat_model import LlamaChatModel
+from tests.test_functional.models_configuration import create_llama, models_to_test
+
+
+class TestInvoke:
+
+    @pytest.fixture()
+    def llama(self) -> Llama:
+
+        return create_llama(models_to_test[0])
+
+    @pytest.fixture
+    def instance(self, llama):
+        return LlamaChatModel(llama=llama)
+
+    def test_llm_type(self, instance: LlamaChatModel):
+        result = instance._llm_type
+        assert models_to_test[0]["repo_id"] in result
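
This new test pins down the feature itself: _llm_type should contain the model's repo_id. That presumably holds because create_llama resolves the GGUF through the model's repo_id, so the downloaded file's local path embeds it; the resolving helper (_model_local_path) isn't shown in this diff, so that detail is an assumption.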

tests/test_functional/test_stream.py (+1 -1)

@@ -12,7 +12,7 @@ class TestStream:
         params=models_to_test, ids=[config["repo_id"] for config in models_to_test]
     )
     def llama(self, request) -> Llama:
-        return create_llama(request)
+        return create_llama(request.param)
 
     @pytest.fixture
     def instance(self, llama):
