-
Notifications
You must be signed in to change notification settings - Fork 93
/
Copy pathtrain_image2video_lora.py
786 lines (695 loc) · 28.6 KB
/
train_image2video_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
import os
import sys
import time
import warnings
import json
from collections import defaultdict
from dataclasses import dataclass, asdict, field
from pathlib import Path
from typing import Optional, Dict, Union
from functools import partial
warnings.filterwarnings("ignore")
import deepspeed
import torch
import torchvision
import torch.distributed as dist
from deepspeed.runtime import lr_schedules
from deepspeed.runtime.engine import DeepSpeedEngine
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from einops import rearrange
from hyvideo.config import parse_args
from hyvideo.constants import C_SCALE, PROMPT_TEMPLATE
from hyvideo.dataset.video_loader import VideoDataset
from hyvideo.diffusion import load_denoiser
from hyvideo.ds_config import get_deepspeed_config
from hyvideo.utils.train_utils import (
prepare_model_inputs,
load_state_dict,
set_worker_seed_builder,
get_module_kohya_state_dict,
load_lora,
)
from hyvideo.modules import load_model
from hyvideo.text_encoder import TextEncoder
from hyvideo.utils.file_utils import (
safe_dir,
get_experiment_max_number,
empty_logger,
dump_args,
dump_codes,
resolve_resume_path,
logger_filter,
)
from hyvideo.utils.helpers import (
as_tuple,
set_manual_seed,
set_reproducibility,
profiler_context,
all_gather_sum,
EventsMonitor,
)
from hyvideo.vae import load_vae
from hyvideo.constants import PRECISION_TO_TYPE
from peft import LoraConfig, get_peft_model
from safetensors.torch import save_file
def setup_distributed_training(args):
    """Initialize the distributed backend and resolve the batch-size settings.

    Args:
        args: Parsed training arguments. Reads ``data_type``,
            ``micro_batch_size``, ``video_micro_batch_size``,
            ``gradient_accumulation_steps``, ``global_batch_size``,
            ``global_seed`` and ``reproduce``. ``video_micro_batch_size`` is
            written back when it is ``None`` and ``data_type == "video"``.

    Returns:
        The tuple ``(rank, device, world_size, micro_batch_size,
        video_micro_batch_size, grad_accu_steps, global_batch_size)``.

    Raises:
        AssertionError: If an explicit global batch size is not equal to
            ``micro_batch_size * world_size * grad_accu_steps``.
    """
    deepspeed.init_distributed()

    # Treat micro/global batch size as tuples for compatibility with mix-scale training.
    world_size = dist.get_world_size()
    if args.data_type == "video" and args.video_micro_batch_size is None:
        # When data_type is video and video_micro_batch_size is None, we set the value from micro_batch_size
        args.video_micro_batch_size = args.micro_batch_size
    micro_batch_size = as_tuple(args.micro_batch_size)
    video_micro_batch_size = as_tuple(args.video_micro_batch_size)
    grad_accu_steps = args.gradient_accumulation_steps
    global_batch_size = as_tuple(args.global_batch_size)

    if "video" in args.data_type:
        refer_micro_batch_size = video_micro_batch_size
    else:
        refer_micro_batch_size = micro_batch_size

    # Note: Model/Pipeline parallel is not supported yet. So, data-parallel-size equals to world-size.
    expected_global_batch_size = tuple(
        mbs_i * world_size * grad_accu_steps for mbs_i in refer_micro_batch_size
    )
    if global_batch_size[0] is None:
        global_batch_size = expected_global_batch_size
    else:
        # Compare tuple with tuple: a tuple never equals a list, so comparing
        # against a list comprehension rejected every explicit value.
        assert tuple(global_batch_size) == expected_global_batch_size, (
            f"Global batch size should equal micro batch size * world size * "
            f"gradient accumulation steps ({expected_global_batch_size}), "
            f"but got {global_batch_size} and world size {world_size}."
        )

    rank = dist.get_rank()  # Rank of the current process in the cluster.
    # Device of the current process in current node.
    device = rank % torch.cuda.device_count()
    # Set current device for the current process, otherwise dist.barrier() will occupy more memory in rank 0.
    torch.cuda.set_device(device)

    # Setup seed for reproducibility or performance.
    set_manual_seed(args.global_seed)
    set_reproducibility(args.reproduce, args.global_seed)

    return (
        rank,
        device,
        world_size,
        micro_batch_size,
        video_micro_batch_size,
        grad_accu_steps,
        global_batch_size,
    )
