|
2 | 2 |
|
3 | 3 | import multiprocessing
|
4 | 4 |
|
5 |
| -from typing import Optional, List, Literal, Union |
6 |
| -from pydantic import Field, root_validator |
| 5 | +from typing import Optional, List, Literal, Union, Dict, cast |
| 6 | +from typing_extensions import Self |
| 7 | + |
| 8 | +from pydantic import Field, model_validator |
7 | 9 | from pydantic_settings import BaseSettings
|
8 | 10 |
|
9 | 11 | import llama_cpp
|
@@ -173,15 +175,16 @@ class ModelSettings(BaseSettings):
|
173 | 175 | default=True, description="Whether to print debug information."
|
174 | 176 | )
|
175 | 177 |
|
176 |
| - @root_validator(pre=True) # pre=True to ensure this runs before any other validation |
177 |
| - def set_dynamic_defaults(cls, values): |
| 178 | + @model_validator(mode="before") # pre=True to ensure this runs before any other validation |
| 179 | + def set_dynamic_defaults(self) -> Self: |
178 | 180 | # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count()
|
179 | 181 | cpu_count = multiprocessing.cpu_count()
|
| 182 | + values = cast(Dict[str, int], self) |
180 | 183 | if values.get('n_threads', 0) == -1:
|
181 | 184 | values['n_threads'] = cpu_count
|
182 | 185 | if values.get('n_threads_batch', 0) == -1:
|
183 | 186 | values['n_threads_batch'] = cpu_count
|
184 |
| - return values |
| 187 | + return self |
185 | 188 |
|
186 | 189 |
|
187 | 190 | class ServerSettings(BaseSettings):
|
|
0 commit comments