
Fixes MPS device errors from Tensor.type() when using generate_text_semantic and generate_coarse #27

Merged: 5 commits merged into serp-ai:main on May 14, 2023

Conversation

@fiq commented May 6, 2023

Addresses MPS-specific errors in generate_text_semantic and generate_coarse when calling Tensor.type() for logit handling on MPS devices. See the underlying PyTorch issue: pytorch/pytorch#78929.

Note that I've also submitted a similar PR directly to bark. I'm not clear on your syncing policy, so you may want to wait and see how that fares, although this issue applies to both bark and bark-with-voice-clone. It has fixed it for me on an M2 Pro.

Fix tested on: M2 Pro

Expected:

  • Seamless voice generation

Actual:

```
❯ SUNO_ENABLE_MPS=True python ./test-case.py
  0%|          | 0/100 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/Users/innovation/Code/bark/./test-case.py", line 57, in <module>
    audio_array = say_semantic(training_text, voice_name)
  File "/Users/innovation/Code/bark/./test-case.py", line 30, in say_semantic
    x_semantic = generate_text_semantic(
  File "/Users/innovation/Code/bark/bark/generation.py", line 479, in generate_text_semantic
    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
ValueError: invalid type: 'torch.mps.FloatTensor'
```
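For context, the failure comes from PyTorch's legacy Tensor.type() API, which resolves dtype strings such as 'torch.FloatTensor' through a per-backend type registry; no 'torch.mps.FloatTensor' entry exists, hence the ValueError (see pytorch/pytorch#78929). A minimal sketch of the device-agnostic alternative, runnable on plain CPU, might look like:

```python
import torch

logits = torch.randn(4)

# Tensor.type() with a string goes through the legacy tensor-type
# registry ('torch.FloatTensor', 'torch.cuda.FloatTensor', ...);
# there is no MPS entry, which is what triggers the ValueError above.
# Passing a torch.dtype to .to() sidesteps that registry entirely and
# behaves the same on CPU, CUDA, and MPS.
converted = logits.to(dtype=torch.float32)
print(converted.dtype)  # torch.float32
```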

To recreate:

To recreate prior to this PR, on an MPS device (tested on M2 Pro), use this test script and run as above with SUNO_ENABLE_MPS set:

```python
from bark.generation import (
    SAMPLE_RATE,
    codec_decode,
    generate_coarse,
    generate_fine,
    generate_text_semantic,
    preload_models,
)
from scipy.io.wavfile import write as write_wav
import sounddevice as sd
import pprint

pp = pprint.PrettyPrinter()


def say_semantic(text_prompt, voice_name):
    # Load the full-size models on the GPU (MPS when SUNO_ENABLE_MPS is set).
    preload_models(
        text_use_gpu=True,
        text_use_small=False,
        coarse_use_gpu=True,
        coarse_use_small=False,
        fine_use_gpu=True,
        fine_use_small=False,
        codec_use_gpu=True,
        force_reload=True,
    )

    # Text -> semantic tokens (this is where the MPS ValueError is raised).
    x_semantic = generate_text_semantic(
        text_prompt,
        history_prompt=voice_name,
        temp=0.7,
        top_k=50,
        top_p=0.95,
    )

    # Semantic tokens -> coarse audio tokens (also affected by this fix).
    x_coarse_gen = generate_coarse(
        x_semantic,
        history_prompt=voice_name,
        temp=0.7,
        top_k=50,
        top_p=0.95,
    )

    # Coarse -> fine audio tokens, then decode to a waveform.
    x_fine_gen = generate_fine(
        x_coarse_gen,
        history_prompt=voice_name,
        temp=0.5,
    )
    return codec_decode(x_fine_gen)


voice_name = "en_speaker_0"
training_text = "Hello, there!"
audio_array = say_semantic(training_text, voice_name)

pp.pprint(audio_array)
write_wav("output.wav", SAMPLE_RATE, audio_array)
sd.play(audio_array, SAMPLE_RATE)
# allow async sd.play to complete
sd.wait()
```

Btw, I'm LOVING both bark and bark-with-voice-clone (even though it sounds nothing like my tunings yet 😂). Thanks for forking and unlocking the voice cloning!


@fiq (Author) commented May 7, 2023

I have just pushed up a simplification which shouldn't break support on other device types and reduces my previous MPS fix to two lines. Tested with device types of CPU and MPS.
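For reference, the core of the change is of this shape (a sketch based on the traceback above; the exact lines in the merged commits may differ): combining the device move and the dtype cast into a single .to() call, which accepts a torch.dtype rather than a legacy type string.

```diff
- relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+ relevant_logits = relevant_logits.to(device=logits_device, dtype=logits_dtype)
```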

@devinschumacher (Member) commented:

thank you @fiq ! i tagged @francislabountyjr here and in our Discord group to take a look at your PR.

@fiq (Author) commented May 13, 2023

Thanks @devinschumacher.

@francislabountyjr I've pushed up a further simplification/cleanup. Tested against MPS and CPU devices.

@francislabountyjr (Member) left a comment:


Looks good! The PR makes the necessary changes to run on MPS devices while not affecting the functionality of CUDA devices.

@francislabountyjr francislabountyjr merged commit 43cbc43 into serp-ai:main May 14, 2023
maximus-sallam pushed a commit to maximus-sallam/bark-with-voice-clone that referenced this pull request Jun 5, 2023
Add key/value caching for autoregressive generation