From 854784880b75f89a90594a73585de9696ac0fcfd Mon Sep 17 00:00:00 2001 From: "matt@aero" Date: Tue, 4 Feb 2025 17:37:09 -0600 Subject: [PATCH] fix(transcribe): fix censor re has been added to imports and censor_path added to params. The goal is to allow users to create their own censor json file to use rather than have it supplied to them. A check is used to verify the file exists if the censor flag is set, and if it does not or it is not the proper file tye, the censor is disabled. Segments and full text are both censored. The returned dict was set to a variable called "data" to allow this to occur. To do so another way would be text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]) if not censor else censor_text(tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), forbidden_words).... which is much more difficult to read. BREAKING CHANGE: I have not confirmed issues yet, however it may be possible for the censor to bug if weird formats or improper design is put in place of the json file. Signed-off-by: matt@aero --- whisper/transcribe.py | 52 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 3 deletions(-) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0a4cc3623..a1154438a 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,6 +2,8 @@ import os import traceback import warnings +import json +import re from typing import TYPE_CHECKING, List, Optional, Tuple, Union import numpy as np @@ -52,6 +54,8 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + censor: bool = False, + censor_path: str = None, **decode_options, ): """ @@ -124,6 +128,8 @@ def transcribe( A dictionary containing the resulting text ("text") and segment-level details ("segments"), and the spoken language ("language"), which is detected when `decode_options["language"]` is None. """ + + dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32 if model.device == torch.device("cpu"): if torch.cuda.is_available(): @@ -165,6 +171,21 @@ def transcribe( task=task, ) + forbidden_words = [] + if censor: + if ( + censor_path is None + or not os.path.exists(censor_path) + or not censor_path.endswith(".json") + ): + warnings.warn("Please provide a valid censor directory, censoring disabled.") + censor = False + else: + with open(f'{censor_path}', 'r') as f: + censor_data = json.load(f) + + forbidden_words = censor_data.get(language, []) + if isinstance(clip_timestamps, str): clip_timestamps = [ float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else []) @@ -243,16 +264,32 @@ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult: else: initial_prompt_tokens = [] + def censor_text(text, forbidden): + + def censor_match(match): + word = match.group(0) + return '*' * len(word) if word.lower() in forbidden_words else word + + censored_text = re.sub(r'\w+|[^\w\s]', censor_match, text) + + return censored_text + def new_segment( *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult ): tokens = tokens.tolist() text_tokens = [token for token in tokens if token < tokenizer.eot] + + if censor: + text = censor_text(tokenizer.decode(text_tokens), forbidden_words) + else: + text = tokenizer.decode(text_tokens) + return { "seek": seek, "start": start, "end": end, - "text": tokenizer.decode(text_tokens), + "text": text, "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, @@ -507,12 +544,19 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: # update progress bar pbar.update(min(content_frames, seek) - previous_seek) - return dict( - text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), + text = tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]) + + if censor: + text = censor_text(text, forbidden_words) + + data = dict( + text=text, segments=all_segments, language=language, ) + return data + def cli(): from . import available_models @@ -533,6 +577,8 @@ def valid_model_name(name): parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") + parser.add_argument("--censor", type=str2bool, default=True, help="(requires --censor_path=\"\") whether to censor out profanity or not") + parser.add_argument("--censor_path", type=str2bool, default=True, help="censored words path. Use json format - {lang: [words]}") parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')") parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")