diff --git a/README.md b/README.md
index 10ef22bf..ab7bf35b 100644
--- a/README.md
+++ b/README.md
@@ -20,11 +20,11 @@ Open-Sora not only democratizes access to advanced video generation techniques,
streamlined and user-friendly platform that simplifies the complexities of video generation.
With Open-Sora, our goal is to foster innovation, creativity, and inclusivity within the field of content creation.
-[[中文文档]](/docs/zh_CN/README.md) [[潞晨云部署视频教程]](https://www.bilibili.com/video/BV141421R7Ag)
+[[中文文档](/docs/zh_CN/README.md)] [[潞晨云](https://cloud.luchentech.com/)|[OpenSora镜像](https://cloud.luchentech.com/doc/docs/image/open-sora/)|[视频教程](https://www.bilibili.com/video/BV1ow4m1e7PX/?vd_source=c6b752764cd36ff0e535a768e35d98d2)]
## 📰 News
-- **[2024.06.17]** 🔥 We released **Open-Sora 1.2**, which includes **3D-VAE**, **rectified flow**, and **score condition**. The video quality is greatly improved. [[checkpoints]](#open-sora-10-model-weights) [[report]](/docs/report_03.md)
+- **[2024.06.17]** 🔥 We released **Open-Sora 1.2**, which includes **3D-VAE**, **rectified flow**, and **score condition**. The video quality is greatly improved. [[checkpoints]](#open-sora-10-model-weights) [[report]](/docs/report_03.md) [[blog]](https://hpc-ai.com/blog/open-sora-from-hpc-ai-tech-team-continues-open-source-generate-any-16-second-720p-hd-video-with-one-click-model-weights-ready-to-use)
- **[2024.04.25]** 🤗 We released the [Gradio demo for Open-Sora](https://huggingface.co/spaces/hpcai-tech/open-sora) on Hugging Face Spaces.
- **[2024.04.25]** We released **Open-Sora 1.1**, which supports **2s~15s, 144p to 720p, any aspect ratio** text-to-image, **text-to-video, image-to-video, video-to-video, infinite time** generation. In addition, a full video processing pipeline is released. [[checkpoints]]() [[report]](/docs/report_02.md)
- **[2024.03.18]** We released **Open-Sora 1.0**, a fully open-source project for video generation.
@@ -38,8 +38,7 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi
## 🎥 Latest Demo
-🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/).
-
+🔥 You can experience Open-Sora on our [🤗 Gradio application on Hugging Face](https://huggingface.co/spaces/hpcai-tech/open-sora). More samples and corresponding prompts are available in our [Gallery](https://hpcaitech.github.io/Open-Sora/).
| **4s 720×1280** | **4s 720×1280** | **4s 720×1280** |
| ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
@@ -47,7 +46,6 @@ With Open-Sora, our goal is to foster innovation, creativity, and inclusivity wi
| [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/644bf938-96ce-44aa-b797-b3c0b513d64c) | [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/272d88ac-4b4a-484d-a665-8d07431671d0) | [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/ebbac621-c34e-4bb4-9543-1c34f8989764) |
| [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/a1e3a1a3-4abd-45f5-8df2-6cced69da4ca) | [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/d6ce9c13-28e1-4dff-9644-cc01f5f11926) | [
](https://github.com/hpcaitech/Open-Sora/assets/99191637/561978f8-f1b0-4f4d-ae7b-45bec9001b4a) |
-
OpenSora 1.1 Demo
diff --git a/configs/opensora-v1-2/train/demo_360p.py b/configs/opensora-v1-2/train/demo_360p.py
new file mode 100644
index 00000000..e27bd3cd
--- /dev/null
+++ b/configs/opensora-v1-2/train/demo_360p.py
@@ -0,0 +1,58 @@
+# Dataset settings
+dataset = dict(
+ type="VariableVideoTextDataset",
+ transform_name="resize_crop",
+)
+
+# webvid
+bucket_config = {"360p": {102: (1.0, 5)}}
+grad_checkpoint = True
+
+# Acceleration settings
+num_workers = 8
+num_bucket_build_workers = 16
+dtype = "bf16"
+plugin = "zero2"
+
+# Model settings
+model = dict(
+ type="STDiT3-XL/2",
+ from_pretrained=None,
+ qk_norm=True,
+ enable_flash_attn=True,
+ enable_layernorm_kernel=True,
+ freeze_y_embedder=True,
+)
+vae = dict(
+ type="OpenSoraVAE_V1_2",
+ from_pretrained="hpcai-tech/OpenSora-VAE-v1.2",
+ micro_frame_size=17,
+ micro_batch_size=4,
+)
+text_encoder = dict(
+ type="t5",
+ from_pretrained="DeepFloyd/t5-v1_1-xxl",
+ model_max_length=300,
+ shardformer=True,
+)
+scheduler = dict(
+ type="rflow",
+ use_timestep_transform=True,
+ sample_method="logit-normal",
+)
+
+# Log settings
+seed = 42
+outputs = "outputs"
+wandb = False
+epochs = 1000
+log_every = 10
+ckpt_every = 200
+
+# optimization settings
+load = None
+grad_clip = 1.0
+lr = 1e-4
+ema_decay = 0.99
+adam_eps = 1e-15
+warmup_steps = 1000
diff --git a/configs/opensora-v1-2/train/stage3_480p.py b/configs/opensora-v1-2/train/demo_480p.py
similarity index 76%
rename from configs/opensora-v1-2/train/stage3_480p.py
rename to configs/opensora-v1-2/train/demo_480p.py
index b4b9ffdb..08121c7b 100644
--- a/configs/opensora-v1-2/train/stage3_480p.py
+++ b/configs/opensora-v1-2/train/demo_480p.py
@@ -9,7 +9,7 @@
grad_checkpoint = True
# Acceleration settings
-num_workers = 0
+num_workers = 8
num_bucket_build_workers = 16
dtype = "bf16"
plugin = "zero2"
@@ -41,21 +41,6 @@
sample_method="logit-normal",
)
-# Mask settings
-# 25%
-mask_ratios = {
- "random": 0.01,
- "intepolate": 0.002,
- "quarter_random": 0.002,
- "quarter_head": 0.002,
- "quarter_tail": 0.002,
- "quarter_head_tail": 0.002,
- "image_random": 0.0,
- "image_head": 0.22,
- "image_tail": 0.005,
- "image_head_tail": 0.005,
-}
-
# Log settings
seed = 42
outputs = "outputs"
diff --git a/docs/zh_CN/report_v1.md b/docs/zh_CN/report_v1.md
index bf12131a..feedda37 100644
--- a/docs/zh_CN/report_v1.md
+++ b/docs/zh_CN/report_v1.md
@@ -11,11 +11,11 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它
如图中所示,在STDiT(ST代表时空)中,我们在每个空间注意力之后立即插入一个时间注意力。这类似于Latte论文中的变种3。然而,我们并没有控制这些变体的相似数量的参数。虽然Latte的论文声称他们的变体比变种3更好,但我们在16x256x256视频上的实验表明,相同数量的迭代次数下,性能排名为:DiT(完整)> STDiT(顺序)> STDiT(并行)≈ Latte。因此,我们出于效率考虑选择了STDiT(顺序)。[这里](/docs/acceleration.md#efficient-stdit)提供了速度基准测试。
-
+
为了专注于视频生成,我们希望基于一个强大的图像生成模型来训练我们的模型。PixArt-α是一个经过高效训练的高质量图像生成模型,具有T5条件化的DiT结构。我们使用[PixArt-α](https://github.com/PixArt-alpha/PixArt-alpha)初始化我们的模型,并将插入的时间注意力的投影层初始化为零。这种初始化在开始时保留了模型的图像生成能力,而Latte的架构则不能。插入的注意力将参数数量从5.8亿增加到7.24亿。
-
+
借鉴PixArt-α和Stable Video Diffusion的成功,我们还采用了渐进式训练策略:在366K预训练数据集上进行16x256x256的训练,然后在20K数据集上进行16x256x256、16x512x512和64x512x512的训练。通过扩展位置嵌入,这一策略极大地降低了计算成本。
@@ -26,7 +26,7 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它
我们发现数据的数量和质量对生成视频的质量有很大的影响,甚至比模型架构和训练策略的影响还要大。目前,我们只从[HD-VG-130M](https://github.com/daooshee/HD-VG-130M)准备了第一批分割(366K个视频片段)。这些视频的质量参差不齐,而且字幕也不够准确。因此,我们进一步从提供免费许可视频的[Pexels](https://www.pexels.com/)收集了20k相对高质量的视频。我们使用LLaVA,一个图像字幕模型,通过三个帧和一个设计好的提示来标记视频。有了设计好的提示,LLaVA能够生成高质量的字幕。
-
+
由于我们更加注重数据质量,我们准备收集更多数据,并在下一版本中构建一个视频预处理流程。
@@ -38,12 +38,12 @@ OpenAI的Sora在生成一分钟高质量视频方面非常出色。然而,它
16x256x256 预训练损失曲线
-
+
16x256x256 高质量训练损失曲线
-
+
16x512x512 高质量训练损失曲线
-
+
diff --git a/gradio/README.md b/gradio/README.md
index 5dddd470..aee7303a 100644
--- a/gradio/README.md
+++ b/gradio/README.md
@@ -1,4 +1,4 @@
----
+gaungxiangyang---
title: Open Sora
emoji: 🎥
colorFrom: red
diff --git a/opensora/datasets/dataloader.py b/opensora/datasets/dataloader.py
index 8bcaed95..15058ac8 100644
--- a/opensora/datasets/dataloader.py
+++ b/opensora/datasets/dataloader.py
@@ -34,6 +34,7 @@ def prepare_dataloader(
process_group: Optional[ProcessGroup] = None,
bucket_config=None,
num_bucket_build_workers=1,
+ prefetch_factor=None,
**kwargs,
):
_kwargs = kwargs.copy()
@@ -57,6 +58,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
+ prefetch_factor=prefetch_factor,
**_kwargs,
),
batch_sampler,
@@ -79,6 +81,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_default,
+ prefetch_factor=prefetch_factor,
**_kwargs,
),
sampler,
@@ -98,6 +101,7 @@ def prepare_dataloader(
pin_memory=pin_memory,
num_workers=num_workers,
collate_fn=collate_fn_batch,
+ prefetch_factor=prefetch_factor,
**_kwargs,
),
sampler,
diff --git a/opensora/datasets/datasets.py b/opensora/datasets/datasets.py
index 34a5dcf6..8b5fdd6a 100644
--- a/opensora/datasets/datasets.py
+++ b/opensora/datasets/datasets.py
@@ -151,9 +151,11 @@ def getitem(self, index):
# Sampling video frames
video = temporal_random_crop(vframes, num_frames, self.frame_interval)
+ video = video.clone()
+ del vframes
video_fps = video_fps // self.frame_interval
-
+
# transform
transform = get_transforms_video(self.transform_name, (height, width))
video = transform(video) # T C H W
diff --git a/opensora/datasets/read_video.py b/opensora/datasets/read_video.py
index f988c306..6b5da346 100644
--- a/opensora/datasets/read_video.py
+++ b/opensora/datasets/read_video.py
@@ -1,20 +1,19 @@
import gc
import math
import os
+import re
+import warnings
from fractions import Fraction
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
import av
import cv2
import numpy as np
import torch
-from torchvision.io.video import (
- _align_audio_frames,
- _check_av_available,
- _log_api_usage_once,
- _read_from_stream,
- _video_opt,
-)
+from torchvision import get_video_backend
+from torchvision.io.video import _check_av_available
+
+MAX_NUM_FRAMES = 2500
def read_video_av(
@@ -27,6 +26,13 @@ def read_video_av(
"""
Reads a video from a file, returning both the video frames and the audio frames
+ This method is modified from torchvision.io.video.read_video, with the following changes:
+
+ 1. will not extract audio frames and return empty for aframes
+ 2. remove checks and only support pyav
+ 3. add container.close() and gc.collect() to avoid thread leakage
+ 4. try our best to avoid memory leak
+
Args:
filename (str): path to the video file
start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
@@ -42,99 +48,169 @@ def read_video_av(
aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
"""
- if not torch.jit.is_scripting() and not torch.jit.is_tracing():
- _log_api_usage_once(read_video)
-
+ # format
output_format = output_format.upper()
if output_format not in ("THWC", "TCHW"):
raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")
-
- from torchvision import get_video_backend
-
+ # file existence
if not os.path.exists(filename):
raise RuntimeError(f"File not found: {filename}")
-
- if get_video_backend() != "pyav":
- vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)
- else:
- _check_av_available()
-
- if end_pts is None:
- end_pts = float("inf")
-
- if end_pts < start_pts:
- raise ValueError(
- f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
- )
-
- info = {}
- video_frames = []
- audio_frames = []
- audio_timebase = _video_opt.default_timebase
-
+ # backend check
+ assert get_video_backend() == "pyav", "pyav backend is required for read_video_av"
+ _check_av_available()
+ # end_pts check
+ if end_pts is None:
+ end_pts = float("inf")
+ if end_pts < start_pts:
+ raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
+
+ # == get video info ==
+ info = {}
+ # TODO: creating an container leads to memory leak (1G for 8 workers 1 GPU)
+ container = av.open(filename, metadata_errors="ignore")
+ # fps
+ video_fps = container.streams.video[0].average_rate
+ # guard against potentially corrupted files
+ if video_fps is not None:
+ info["video_fps"] = float(video_fps)
+ iter_video = container.decode(**{"video": 0})
+ frame = next(iter_video).to_rgb().to_ndarray()
+ height, width = frame.shape[:2]
+ total_frames = container.streams.video[0].frames
+ if total_frames == 0:
+ total_frames = MAX_NUM_FRAMES
+ warnings.warn(f"total_frames is 0, using {MAX_NUM_FRAMES} as a fallback")
+ container.close()
+ del container
+
+ # HACK: must create before iterating stream
+ # use np.zeros will not actually allocate memory
+ # use np.ones will lead to a little memory leak
+ video_frames = np.zeros((total_frames, height, width, 3), dtype=np.uint8)
+
+ # == read ==
+ try:
+ # TODO: The reading has memory leak (4G for 8 workers 1 GPU)
container = av.open(filename, metadata_errors="ignore")
- try:
- if container.streams.audio:
- audio_timebase = container.streams.audio[0].time_base
- if container.streams.video:
- video_frames = _read_from_stream(
- container,
- start_pts,
- end_pts,
- pts_unit,
- container.streams.video[0],
- {"video": 0},
- )
- video_fps = container.streams.video[0].average_rate
- # guard against potentially corrupted files
- if video_fps is not None:
- info["video_fps"] = float(video_fps)
-
- if container.streams.audio:
- audio_frames = _read_from_stream(
- container,
- start_pts,
- end_pts,
- pts_unit,
- container.streams.audio[0],
- {"audio": 0},
- )
- info["audio_fps"] = container.streams.audio[0].rate
- except av.AVError:
- # TODO raise a warning?
- pass
- finally:
- container.close()
- del container
- # NOTE: manually garbage collect to close pyav threads
- gc.collect()
-
- vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]
- aframes_list = [frame.to_ndarray() for frame in audio_frames]
-
- if vframes_list:
- vframes = torch.as_tensor(np.stack(vframes_list))
- else:
- vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-
- if aframes_list:
- aframes = np.concatenate(aframes_list, 1)
- aframes = torch.as_tensor(aframes)
- if pts_unit == "sec":
- start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
- if end_pts != float("inf"):
- end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
- aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
- else:
- aframes = torch.empty((1, 0), dtype=torch.float32)
-
+ assert container.streams.video is not None
+ video_frames = _read_from_stream(
+ video_frames,
+ container,
+ start_pts,
+ end_pts,
+ pts_unit,
+ container.streams.video[0],
+ {"video": 0},
+ filename=filename,
+ )
+ except av.AVError as e:
+ print(f"[Warning] Error while reading video {filename}: {e}")
+
+ vframes = torch.from_numpy(video_frames).clone()
+ del video_frames
if output_format == "TCHW":
# [T,H,W,C] --> [T,C,H,W]
vframes = vframes.permute(0, 3, 1, 2)
+ aframes = torch.empty((1, 0), dtype=torch.float32)
return vframes, aframes, info
+def _read_from_stream(
+ video_frames,
+ container: "av.container.Container",
+ start_offset: float,
+ end_offset: float,
+ pts_unit: str,
+ stream: "av.stream.Stream",
+ stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
+ filename: Optional[str] = None,
+) -> List["av.frame.Frame"]:
+
+ if pts_unit == "sec":
+ # TODO: we should change all of this from ground up to simply take
+ # sec and convert to MS in C++
+ start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
+ if end_offset != float("inf"):
+ end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
+ else:
+ warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")
+
+ should_buffer = True
+ max_buffer_size = 5
+ if stream.type == "video":
+ # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
+ # so need to buffer some extra frames to sort everything
+ # properly
+ extradata = stream.codec_context.extradata
+ # overly complicated way of finding if `divx_packed` is set, following
+ # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
+ if extradata and b"DivX" in extradata:
+ # can't use regex directly because of some weird characters sometimes...
+ pos = extradata.find(b"DivX")
+ d = extradata[pos:]
+ o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
+ if o is None:
+ o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
+ if o is not None:
+ should_buffer = o.group(3) == b"p"
+ seek_offset = start_offset
+ # some files don't seek to the right location, so better be safe here
+ seek_offset = max(seek_offset - 1, 0)
+ if should_buffer:
+ # FIXME this is kind of a hack, but we will jump to the previous keyframe
+ # so this will be safe
+ seek_offset = max(seek_offset - max_buffer_size, 0)
+ try:
+ # TODO check if stream needs to always be the video stream here or not
+ container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
+ except av.AVError as e:
+ print(f"[Warning] Error while seeking video {filename}: {e}")
+ return []
+
+ # == main ==
+ buffer_count = 0
+ frames_pts = []
+ cnt = 0
+ try:
+ for _idx, frame in enumerate(container.decode(**stream_name)):
+ frames_pts.append(frame.pts)
+ video_frames[cnt] = frame.to_rgb().to_ndarray()
+ cnt += 1
+ if cnt >= len(video_frames):
+ break
+ if frame.pts >= end_offset:
+ if should_buffer and buffer_count < max_buffer_size:
+ buffer_count += 1
+ continue
+ break
+ except av.AVError as e:
+ print(f"[Warning] Error while reading video {filename}: {e}")
+
+ # garbage collection for thread leakage
+ container.close()
+ del container
+ # NOTE: manually garbage collect to close pyav threads
+ gc.collect()
+
+ # ensure that the results are sorted wrt the pts
+ # NOTE: here we assert frames_pts is sorted
+ start_ptr = 0
+ end_ptr = cnt
+ while start_ptr < end_ptr and frames_pts[start_ptr] < start_offset:
+ start_ptr += 1
+ while start_ptr < end_ptr and frames_pts[end_ptr - 1] > end_offset:
+ end_ptr -= 1
+ if start_offset > 0 and start_offset not in frames_pts[start_ptr:end_ptr]:
+ # if there is no frame that exactly matches the pts of start_offset
+ # add the last frame smaller than start_offset, to guarantee that
+ # we will have all the necessary data. This is most useful for audio
+ if start_ptr > 0:
+ start_ptr -= 1
+ result = video_frames[start_ptr:end_ptr].copy()
+ return result
+
+
def read_video_cv2(video_path):
cap = cv2.VideoCapture(video_path)
@@ -181,8 +257,3 @@ def read_video(video_path, backend="av"):
raise ValueError
return vframes, vinfo
-
-
-if __name__ == "__main__":
- vframes, vinfo = read_video("./data/colors/9.mp4", backend="cv2")
- x = 0
diff --git a/opensora/models/layers/blocks.py b/opensora/models/layers/blocks.py
index 8bc7e720..5e2c13da 100644
--- a/opensora/models/layers/blocks.py
+++ b/opensora/models/layers/blocks.py
@@ -499,7 +499,7 @@ def forward(self, x, cond, mask=None):
# shape:
# q, k, v: [B, SUB_N, NUM_HEADS, HEAD_DIM]
- q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
+ q = self.q_linear(x).view(B, -1, self.num_heads, self.head_dim)
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
kv = split_forward_gather_backward(kv, get_sequence_parallel_group(), dim=3, grad_scale="down")
k, v = kv.unbind(2)
diff --git a/opensora/models/stdit/stdit.py b/opensora/models/stdit/stdit.py
index 0bb3d5dd..428ad269 100644
--- a/opensora/models/stdit/stdit.py
+++ b/opensora/models/stdit/stdit.py
@@ -255,7 +255,7 @@ def __init__(
else:
self.sp_rank = None
- def forward(self, x, timestep, y, mask=None, x_mask=None):
+ def forward(self, x, timestep, y, mask=None, x_mask=None, **kwargs):
"""
Forward pass of STDiT.
Args:
diff --git a/opensora/models/stdit/stdit3.py b/opensora/models/stdit/stdit3.py
index 8703b2d1..bd9672db 100644
--- a/opensora/models/stdit/stdit3.py
+++ b/opensora/models/stdit/stdit3.py
@@ -448,7 +448,7 @@ def unpatchify(self, x, N_t, N_h, N_w, R_t, R_h, R_w):
@MODELS.register_module("STDiT3-XL/2")
def STDiT3_XL_2(from_pretrained=None, **kwargs):
force_huggingface = kwargs.pop("force_huggingface", False)
- if force_huggingface or from_pretrained is not None and not os.path.isdir(from_pretrained):
+ if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1152, patch_size=(1, 2, 2), num_heads=16, **kwargs)
@@ -460,7 +460,8 @@ def STDiT3_XL_2(from_pretrained=None, **kwargs):
@MODELS.register_module("STDiT3-3B/2")
def STDiT3_3B_2(from_pretrained=None, **kwargs):
- if from_pretrained is not None and not os.path.isdir(from_pretrained):
+ force_huggingface = kwargs.pop("force_huggingface", False)
+ if force_huggingface or from_pretrained is not None and not os.path.exists(from_pretrained):
model = STDiT3.from_pretrained(from_pretrained, **kwargs)
else:
config = STDiT3Config(depth=28, hidden_size=1872, patch_size=(1, 2, 2), num_heads=26, **kwargs)
diff --git a/opensora/models/vae/vae.py b/opensora/models/vae/vae.py
index bf50ec83..9802b02d 100644
--- a/opensora/models/vae/vae.py
+++ b/opensora/models/vae/vae.py
@@ -277,7 +277,7 @@ def OpenSoraVAE_V1_2(
scale=scale,
)
- if force_huggingface or (from_pretrained is not None and not os.path.isdir(from_pretrained)):
+ if force_huggingface or (from_pretrained is not None and not os.path.exists(from_pretrained)):
model = VideoAutoencoderPipeline.from_pretrained(from_pretrained, **kwargs)
else:
config = VideoAutoencoderPipelineConfig(**kwargs)
diff --git a/opensora/schedulers/rf/rectified_flow.py b/opensora/schedulers/rf/rectified_flow.py
index 58d7b486..8acaff5d 100644
--- a/opensora/schedulers/rf/rectified_flow.py
+++ b/opensora/schedulers/rf/rectified_flow.py
@@ -15,6 +15,11 @@ def timestep_transform(
scale=1.0,
num_timesteps=1,
):
+ # Force fp16 input to fp32 to avoid nan output
+ for key in ["height", "width", "num_frames"]:
+ if model_kwargs[key].dtype == torch.float16:
+ model_kwargs[key] = model_kwargs[key].float()
+
t = t / num_timesteps
resolution = model_kwargs["height"] * model_kwargs["width"]
ratio_space = (resolution / base_resolution).sqrt()
diff --git a/scripts/train.py b/scripts/train.py
index 2ebdd412..98bb2d70 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -98,6 +98,7 @@ def main():
drop_last=True,
pin_memory=True,
process_group=get_data_parallel_group(),
+ prefetch_factor=cfg.get("prefetch_factor", None),
)
dataloader, sampler = prepare_dataloader(
bucket_config=cfg.get("bucket_config", None),
diff --git a/tools/architecture/__init__.py b/tools/architecture/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/tools/architecture/net2net.py b/tools/architecture/net2net.py
deleted file mode 100644
index d5d7eb34..00000000
--- a/tools/architecture/net2net.py
+++ /dev/null
@@ -1,73 +0,0 @@
-"""
-Implementation of Net2Net (http://arxiv.org/abs/1511.05641)
-Numpy modules for Net2Net
-- Net2Wider
-- Net2Deeper
-
-Written by Kyunghyun Paeng
-
-"""
-
-
-def net2net(teach_param, stu_param):
- # teach param with shape (a, b)
- # stu param with shape (c, d)
- # net to net (a, b) -> (c, d) where c >= a and d >= b
- teach_param_shape = teach_param.shape
- stu_param_shape = stu_param.shape
-
- if len(stu_param_shape) > 2:
- teach_param = teach_param.reshape(teach_param_shape[0], -1)
- stu_param = stu_param.reshape(stu_param_shape[0], -1)
-
- assert len(stu_param.shape) == 1 or len(stu_param.shape) == 2, "teach_param and stu_param must be 2-dim array"
- assert len(teach_param_shape) == len(stu_param_shape), "teach_param and stu_param must have same dimension"
-
- if len(teach_param_shape) == 1:
- stu_param[: teach_param_shape[0]] = teach_param
- elif len(teach_param_shape) == 2:
- stu_param[: teach_param_shape[0], : teach_param_shape[1]] = teach_param
- else:
- breakpoint()
-
- if stu_param.shape != stu_param_shape:
- stu_param = stu_param.reshape(stu_param_shape)
-
- return stu_param
-
-
-if __name__ == "__main__":
- """Net2Net Class Test"""
-
- import torch
-
- from opensora.models.pixart import PixArt_1B_2
-
- model = PixArt_1B_2(no_temporal_pos_emb=True, space_scale=4, enable_flash_attn=True, enable_layernorm_kernel=True)
- print("load model done")
-
- ckpt = torch.load("/home/zhouyukun/projs/opensora/pretrained_models/PixArt-Sigma-XL-2-2K-MS.pth")
- print("load ckpt done")
-
- ckpt = ckpt["state_dict"]
- ckpt["x_embedder.proj.weight"] = ckpt["x_embedder.proj.weight"].unsqueeze(2)
-
- missing_keys = []
- for name, module in model.named_parameters():
- if name in ckpt:
- teach_param = ckpt[name].data
- stu_param = module.data
- stu_param = net2net(teach_param, stu_param)
-
- module.data = stu_param
-
- print("processing layer: ", name, "shape: ", module.size())
-
- else:
- # print("Missing key: ", name)
- missing_keys.append(name)
-
- print(missing_keys)
-
- breakpoint()
- torch.save({"state_dict": model.state_dict()}, "PixArt-1B-2.pth")