|
3 | 3 | import os
|
4 | 4 | import re
|
5 | 5 | from dataclasses import dataclass
|
| 6 | +from fractions import Fraction |
6 | 7 | from typing import TYPE_CHECKING
|
7 | 8 |
|
8 | 9 | import numpy as np
|
|
25 | 26 | pAttr,
|
26 | 27 | pAttrs,
|
27 | 28 | )
|
| 29 | +from auto_editor.utils.subtitle_tools import convert_ass_to_text |
28 | 30 | from auto_editor.wavfile import read
|
29 | 31 |
|
30 | 32 | if TYPE_CHECKING:
|
@@ -307,31 +309,65 @@ def subtitle(
|
307 | 309 | except re.error as e:
|
308 | 310 | self.log.error(e)
|
309 | 311 |
|
310 |
| - sub_file = self.ensure.subtitle(self.src, stream) |
311 |
| - parser = SubtitleParser(self.tb) |
| 312 | + import av |
312 | 313 |
|
313 |
| - with open(sub_file, encoding="utf-8") as file: |
314 |
| - parser.parse(file.read(), "webvtt") |
| 314 | + try: |
| 315 | + container = av.open(self.src.path, "r") |
| 316 | + subtitle_stream = container.streams.subtitles[stream] |
| 317 | + assert isinstance(subtitle_stream.time_base, Fraction) |
| 318 | + except Exception as e: |
| 319 | + self.log.error(e) |
315 | 320 |
|
316 |
| - # stackoverflow.com/questions/9662346/python-code-to-remove-html-tags-from-a-string |
317 |
| - def cleanhtml(raw_html: str) -> str: |
318 |
| - cleanr = re.compile("<.*?>") |
319 |
| - return re.sub(cleanr, "", raw_html) |
| 321 | + # Get the length of the subtitle stream. |
| 322 | + sub_length = 0 |
| 323 | + for packet in container.demux(subtitle_stream): |
| 324 | + for subset in packet.decode(): |
| 325 | + if packet.pts is None or packet.duration is None: |
| 326 | + continue |
| 327 | + # See definition of `AVSubtitle` |
| 328 | + # in: https://ffmpeg.org/doxygen/trunk/avcodec_8h_source.html |
| 329 | + start = float(packet.pts * subtitle_stream.time_base) |
| 330 | + dur = float(packet.duration * subtitle_stream.time_base) |
320 | 331 |
|
321 |
| - if not parser.contents: |
322 |
| - self.log.error("subtitle has no valid entries") |
| 332 | + end = round((start + dur) * self.tb) |
| 333 | + sub_length = max(sub_length, end) |
323 | 334 |
|
324 |
| - result = np.zeros((parser.contents[-1].end), dtype=np.bool_) |
| 335 | + result = np.zeros((sub_length), dtype=np.bool_) |
| 336 | + del sub_length |
325 | 337 |
|
326 | 338 | count = 0
|
327 |
| - for content in parser.contents: |
328 |
| - if max_count is not None and count >= max_count: |
| 339 | + early_exit = False |
| 340 | + container.seek(0) |
| 341 | + for packet in container.demux(subtitle_stream): |
| 342 | + if early_exit: |
329 | 343 | break
|
330 | 344 |
|
331 |
| - line = cleanhtml(content.after.strip()) |
332 |
| - if line and re.search(pattern, line): |
333 |
| - result[content.start : content.end] = 1 |
334 |
| - count += 1 |
| 345 | + for subset in packet.decode(): |
| 346 | + if packet.pts is None or packet.duration is None: |
| 347 | + continue |
| 348 | + if max_count is not None and count >= max_count: |
| 349 | + early_exit = True |
| 350 | + break |
| 351 | + |
| 352 | + start = float(packet.pts * subtitle_stream.time_base) |
| 353 | + dur = float(packet.duration * subtitle_stream.time_base) |
| 354 | + |
| 355 | + san_start = round(start * self.tb) |
| 356 | + san_end = round((start + dur) * self.tb) |
| 357 | + |
| 358 | + for sub in subset: |
| 359 | + if sub.type == b"ass": |
| 360 | + line = convert_ass_to_text(sub.ass.decode(errors="ignore")) |
| 361 | + elif sub.type == b"text": |
| 362 | + line = sub.text.decode(errors="ignore") |
| 363 | + else: |
| 364 | + continue |
| 365 | + |
| 366 | + if line and re.search(pattern, line): |
| 367 | + result[san_start:san_end] = 1 |
| 368 | + count += 1 |
| 369 | + |
| 370 | + container.close() |
335 | 371 |
|
336 | 372 | return result
|
337 | 373 |
|
|
0 commit comments