Skip to content

Commit

Permalink
Merge branch 'dev' into add-monaihosting-backup-url
Browse files Browse the repository at this point in the history
  • Loading branch information
KumoLiu authored Feb 24, 2025
2 parents 1d92253 + ab07523 commit aab8ad9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 11 deletions.
40 changes: 29 additions & 11 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,13 +1334,15 @@ def __call__( # type: ignore[override]
raise NotImplementedError(f"{mode} condition is not supported")

noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
)

if mode == "concat" and condition is not None:
noisy_image = torch.cat([noisy_image, condition], dim=1)
condition = None

down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond, context=condition
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)
Expand Down Expand Up @@ -1396,17 +1398,21 @@ def sample( # type: ignore[override]
progress_bar = iter(scheduler.timesteps)
intermediates = []
for t in progress_bar:
# 1. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
)
# 2. predict noise model_output
diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
# 1. Conditioning
model_input = torch.cat([image, conditioning], dim=1)
# 2. ControlNet forward
down_block_res_samples, mid_block_res_sample = controlnet(
x=model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=None,
)
# 3. predict noise model_output
model_output = diffuse(
model_input,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand All @@ -1415,6 +1421,12 @@ def sample( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
image,
timesteps=torch.Tensor((t,)).to(input_noise.device),
Expand Down Expand Up @@ -1485,16 +1497,16 @@ def get_likelihood( # type: ignore[override]
for t in progress_bar:
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
)

diffuse = diffusion_model
if isinstance(diffusion_model, SPADEDiffusionModelUNet):
diffuse = partial(diffusion_model, seg=seg)

if mode == "concat" and conditioning is not None:
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond, context=None
)
model_output = diffuse(
noisy_image,
timesteps=timesteps,
Expand All @@ -1503,6 +1515,12 @@ def get_likelihood( # type: ignore[override]
mid_block_additional_residual=mid_block_res_sample,
)
else:
down_block_res_samples, mid_block_res_sample = controlnet(
x=noisy_image,
timesteps=torch.Tensor((t,)).to(inputs.device),
controlnet_cond=cn_cond,
context=conditioning,
)
model_output = diffuse(
x=noisy_image,
timesteps=timesteps,
Expand Down
9 changes: 9 additions & 0 deletions tests/inferers/test_controlnet_inferers.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,8 @@ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
model_params["with_conditioning"] = True
model_params["cross_attention_dim"] = 3
controlnet_params["with_conditioning"] = True
controlnet_params["cross_attention_dim"] = 3
model = DiffusionModelUNet(**model_params)
controlnet = ControlNet(**controlnet_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -619,8 +621,11 @@ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input
model_params = model_params.copy()
n_concat_channel = 2
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
model_params["cross_attention_dim"] = None
controlnet_params["cross_attention_dim"] = None
model_params["with_conditioning"] = False
controlnet_params["with_conditioning"] = False
model = DiffusionModelUNet(**model_params)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
Expand Down Expand Up @@ -1023,8 +1028,10 @@ def test_prediction_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down Expand Up @@ -1106,8 +1113,10 @@ def test_sample_shape_conditioned_concat(
if ae_model_type == "SPADEAutoencoderKL":
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
stage_2_params = stage_2_params.copy()
controlnet_params = controlnet_params.copy()
n_concat_channel = 3
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
if dm_model_type == "SPADEDiffusionModelUNet":
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
else:
Expand Down

0 comments on commit aab8ad9

Please sign in to comment.