Skip to content

Commit

Permalink
improves insanely fast whisper with tested outputs and conversion fro…
Browse files Browse the repository at this point in the history
…m json to srt/txt
  • Loading branch information
zackees committed Jan 12, 2024
1 parent 338cf75 commit d024607
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 57 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,5 @@ dmypy.json
.pyre/
activate.sh
tests/test_data
tests/localfile/text_video
tests/localfile/text_video
tests/localfile/text_video_insane
37 changes: 23 additions & 14 deletions tests/test_insanely_fast_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,40 @@
import shutil
from pathlib import Path

from transcribe_anything.insanely_fast_whisper import (
run_insanely_fast_whisper, has_nvidia_smi, CudaInfo, get_cuda_info
)

HERE = Path(os.path.abspath(os.path.dirname(__file__)))
LOCALFILE_DIR = HERE / "localfile"
TESTS_DATA_DIR = LOCALFILE_DIR / "text_video" / "en"
TESTS_DATA_DIR = LOCALFILE_DIR / "text_video_insane" / "en"
TEST_WAV = LOCALFILE_DIR / "video.wav"


class InsanelFastWhisperTester(unittest.TestCase):
    """Tester for transcribe anything."""

    # NOTE: the stale `@unittest.skip("DISABLED FOR NOW - WORK IN PROGRESS")`
    # decorator and the commented-out duplicate call block were removed: the
    # test below is a working, GPU-gated invocation.
    @unittest.skipUnless(has_nvidia_smi(), "No GPU detected")
    def test_local_file(self) -> None:
        """Check that the command works on a local file."""
        # Start from a clean output directory so stale artifacts can't mask failures.
        shutil.rmtree(TESTS_DATA_DIR, ignore_errors=True)
        run_insanely_fast_whisper(
            input_wav=TEST_WAV,
            model="small",
            output_dir=TESTS_DATA_DIR,
            task="transcribe",
            language="en",
            other_args=None,
        )

    @unittest.skipUnless(has_nvidia_smi(), "No GPU detected")
    def test_cuda_info(self) -> None:
        """Round-trips CudaInfo through its json serialization."""
        cuda_info0 = get_cuda_info()
        out = cuda_info0.to_json_str()
        cuda_info1 = CudaInfo.from_json_str(out)
        print(out)
        # Serialize -> deserialize must reproduce an equal object.
        self.assertEqual(cuda_info0, cuda_info1)


if __name__ == "__main__":
Expand Down
73 changes: 73 additions & 0 deletions tests/test_insanely_fast_whisper_json_to_srt_txt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@




"""
Tests transcribe_anything
"""

# pylint: disable=bad-option-value,useless-option-value,no-self-use,protected-access,R0801
# flake8: noqa E501

import unittest
import json

from transcribe_anything.insanely_fast_whisper import convert_json_to_srt, convert_json_to_text

EXAMPLE_JSON = """
{
"speakers": [],
"chunks": [
{
"timestamp": [
0,
3.24
],
"text": " Oh wow, I'm so nervous."
},
{
"timestamp": [
3.24,
6.56
],
"text": " Gosh, these lights are so bright."
},
{
"timestamp": [
6.56,
8.52
],
"text": " Is this mic on?"
},
{
"timestamp": [
8.52,
9.52
],
"text": " Is there even a mic?"
}
],
"text": " Oh wow, I'm so nervous. Gosh, these lights are so bright. Is this mic on? Is there even a mic?"
}
"""

class JsonToSrtTester(unittest.TestCase):
    """Tester for the json -> srt/txt conversion helpers."""

    # The original tests only printed the converter output and asserted
    # nothing, so they could never fail on wrong output. Assert that the
    # chunk text survives each conversion.

    def test_json_to_srt(self) -> None:
        """Converts the example json to srt and checks chunk text is present."""
        data = json.loads(EXAMPLE_JSON)
        out = convert_json_to_srt(data)
        self.assertIn("Oh wow, I'm so nervous.", out)
        self.assertIn("Is there even a mic?", out)

    def test_json_to_txt(self) -> None:
        """Converts the example json to plain text and checks chunk text is present."""
        data = json.loads(EXAMPLE_JSON)
        out = convert_json_to_text(data)
        self.assertIn("Gosh, these lights are so bright.", out)
        self.assertIn("Is this mic on?", out)


