import random
+ from abc import abstractmethod

import numpy as np
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, pipeline
+ from loguru import logger
+
+ try:
+     import torch
+ except ImportError:
+     logger.warning("torch is not installed. This module will not be available.")


class ReproducibleHF:
-     def __init__(
-         self,
-         model_id: str = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",
-         device: str = "cuda:0",
-         sampling_params: dict[str, str | float | int | bool] | None = None,
-     ):
-         """Deterministic HuggingFace model."""
+     def __init__(self, model_id: str, device: str, sampling_params: dict[str, str | float | int | bool] | None = None):
+         self.model_id = model_id
        self._device = device
-         self.sampling_params = {} if sampling_params is None else sampling_params
-         self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
-             model_id,
-             torch_dtype=torch.float16,
-             low_cpu_mem_usage=True,
-             device_map=self._device,
-         )
+         self.sampling_params = sampling_params if sampling_params else {}

-         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-         self.valid_generation_params = set(
-             AutoModelForCausalLM.from_pretrained(model_id).generation_config.to_dict().keys()
-         )
-         self.llm = pipeline("text-generation", model=self.model, tokenizer=self.tokenizer)
+     @staticmethod
+     @abstractmethod
+     def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
+         raise NotImplementedError("This method must be implemented by the subclass")

-     @torch.inference_mode()
    def generate(
        self,
        messages: list[str] | list[dict[str, str]],
        sampling_params: dict[str, str | float | int | bool] | None = None,
        seed: int | None = None,
    ) -> str:
        """Generate text with optimized performance."""
-         self.set_random_seeds(seed)
+         with torch.inference_mode():
+             self.set_random_seeds(seed)

-         inputs = self.tokenizer.apply_chat_template(
-             messages,
-             tokenize=True,
-             add_generation_prompt=True,
-             return_tensors="pt",
-             return_dict=True,
-         ).to(self._device)
+             inputs = self.tokenizer.apply_chat_template(
+                 self.message_formater(messages),
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 return_tensors="pt",
+                 return_dict=True,
+             ).to(self._device)

-         params = sampling_params if sampling_params else self.sampling_params
-         filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}
+             params = sampling_params if sampling_params else self.sampling_params
+             filtered_params = {k: v for k, v in params.items() if k in self.valid_generation_params}

-         outputs = self.model.generate(
-             **inputs,
-             **filtered_params,
-             eos_token_id=self.tokenizer.eos_token_id,
-         )
+             outputs = self.model.generate(
+                 **inputs,
+                 **filtered_params,
+             )

-         results = self.tokenizer.batch_decode(
-             outputs[:, inputs["input_ids"].shape[1] :],
-             skip_special_tokens=True,
-         )[0]
+             results = self.tokenizer.batch_decode(
+                 outputs[:, inputs["input_ids"].shape[1] :],
+                 skip_special_tokens=True,
+             )[0]

-         return results if len(results) > 1 else results[0]
+             return results if len(results) > 1 else results[0]

    def set_random_seeds(self, seed: int | None = 42):
        """Set random seeds for reproducibility across all relevant libraries."""
@@ -72,8 +64,3 @@ def set_random_seeds(self, seed: int | None = 42):
            torch.cuda.manual_seed_all(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
-
-
- # if __name__ == "__main__":
- #     llm = ReproducibleHF(model="Qwen/Qwen2-0.5B", tensor_parallel_size=1, seed=42)
- #     llm.generate({"role": "user", "content": "Hello, world!"})
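A minimal sketch of how the new abstract hook might be satisfied after this change; the subclass name, the message-wrapping convention, and the assumption that model/tokenizer setup happens outside this commit are illustrative, not taken from the diff:

class ChatReproducibleHF(ReproducibleHF):
    # Hypothetical subclass: self.model, self.tokenizer and
    # self.valid_generation_params are assumed to be set up elsewhere.

    @staticmethod
    def format_messages(messages: list[str] | list[dict[str, str]]) -> list[dict[str, str | list[dict[str, str]]]]:
        # Assumption: bare strings become user-role chat messages;
        # already-structured dicts pass through unchanged.
        if messages and isinstance(messages[0], str):
            return [{"role": "user", "content": m} for m in messages]
        return messages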