Skip to content

Commit

Permalink
Performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
eleanorTurintech committed Feb 3, 2025
1 parent 517a43e commit 30abb70
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 10 deletions.
30 changes: 21 additions & 9 deletions whisper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")

def compute_sha256(file_path: str) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *file_path*.

    The file is consumed in 8192-byte pieces, so its full content is never
    held in memory at once.
    """
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(8192)
            # An empty read marks end-of-file.
            if piece == b"":
                break
            digest.update(piece)
    return digest.hexdigest()

if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
if compute_sha256(download_target) == expected_sha256:
if in_memory:
with open(download_target, "rb") as f:
return f.read()
else:
return download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
Expand All @@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
output.write(buffer)
loop.update(len(buffer))

model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
if compute_sha256(download_target) != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)

return model_bytes if in_memory else download_target
if in_memory:
with open(download_target, "rb") as f:
return f.read()
else:
return download_target


def available_models() -> List[str]:
Expand Down Expand Up @@ -147,7 +159,7 @@ def load_model(
with (
io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
) as fp:
checkpoint = torch.load(fp, map_location=device)
checkpoint = torch.load(fp, map_location=device,weights_only=True)
del checkpoint_file

dims = ModelDimensions(**checkpoint["dims"])
Expand All @@ -157,4 +169,4 @@ def load_model(
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model.to(device)
return model.to(device)
43 changes: 42 additions & 1 deletion whisper/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,47 @@ def __init__(
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)

# Optimisation: pre-compute and register the mask in CUDA if available
if torch.cuda.is_available():
self.register_buffer("mask_cuda", mask.cuda(), persistent=False)


def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
    """
    Compute vocabulary logits for `tokens` conditioned on `audio_features`.

    NOTE(review): another `forward` is defined directly after this one. If
    both live in the same class body, the later definition replaces this
    method and it is never called -- confirm which one should be kept.

    Args:
        tokens: (n_batch, n_token)
        audio_features: (n_batch, n_audio_ctx, n_audio_state)
    Returns:
        logits: (n_batch, n_token, n_vocab)
    """
    n_token = tokens.shape[-1]

    x = self.token_embedding(tokens) + self.positional_embedding[:n_token]

    # Follow the device the activations are already on instead of forcing
    # CUDA whenever a GPU exists: a model kept on the CPU would otherwise
    # hit a device mismatch inside the blocks.
    audio_features = audio_features.to(x.device)

    # The causal mask restricts attention between token positions, so it is
    # handed to the blocks rather than added to the vocabulary logits, whose
    # last dimension is n_vocab and does not line up with the mask.
    # `self.mask` is a registered buffer and therefore moves with the
    # module; no separate CUDA copy is needed.
    # NOTE(review): assumes every block accepts a `mask` keyword -- confirm
    # against the block class's forward signature.
    for block in self.blocks:
        x = block(x, audio_features, mask=self.mask)

    x = self.ln(x)
    logits = x @ self.token_embedding.weight.T

    return logits


def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
Expand Down Expand Up @@ -342,4 +383,4 @@ def install_hooks(layer: nn.Module):

detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
decode = decode_function

0 comments on commit 30abb70

Please sign in to comment.