# Allow running this test module directly (python tests/test_..._json_to_srt_txt.py).
if __name__ == "__main__":
    unittest.main()
89 changes: 81 additions & 8 deletions transcribe_anything/cuda_available.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,92 @@
# pylint: disable

"""
Returns 0 if cuda is available, 1 otherwise.. This is
designed to be run in an isolated environment.
Queries the system for CUDA devices and returns a json string with the information.
This is meant to be run under a "isolated-environment".
"""


import sys
import torch # type: ignore # pylint: disable=import-error
from dataclasses import dataclass, asdict, fields
import json
from typing import Any

import torch


@dataclass
class CudaDevice:
    """Describes one CUDA device as reported by torch."""

    name: str
    # VRAM in bytes.
    vram: int
    # Number of streaming multiprocessors.
    multiprocessors: int
    device_id: int

    def __str__(self) -> str:
        """Human-readable summary; VRAM is shown in GB."""
        vram_gb = self.vram / (1024 ** 3)
        return (
            f"{self.name} - VRAM: {vram_gb:.2f} GB, "
            f"Multiprocessors: {self.multiprocessors}"
        )

    def to_json(self) -> dict[str, str | int]:
        """Returns a dictionary representation of the object."""
        return asdict(self)

    @staticmethod
    def from_json(json_data: dict[str, str | int]) -> 'CudaDevice':
        """Reconstructs a CudaDevice from a dict produced by to_json()."""
        return CudaDevice(**json_data)  # type: ignore

@dataclass
class CudaInfo:
    """Cuda info."""

    cuda_available: bool
    num_cuda_devices: int
    cuda_devices: list[CudaDevice]

    def to_json_str(self) -> str:
        """Returns json str."""
        # Sorted keys + indent keep the output stable and diff-friendly.
        return json.dumps(self.to_json(), indent=4, sort_keys=True)

    def to_json(self) -> dict[str, Any]:
        """Returns a dictionary representation of the object."""
        # Nested CudaDevice entries serialize through their own to_json();
        # every other field is json-native already.
        result: dict[str, Any] = {}
        for fld in fields(self):
            value = getattr(self, fld.name)
            if fld.name == 'cuda_devices':
                value = [device.to_json() for device in value]
            result[fld.name] = value
        return result

    @staticmethod
    def from_json_str(json_str: str) -> 'CudaInfo':
        """Loads from json str and returns a CudaInfo object."""
        data = json.loads(json_str)
        devices = [CudaDevice(**entry) for entry in data.get('cuda_devices', [])]
        return CudaInfo(data['cuda_available'], data['num_cuda_devices'], devices)

def cuda_cards_available() -> CudaInfo:
    """
    Returns a CudaInfo object with information about the CUDA cards,
    ordered by VRAM and multiprocessors.
    """
    # No CUDA at all -> empty report.
    if not torch.cuda.is_available():
        return CudaInfo(False, 0, [])
    devices: list[CudaDevice] = []
    for idx in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(idx)
        devices.append(
            CudaDevice(
                name=torch.cuda.get_device_name(idx),
                vram=props.total_memory,
                multiprocessors=props.multi_processor_count,
                device_id=idx,
            )
        )
    # Biggest cards first: VRAM, then multiprocessor count, descending.
    devices.sort(key=lambda dev: (dev.vram, dev.multiprocessors), reverse=True)
    return CudaInfo(True, len(devices), devices)

def main() -> int:
    """Prints the CUDA device report; returns 0 if cuda is available, 1 otherwise."""
    cuda_info = cuda_cards_available()
    print(cuda_info.to_json_str())
    if cuda_info.cuda_available:
        return 0
    print("CUDA is not available")
    return 1

if __name__ == '__main__':
Expand Down
Loading

0 comments on commit d024607

Please sign in to comment.