def setup_experiment_directory(args, rank):
    """Pick the next numbered experiment directory and set up the loggers.

    Args:
        args: Parsed training arguments; reads ``output_dir``, ``model`` and
            ``task_flag``.
        rank: Rank of the current process. Only rank 0 registers file sinks
            and passes the checkpoint directory through ``safe_dir``.

    Returns:
        ``(experiment_dir, ckpt_dir, train_logger, val_logger)``. On ranks
        other than 0 both loggers are the same ``empty_logger()`` object.
    """
    root_dir = safe_dir(args.output_dir)

    # Automatically increase the experiment number.
    next_index = get_experiment_max_number(list(root_dir.glob("*"))) + 1
    # Replace '/' to avoid sub-directory.
    model_name = args.model.replace("/", "").replace("-", "_")
    experiment_dir = root_dir / f"{next_index:04d}_{model_name}_{args.task_flag}"
    ckpt_dir = experiment_dir / "checkpoints"

    # Make sure all processes have the same experiment directory.
    dist.barrier()

    if rank != 0:
        train_logger = val_logger = empty_logger()
    else:
        from loguru import logger

        # One DEBUG-level file sink per channel, filtered by the bound name.
        for log_file, channel in (("train.log", "train"), ("val.log", "val")):
            logger.add(
                experiment_dir / log_file,
                level="DEBUG",
                colorize=False,
                backtrace=True,
                diagnose=True,
                encoding="utf-8",
                filter=logger_filter(channel),
            )
        train_logger = logger.bind(name="train")
        val_logger = logger.bind(name="val")
        ckpt_dir = safe_dir(ckpt_dir)

    train_logger.info(f"Experiment directory created at: {experiment_dir}")
    return experiment_dir, ckpt_dir, train_logger, val_logger
def get_trainable_params(model, args):
    """Collect the parameters of ``model`` that require gradients.

    Args:
        model: Object exposing ``parameters()`` whose items carry a
            ``requires_grad`` flag.
        args: Parsed training arguments; only ``training_parts`` is read.

    Returns:
        A list of the parameters with ``requires_grad`` set, in the order
        yielded by ``model.parameters()``.

    Raises:
        ValueError: If ``args.training_parts`` is not ``None``.
    """
    if args.training_parts is not None:
        raise ValueError(f"Unknown training_parts {args.training_parts}")
    return [param for param in model.parameters() if param.requires_grad]
@dataclass
class ScalarStates:
    """Scalar counters tracked over the course of training."""

    # rank id
    rank: int = 0
    # Accumulated training epochs
    epoch: int = 1
    # Accumulated training steps in current epoch
    epoch_train_steps: int = 0
    # Accumulated update steps in current epoch
    epoch_update_steps: int = 0
    # Accumulated training steps
    train_steps: int = 0
    # Accumulated update steps
    update_steps: int = 0
    # Update steps in current run
    current_run_update_steps: int = 0
    # Accumulated consumed samples
    consumed_samples_total: int = 0
    # Accumulated consumed video samples
    consumed_video_samples_total: int = 0
    # Accumulated consumed samples per data-parallel group
    consumed_samples_per_dp: int = 0
    # Accumulated consumed video samples per data-parallel group
    consumed_video_samples_per_dp: int = 0
    # Accumulated consumed tokens
    consumed_tokens_total: int = 0
    # Accumulated consumed computations of attention + mlp
    consumed_computations_attn: int = 0
    # Accumulated consumed computations of total
    consumed_computations_total: int = 0

    def add(self, **kwargs):
        """Increment each named field by the given amount."""
        for field_name, delta in kwargs.items():
            current = getattr(self, field_name)
            setattr(self, field_name, current + delta)
