
Commit 76cf875

Merge pull request #1374 from bghira/main

merge

2 parents 2c8edeb + 7abe9b5

File tree: 8 files changed, +109 −27 lines

docker-start.sh

Lines changed: 14 additions & 4 deletions
```diff
@@ -4,7 +4,7 @@
 # This file can then later be sourced in a login shell
 echo "Exporting environment variables..."
 printenv |
-    grep -E '^RUNPOD_|^PATH=|^HF_HOME=|^HUGGING_FACE_HUB_TOKEN=|^_=' |
+    grep -E '^RUNPOD_|^PATH=|^HF_HOME=|^HF_TOKEN=|^HUGGING_FACE_HUB_TOKEN=|^WANDB_API_KEY=|^WANDB_TOKEN=|^_=' |
     sed 's/^\(.*\)=\(.*\)$/export \1="\2"/' >>/etc/rp_environment

 # Add it to Bash login script
@@ -26,9 +26,19 @@ fi
 # Start SSH server
 service ssh start

-# Load HF, WanDB tokens
-if [ -n "$HUGGING_FACE_HUB_TOKEN" ]; then huggingface-cli login --token "$HUGGING_FACE_HUB_TOKEN" --add-to-git-credential; else echo "HUGGING_FACE_HUB_TOKEN not set; skipping login"; fi
-if [ -n "$WANDB_TOKEN" ]; then wandb login "$WANDB_TOKEN"; else echo "WANDB_TOKEN not set; skipping login"; fi
+# Login to HF
+if [[ -n "${HF_TOKEN:-$HUGGING_FACE_HUB_TOKEN}" ]]; then
+    huggingface-cli login --token "${HF_TOKEN:-$HUGGING_FACE_HUB_TOKEN}" --add-to-git-credential
+else
+    echo "HF_TOKEN or HUGGING_FACE_HUB_TOKEN not set; skipping login"
+fi
+
+# Login to WanDB
+if [[ -n "${WANDB_API_KEY:-$WANDB_TOKEN}" ]]; then
+    wandb login "${WANDB_API_KEY:-$WANDB_TOKEN}"
+else
+    echo "WANDB_API_KEY or WANDB_TOKEN not set; skipping login"
+fi

 # 🫡
 sleep infinity
```
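The new grep pattern exports both the legacy and the current token variable names into `/etc/rp_environment`, and each login block prefers the newer name via `${VAR:-fallback}` parameter expansion. A minimal Python sketch of the same fallback behavior, assuming the `huggingface_hub` and `wandb` packages are present in the image (the equivalence to the CLI calls is an assumption):

```python
import os

import wandb
from huggingface_hub import login

# Prefer the newer variable, fall back to the legacy one, mirroring
# the ${HF_TOKEN:-$HUGGING_FACE_HUB_TOKEN} expansion in the script.
hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
if hf_token:
    login(token=hf_token, add_to_git_credential=True)
else:
    print("HF_TOKEN or HUGGING_FACE_HUB_TOKEN not set; skipping login")

wandb_key = os.environ.get("WANDB_API_KEY") or os.environ.get("WANDB_TOKEN")
if wandb_key:
    wandb.login(key=wandb_key)
else:
    print("WANDB_API_KEY or WANDB_TOKEN not set; skipping login")
```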

documentation/DOCKER.md

Lines changed: 4 additions & 4 deletions
````diff
@@ -39,7 +39,7 @@ This command sets up the container with GPU access and maps the SSH port for ext
 To facilitate integration with external tools, the container supports environment variables for Huggingface and WandB tokens. Pass these at runtime as follows:

 ```bash
-docker run --gpus all -e HUGGING_FACE_HUB_TOKEN='your_token' -e WANDB_TOKEN='your_token' -it -p 22:22 simpletuner
+docker run --gpus all -e HF_TOKEN='your_token' -e WANDB_API_KEY='your_token' -it -p 22:22 simpletuner
 ```

 ### 4. Data Volumes
@@ -98,8 +98,8 @@ services:
       - "[path to your datasets]:/datasets"
       - "[path to your configs]:/workspace/SimpleTuner/config"
     environment:
-      HUGGING_FACE_HUB_TOKEN: [your hugging face token]
-      WANDB_TOKEN: [your wanddb token]
+      HF_TOKEN: [your hugging face token]
+      WANDB_API_KEY: [your wandb token]
     command: ["tail", "-f", "/dev/null"]
     deploy:
       resources:
@@ -155,4 +155,4 @@ services:
 ### General Advice

 - **Logs and Output**: Review the container logs and output for any error messages or warnings that can provide more context on the issue.
-- **Documentation and Forums**: Consult the Docker and NVIDIA CUDA documentation for more detailed troubleshooting advice. Community forums and issue trackers related to the specific software or dependencies you are using can also be valuable resources.
+- **Documentation and Forums**: Consult the Docker and NVIDIA CUDA documentation for more detailed troubleshooting advice. Community forums and issue trackers related to the specific software or dependencies you are using can also be valuable resources.
````

helpers/models/sd3/__init__.py

Whitespace-only changes.

helpers/models/sd3/pipeline.py

Lines changed: 78 additions & 14 deletions
````diff
@@ -69,7 +69,32 @@
         >>> image.save("sd3.png")
         ```
 """
-
+@torch.cuda.amp.autocast(dtype=torch.float32)
+def optimized_scale(positive_flat, negative_flat):
+
+    # Calculate dot product
+    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+    # Squared norm of the unconditional prediction
+    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+    # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+    st_star = dot_product / squared_norm
+
+    return st_star
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu

 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
````
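`calculate_shift` linearly interpolates the timestep shift `mu` between `base_shift` and `max_shift` as the patch sequence length grows, so higher resolutions get a stronger schedule shift. A quick worked check; the 1024 px, VAE factor 8, patch size 2 figures are the usual SD3 defaults, assumed here for illustration:

```python
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    # Linear interpolation in sequence length: mu = m * seq_len + b.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

# A 1024x1024 image -> 128x128 latent -> (128 // 2) * (128 // 2) = 4096 patches.
print(calculate_shift((128 // 2) * (128 // 2)))  # 1.16, max_shift at max_seq_len
print(calculate_shift(256))                      # 0.5, base_shift at base_seq_len
```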
```diff
@@ -763,6 +788,7 @@ def __call__(
         width: Optional[int] = None,
         num_inference_steps: int = 28,
         timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -785,6 +811,11 @@ def __call__(
         skip_layer_guidance_scale: int = 2.8,
         skip_layer_guidance_stop: int = 0.2,
         skip_layer_guidance_start: int = 0.01,
+        mu: Optional[float] = None,
+        use_cfg_zero_star: Optional[bool] = True,
+        use_zero_init: Optional[bool] = True,
+        zero_steps: Optional[int] = 0,
+
     ):
         r"""
         Function invoked when calling the pipeline for generation.
```
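A hedged usage sketch of the new keyword arguments; the pipeline class name, import path, and checkpoint id are assumptions, while the keyword names come straight from the signature above:

```python
import torch
from helpers.models.sd3.pipeline import StableDiffusion3Pipeline  # assumed path/class

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
    use_cfg_zero_star=True,  # enable the projected-CFG path
    use_zero_init=True,      # zero the prediction for early steps...
    zero_steps=0,            # ...up to and including this step index
    mu=None,                 # let dynamic shifting derive mu from the latent size
).images[0]
image.save("sd3.png")
```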
```diff
@@ -970,16 +1001,7 @@ def __call__(
             [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0
         )

-        # 4. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps
-        )
-        num_warmup_steps = max(
-            len(timesteps) - num_inference_steps * self.scheduler.order, 0
-        )
-        self._num_timesteps = len(timesteps)
-
-        # 5. Prepare latent variables
+        # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
@@ -992,6 +1014,35 @@
             latents,
         )

+        # 5. Prepare timesteps
+        scheduler_kwargs = {}
+        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+            _, _, height, width = latents.shape
+            image_seq_len = (height // self.transformer.config.patch_size) * (
+                width // self.transformer.config.patch_size
+            )
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            scheduler_kwargs["mu"] = mu
+        elif mu is not None:
+            scheduler_kwargs["mu"] = mu
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler,
+            num_inference_steps,
+            device,
+            sigmas=sigmas,
+            **scheduler_kwargs,
+        )
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
+        self._num_timesteps = len(timesteps)
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
```
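Timestep preparation now runs after the latents exist because dynamic shifting needs the latent grid size to compute `mu`, which is then forwarded to the scheduler through `retrieve_timesteps`. A hedged sketch of what ultimately reaches the scheduler; `FlowMatchEulerDiscreteScheduler` is the scheduler SD3 normally ships with, and its acceptance of `mu` under `use_dynamic_shifting=True` is taken from diffusers, not from this diff:

```python
from diffusers import FlowMatchEulerDiscreteScheduler

# With dynamic shifting enabled, set_timesteps() requires a mu value,
# which is exactly what the new scheduler_kwargs["mu"] plumbing supplies.
scheduler = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
scheduler.set_timesteps(num_inference_steps=28, mu=1.16, device="cpu")
print(scheduler.timesteps[:4])
```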
```diff
@@ -1026,9 +1077,21 @@
                     # perform guidance
                     if self.do_classifier_free_guidance:
                         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (
-                            noise_pred_text - noise_pred_uncond
-                        )
+                        if use_cfg_zero_star:
+                            positive_flat = noise_pred_text.view(batch_size, -1)
+                            negative_flat = noise_pred_uncond.view(batch_size, -1)
+
+                            alpha = optimized_scale(positive_flat, negative_flat)
+                            alpha = alpha.view(batch_size, 1, 1, 1)
+                            alpha = alpha.to(positive_flat.dtype)
+
+                            if (i <= zero_steps) and use_zero_init:
+                                noise_pred = noise_pred_text * 0.0
+                            else:
+                                noise_pred = noise_pred_uncond * alpha + guidance_scale * (noise_pred_text - noise_pred_uncond * alpha)
+                        else:
+                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
                     should_skip_layers = (
                         True
                         if i > num_inference_steps * skip_layer_guidance_start
```
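The projected branch rescales the unconditional prediction by `alpha = <cond, uncond> / ||uncond||^2`, the least-squares scale that best aligns the unconditional direction with the conditional one, before applying the usual CFG combination; the earliest steps can optionally be zeroed out. A minimal self-contained sketch on dummy tensors, with illustrative shapes:

```python
import torch

def optimized_scale(positive_flat, negative_flat):
    # alpha = <cond, uncond> / (||uncond||^2 + eps), computed per sample.
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
    return dot_product / squared_norm

batch_size, guidance_scale, zero_steps, i = 2, 7.0, 0, 5
noise_pred_text = torch.randn(batch_size, 16, 64, 64)
noise_pred_uncond = torch.randn(batch_size, 16, 64, 64)

alpha = optimized_scale(
    noise_pred_text.view(batch_size, -1),
    noise_pred_uncond.view(batch_size, -1),
).view(batch_size, 1, 1, 1)

if i <= zero_steps:
    noise_pred = noise_pred_text * 0.0  # zero-init for the earliest steps
else:
    # Standard CFG with the unconditional branch rescaled by alpha.
    noise_pred = noise_pred_uncond * alpha + guidance_scale * (
        noise_pred_text - noise_pred_uncond * alpha
    )
print(noise_pred.shape)  # torch.Size([2, 16, 64, 64])
```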
```diff
@@ -1810,6 +1873,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+
     ):
         r"""
         Function invoked when calling the pipeline for generation.
```

helpers/prompts.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -30,7 +30,6 @@
     "fairy_garden": "Whimsical garden filled with fairies, magical plants, sparkling lights, serene atmosphere, high detail",
     "fantasy_dragon": "Majestic dragon soaring through the sky, detailed scales, dynamic pose, fantasy art, high resolution",
     "floating_islands": "Fantasy world, floating islands in the sky, waterfalls, lush vegetation, detailed landscape, high resolution",
-    "futuristic_cityscape": "Futuristic city skyline at night, neon lights, cyberpunk style, high contrast, sharp focus",
     "galactic_battle": "Space battle scene, starships fighting, laser beams, explosions, cosmic background",
     "haunted_fairground": "Abandoned fairground at night, eerie rides, ghostly figures, fog, dark atmosphere, high detail",
     "haunted_mansion": "Spooky haunted mansion on a hill, dark and eerie, glowing windows, ghostly atmosphere, high detail",
```

helpers/training/trainer.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -779,7 +779,9 @@ def init_unload_text_encoder(self):
         reclaim_memory()
         memory_after_unload = self.stats_memory_used()
         memory_saved = memory_after_unload - memory_before_unload
-        logger.info(f"After nuking text encoders from orbit, we freed {abs(round(memory_saved, 2))} GB of VRAM.")
+        logger.info(
+            f"After nuking text encoders from orbit, we freed {abs(round(memory_saved, 2))} GB of VRAM."
+        )

     def init_precision(
         self, preprocessing_models_only: bool = False, ema_only: bool = False
@@ -1650,6 +1652,8 @@ def init_resume_checkpoint(self, lr_scheduler):
                 p = group["params"][0]
                 group["running_d_numerator"] = group["running_d_numerator"].to(p.device)
                 group["running_d_denom"] = group["running_d_denom"].to(p.device)
+                if "use_focus" not in group:
+                    group["use_focus"] = False

         return lr_scheduler
```
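The `use_focus` guard backfills a param-group key that older optimizer checkpoints predate, so resuming them does not trip a `KeyError` once newer code reads the flag. A small sketch of the pattern; the optimizer is illustrative, with `False` as the default taken from the diff:

```python
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# ... optimizer.load_state_dict(...) from an older checkpoint would go here ...

# Backfill keys the old checkpoint never knew about.
for group in optimizer.param_groups:
    if "use_focus" not in group:
        group["use_focus"] = False
```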

helpers/training/validation.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -611,7 +611,7 @@ def init_vae(self):
         ).to(self.inference_device)
         StateTracker.set_vae(self.vae)

-        logger.info(f"VAE type: {type(self.vae)}")
+        # logger.info(f"VAE type: {type(self.vae)}")
         return self.vae

     def _discover_validation_input_samples(self):
@@ -1108,7 +1108,7 @@ def setup_scheduler(self):
             scheduler_args["use_beta_sigmas"] = True
             scheduler_args["shift"] = self.args.flow_schedule_shift
         if self.args.validation_noise_scheduler == "unipc":
-            scheduler_args["prediction_type"] = 'flow_prediction'
+            scheduler_args["prediction_type"] = "flow_prediction"
             scheduler_args["use_flow_sigmas"] = True
             scheduler_args["num_train_timesteps"] = 1000
             scheduler_args["flow_shift"] = self.args.flow_schedule_shift
```
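For reference, the four arguments set for `unipc` configure a UniPC scheduler to consume flow-matching sigmas instead of the default DDPM betas. A hedged construction sketch; whether your installed `diffusers` version exposes these parameters on `UniPCMultistepScheduler` should be verified, and the shift value stands in for `args.flow_schedule_shift`:

```python
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction",
    use_flow_sigmas=True,
    num_train_timesteps=1000,
    flow_shift=3.0,  # placeholder for args.flow_schedule_shift
)
```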

train.sh

Lines changed: 6 additions & 1 deletion
```diff
@@ -91,7 +91,12 @@ if [ -z "${DISABLE_UPDATES}" ]; then
 fi
 # Run the training script.
 if [[ -z "${ACCELERATE_CONFIG_PATH}" ]]; then
-    ACCELERATE_CONFIG_PATH="${HOME}/.cache/huggingface/accelerate/default_config.yaml"
+    # Look for the accelerate config in HF_HOME first, otherwise fall back to $HOME
+    if [[ -f "${HF_HOME}/accelerate/default_config.yaml" ]]; then
+        ACCELERATE_CONFIG_PATH="${HF_HOME}/accelerate/default_config.yaml"
+    else
+        ACCELERATE_CONFIG_PATH="${HOME}/.cache/huggingface/accelerate/default_config.yaml"
+    fi
 fi
 if [ -f "${ACCELERATE_CONFIG_PATH}" ]; then
     echo "Using Accelerate config file: ${ACCELERATE_CONFIG_PATH}"
```
