Fixes MPS device errors from Tensor.type() when using generate_text_semantic and generate_coarse #27

Merged 5 commits on May 14, 2023
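Background, summarized from the diff (the PR page itself carries no description): in the top-p sampling paths of generate_text_semantic and generate_coarse, the logits take a detour through CPU/NumPy for filtering and are then moved back. The old code also captured the legacy dtype string from relevant_logits.type() and tried to restore it afterwards with Tensor.type(logits_dtype); on the "mps" backend that string round trip errors out, which is the failure the title refers to. The fix remembers only the device and restores it with .to(original_device), leaving the logits as float32, which is all the subsequent sampling needs. A minimal sketch of the removed pattern, assuming an Apple-silicon build of PyTorch with MPS available (the exact type string is an assumption, not from the PR):

import torch

t = torch.zeros(4, device="mps")
type_string = t.type()   # legacy type string (e.g. "torch.mps.FloatTensor")
t = t.type(type_string)  # the round trip this PR removes; it raises on MPS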
bark/generation.py: 12 changes (4 additions, 8 deletions)
@@ -339,7 +339,6 @@ def preload_models(
 ####
 # Generation Functionality
 ####
 
-
 def _tokenize(tokenizer, text):
     return tokenizer.encode(text, add_special_tokens=False)

@@ -351,7 +350,6 @@ def _detokenize(tokenizer, enc_text):
 def _normalize_whitespace(text):
     return re.sub(r"\s+", " ", text).strip()
 
-
 TEXT_ENCODING_OFFSET = 10_048
 SEMANTIC_PAD_TOKEN = 10_000
 TEXT_PAD_TOKEN = 129_595

@@ -453,8 +451,7 @@ def generate_text_semantic(
             )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -464,7 +461,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = relevant_logits.to(original_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")

@@ -649,8 +646,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -660,7 +656,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = relevant_logits.to(original_device)
                     if top_k is not None:
                         v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                         relevant_logits[relevant_logits < v[-1]] = -float("Inf")
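For reference, here is the pattern the diff converges on, pulled out into a standalone helper. This is a minimal sketch, not code from the PR: the top_p_filter name and the inline softmax are illustrative, Bark performs these steps inline inside its sampling loops, and a 1-D logits vector is assumed, as in those loops.

import numpy as np
import torch

def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Nucleus (top-p) filtering via a CPU/NumPy detour, restoring only the device."""
    original_device = relevant_logits.device
    # Detour through float32 NumPy: argsort/cumsum on one small vector are
    # cheap on CPU, and float32 round-trips cleanly to CUDA and MPS alike.
    logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
    sorted_indices = np.argsort(logits)[::-1]  # indices in descending logit order
    sorted_logits = logits[sorted_indices]
    probs = np.exp(sorted_logits - sorted_logits.max())  # numerically stable softmax
    probs /= probs.sum()
    cumulative_probs = np.cumsum(probs)
    sorted_indices_to_remove = cumulative_probs > top_p
    # Shift the mask right so the first token to cross the threshold is kept.
    sorted_indices_to_remove = np.roll(sorted_indices_to_remove, 1)
    sorted_indices_to_remove[0] = False
    logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
    # Restore only the device; staying in float32 sidesteps the Tensor.type()
    # string round trip that raised errors on MPS.
    return torch.from_numpy(logits).to(original_device)

Inside the sampling loops this would be used as relevant_logits = top_p_filter(relevant_logits, top_p); the result is a float32 tensor on the caller's device, which is all the downstream softmax/multinomial sampling requires.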