@dataclass
class CycleStates:
    """Running statistics accumulated between two logging points."""

    log_steps: int = 0
    running_loss: float = 0
    running_tokens: int = 0
    running_samples: int = 0
    running_video_samples: int = 0
    running_grad_norm: float = 0
    running_loss_dict: Dict[int, float] = field(
        default_factory=lambda: defaultdict(float)
    )
    log_steps_dict: Dict[int, int] = field(default_factory=lambda: defaultdict(int))

    def add(self, **kwargs):
        """Increment each named field by the given amount."""
        for field_name, delta in kwargs.items():
            current = getattr(self, field_name)
            setattr(self, field_name, current + delta)

    def reset(self):
        """Zero every scalar counter and replace both per-key dictionaries."""
        scalar_fields = (
            "log_steps",
            "running_loss",
            "running_tokens",
            "running_samples",
            "running_video_samples",
            "running_grad_norm",
        )
        for field_name in scalar_fields:
            setattr(self, field_name, 0)
        # Must be reset to float to avoid all_reduce type error.
        self.running_loss_dict = defaultdict(float)
        self.log_steps_dict = defaultdict(int)
def save_checkpoint(
    args,
    rank: int,
    logger,
    model_engine: DeepSpeedEngine,
    ema,
    scalar_state: ScalarStates,
    ckpt_dir: Path,
):
    """Save a DeepSpeed checkpoint tagged with the current update step.

    The scalar state of every rank is gathered with ``all_gather_object`` and
    stored, keyed by rank, in the checkpoint's client state together with
    ``args`` and (when ``ema`` is given) the EMA state and config.

    Args:
        args: Parsed training arguments, stored in the client state.
        rank: Rank of the current process (currently unused).
        logger: Logger used to report success or failure.
        model_engine: Engine whose ``save_checkpoint`` is called.
        ema: Optional EMA object exposing ``state_dict()`` and ``config``.
        scalar_state: Counters of the current process.
        ckpt_dir: Directory the checkpoint is written under.

    Returns:
        A one-element list holding the checkpoint path, or ``None`` in its
        place when saving raised an exception.
    """
    _ = rank  # Currently not used.

    # Gather the scalar state of every process and index it by rank.
    local_state = dict(**asdict(scalar_state))
    gathered = [None for _ in range(dist.get_world_size())]
    torch.distributed.all_gather_object(gathered, local_state)
    states_by_rank = {}
    for entry in gathered:
        states_by_rank[entry["rank"]] = entry

    client_state = {
        "args": args,
        "scalar_state": states_by_rank,
    }
    if ema is not None:
        client_state["ema"] = ema.state_dict()
        client_state["ema_config"] = ema.config

    tag = f"{scalar_state.update_steps:07d}"
    checkpoint_path = ckpt_dir / tag
    try:
        model_engine.save_checkpoint(
            str(ckpt_dir),
            client_state=client_state,
            tag=tag,
        )
        logger.info(f"Saved checkpoint to {checkpoint_path}")
        saved_path = checkpoint_path
    except Exception as e:
        # Best effort: a failed save is logged and reported as None.
        logger.error(f"Saved failed to {checkpoint_path}. {type(e)}: {e}")
        saved_path = None
    return [saved_path]
