diff --git a/whisper/__init__.py b/whisper/__init__.py index e210718f3..afb7a2e78 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -60,11 +60,20 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: if os.path.exists(download_target) and not os.path.isfile(download_target): raise RuntimeError(f"{download_target} exists and is not a regular file") + def compute_sha256(file_path: str) -> str: + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + sha256.update(chunk) + return sha256.hexdigest() + if os.path.isfile(download_target): - with open(download_target, "rb") as f: - model_bytes = f.read() - if hashlib.sha256(model_bytes).hexdigest() == expected_sha256: - return model_bytes if in_memory else download_target + if compute_sha256(download_target) == expected_sha256: + if in_memory: + with open(download_target, "rb") as f: + return f.read() + else: + return download_target else: warnings.warn( f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" @@ -86,13 +95,16 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]: output.write(buffer) loop.update(len(buffer)) - model_bytes = open(download_target, "rb").read() - if hashlib.sha256(model_bytes).hexdigest() != expected_sha256: + if compute_sha256(download_target) != expected_sha256: raise RuntimeError( "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model." ) - return model_bytes if in_memory else download_target + if in_memory: + with open(download_target, "rb") as f: + return f.read() + else: + return download_target def available_models() -> List[str]: @@ -147,7 +159,7 @@ def load_model( with ( io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") ) as fp: - checkpoint = torch.load(fp, map_location=device) + checkpoint = torch.load(fp, map_location=device,weights_only=True) del checkpoint_file dims = ModelDimensions(**checkpoint["dims"]) @@ -157,4 +169,4 @@ def load_model( if alignment_heads is not None: model.set_alignment_heads(alignment_heads) - return model.to(device) + return model.to(device) \ No newline at end of file diff --git a/whisper/model.py b/whisper/model.py index e53744738..997b54c3b 100644 --- a/whisper/model.py +++ b/whisper/model.py @@ -224,6 +224,47 @@ def __init__( mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1) self.register_buffer("mask", mask, persistent=False) + # Optimisation: pre-compute and register the mask in CUDA if available + if torch.cuda.is_available(): + self.register_buffer("mask_cuda", mask.cuda(), persistent=False) + + + def forward(self, tokens: Tensor, audio_features: Tensor) -> Tensor: + """ + Args: + tokens: (n_batch, n_token) + audio_features: (n_batch, n_audio_ctx, n_audio_state) + + Returns: + logits: (n_batch, n_token, n_vocab) + """ + n_batch, n_token = tokens.shape + n_audio_ctx, n_audio_state = audio_features.shape[1:] + + x = self.token_embedding(tokens) + self.positional_embedding[:n_token] + + # Optimisation: Move audio_features to GPU once here. + if torch.cuda.is_available(): + audio_features = audio_features.cuda() + + + for block in self.blocks: + x = block(x, audio_features) + + x = self.ln(x) + logits = x @ self.token_embedding.weight.T + + # Optimisation: Apply the precomputed CUDA mask if available. + if torch.cuda.is_available(): + mask = self.mask_cuda[:n_token, :n_token] + else: + mask = self.mask[:n_token, :n_token] + + logits = logits + mask + + return logits + + def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None): """ x : torch.LongTensor, shape = (batch_size, <= n_ctx) @@ -342,4 +383,4 @@ def install_hooks(layer: nn.Module): detect_language = detect_language_function transcribe = transcribe_function - decode = decode_function + decode = decode_function \ No newline at end of file