From d024607bfbbd87a128e53cb4c81f0276fe66939c Mon Sep 17 00:00:00 2001 From: zackees Date: Thu, 11 Jan 2024 19:04:47 -0800 Subject: [PATCH] improves insanely fast whisper with tested outputs and conversion from json to srt/txt --- .gitignore | 3 +- tests/test_insanely_fast_whisper.py | 37 +++-- ...t_insanely_fast_whisper_json_to_srt_txt.py | 73 ++++++++++ transcribe_anything/cuda_available.py | 89 ++++++++++-- transcribe_anything/insanely_fast_whisper.py | 133 +++++++++++++----- transcribe_anything/whisper.py | 2 +- 6 files changed, 280 insertions(+), 57 deletions(-) create mode 100644 tests/test_insanely_fast_whisper_json_to_srt_txt.py diff --git a/.gitignore b/.gitignore index a45aec9..7021ada 100644 --- a/.gitignore +++ b/.gitignore @@ -129,4 +129,5 @@ dmypy.json .pyre/ activate.sh tests/test_data -tests/localfile/text_video \ No newline at end of file +tests/localfile/text_video +tests/localfile/text_video_insane \ No newline at end of file diff --git a/tests/test_insanely_fast_whisper.py b/tests/test_insanely_fast_whisper.py index b473b45..d688146 100644 --- a/tests/test_insanely_fast_whisper.py +++ b/tests/test_insanely_fast_whisper.py @@ -10,31 +10,40 @@ import shutil from pathlib import Path +from transcribe_anything.insanely_fast_whisper import ( + run_insanely_fast_whisper, has_nvidia_smi, CudaInfo, get_cuda_info +) + HERE = Path(os.path.abspath(os.path.dirname(__file__))) LOCALFILE_DIR = HERE / "localfile" -TESTS_DATA_DIR = LOCALFILE_DIR / "text_video" / "en" +TESTS_DATA_DIR = LOCALFILE_DIR / "text_video_insane" / "en" TEST_WAV = LOCALFILE_DIR / "video.wav" class InsanelFastWhisperTester(unittest.TestCase): """Tester for transcribe anything.""" - @unittest.skip("DISABLED FOR NOW - WORK IN PROGRESS") + @unittest.skipUnless(has_nvidia_smi(), "No GPU detected") def test_local_file(self) -> None: """Check that the command works on a local file.""" shutil.rmtree(TESTS_DATA_DIR, ignore_errors=True) - #run_insanely_fast_whisper( - # input_wav=TEST_WAV, - # #device="cuda", - # #device="cpu", - # device="cuda:0", - # model="small", - # output_dir=TESTS_DATA_DIR, - # task="transcribe", - # language="en", - # other_args=None, - #) - + run_insanely_fast_whisper( + input_wav=TEST_WAV, + model="small", + output_dir=TESTS_DATA_DIR, + task="transcribe", + language="en", + other_args=None, + ) + + @unittest.skipUnless(has_nvidia_smi(), "No GPU detected") + def test_cuda_info(self) -> None: + """Check that the command works on a local file.""" + cuda_info0 = get_cuda_info() + out = cuda_info0.to_json_str() + cuda_info1 = CudaInfo.from_json_str(out) + print(out) + self.assertEqual(cuda_info0, cuda_info1) if __name__ == "__main__": diff --git a/tests/test_insanely_fast_whisper_json_to_srt_txt.py b/tests/test_insanely_fast_whisper_json_to_srt_txt.py new file mode 100644 index 0000000..5723a34 --- /dev/null +++ b/tests/test_insanely_fast_whisper_json_to_srt_txt.py @@ -0,0 +1,73 @@ + + + + +""" +Tests transcribe_anything +""" + +# pylint: disable=bad-option-value,useless-option-value,no-self-use,protected-access,R0801 +# flake8: noqa E501 + +import unittest +import json + +from transcribe_anything.insanely_fast_whisper import convert_json_to_srt, convert_json_to_text + +EXAMPLE_JSON = """ +{ + "speakers": [], + "chunks": [ + { + "timestamp": [ + 0, + 3.24 + ], + "text": " Oh wow, I'm so nervous." + }, + { + "timestamp": [ + 3.24, + 6.56 + ], + "text": " Gosh, these lights are so bright." + }, + { + "timestamp": [ + 6.56, + 8.52 + ], + "text": " Is this mic on?" + }, + { + "timestamp": [ + 8.52, + 9.52 + ], + "text": " Is there even a mic?" + } + ], + "text": " Oh wow, I'm so nervous. Gosh, these lights are so bright. Is this mic on? Is there even a mic?" +} +""" + +class JsonToSrtTester(unittest.TestCase): + """Tester for transcribe anything.""" + + def test_json_to_srt(self) -> None: + """Check that the command works on a local file.""" + data = json.loads(EXAMPLE_JSON) + out = convert_json_to_srt(data) + print(out) + print() + + def test_json_to_txt(self) -> None: + """Check that the command works on a local file.""" + data = json.loads(EXAMPLE_JSON) + out = convert_json_to_text(data) + print(out) + print() + + +if __name__ == "__main__": + unittest.main() diff --git a/transcribe_anything/cuda_available.py b/transcribe_anything/cuda_available.py index ed61d28..b19f655 100644 --- a/transcribe_anything/cuda_available.py +++ b/transcribe_anything/cuda_available.py @@ -1,19 +1,92 @@ -# pylint: disable - """ -Returns 0 if cuda is available, 1 otherwise.. This is -designed to be run in an isolated environment. +Queries the system for CUDA devices and returns a json string with the information. +This is meant to be run under a "isolated-environment". """ + import sys -import torch # type: ignore # pylint: disable=import-error +from dataclasses import dataclass, asdict, fields +import json +from typing import Any + +import torch + + +@dataclass +class CudaDevice: + """Data class to hold CUDA device information.""" + name: str + vram: int # VRAM in bytes + multiprocessors: int # Number of multiprocessors + device_id: int + + def __str__(self): + return (f"{self.name} - VRAM: {self.vram / (1024 ** 3):.2f} GB, " + f"Multiprocessors: {self.multiprocessors}") + def to_json(self) -> dict[str, str | int]: + """Returns a dictionary representation of the object.""" + return asdict(self) + + @staticmethod + def from_json(json_data: dict[str, str | int]) -> 'CudaDevice': + """Returns a CudaDevice object from a dictionary.""" + return CudaDevice(**json_data) # type: ignore + +@dataclass +class CudaInfo: + """Cuda info.""" + cuda_available: bool + num_cuda_devices: int + cuda_devices: list[CudaDevice] + + def to_json_str(self) -> str: + """Returns json str.""" + # Convert dataclass to dictionary for serialization + data = self.to_json() + return json.dumps(data, indent=4, sort_keys=True) + + def to_json(self) -> dict[str, Any]: + """Returns a dictionary representation of the object.""" + out = {} + for field in fields(self): + out[field.name] = getattr(self, field.name) + if field.name == 'cuda_devices': + out[field.name] = [device.to_json() for device in out[field.name]] + return out + + @staticmethod + def from_json_str(json_str: str) -> 'CudaInfo': + """Loads from json str and returns a CudaInfo object.""" + data = json.loads(json_str) + cuda_devices_data = data.get('cuda_devices', []) + cuda_devices = [CudaDevice(**device) for device in cuda_devices_data] + return CudaInfo(data['cuda_available'], data['num_cuda_devices'], cuda_devices) + +def cuda_cards_available() -> CudaInfo: + """ + Returns a CudaInfo object with information about the CUDA cards, + ordered by VRAM and multiprocessors. + """ + if torch.cuda.is_available(): + devices = [ + CudaDevice( + name=torch.cuda.get_device_name(i), + vram=torch.cuda.get_device_properties(i).total_memory, + multiprocessors=torch.cuda.get_device_properties(i).multi_processor_count, + device_id=i + ) for i in range(torch.cuda.device_count()) + ] + # Sort devices by VRAM and then by number of multiprocessors in descending order + devices.sort(key=lambda x: (x.vram, x.multiprocessors), reverse=True) + return CudaInfo(True, len(devices), devices) + return CudaInfo(False, 0, []) def main() -> int: """Returns 0 if cuda is available, 1 otherwise.""" - if torch.cuda.is_available(): - print("CUDA is available") + cuda_info = cuda_cards_available() + print(cuda_info.to_json_str()) + if cuda_info.cuda_available: return 0 - print("CUDA is not available") return 1 if __name__ == '__main__': diff --git a/transcribe_anything/insanely_fast_whisper.py b/transcribe_anything/insanely_fast_whisper.py index abb23e3..69b2284 100644 --- a/transcribe_anything/insanely_fast_whisper.py +++ b/transcribe_anything/insanely_fast_whisper.py @@ -8,14 +8,18 @@ import shutil import sys import time +import json from pathlib import Path import subprocess -from typing import Optional +from typing import Optional, Any from isolated_environment import IsolatedEnvironment # type: ignore +from transcribe_anything.cuda_available import CudaInfo HERE = Path(__file__).parent ENV: Optional[IsolatedEnvironment] = None +CUDA_INFO: Optional[CudaInfo] = None + # Set the versions TENSOR_VERSION = "2.1.2" @@ -34,39 +38,85 @@ def has_nvidia_smi() -> bool: """Returns True if nvidia-smi is installed.""" return shutil.which("nvidia-smi") is not None - -def install_whisper_if_necessary() -> None: - """Installs whisper if necessary.""" - install_gpu = has_nvidia_smi() - gpu_requirements = [ - "torch==2.1.2", - "openai-whisper" - ] - TENSOR_VERSION = "2.1.2" - CUDA_VERSION = "cu121" - TENSOR_CUDA_VERSION = f"{TENSOR_VERSION}+{CUDA_VERSION}" - EXTRA_INDEX_URL = f"https://download.pytorch.org/whl/{CUDA_VERSION}" - - # Installing using pipx - try: - # Step 1: Install Python 3.11 (handled externally, not via Python script) - # Step 2: Install insanely-fast-whisper using pipx - subprocess.run(["pipx", "install", "insanely-fast-whisper", "--python", "python3.11"], check=True) - - # Steps 3-5: Injecting packages into the pipx environment - subprocess.run(["pipx", "inject", "insanely-fast-whisper", "torch==2.1.2"], check=True) - subprocess.run(["pipx", "inject", "insanely-fast-whisper", "openai-whisper"], check=True) - subprocess.run(["pipx", "inject", "insanely-fast-whisper", "transformers"], check=True) - - print("Whisper installation and configuration complete.") - except subprocess.CalledProcessError as e: - print(f"An error occurred during installation: {e}", file=sys.stderr) - +def get_environment() -> IsolatedEnvironment: + """Returns the environment.""" + global ENV # pylint: disable=global-statement + if ENV is not None: + return ENV + venv_dir = HERE / "venv" / "insanely_fast_whisper" + env = IsolatedEnvironment(venv_dir) + if not venv_dir.exists(): + env.install_environment() + if has_nvidia_smi(): + env.pip_install(f"torch=={TENSOR_VERSION}", extra_index=EXTRA_INDEX_URL) + else: + env.pip_install(f"torch=={TENSOR_VERSION}") + env.pip_install("openai-whisper") + env.pip_install("insanely-fast-whisper") + ENV = env + return env + + +def get_cuda_info() -> CudaInfo: + """Get the computing device.""" + global CUDA_INFO # pylint: disable=global-statement + if CUDA_INFO is None: + iso_env = get_environment() + env = iso_env.environment() + py_file = HERE / "cuda_available.py" + cp: subprocess.CompletedProcess = subprocess.run([ + "python", py_file + ], check=False, env=env, universal_newlines=True, stdout=subprocess.PIPE) + stdout = cp.stdout + CUDA_INFO = CudaInfo.from_json_str(stdout) + return CUDA_INFO + +def get_device_id() -> str: + """Get the device id.""" + # on mac, we just return "mps" + if sys.platform == "darwin": + return "mps" + cuda_info = get_cuda_info() + if not cuda_info.cuda_available: + raise ValueError("CUDA is not available.") + device_id = cuda_info.cuda_devices[0].device_id + return f"{device_id}" + +def get_batch_size() -> int | None: + """Returns the batch size.""" + if sys.platform == "darwin": + return 4 + return None + +def convert_time_to_srt_format(timestamp: float) -> str: + """Converts timestamp in seconds to SRT time format (hours:minutes:seconds,milliseconds).""" + hours, remainder = divmod(timestamp, 3600) + minutes, seconds = divmod(remainder, 60) + milliseconds = int((seconds % 1) * 1000) + seconds = int(seconds) + return f"{hours:02}:{minutes:02}:{seconds:02},{milliseconds:03}" + +def convert_json_to_srt(json_data: dict[str, Any]) -> str: + """Converts JSON data from speech-to-text tool to SRT format.""" + srt_content = "" + for index, chunk in enumerate(json_data['chunks'], start=1): + start_time, end_time = chunk['timestamp'] + start_time_str = convert_time_to_srt_format(start_time) + end_time_str = convert_time_to_srt_format(end_time) + text = str(chunk['text']).strip() + srt_content += f"{index}\n{start_time_str} --> {end_time_str}\n{text}\n\n" + return srt_content + +def convert_json_to_text(json_data: dict[str, Any]) -> str: + """Converts JSON data from speech-to-text tool to text.""" + text = "" + for chunk in json_data['chunks']: + text += str(chunk['text']).strip() + "\n" + return text def run_insanely_fast_whisper( # pylint: disable=too-many-arguments input_wav: Path, - device: str, # pylint: disable=unused-argument model: str, output_dir: Path, task: str, @@ -74,7 +124,11 @@ def run_insanely_fast_whisper( # pylint: disable=too-many-arguments other_args: Optional[list[str]] ) -> None: """Runs insanely fast whisper.""" + iso_env = get_environment() + device_id = get_device_id() cmd_list = [] + output_dir.mkdir(parents=True, exist_ok=True) + outfile = output_dir / "out.json" model = f"openai/whisper-{model}" if sys.platform == "win32": # Set the text mode to UTF-8 on Windows. @@ -82,12 +136,15 @@ def run_insanely_fast_whisper( # pylint: disable=too-many-arguments cmd_list += [ "insanely-fast-whisper", "--file-name", str(input_wav), - "--device-id", "0", + "--device-id", f"{device_id}", "--model-name", model, "--task", task, "--language", language, - "--transcript-path", str(output_dir), + "--transcript-path", str(outfile), ] + batch_size = get_batch_size() + if batch_size is not None: + cmd_list += ["--batch-size", f"{batch_size}"] if other_args: cmd_list.extend(other_args) # Remove the empty strings. @@ -96,7 +153,7 @@ def run_insanely_fast_whisper( # pylint: disable=too-many-arguments sys.stderr.write(f"Running:\n {cmd}\n") proc = subprocess.Popen( # pylint: disable=consider-using-with cmd, shell=True, universal_newlines=True, - encoding="utf-8" + encoding="utf-8", env=iso_env.environment() ) while True: rtn = proc.poll() @@ -107,3 +164,13 @@ def run_insanely_fast_whisper( # pylint: disable=too-many-arguments msg = f"Failed to execute {cmd}\n " raise OSError(msg) break + assert outfile.exists(), f"Expected {outfile} to exist." + json_text = outfile.read_text(encoding="utf-8") + json_data = json.loads(json_text) + srt_content = convert_json_to_srt(json_data) + srt_file = output_dir / "out.srt" + txt_content = convert_json_to_text(json_data) + srt_file.write_text(srt_content, encoding="utf-8") + txt_file = output_dir / "out.txt" + txt_file.write_text(txt_content, encoding="utf-8") + diff --git a/transcribe_anything/whisper.py b/transcribe_anything/whisper.py index b36015e..7dfda52 100644 --- a/transcribe_anything/whisper.py +++ b/transcribe_anything/whisper.py @@ -36,7 +36,7 @@ def get_environment() -> IsolatedEnvironment: if ENV is not None: return ENV venv_dir = HERE / "venv" / "whisper" - env = IsolatedEnvironment(HERE / "venv" / "whisper") + env = IsolatedEnvironment(venv_dir) if not venv_dir.exists(): env.install_environment() if has_nvidia_smi():