Skip to content

Commit

Permalink
Inferer modification - save_intermediates clashes with latent shape a…
Browse files Browse the repository at this point in the history
…djustment in latent diffusion inferers (#8343)

Fixes #8334 

### Description

The `sample` method of the latent diffusion inferers (`LatentDiffusionInferer`
and `ControlNetLatentDiffusionInferer`) was missing an `if save_intermediates`
check, so it tried to crop the intermediate latents to the autoencoder latent
shape even when no intermediates had been created.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).

---------

Signed-off-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Virginia Fernandez <virginia.fernandez@kcl.ac.uk>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 18, 2025
1 parent 44add8d commit 0a85eed
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 10 deletions.
17 changes: 10 additions & 7 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,15 +1202,16 @@ def sample( # type: ignore[override]

if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
image = decode(latent / self.scale_factor)

if save_intermediates:
intermediates = []
for latent_intermediate in latent_intermediates:
Expand Down Expand Up @@ -1727,9 +1728,11 @@ def sample( # type: ignore[override]

if self.autoencoder_latent_shape is not None:
latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
]
if save_intermediates:
latent_intermediates = [
torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
for l in latent_intermediates
]

decode = autoencoder_model.decode_stage_2_outputs
if isinstance(autoencoder_model, SPADEAutoencoderKL):
Expand Down
82 changes: 80 additions & 2 deletions tests/inferers/test_controlnet_inferers.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ def test_prediction_shape(

@parameterized.expand(LATENT_CNDM_TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape(
def test_pred_shape(
self,
ae_model_type,
autoencoder_params,
Expand Down Expand Up @@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(

@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
def test_shape_different_latents(
self,
ae_model_type,
autoencoder_params,
Expand Down Expand Up @@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
    self,
    ae_model_type,
    autoencoder_params,
    dm_model_type,
    stage_2_params,
    controlnet_params,
    input_shape,
    latent_shape,
):
    """Check that ``sample`` returns a tensor of ``input_shape`` when the inferer is
    given both ``ldm_latent_shape`` and ``autoencoder_latent_shape``.

    SPADE configurations sample with ``save_intermediates=True`` (and a ``seg``
    input); every other configuration samples with ``save_intermediates=False``.
    """
    # Pick the first-stage model by name; an unknown name leaves ``stage_1`` as None.
    autoencoder_classes = {
        "AutoencoderKL": AutoencoderKL,
        "VQVAE": VQVAE,
        "SPADEAutoencoderKL": SPADEAutoencoderKL,
    }
    autoencoder_cls = autoencoder_classes.get(ae_model_type)
    stage_1 = autoencoder_cls(**autoencoder_params) if autoencoder_cls is not None else None

    unet_cls = SPADEDiffusionModelUNet if dm_model_type == "SPADEDiffusionModelUNet" else DiffusionModelUNet
    stage_2 = unet_cls(**stage_2_params)
    controlnet = ControlNet(**controlnet_params)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    networks = (stage_1, stage_2, controlnet)
    for network in networks:
        network.to(device)
    for network in networks:
        network.eval()

    noise = torch.randn(latent_shape).to(device)
    mask = torch.randn(input_shape).to(device)

    # We infer the VAE shape: the exponent is len(channels) for VQVAE and len(channels) - 1 otherwise.
    exponent = len(autoencoder_params["channels"])
    if ae_model_type != "VQVAE":
        exponent -= 1
    reduction = 2**exponent
    autoencoder_latent_shape = [size // reduction for size in input_shape[2:]]

    scheduler = DDPMScheduler(num_train_timesteps=10)
    inferer = ControlNetLatentDiffusionInferer(
        scheduler=scheduler,
        scale_factor=1.0,
        ldm_latent_shape=list(latent_shape[2:]),
        autoencoder_latent_shape=autoencoder_latent_shape,
    )
    scheduler.set_timesteps(num_inference_steps=10)

    # Arguments common to both sampling paths.
    sample_kwargs = {
        "autoencoder_model": stage_1,
        "diffusion_model": stage_2,
        "controlnet": controlnet,
        "cn_cond": mask,
        "input_noise": noise,
    }

    uses_spade = dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL"
    if uses_spade:
        # The segmentation input takes its channel count from whichever config defines ``label_nc``.
        label_source = stage_2_params if "label_nc" in stage_2_params else autoencoder_params
        seg_shape = list(input_shape)
        seg_shape[1] = label_source["label_nc"]
        input_seg = torch.randn(seg_shape).to(device)
        prediction, _ = inferer.sample(seg=input_seg, save_intermediates=True, **sample_kwargs)
    else:
        prediction = inferer.sample(save_intermediates=False, **sample_kwargs)

    self.assertEqual(prediction.shape, input_shape)

@skipUnless(has_einops, "Requires einops")
def test_incompatible_spade_setup(self):
stage_1 = SPADEAutoencoderKL(
Expand Down
62 changes: 61 additions & 1 deletion tests/inferers/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(

@parameterized.expand(TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
def test_shape_different_latents(
self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
):
stage_1 = None
Expand Down Expand Up @@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
)
self.assertEqual(prediction.shape, latent_shape)

@parameterized.expand(TEST_CASES_DIFF_SHAPES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_different_latents(
    self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
):
    """Check that ``sample`` returns a tensor of ``input_shape`` when the inferer is
    given both ``ldm_latent_shape`` and ``autoencoder_latent_shape``.

    SPADE configurations sample with ``save_intermediates=True`` (and a ``seg``
    input); every other configuration samples with ``save_intermediates=False``.
    """
    # Pick the first-stage model by name; an unknown name leaves ``stage_1`` as None.
    autoencoder_classes = {
        "AutoencoderKL": AutoencoderKL,
        "VQVAE": VQVAE,
        "SPADEAutoencoderKL": SPADEAutoencoderKL,
    }
    autoencoder_cls = autoencoder_classes.get(ae_model_type)
    stage_1 = autoencoder_cls(**autoencoder_params) if autoencoder_cls is not None else None

    unet_cls = SPADEDiffusionModelUNet if dm_model_type == "SPADEDiffusionModelUNet" else DiffusionModelUNet
    stage_2 = unet_cls(**stage_2_params)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    networks = (stage_1, stage_2)
    for network in networks:
        network.to(device)
    for network in networks:
        network.eval()

    noise = torch.randn(latent_shape).to(device)

    # We infer the VAE shape: the exponent is len(channels) for VQVAE and len(channels) - 1 otherwise.
    exponent = len(autoencoder_params["channels"])
    if ae_model_type != "VQVAE":
        exponent -= 1
    reduction = 2**exponent
    autoencoder_latent_shape = [size // reduction for size in input_shape[2:]]

    scheduler = DDPMScheduler(num_train_timesteps=10)
    inferer = LatentDiffusionInferer(
        scheduler=scheduler,
        scale_factor=1.0,
        ldm_latent_shape=list(latent_shape[2:]),
        autoencoder_latent_shape=autoencoder_latent_shape,
    )
    scheduler.set_timesteps(num_inference_steps=10)

    # Arguments common to both sampling paths.
    sample_kwargs = {"autoencoder_model": stage_1, "diffusion_model": stage_2, "input_noise": noise}

    uses_spade = dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL"
    if uses_spade:
        # The segmentation input takes its channel count from whichever config defines ``label_nc``.
        label_source = stage_2_params if "label_nc" in stage_2_params else autoencoder_params
        seg_shape = list(input_shape)
        seg_shape[1] = label_source["label_nc"]
        input_seg = torch.randn(seg_shape).to(device)
        prediction, _ = inferer.sample(save_intermediates=True, seg=input_seg, **sample_kwargs)
    else:
        prediction = inferer.sample(save_intermediates=False, **sample_kwargs)

    self.assertEqual(prediction.shape, input_shape)

@skipUnless(has_einops, "Requires einops")
def test_incompatible_spade_setup(self):
stage_1 = SPADEAutoencoderKL(
Expand Down

0 comments on commit 0a85eed

Please sign in to comment.