def main(args):
    """Train the image-to-video model (optionally with LoRA) using DeepSpeed.

    The function runs these stages in order: distributed setup, experiment
    directory and logging setup, model construction (with optional LoRA
    wrapping), DeepSpeed engine initialization, construction of the denoiser,
    VAE and text encoders, dataset/dataloader construction, and the training
    loop with periodic logging and checkpointing.

    Args:
        args: Parsed training arguments as returned by
            ``parse_args(mode="train")``.
    """
    # ============================= Setup ==============================
    # Setup distributed training environment and reproducibility.
    (
        rank,
        device,
        world_size,
        micro_batch_size,
        video_micro_batch_size,
        grad_accu_steps,
        global_batch_size,
    ) = setup_distributed_training(args)

    # Setup experiment directory
    exp_dir, ckpt_dir, logger, val_logger = setup_experiment_directory(args, rank)

    # Load deepspeed config
    deepspeed_config = get_deepspeed_config(
        args,
        video_micro_batch_size[0],
        global_batch_size[0],
        args.output_dir,
        exp_dir.name,
    )

    # Log and dump the arguments and codes.
    logger.info(sys.argv)
    logger.info(str(args))
    if rank == 0:
        # Dump the arguments to a file.
        extra_args = {"world_size": world_size, "global_batch_size": global_batch_size}
        dump_args(args, exp_dir / "args.json", extra_args)
        # Dump codes to the experiment directory.
        dump_codes(
            exp_dir / "codes.tar.gz",
            root=Path(__file__).parent.parent,
            sub_dirs=["hymm", "jobs"],
            save_prefix=args.task_flag,
        )

    # =========================== Build main model ===========================
    logger.info("Building model...")
    factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}

    # The input channel count and the image-embed interleave depend on how the
    # image condition is injected in i2v mode.
    if args.i2v_mode and args.i2v_condition_type == "latent_concat":
        in_channels = args.latent_channels * 2 + 1
        image_embed_interleave = 2
    elif args.i2v_mode and args.i2v_condition_type == "token_replace":
        in_channels = args.latent_channels
        image_embed_interleave = 4
    else:
        in_channels = args.latent_channels
        image_embed_interleave = 1
    out_channels = args.latent_channels

    if args.embedded_cfg_scale:
        factor_kwargs["guidance_embed"] = True

    model = load_model(
        args,
        in_channels=in_channels,
        out_channels=out_channels,
        factor_kwargs=factor_kwargs,
    )
    model = load_state_dict(args, model, logger)

    if args.use_lora:
        # Freeze the base weights; only the LoRA adapters added below train.
        for param in model.parameters():
            param.requires_grad_(False)

        target_modules = [
            "linear",
            "fc1",
            "fc2",
            "img_attn_qkv",
            "img_attn_proj",
            "txt_attn_qkv",
            "txt_attn_proj",
            "linear1",
            "linear2",
        ]

        lora_config = LoraConfig(
            r=args.lora_rank,
            lora_alpha=args.lora_rank,
            init_lora_weights="gaussian",
            target_modules=target_modules,
        )
        model = get_peft_model(model, lora_config)

    # NOTE(review): placed at function level, i.e. independent of
    # `args.use_lora` -- confirm this matches the intended nesting.
    if args.lora_path != "":
        model = load_lora(model, args.lora_path, device=device)

    logger.info(model)

    if args.reproduce:
        model.enable_deterministic()

    # After model initialization, we set different seed for each process.
    if args.same_data_batch:
        set_manual_seed(args.global_seed)
    else:
        set_manual_seed(args.global_seed + rank)

    ema = None
    ss = ScalarStates(rank=rank)

    # ========================== Initialize model_engine, optimizer =========================
    if args.warmup_num_steps > 0:
        logger.info(
            f"Building scheduler with warmup_min_lr={args.warmup_min_lr}, warmup_max_lr={args.lr}, "
            f"warmup_num_steps={args.warmup_num_steps}."
        )
        # A partial is passed so that DeepSpeed builds the scheduler itself.
        lr_scheduler = partial(
            lr_schedules.WarmupLR,
            warmup_min_lr=args.warmup_min_lr,
            warmup_max_lr=args.lr,
            warmup_num_steps=args.warmup_num_steps,
        )
    else:
        lr_scheduler = None

    logger.info("Initializing optimizer (using deepspeed)...")
    model_engine, opt, _, scheduler = deepspeed.initialize(
        args=args,
        model=model,
        model_parameters=get_trainable_params(model, args),
        config_params=deepspeed_config,
        lr_scheduler=lr_scheduler,
    )

    # ====================== Build denoise scheduler ========================
    logger.info("Building denoise scheduler...")
    denoiser = load_denoiser(args)

    # ============================= Build extra models =========================
    # 2d/3d VAE
    vae, vae_path, s_ratio, t_ratio = load_vae(
        args.vae, args.vae_precision, logger=logger, device=device
    )

    # Text encoder
    text_encoder = TextEncoder(
        text_encoder_type=args.text_encoder,
        # Extend the max length by the template's `crop_start`, preferring the
        # video template over the image template when both are set.
        max_length=args.text_len
        + (
            PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
            if args.prompt_template_video is not None
            else PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
            if args.prompt_template is not None
            else 0
        ),
        text_encoder_precision=args.text_encoder_precision,
        tokenizer_type=args.tokenizer,
        i2v_mode=args.i2v_mode,
        prompt_template=(
            PROMPT_TEMPLATE[args.prompt_template]
            if args.prompt_template is not None
            else None
        ),
        prompt_template_video=(
            PROMPT_TEMPLATE[args.prompt_template_video]
            if args.prompt_template_video is not None
            else None
        ),
        hidden_state_skip_layer=args.hidden_state_skip_layer,
        apply_final_norm=args.apply_final_norm,
        reproduce=args.reproduce,
        logger=logger,
        device=device,
        image_embed_interleave=image_embed_interleave,
    )
    if args.text_encoder_2 is not None:
        text_encoder_2 = TextEncoder(
            text_encoder_type=args.text_encoder_2,
            max_length=args.text_len_2,
            text_encoder_precision=args.text_encoder_precision_2,
            tokenizer_type=args.tokenizer_2,
            reproduce=args.reproduce,
            logger=logger,
            device=device,
        )
    else:
        text_encoder_2 = None

    # ================== Define dtype and forward autocast ===============
    target_dtype = None
    autocast_enabled = False
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
        autocast_enabled = True
    elif model_engine.fp16_enabled():
        target_dtype = torch.half
        autocast_enabled = True

    # ============================== Load dataset ==============================
    if "video" in args.data_type:
        video_dataset = VideoDataset(
            data_jsons_path=args.data_jsons_path,
            sample_n_frames=args.sample_n_frames,
            sample_stride=args.sample_stride,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            uncond_p=args.uncond_p,
            args=args,
            logger=logger,
        )
        video_sampler = DistributedSampler(
            video_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.global_seed,
            drop_last=False,
        )
        video_batch_sampler = None
        video_loader = DataLoader(
            video_dataset,
            batch_size=video_micro_batch_size[0],
            shuffle=False,
            sampler=video_sampler,
            batch_sampler=video_batch_sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
            prefetch_factor=None if args.num_workers == 0 else args.prefetch_factor,
            worker_init_fn=set_worker_seed_builder(rank),
            # DataLoader raises ValueError for persistent_workers=True when
            # num_workers == 0, so only enable it when workers exist.
            persistent_workers=args.num_workers > 0,
        )
        num_video_samples = len(video_dataset)
    else:
        video_dataset = None
        video_loader = None
        num_video_samples = 0
    loader = video_loader

    # ============================= Print key info =============================
    print(f"[{rank}] Worker ready.")
    dist.barrier()

    main_loader = video_loader
    # `iters_per_epoch` counts loader iterations; the number of optimizer
    # updates is derived from it once (it used to be divided twice).
    try:
        iters_per_epoch = len(main_loader)
    except (NotImplementedError, TypeError):
        iters_per_epoch = 0
    updates_per_epoch = iters_per_epoch // grad_accu_steps

    params_count = model.params_count()
    logger.info("****************************** Running training ******************************")
    logger.info(f" Number GPUs: {world_size}")
    logger.info(f" Training video samples(total): {num_video_samples:,}")
    for k, v in params_count.items():
        logger.info(f" Number {k} parameters: {v:,}")
    logger.info(f" Number trainable params: {sum(p.numel() for p in get_trainable_params(model, args)):,}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f" Iters per epoch: {iters_per_epoch:,}")
    logger.info(f" Updates per epoch: {updates_per_epoch:,}")
    logger.info(f" Batch size per device: {video_micro_batch_size}")
    logger.info(f" Batch size all device: {global_batch_size:}")
    logger.info(f" Gradient Accu steps: {args.gradient_accumulation_steps}")
    logger.info(f" Training epochs: {ss.epoch}/{args.epochs}")
    logger.info(f" Training total steps: {ss.update_steps:,}/{args.max_training_steps:,}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f" Path type: {args.flow_path_type}")
    logger.info(f" Predict type: {args.flow_predict_type}")
    logger.info(f" Loss weight: {args.flow_loss_weight}")
    logger.info(f" Flow reverse: {args.flow_reverse}")
    logger.info(f" Flow shift: {args.flow_shift}")
    logger.info(f" Train eps: {args.flow_train_eps}")
    logger.info(f" Sample eps: {args.flow_sample_eps}")
    logger.info(f" Timestep type: {args.flow_snr_type}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f" Main model precision: {args.precision}")
    logger.info("------------------------------------------------------------------------------")
    logger.info(f" VAE: {args.vae} ({args.vae_precision}) - {vae_path}")
    logger.info(f" Text encoder: {text_encoder}")
    if text_encoder_2 is not None:
        logger.info(f" Text encoder 2: {text_encoder_2}")
    logger.info(f" Experiment directory: {ckpt_dir}")
    logger.info("*******************************************************************************")

    # ============================= Start training =============================
    model_engine.train()

    if args.init_save:
        save_checkpoint(args, rank, logger, model_engine, ema, ss, ckpt_dir)

    # Training loop
    start_epoch = ss.epoch
    finished = False
    ss.current_run_update_steps = 0

    # NOTE(review): epochs start from 1, so range(start_epoch, args.epochs)
    # runs args.epochs - 1 epochs -- confirm whether this is intended.
    for epoch in range(start_epoch, args.epochs):
        if video_dataset is not None:
            logger.info(f"Start video random shuffle(seed={args.global_seed + epoch})")
            video_sampler.set_epoch(epoch)  # epoch start from 1
            logger.info(f"End of video random shuffle")
        logger.info(f"Beginning epoch {epoch}...")

        with profiler_context(
            args.profile, exp_dir, worker_name=f"Rank_{rank}"
        ) as prof:
            # Define cycle states, which accumulate the training information between log_steps.
            cs = CycleStates()

            start_time = time.time()

            for batch_idx, batch in enumerate(loader):
                # broadcast a zero size tensor to indicate starting of step
                start_flag_tensor = torch.cuda.FloatTensor([])
                if torch.distributed.is_initialized():
                    torch.distributed.broadcast(start_flag_tensor, 0, async_op=True)

                # main diff
                (
                    latents,
                    model_kwargs,
                    n_tokens,
                    cond_latents,
                ) = prepare_model_inputs(
                    args,
                    batch,
                    device,
                    model,
                    vae,
                    text_encoder,
                    text_encoder_2,
                    rope_theta_rescale_factor=args.rope_theta_rescale_factor,
                    rope_interpolation_factor=args.rope_interpolation_factor,
                )

                cur_batch_size = latents.shape[0]
                cur_anchor_size = max(args.video_size)

                # A forward-backward step
                with torch.autocast(
                    device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
                ):
                    _, loss_dict = denoiser.training_losses(
                        model_engine,
                        latents,
                        model_kwargs,
                        n_tokens=n_tokens,
                        i2v_mode=args.i2v_mode,
                        cond_latents=cond_latents,
                        args=args,
                    )
                # NOTE(review): the loss reduction is placed outside the
                # autocast region -- confirm against the intended nesting.
                loss = loss_dict["loss"].mean()
                model_engine.backward(loss)

                # Update model parameters at the step of gradient accumulation.
                model_engine.step(lr_kwargs={"last_batch_iteration": ss.update_steps})

                # Update accumulated states
                ss.add(
                    train_steps=1,
                    epoch_train_steps=1,
                    consumed_samples_per_dp=cur_batch_size,
                )
                ss.add(consumed_video_samples_per_dp=cur_batch_size)

                # We enable `is_update_step` if the current step is the gradient accumulation boundary.
                is_update_step = ss.train_steps % grad_accu_steps == 0
                if is_update_step:
                    ss.add(
                        update_steps=1, epoch_update_steps=1, current_run_update_steps=1
                    )

                # NOTE(review): checked on every training step rather than
                # only on update steps -- confirm the intended nesting.
                if ss.update_steps >= args.max_training_steps:
                    # Enter stopping routine if max steps reached after this step.
                    finished = True

                # Log training information:
                cs.add(
                    log_steps=1,
                    running_loss=loss.item(),
                    running_samples=cur_batch_size,
                    running_tokens=cur_batch_size * n_tokens,
                    running_grad_norm=0,
                )
                cs.add(running_video_samples=cur_batch_size)
                cs.running_loss_dict[cur_anchor_size] += loss.item()
                cs.log_steps_dict[cur_anchor_size] += 1

                if is_update_step and ss.update_steps % args.log_every == 0:
                    # Reduce loss history over all processes:
                    avg_loss = (
                        all_gather_sum(cs.running_loss / cs.log_steps, device)
                        / world_size
                    )
                    avg_grad_norm = (
                        all_gather_sum(cs.running_grad_norm / cs.log_steps, device)
                        / world_size
                    )
                    cum_samples = all_gather_sum(cs.running_samples, device)
                    cum_video_samples = all_gather_sum(cs.running_video_samples, device)
                    cum_tokens = all_gather_sum(cs.running_tokens, device)

                    # Measure training speed:
                    torch.cuda.synchronize()
                    end_time = time.time()
                    steps_per_sec = (
                        cs.log_steps / (end_time - start_time) / grad_accu_steps
                    )
                    samples_per_sec = cum_samples / (end_time - start_time)
                    sec_per_step = (end_time - start_time) / cs.log_steps

                    ss.add(
                        consumed_samples_total=cum_samples,
                        consumed_video_samples_total=cum_video_samples,
                        consumed_tokens_total=cum_tokens,
                        consumed_computations_attn=6
                        * params_count["attn+mlp"]
                        * cum_tokens
                        / C_SCALE,
                        consumed_computations_total=6
                        * params_count["total"]
                        * cum_tokens
                        / C_SCALE,
                    )

                    log_events = [
                        f"Train Loss: {avg_loss:.4f}",
                        f"Grad Norm: {avg_grad_norm:.4f}",
                        f"Lr: {opt.param_groups[0]['lr']:.6g}",
                        f"Sec/Step: {sec_per_step:.2f}, "
                        f"Steps/Sec: {steps_per_sec:.2f}",
                        f"Samples/Sec: {int(samples_per_sec):d}",
                        f"Consumed Samples: {ss.consumed_samples_total:,}",
                        f"Consumed Video Samples: {ss.consumed_video_samples_total:,}",
                        f"Consumed Tokens: {ss.consumed_tokens_total:,}",
                    ]
                    summary_events = [
                        ("Train/Steps/train_loss", avg_loss, ss.update_steps),
                        ("Train/Steps/grad_norm", avg_grad_norm, ss.update_steps),
                        ("Train/Steps/steps_per_sec", steps_per_sec, ss.update_steps),
                        (
                            "Train/Steps/samples_per_sec",
                            int(samples_per_sec),
                            ss.update_steps,
                        ),
                        ("Train/Tokens/train_loss", avg_loss, ss.consumed_tokens_total),
                        (
                            "Train/ComputationsAttn/train_loss",
                            avg_loss,
                            ss.consumed_computations_attn,
                        ),
                        (
                            "Train/ComputationsTotal/train_loss",
                            avg_loss,
                            ss.consumed_computations_total,
                        ),
                    ]

                    # Log the training information to the logger.
                    logger.info(
                        f"(step={ss.update_steps:07d}) " + ", ".join(log_events)
                    )
                    if model_engine.monitor.enabled and rank == 0:
                        model_engine.monitor.write_events(summary_events)

                    # Reset monitoring variables:
                    cs.reset()
                    start_time = time.time()

                # Save checkpoint:
                if (is_update_step and ss.update_steps % args.ckpt_every == 0) or (
                    finished and args.final_save
                ):
                    if args.use_lora:
                        # Only rank 0 exports the LoRA weights to safetensors.
                        if rank == 0:
                            output_dir = os.path.join(
                                ckpt_dir, f"global_step{ss.update_steps}"
                            )
                            os.makedirs(output_dir, exist_ok=True)

                            lora_kohya_state_dict = get_module_kohya_state_dict(
                                model, "Hunyuan_video_I2V_lora", dtype=torch.bfloat16
                            )
                            save_file(
                                lora_kohya_state_dict,
                                f"{output_dir}/pytorch_lora_kohaya_weights.safetensors",
                            )
                    else:
                        # save_checkpoint gathers state across ranks, so every
                        # rank has to reach this call.
                        save_checkpoint(
                            args, rank, logger, model_engine, ema, ss, ckpt_dir
                        )

                if prof:
                    prof.step()

                if finished:
                    logger.info(
                        f"Finished and breaking loop at step={ss.update_steps}."
                    )
                    break

        if finished:
            logger.info(f"Finished and breaking loop at epoch={epoch}.")
            break

        # Reset epoch states
        ss.epoch += 1
        ss.epoch_train_steps = 0
        ss.epoch_update_steps = 0

    logger.info("Training Finished!")
# Script entry point: parse the arguments in "train" mode and start training.
if __name__ == "__main__":
    main(parse_args(mode="train"))