adds srt wrap and upgrades isolated_environment to use the new pattern of installing dependencies
zackees committed Jan 14, 2024
1 parent d592046 commit 07ffc65
Showing 8 changed files with 69 additions and 43 deletions.
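The heart of the "new pattern" is that pip flags such as --extra-index-url now ride along inside the dependency strings handed to isolated_environment, which appears to be why the pin moves to isolated-environment>=1.2.3. A minimal sketch of that pattern, assuming the isolated_environment(venv_dir, deps) call shown in the whisper.py hunk below and an illustrative venv path:

from pathlib import Path
from isolated_environment import isolated_environment  # type: ignore

TENSOR_VERSION = "2.1.2"
CUDA_VERSION = "cu121"
EXTRA_INDEX_URL = f"https://download.pytorch.org/whl/{CUDA_VERSION}"

venv_dir = Path("example_venv")  # illustrative location, not from this commit
deps = [
    "insanely-fast-whisper",
    # CUDA wheel comes from the extra index; the flag is carried inline in the dep string
    f"torch=={TENSOR_VERSION}+{CUDA_VERSION} --extra-index-url {EXTRA_INDEX_URL}",
]
env = isolated_environment(venv_dir, deps)  # env dict later passed to subprocess calls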
File renamed without changes.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -2,7 +2,7 @@ static-ffmpeg>=2.5
yt-dlp>=2023.3.4
appdirs==1.4.4
disklru>=1.0.7
-isolated-environment>=1.2.0
+isolated-environment>=1.2.3
json5
FileLock
webvtt-py==0.4.6
0 changes: test.sh → test (file mode 100755 → 100644)
File renamed without changes.
3 changes: 3 additions & 0 deletions tests/localfile/long.srt
@@ -0,0 +1,3 @@
1
00:00:00,000 --> 00:00:03,240
sdjfksdl ksdj fklsd;jf sdl;kfj sd;lkj fdklsfj sd;klf jwsdk fsdj fksdj fsdl;k fjds;kl fjsdkl; fjsd;klf jksdlj
40 changes: 40 additions & 0 deletions tests/test_srt_wrap.py
@@ -0,0 +1,40 @@
"""
Tests transcribe_anything
"""

# pylint: disable=bad-option-value,useless-option-value,no-self-use,protected-access,R0801
# flake8: noqa E501

import os
import unittest
import shutil
from pathlib import Path
import tempfile
from transcribe_anything.insanely_fast_whisper import (
    srt_wrap_to_string,
    has_nvidia_smi,
)


HERE = Path(os.path.abspath(os.path.dirname(__file__)))
LOCALFILE_DIR = HERE / "localfile"
TEST_SRT = LOCALFILE_DIR / "long.srt"


class InsanelFastWhisperTester(unittest.TestCase):
"""Tester for transcribe anything."""

    @unittest.skipUnless(has_nvidia_smi(), "No GPU detected")
    def test_srt_wrap(self) -> None:
        """Check that the command works on a local file."""
        with tempfile.TemporaryDirectory() as tempdir:
            td = Path(tempdir)
            target = td / "long.srt"
            shutil.copy(TEST_SRT, target)
            wrapped_srt = srt_wrap_to_string(target)
            print(wrapped_srt)
            print()


if __name__ == "__main__":
    unittest.main()
26 changes: 24 additions & 2 deletions transcribe_anything/insanely_fast_whisper.py
@@ -13,6 +13,7 @@
from pathlib import Path
import subprocess
from typing import Optional, Any
+import warnings

import webvtt # type: ignore
from isolated_environment import isolated_environment # type: ignore
@@ -26,6 +27,7 @@
CUDA_VERSION = "cu121"
TENSOR_CUDA_VERSION = f"{TENSOR_VERSION}+{CUDA_VERSION}"
EXTRA_INDEX_URL = f"https://download.pytorch.org/whl/{CUDA_VERSION}"
+WRAP_SRT_PY = HERE / "srt_wrap.py"


def get_current_python_version() -> str:
@@ -50,7 +52,7 @@ def get_environment() -> dict[str, Any]:
        "insanely-fast-whisper",
    ]
    if has_nvidia_smi():
-        deps.append(f"torch=={TENSOR_CUDA_VERSION}")
+        deps.append(f"torch=={TENSOR_CUDA_VERSION} --extra-index-url {EXTRA_INDEX_URL}")
    else:
        deps.append(f"torch=={TENSOR_VERSION}")
    deps += [
@@ -68,6 +70,7 @@ def get_cuda_info() -> CudaInfo:
    py_file = HERE / "cuda_available.py"
    cp: subprocess.CompletedProcess = subprocess.run(
        ["python", py_file],
+        shell=True,
        check=False,
        env=env,
        universal_newlines=True,
@@ -180,8 +183,26 @@ def visit(node: dict[str, Any]) -> None:
    visit(json_data)


+def srt_wrap_to_string(srt_file: Path) -> str:
+    env = get_environment()
+    process = subprocess.run(
+        ["python", str(WRAP_SRT_PY), str(srt_file)],
+        env=env,
+        capture_output=True,
+        text=True,
+        shell=True,
+    )
+    out = process.stdout
+    return out


def srt_wrap(srt_file: Path) -> None:
-    pass
+    try:
+        out = srt_wrap_to_string(srt_file)
+        WRAP_SRT_PY.write_text(out, encoding="utf-8")
+    except subprocess.CalledProcessError as exc:
+        warnings.warn(f"Failed to run srt_wrap: {exc}")
+        return


def run_insanely_fast_whisper(
@@ -271,6 +292,7 @@ def run_insanely_fast_whisper(
        error_file.write_text(json_text, encoding="utf-8")
        raise
    srt_file.write_text(srt_content, encoding="utf-8")
+    srt_wrap(srt_file)
    txt_file = output_dir / "out.txt"
    txt_file.write_text(txt_content, encoding="utf-8")
    convert_to_webvtt(srt_file, output_dir / "out.vtt")
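As a usage sketch mirroring the new test above, the wrapper helper can be exercised directly on the fixture added in this commit (the path relative to the repo root is an assumption):

from pathlib import Path

from transcribe_anything.insanely_fast_whisper import srt_wrap_to_string

# Wrap the long subtitle line in the test fixture and print the result.
wrapped = srt_wrap_to_string(Path("tests/localfile/long.srt"))
print(wrapped)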
37 changes: 0 additions & 37 deletions transcribe_anything/srt_wrap.py
@@ -1,37 +0,0 @@
"""
Wraps the srt file.
"""

# designed to be run in an isolated environment.

from argparse import ArgumentParser, Namespace
import sys
from srtranslator import SrtFile # type: ignore


# srtranslator==0.2.6
def srt_wrap(srt_file: str, out_file: str) -> None:
"""Wrap lines in a srt file."""
srt = SrtFile(srt_file)
srt.wrap_lines()
srt.save(out_file)


def create_args() -> Namespace:
"""Create args."""
parser = ArgumentParser(description="Wrap lines in a srt file.")
parser.add_argument("src_srt_file", help="The srt file to wrap.")
parser.add_argument("dst_srt_file", help="The output file.", nargs="?")
args = parser.parse_args()
return args


def main() -> int:
"""Main entry point for the command line tool."""
args = create_args()
srt_wrap(args.src_srt_file, args.dst_srt_file or args.src_srt_file)
return 0


if __name__ == "__main__":
    sys.exit(main())
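For reference, the wrapping itself comes from srtranslator (the removed module pins srtranslator==0.2.6 in a comment); a minimal sketch of the same calls, with illustrative file paths:

from srtranslator import SrtFile  # type: ignore

srt = SrtFile("long.srt")         # illustrative input path
srt.wrap_lines()                  # wrap overly long subtitle lines
srt.save("long_wrapped.srt")      # illustrative output path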
4 changes: 1 addition & 3 deletions transcribe_anything/whisper.py
@@ -9,12 +9,10 @@
import subprocess
from typing import Optional
from typing import Any
-from filelock import FileLock
from isolated_environment import isolated_environment  # type: ignore

HERE = Path(__file__).parent
CUDA_AVAILABLE: Optional[bool] = None
-ENV_LOCK = FileLock(HERE / "whisper_env.lock")

# Set the versions
TENSOR_VERSION = "2.1.2"
@@ -38,7 +36,7 @@ def get_environment() -> dict[str, Any]:
        "openai-whisper",
    ]
    if has_nvidia_smi():
-        deps.append(f"torch=={TENSOR_VERSION}+{CUDA_VERSION}")
+        deps.append(f"torch=={TENSOR_VERSION}+{CUDA_VERSION} --extra-index-url {EXTRA_INDEX_URL}")
    else:
        deps.append(f"torch=={TENSOR_VERSION}")
    env = isolated_environment(venv_dir, deps)
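By analogy with get_cuda_info() in insanely_fast_whisper.py, the environment dict returned here is presumably passed as env= to subprocess calls so the whisper CLI runs against the isolated venv; a sketch under that assumption, with an illustrative command:

import subprocess

from transcribe_anything.whisper import get_environment  # assumes this module path

env = get_environment()
subprocess.run(
    ["whisper", "--help"],  # illustrative invocation; the real command is outside this hunk
    env=env,
    check=False,
)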
