From 0a85eed276859ef4b1d9280a2b8ecfb2ac23e6a5 Mon Sep 17 00:00:00 2001
From: Virginia Fernandez <61539159+virginiafdez@users.noreply.github.com>
Date: Tue, 18 Feb 2025 12:57:02 +0000
Subject: [PATCH] Inferer modification - save_intermediates clashes with
 latent shape adjustment in latent diffusion inferers (#8343)

Fixes #8334

### Description

The `sample` method of the latent diffusion inferers (`LatentDiffusionInferer` and its ControlNet counterpart, `ControlNetLatentDiffusionInferer`) was missing an `if save_intermediates` guard: it tried to resize (crop) the latent-space intermediates even when `save_intermediates=False`, i.e. when no intermediates had been collected. This change adds the missing guard; a short sketch of the guarded flow follows the `monai/inferers/inferer.py` diff below.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).

---------

Signed-off-by: Virginia Fernandez
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Virginia Fernandez
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/inferers/inferer.py                     | 17 ++--
 tests/inferers/test_controlnet_inferers.py    | 82 ++++++++++++++++++-
 .../inferers/test_latent_diffusion_inferer.py | 62 +++++++++++++-
 3 files changed, 151 insertions(+), 10 deletions(-)

diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 769b6cc0e7..7083373859 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -1202,15 +1202,16 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
             decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
         image = decode(latent / self.scale_factor)
-
         if save_intermediates:
             intermediates = []
             for latent_intermediate in latent_intermediates:
@@ -1727,9 +1728,11 @@ def sample(  # type: ignore[override]
 
         if self.autoencoder_latent_shape is not None:
             latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
-            latent_intermediates = [
-                torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
-            ]
+            if save_intermediates:
+                latent_intermediates = [
+                    torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0)
+                    for l in latent_intermediates
+                ]
 
         decode = autoencoder_model.decode_stage_2_outputs
         if isinstance(autoencoder_model, SPADEAutoencoderKL):
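Schematically, the flow the two hunks above establish looks like the sketch below. This is an illustration only, not the MONAI implementation: `adjust_and_decode`, `resize`, and `decode` are stand-in names for the inferer's `autoencoder_resizer` and the autoencoder's decode step.

```python
# Minimal sketch of the guarded flow introduced by the hunks above.
# resize/decode are stand-ins for the inferer's internals; illustrative only.
def adjust_and_decode(latent, latent_intermediates, resize, decode, save_intermediates):
    latent = resize(latent)
    if save_intermediates:
        # Before the fix, this comprehension ran unconditionally and referenced
        # intermediates that are only collected when save_intermediates=True.
        latent_intermediates = [resize(l) for l in latent_intermediates]
    image = decode(latent)
    if save_intermediates:
        return image, [decode(l) for l in latent_intermediates]
    return image
```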
diff --git a/tests/inferers/test_controlnet_inferers.py b/tests/inferers/test_controlnet_inferers.py
index e3b0aeb5a2..2ab5cec335 100644
--- a/tests/inferers/test_controlnet_inferers.py
+++ b/tests/inferers/test_controlnet_inferers.py
@@ -722,7 +722,7 @@ def test_prediction_shape(
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape(
+    def test_pred_shape(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1165,7 +1165,7 @@ def test_sample_shape_conditioned_concat(
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape_different_latents(
+    def test_shape_different_latents(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1242,6 +1242,84 @@ def test_sample_shape_different_latents(
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self,
+        ae_model_type,
+        autoencoder_params,
+        dm_model_type,
+        stage_2_params,
+        controlnet_params,
+        input_shape,
+        latent_shape,
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+        controlnet = ControlNet(**controlnet_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        controlnet.to(device)
+        stage_1.eval()
+        stage_2.eval()
+        controlnet.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        mask = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = ControlNetLatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                controlnet=controlnet,
+                cn_cond=mask,
+                input_noise=noise,
+                seg=input_seg,
+                save_intermediates=True,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=False,
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
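Both regression tests, here and in the latent diffusion file below, derive the expected autoencoder latent shape from the `channels` config rather than hard-coding it. The toy numbers below (a 32x32 input and a two-entry `channels` list) are hypothetical and only illustrate the arithmetic the tests use:

```python
# Hypothetical numbers illustrating the latent-shape arithmetic in the tests:
# VQVAE halves each spatial dim once per entry in "channels", while
# AutoencoderKL halves once per entry after the first.
input_spatial = [32, 32]  # hypothetical input H, W
channels = [8, 8]         # hypothetical autoencoder "channels" config

vqvae_latent = [i // (2 ** len(channels)) for i in input_spatial]     # [8, 8]
kl_latent = [i // (2 ** (len(channels) - 1)) for i in input_spatial]  # [16, 16]
print(vqvae_latent, kl_latent)
```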
diff --git a/tests/inferers/test_latent_diffusion_inferer.py b/tests/inferers/test_latent_diffusion_inferer.py
index 2e04ad6c5c..4f81b96ca1 100644
--- a/tests/inferers/test_latent_diffusion_inferer.py
+++ b/tests/inferers/test_latent_diffusion_inferer.py
@@ -714,7 +714,7 @@ def test_sample_shape_conditioned_concat(
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def test_sample_shape_different_latents(
+    def test_shape_different_latents(
         self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
     ):
         stage_1 = None
@@ -772,6 +772,66 @@ def test_sample_shape_different_latents(
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = LatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                save_intermediates=True,
+                seg=input_seg,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(