Skip to content

Commit

Permalink
feat: add support for MPS device on macOS for whisper processing
Browse files Browse the repository at this point in the history
  • Loading branch information
zackees committed Feb 1, 2025
1 parent b31cc59 commit 0593635
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/transcribe_anything/_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import argparse
import json
import os
import platform
import sys
import traceback
from pathlib import Path
Expand Down Expand Up @@ -94,11 +95,14 @@ def parse_arguments() -> argparse.Namespace:
default=None,
choices=[None] + whisper_options["language"],
)
choices = [None, "cpu", "cuda", "insane"]
if platform.system() == "Darwin":
choices.append("mps")
parser.add_argument(
"--device",
help="device to use for processing, None will auto select CUDA if available or else CPU",
default=None,
choices=[None, "cpu", "cuda", "insane"],
choices=choices,
)
parser.add_argument(
"--hf_token",
Expand Down
12 changes: 12 additions & 0 deletions src/transcribe_anything/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from transcribe_anything.logger import log_error
from transcribe_anything.util import chop_double_extension, sanitize_filename
from transcribe_anything.whisper import get_computing_device, run_whisper
from transcribe_anything.whisper_mac import run_whisper_mac_english

DISABLED_WARNINGS = [
".*set_audio_backend has been deprecated.*",
Expand All @@ -51,6 +52,7 @@ class Device(Enum):
CPU = "cpu"
CUDA = "cuda"
INSANE = "insane"
MPS = "mps"

def __str__(self) -> str:
return self.value
Expand All @@ -67,6 +69,10 @@ def from_str(device: str) -> "Device":
return Device.CUDA
if device == "insane":
return Device.INSANE
if device == "mps":
if sys.platform != "darwin":
raise ValueError("MPS is only supported on macOS.")
return Device.MPS
raise ValueError(f"Unknown device {device}")


Expand Down Expand Up @@ -222,6 +228,10 @@ def transcribe(
print("#####################################")
elif device_enum == Device.CPU:
print("WARNING: NOT using GPU acceleration, using 10x slower CPU instead.")
elif device_enum == Device.MPS:
print("#####################################")
print("####### MAC MPS GPU MODE! ###########")
print("#####################################")
else:
raise ValueError(f"Unknown device {device}")
print(f"Using device {device}")
Expand All @@ -241,6 +251,8 @@ def transcribe(
hugging_face_token=hugging_face_token,
other_args=other_args,
)
elif device_enum == Device.MPS and (language_str == "" or language_str == "en" or language_str == "English"):
run_whisper_mac_english(input_wav=Path(tmp_wav), model=model_str, output_dir=Path(tmpdir), task=task_str)
else:
run_whisper(
input_wav=Path(tmp_wav),
Expand Down
78 changes: 78 additions & 0 deletions src/transcribe_anything/whisper_mac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
Runs whisper through the whisper-mps command line tool inside an isolated environment.
"""

import subprocess
import sys
import time
from pathlib import Path
from typing import Optional

from iso_env import IsoEnv, IsoEnvArgs, PyProjectToml # type: ignore

HERE = Path(__file__).parent
CUDA_AVAILABLE: Optional[bool] = None

# whisper-mps --file-name tests/localfile/video.mp4


def get_environment() -> IsoEnv:
    """Build the isolated environment used to run ``whisper-mps``.

    The environment lives in ``venv/whisper_darwin`` next to this module and
    is described by an in-memory pyproject.toml whose only dependency is
    ``whisper-mps``, pinned to Python 3.10.

    Returns:
        An ``IsoEnv`` constructed from that pyproject description.
    """
    # One pyproject.toml line per entry; joined below without a trailing newline.
    pyproject_lines = (
        "[build-system]",
        'requires = ["setuptools", "wheel"]',
        'build-backend = "setuptools.build_meta"',
        "",
        "[project]",
        'name = "project"',
        'version = "0.1.0"',
        'requires-python = "==3.10.*"',
        "dependencies = [",
        ' "whisper-mps",',
        "]",
    )
    build_info = PyProjectToml("\n".join(pyproject_lines))
    env_args = IsoEnvArgs(HERE / "venv" / "whisper_darwin", build_info=build_info)
    return IsoEnv(env_args)


def run_whisper_mac_english(  # pylint: disable=too-many-arguments
    input_wav: Path,
    model: str,
    output_dir: Path,
    task: str,
) -> None:
    """Run the ``whisper-mps`` command on ``input_wav`` and wait for it to finish.

    The command is launched inside the isolated environment returned by
    ``get_environment()``, with its working directory set to the directory
    that contains the input file.

    Args:
        input_wav: Audio file to process.
        model: Model name passed via ``--model``; the flag is omitted when empty.
        output_dir: Directory passed via ``--output_dir``; created if missing.
        task: Task passed via ``--task``; the flag is omitted when empty.

    Raises:
        OSError: If the process exits with a non-zero return code.
    """
    input_wav_abs = input_wav.resolve()
    # The child runs with a different cwd, so a relative output directory must
    # be made absolute or the child would resolve it against the wrong place.
    output_dir_abs = output_dir.resolve()
    output_dir_abs.mkdir(parents=True, exist_ok=True)
    env = get_environment()

    # Drop a flag entirely when its value is empty, so no flag is ever left
    # dangling without an argument.
    model_arg = str(model).strip() if model else ""
    task_arg = str(task).strip() if task else ""

    # NOTE(review): flag names are passed through unchanged from the original
    # code; confirm them against the whisper-mps command line interface.
    cmd_list = [
        "whisper-mps",
        "--file-name",
        # cwd is the resolved parent directory, so use the resolved file name.
        input_wav_abs.name,
    ]
    if model_arg:
        cmd_list += ["--model", model_arg]
    cmd_list += ["--output_dir", str(output_dir_abs)]
    if task_arg:
        cmd_list += ["--task", task_arg]

    cmd = subprocess.list2cmdline(cmd_list)
    sys.stderr.write(f"Running:\n {cmd}\n")
    proc = env.open_proc(cmd_list, shell=False, cwd=input_wav_abs.parent)

    # Poll until the process has exited.
    rtn = proc.poll()
    while rtn is None:
        time.sleep(0.25)
        rtn = proc.poll()
    if rtn != 0:
        msg = f"Failed to execute {cmd}\n "
        raise OSError(msg)

0 comments on commit 0593635

Please sign in to comment.