def get_sentences(phoneme_list, max_tokens):
    """
    Group phoneme strings into sentences of at most ``max_tokens`` tokens.

    Consecutive entries of ``phoneme_list`` are joined with single spaces for
    as long as the running whitespace-token count stays within ``max_tokens``.
    An entry that would push the count over the limit starts a new sentence.

    A sentence is only split internally when it exceeds the limit on its own
    (a single entry longer than ``max_tokens`` tokens). It is then cut after
    the last mark from the module-level ``punctuation_split`` that leaves text
    on both sides of the cut, recursively; when no such mark exists it falls
    back to fixed-size chunks of ``max_tokens`` words.

    Args:
        phoneme_list: Iterable of phoneme strings; tokens are separated by
            whitespace.
        max_tokens: Maximum number of tokens allowed per returned sentence.

    Returns:
        A list of non-empty, stripped sentence strings in input order.
    """
    sentences = []

    def chunk_by_tokens(words):
        # Last resort: fixed-size groups of max_tokens words.
        return [
            " ".join(words[i:i + max_tokens])
            for i in range(0, len(words), max_tokens)
        ]

    def split_oversized(sentence):
        # Returns the pieces of `sentence`, each within the token limit.
        sentence = sentence.strip()
        if not sentence:
            # Never emit empty sentences (e.g. nothing accumulated yet).
            return []
        words = sentence.split()
        if len(words) <= max_tokens:
            return [sentence]

        # Only accept marks that end before the final character. A cut at
        # the very end would leave the first part equal to the whole
        # sentence and the recursion below would never terminate.
        search_end = len(sentence) - 1
        cut = -1
        for mark in punctuation_split:
            if not mark:
                continue
            idx = sentence.rfind(mark, 0, search_end)
            if idx != -1:
                cut = max(cut, idx + len(mark))

        if cut == -1:
            # No usable punctuation: force a split by token count.
            return chunk_by_tokens(words)

        # Both halves are strictly shorter than `sentence`, so this ends.
        return split_oversized(sentence[:cut]) + split_oversized(sentence[cut:])

    current_sentence = ""
    current_token_count = 0

    for phoneme in phoneme_list:
        token_count = len(phoneme.split())
        if current_token_count + token_count > max_tokens:
            # Flush what has been accumulated and start over with this entry.
            sentences.extend(split_oversized(current_sentence))
            current_sentence = phoneme
            current_token_count = token_count
        else:
            if current_sentence:
                current_sentence = (current_sentence + " " + phoneme).strip()
            else:
                current_sentence = phoneme
            current_token_count += token_count

    # Flush the trailing sentence (no-op when nothing is left).
    sentences.extend(split_oversized(current_sentence))

    return sentences