config.py
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, Optional

from ...import_utils import torch_version
from ..config import BackendConfig

AMP_DTYPES = ["bfloat16", "float16"]
TORCH_DTYPES = ["bfloat16", "float16", "float32", "auto"]
QUANTIZATION_CONFIGS = {"bnb": {"llm_int8_threshold": 0.0}}

LOGGER = getLogger(__name__)


@dataclass
class PyTorchConfig(BackendConfig):
    name: str = "pytorch"
    version: Optional[str] = torch_version()
    _target_: str = "optimum_benchmark.backends.pytorch.backend.PyTorchBackend"

    # load options
    no_weights: bool = False
    device_map: Optional[str] = None
    torch_dtype: Optional[str] = None

    # optimization options
    eval_mode: bool = True
    to_bettertransformer: bool = False
    low_cpu_mem_usage: Optional[bool] = None
    attn_implementation: Optional[str] = None
    cache_implementation: Optional[str] = None

    # automatic mixed precision options
    autocast_enabled: bool = False
    autocast_dtype: Optional[str] = None

    # torch compile options
    torch_compile: bool = False
    torch_compile_target: str = "forward"
    torch_compile_config: Dict[str, Any] = field(default_factory=dict)

    # quantization options
    quantization_scheme: Optional[str] = None
    quantization_config: Dict[str, Any] = field(default_factory=dict)

    # distributed inference options
    deepspeed_inference: bool = False
    deepspeed_inference_config: Dict[str, Any] = field(default_factory=dict)

    # peft options
    peft_type: Optional[str] = None
    peft_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        super().__post_init__()

        if self.model_kwargs.get("torch_dtype", None) is not None:
            raise ValueError(
                "`torch_dtype` is an explicit argument in the PyTorch backend config. "
                "Please remove it from the `model_kwargs` and set it in the backend config directly."
            )

        if self.torch_dtype is not None and self.torch_dtype not in TORCH_DTYPES:
            raise ValueError(f"`torch_dtype` must be one of {TORCH_DTYPES}. Got {self.torch_dtype} instead.")

        if self.autocast_dtype is not None and self.autocast_dtype not in AMP_DTYPES:
            raise ValueError(f"`autocast_dtype` must be one of {AMP_DTYPES}. Got {self.autocast_dtype} instead.")

        if self.quantization_scheme is not None:
            LOGGER.warning(
                "`backend.quantization_scheme` is deprecated and will be removed in a future version. "
                "Please use `quantization_config.quant_method` instead."
            )
            if self.quantization_config is None:
                self.quantization_config = {"quant_method": self.quantization_scheme}
            else:
                self.quantization_config["quant_method"] = self.quantization_scheme

        if self.quantization_config is not None:
            self.quantization_config = dict(
                QUANTIZATION_CONFIGS.get(self.quantization_scheme, {}),  # default config
                **self.quantization_config,  # user config (overwrites default)
            )
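
A minimal usage sketch (illustrative, not part of the file above): constructing the config with the deprecated `quantization_scheme` triggers the `__post_init__` merge, so the "bnb" defaults from QUANTIZATION_CONFIGS land in `quantization_config` and `quant_method` is filled in. The `model` and `device` arguments are assumptions about fields inherited from `BackendConfig`, and the import path is inferred from the `_target_` string.

    # usage_sketch.py (hypothetical; arguments and import path assumed as noted above)
    from optimum_benchmark.backends.pytorch.config import PyTorchConfig

    config = PyTorchConfig(model="gpt2", device="cpu", quantization_scheme="bnb")

    # The "bnb" default config is applied first, then the user config overwrites it,
    # and the deprecated scheme is copied into `quant_method`:
    assert config.quantization_config == {"llm_int8_threshold": 0.0, "quant_method": "bnb"}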