|
38 | 38 | from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
39 | 39 | SequenceData)
|
40 | 40 | from vllm.transformers_utils.configs.ultravox import UltravoxConfig
|
| 41 | +from vllm.utils import is_list_of |
41 | 42 |
|
42 | 43 | from .interfaces import SupportsMultiModal, SupportsPP
|
43 | 44 |
|
@@ -119,6 +120,10 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
|
119 | 120 | if not isinstance(data, list):
|
120 | 121 | data = [data]
|
121 | 122 |
|
| 123 | + # If the audio inputs are embeddings, no need for preprocessing |
| 124 | + if is_list_of(data, torch.Tensor, check="all"): |
| 125 | + return MultiModalInputs({"audio_embeds": data}) |
| 126 | + |
122 | 127 | audio_features = []
|
123 | 128 | for audio_input in data:
|
124 | 129 | if not isinstance(audio_input, tuple):
|
@@ -165,25 +170,30 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
|
165 | 170 | audios = [audios]
|
166 | 171 |
|
167 | 172 | audio_token_counts = []
|
168 |
| - for audio_data, sample_rate in audios: |
169 |
| - audio_length = audio_data.shape[0] |
170 |
| - if sample_rate != feature_extractor.sampling_rate: |
171 |
| - # Account for resampling. |
172 |
| - adjustment = feature_extractor.sampling_rate / sample_rate |
173 |
| - audio_length = math.ceil(adjustment * audio_length) |
174 |
| - |
175 |
| - feature_extractor_output_length = math.ceil( |
176 |
| - (audio_length - (feature_extractor.hop_length - 1)) / |
177 |
| - feature_extractor.hop_length) |
178 |
| - |
179 |
| - uv_config = ctx.get_hf_config(UltravoxConfig) |
180 |
| - audio_num_tokens = min( |
181 |
| - max( |
182 |
| - 1, |
183 |
| - math.ceil(feature_extractor_output_length / |
184 |
| - (uv_config.stack_factor * 2))), |
185 |
| - get_ultravox_max_audio_tokens(ctx)) |
186 |
| - audio_token_counts.append(audio_num_tokens) |
| 173 | + for audio in audios: |
| 174 | + if isinstance(audio, torch.Tensor): |
| 175 | + audio_num_tokens = audio.shape[1] |
| 176 | + audio_token_counts.append(audio_num_tokens) |
| 177 | + else: |
| 178 | + audio_data, sample_rate = audio |
| 179 | + audio_length = audio_data.shape[0] |
| 180 | + if sample_rate != feature_extractor.sampling_rate: |
| 181 | + # Account for resampling. |
| 182 | + adjustment = feature_extractor.sampling_rate / sample_rate |
| 183 | + audio_length = math.ceil(adjustment * audio_length) |
| 184 | + |
| 185 | + feature_extractor_output_length = math.ceil( |
| 186 | + (audio_length - (feature_extractor.hop_length - 1)) / |
| 187 | + feature_extractor.hop_length) |
| 188 | + |
| 189 | + uv_config = ctx.get_hf_config(UltravoxConfig) |
| 190 | + audio_num_tokens = min( |
| 191 | + max( |
| 192 | + 1, |
| 193 | + math.ceil(feature_extractor_output_length / |
| 194 | + (uv_config.stack_factor * 2))), |
| 195 | + get_ultravox_max_audio_tokens(ctx)) |
| 196 | + audio_token_counts.append(audio_num_tokens) |
187 | 197 |
|
188 | 198 | tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer)
|
189 | 199 |
|
|
0 commit comments