 import torch
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
+
 from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat
+from ..utils.io import del_all
+from ..model.gpt import GPT_warpper
 
 def infer_code(
     models,
@@ -14,39 +17,42 @@ def infer_code(
     repetition_penalty = 1.05,
     max_new_token = 2048,
     stream = False,
+    device = "cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
     if not isinstance(temperature, list):
-        temperature = [temperature] * models['gpt'].num_vq
+        temperature = [temperature] * gpt.num_vq
 
     if spk_emb is not None:
         text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
     else:
         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
 
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
-    input_ids = text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq)
-    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
-
-    inputs = {
-        'input_ids': input_ids,
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    text_token_tmp = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True)
+    text_token = text_token_tmp.to(device)
+    del text_token_tmp
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt)
+    text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=gpt.device_gpt)
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
 
-    emb = models['gpt'].get_emb(**inputs)
     if spk_emb is not None:
-        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = \
-            F.normalize(spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12)
-
-    num_code = models['gpt'].emb_code[0].num_embeddings - 1
-
+        n = F.normalize(spk_emb.to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12).to(gpt.device_gpt)
+        emb[input_ids[..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = n
+        del n
+
+    num_code = int(gpt.emb_code[0].num_embeddings - 1)
+
     LogitsWarpers = []
     if top_P is not None:
         LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
@@ -58,10 +64,10 @@ def infer_code(
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(\
             repetition_penalty, num_code, 16))
 
-    result = models['gpt'].generate(
-        emb, inputs['input_ids'],
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor(temperature, device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = num_code,
@@ -71,6 +77,11 @@ def infer_code(
         **kwargs
     )
 
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return result
 
 
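Note: del_all is imported from ..utils.io but its body is not part of this diff. The sketch below is only an assumption about its contract, namely that it empties a dict-like or list container in place so the tensors it held can be garbage-collected; it is not the project's actual implementation.

# Hypothetical stand-in for ..utils.io.del_all, an assumption rather than the real code.
def del_all(container):
    if hasattr(container, "keys"):               # dicts and dict-like BatchEncoding
        for key in list(container.keys()):
            value = container.pop(key)
            del value
    elif isinstance(container, (list, set)):     # e.g. LogitsWarpers / LogitsProcessors
        container.clear()

Whatever the real helper does, the intent in the new code above is the same: drop the tokenizer output and the logits warper/processor lists as soon as generation has finished.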
@@ -83,11 +94,12 @@ def refine_text(
     repetition_penalty = 1.0,
     max_new_token = 384,
     prompt = '',
+    device = "cpu",
     **kwargs
 ):
-
-    device = next(models['gpt'].parameters()).device
-
+
+    gpt: GPT_warpper = models['gpt']
+
     if not isinstance(text, list):
         text = [text]
 
@@ -97,11 +109,7 @@ def refine_text(
     text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
 
-    inputs = {
-        'input_ids': text_token['input_ids'][...,None].expand(-1, -1, models['gpt'].num_vq),
-        'text_mask': text_mask,
-        'attention_mask': text_token['attention_mask'],
-    }
+    input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq)
 
     LogitsWarpers = []
     if top_P is not None:
@@ -112,11 +120,17 @@ def refine_text(
     LogitsProcessors = []
     if repetition_penalty is not None and repetition_penalty != 1:
         LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))
-
-    result = models['gpt'].generate(
-        models['gpt'].get_emb(**inputs), inputs['input_ids'],
+
+    emb = gpt.get_emb(
+        input_ids=input_ids,
+        text_mask=text_mask,
+    )
+    del text_mask
+
+    result = gpt.generate(
+        emb, input_ids,
         temperature = torch.tensor([temperature,], device=device),
-        attention_mask = inputs['attention_mask'],
+        attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
@@ -125,4 +139,10 @@ def refine_text(
         stream = False,
         **kwargs
     )
+
+    del_all(text_token)
+    del emb, text_token, input_ids
+    del_all(LogitsWarpers)
+    del_all(LogitsProcessors)
+
     return next(result)
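For orientation, a minimal call-site sketch of the refactored infer_code. It assumes the models dict is assembled elsewhere with the 'gpt' and 'tokenizer' entries used above, and that text, spk_emb, temperature and top_P remain keyword parameters in the portion of the signature this diff does not show; the sample text and device string are placeholders.

# Hypothetical call site; the keyword names below are inferred from how
# infer_code's body uses them, not from a documented API.
result = infer_code(
    models,
    text=["A short test sentence."],   # placeholder input
    spk_emb=None,                      # optionally a 1-D speaker-embedding tensor
    temperature=0.3,
    top_P=0.7,
    device="cpu",                      # new: the caller now chooses the device
)

Passing device explicitly replaces the old behavior of deriving it from the GPT module's parameters, which is the interface change visible at the top of this diff.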