Skip to content

Commit d5f4dc7

Browse files
authored
feat: Add NvidiaChatGenerator based on OpenAIChatGenerator (#1776)
* NvidiaChatGenerator based on OpenAIChatGenerator * tool.mypy.overrides openai * add tests with extra_body * use serialize_tools_or_toolset * move tests into single file * fmt * pin lowest direct dependencies
1 parent b2f5303 commit d5f4dc7

File tree

5 files changed

+521
-5
lines changed

5 files changed

+521
-5
lines changed

.github/workflows/nvidia.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ jobs:
5555
if: matrix.python-version == '3.9' && runner.os == 'Linux'
5656
run: hatch run lint:all
5757

58-
- name: Run tests
59-
run: hatch run cov-retry
60-
6158
- name: Generate docs
6259
if: matrix.python-version == '3.9' && runner.os == 'Linux'
6360
run: hatch run docs
6461

62+
- name: Run tests
63+
run: hatch run cov-retry
64+
6565
- name: Run unit tests with lowest direct dependencies
6666
run: |
6767
hatch run uv pip compile pyproject.toml --resolution lowest-direct --output-file requirements_lowest_direct.txt

integrations/nvidia/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ classifiers = [
2323
"Programming Language :: Python :: Implementation :: CPython",
2424
"Programming Language :: Python :: Implementation :: PyPy",
2525
]
26-
dependencies = ["haystack-ai", "requests>=2.25.0", "tqdm>=4.21.0"]
26+
dependencies = ["haystack-ai>=2.13.0", "requests>=2.25.0", "tqdm>=4.21.0"]
2727

2828
[project.urls]
2929
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/nvidia#readme"
@@ -46,6 +46,8 @@ installer = "uv"
4646
dependencies = [
4747
"coverage[toml]>=6.5",
4848
"pytest",
49+
"pytest-asyncio",
50+
"pytz",
4951
"pytest-rerunfailures",
5052
"haystack-pydoc-tools",
5153
"requests_mock",
@@ -160,6 +162,7 @@ module = [
160162
"pytest.*",
161163
"numpy.*",
162164
"requests_mock.*",
165+
"openai.*",
163166
"pydantic.*",
164167
]
165168
ignore_missing_imports = true

integrations/nvidia/src/haystack_integrations/components/generators/nvidia/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from .chat.chat_generator import NvidiaChatGenerator
56
from .generator import NvidiaGenerator
67

7-
__all__ = ["NvidiaGenerator"]
8+
__all__ = ["NvidiaChatGenerator", "NvidiaGenerator"]
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import os
6+
from typing import Any, Dict, List, Optional, Union
7+
8+
from haystack import component, default_to_dict, logging
9+
from haystack.components.generators.chat import OpenAIChatGenerator
10+
from haystack.dataclasses import StreamingCallbackT
11+
from haystack.tools import Tool, Toolset, serialize_tools_or_toolset
12+
from haystack.utils import serialize_callable
13+
from haystack.utils.auth import Secret
14+
15+
from haystack_integrations.utils.nvidia import DEFAULT_API_URL
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
@component
class NvidiaChatGenerator(OpenAIChatGenerator):
    """
    Enables text generation using NVIDIA generative models.
    For supported models, see [NVIDIA Docs](https://build.nvidia.com/models).

    Users can pass any text generation parameters valid for the NVIDIA Chat Completion API
    directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
    parameter in `run` method.

    This component uses the ChatMessage format for structuring both input and output,
    ensuring coherent and contextually relevant responses in chat-based text generation scenarios.
    Details on the ChatMessage format can be found in the
    [Haystack docs](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage)

    For more details on the parameters supported by the NVIDIA API, refer to the
    [NVIDIA Docs](https://build.nvidia.com/models).

    Usage example:
    ```python
    from haystack_integrations.components.generators.nvidia import NvidiaChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = NvidiaChatGenerator()
    response = client.run(messages)
    print(response)
    ```
    """

    def __init__(
        self,
        *,
        api_key: Secret = Secret.from_env_var("NVIDIA_API_KEY"),
        model: str = "meta/llama-3.1-8b-instruct",
        streaming_callback: Optional[StreamingCallbackT] = None,
        api_base_url: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[Union[List[Tool], Toolset]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates an instance of NvidiaChatGenerator.

        :param api_key:
            The NVIDIA API key.
        :param model:
            The name of the NVIDIA chat completion model to use.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url:
            The NVIDIA API Base url. If not set, it falls back to the `NVIDIA_API_URL`
            environment variable, or to the default NVIDIA API URL.
        :param generation_kwargs:
            Other parameters to use for the model. These parameters are all sent directly to
            the NVIDIA API endpoint. See [NVIDIA API docs](https://docs.nvcf.nvidia.com/ai/generative-models/)
            for more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `stream`: Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent
                events as they become available, with the stream terminated by a data: [DONE] message.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param timeout:
            The timeout for the NVIDIA API call.
        :param max_retries:
            Maximum number of retries to contact NVIDIA after an internal error.
            If not set, it defaults to either the `NVIDIA_MAX_RETRIES` environment variable, or set to 5.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        # Resolve the base URL here rather than in the signature default so the
        # NVIDIA_API_URL environment variable is read at construction time, not
        # once at import time (a signature-level os.getenv would freeze the value).
        if api_base_url is None:
            api_base_url = os.getenv("NVIDIA_API_URL", DEFAULT_API_URL)
        super(NvidiaChatGenerator, self).__init__(  # noqa: UP008
            api_key=api_key,
            model=model,
            streaming_callback=streaming_callback,
            api_base_url=api_base_url,
            generation_kwargs=generation_kwargs,
            tools=tools,
            timeout=timeout,
            max_retries=max_retries,
            http_client_kwargs=http_client_kwargs,
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        # Callables cannot go through default serialization; store the dotted path instead.
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None

        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            generation_kwargs=self.generation_kwargs,
            api_key=self.api_key.to_dict(),
            tools=serialize_tools_or_toolset(self.tools),
            timeout=self.timeout,
            max_retries=self.max_retries,
            http_client_kwargs=self.http_client_kwargs,
        )

0 commit comments

Comments
 (0)