Commit b62e0dc

optimize(utils): move custom processors into model (#419)
1 parent e0a9e7e

6 files changed: +50 -49 lines changed


ChatTTS/core.py

Lines changed: 1 addition & 1 deletion

@@ -119,7 +119,7 @@ def _load(
         coef: Optional[str] = None
     ):
         if device is None:
-            device = select_device(4096)
+            device = select_device()
         self.logger.log(logging.INFO, f'use {device}')
         self.device = device
 
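Read together with the ChatTTS/utils/gpu.py hunk below, the net effect is that _load no longer pins the minimum-free-memory floor at 4096 MiB; a minimal before/after sketch using only names from this diff:

    # before: _load forced a 4096 MiB minimum-free-memory floor
    device = select_device(4096)

    # after: _load defers to select_device's own default
    # (now 2047, per the gpu.py hunk below)
    device = select_device()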

ChatTTS/infer/api.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 import torch.nn.functional as F
 from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
 
-from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat
+from ..model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils.io import del_all
 from ..model.gpt import GPT
 

ChatTTS/model/gpt.py

Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 from transformers.cache_utils import Cache
 from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat
+from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
 from ..utils.io import del_all
 

ChatTTS/model/processors.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import torch
+import torch.nn.functional as F
+
+
+class CustomRepetitionPenaltyLogitsProcessorRepeat():
+
+    def __init__(self, penalty: float, max_input_ids, past_window):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        self.penalty = penalty
+        self.max_input_ids = max_input_ids
+        self.past_window = past_window
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        input_ids = input_ids[:, -self.past_window:]
+        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
+        freq[self.max_input_ids:] = 0
+        alpha = self.penalty**freq
+        scores = scores.contiguous()
+        scores = torch.where(scores < 0, scores*alpha, scores/alpha)
+
+        return scores
+
+
+class CustomRepetitionPenaltyLogitsProcessor():
+
+    def __init__(self, penalty: float, max_input_ids, past_window):
+        if not isinstance(penalty, float) or not (penalty > 0):
+            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+        self.penalty = penalty
+        self.max_input_ids = max_input_ids
+        self.past_window = past_window
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+
+        input_ids = input_ids[:, -self.past_window:]
+        score = torch.gather(scores, 1, input_ids)
+        _score = score.detach().clone()
+        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
+        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
+        scores.scatter_(1, input_ids, score)
+
+        return scores
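For orientation, a minimal usage sketch of the relocated Repeat variant; the tensor shapes and parameter values here are illustrative assumptions, not part of the diff:

    import torch
    from ChatTTS.model.processors import CustomRepetitionPenaltyLogitsProcessorRepeat

    # Illustrative values: a vocab of 8 ids, penalty just above 1.0.
    processor = CustomRepetitionPenaltyLogitsProcessorRepeat(
        penalty=1.05,     # > 1.0 dampens repeated tokens; 1.0 is a no-op
        max_input_ids=6,  # intended cutoff above which ids escape the penalty
        past_window=16,   # only the most recent 16 tokens are counted
    )

    input_ids = torch.tensor([[2, 2, 5, 2]])  # token 2 appears 3 times
    scores = torch.randn(1, 8)                # next-step logits
    penalized = processor(input_ids, scores)
    # token 2's logit is scaled by penalty**3 (divided if positive,
    # multiplied if negative), making another repeat less likely

The non-Repeat variant reaches the same goal via gather/scatter: it penalizes only ids actually present in the window, and restores the original score for ids at or above max_input_ids.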

ChatTTS/utils/gpu.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 
 from .log import logger
 
-def select_device(min_memory=2048):
+def select_device(min_memory=2047):
     if torch.cuda.is_available():
        available_gpus = []
        for i in range(torch.cuda.device_count()):
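The loop body is cut off by the hunk; purely as a hedged illustration of how a MiB threshold like min_memory is typically applied when scanning devices (the helper name and the use of torch.cuda.mem_get_info are assumptions, not this file's actual code):

    import torch

    def pick_gpu_by_free_memory(min_memory: int = 2047) -> torch.device:
        # Hypothetical sketch, not ChatTTS's select_device: return the first
        # GPU with at least min_memory MiB free, else fall back to CPU.
        if not torch.cuda.is_available():
            return torch.device("cpu")
        for i in range(torch.cuda.device_count()):
            free_bytes, _total_bytes = torch.cuda.mem_get_info(i)
            if free_bytes / (1024 ** 2) >= min_memory:
                return torch.device(f"cuda:{i}")
        return torch.device("cpu")

One plausible reading of the 2048 -> 2047 change is to admit "2 GB" cards that report marginally less than 2048 MiB free; the diff itself does not state the motivation.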

ChatTTS/utils/infer.py

Lines changed: 1 addition & 45 deletions

@@ -5,51 +5,7 @@
 
 from numba import jit
 import numpy as np
-import torch
-import torch.nn.functional as F
-
-
-class CustomRepetitionPenaltyLogitsProcessorRepeat():
-
-    def __init__(self, penalty: float, max_input_ids, past_window):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        freq = F.one_hot(input_ids, scores.size(1)).sum(1)
-        freq[self.max_input_ids:] = 0
-        alpha = self.penalty**freq
-        scores = scores.contiguous()
-        scores = torch.where(scores < 0, scores*alpha, scores/alpha)
-
-        return scores
-
-
-class CustomRepetitionPenaltyLogitsProcessor():
-
-    def __init__(self, penalty: float, max_input_ids, past_window):
-        if not isinstance(penalty, float) or not (penalty > 0):
-            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
-
-        self.penalty = penalty
-        self.max_input_ids = max_input_ids
-        self.past_window = past_window
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-
-        input_ids = input_ids[:, -self.past_window:]
-        score = torch.gather(scores, 1, input_ids)
-        _score = score.detach().clone()
-        score = torch.where(score < 0, score * self.penalty, score / self.penalty)
-        score[input_ids>=self.max_input_ids] = _score[input_ids>=self.max_input_ids]
-        scores.scatter_(1, input_ids, score)
-
-        return scores
+
 
 @jit
 def _find_index(table: np.ndarray, val: np.uint16):
