
Can you give the visualization code for the uncertainty estimation for each pixel point? #1


Open
hejiaxiang1 opened this issue Mar 22, 2024 · 7 comments


@hejiaxiang1

I tried to visualize the var part, but the output contains no useful information. My modified /sd/dpmsolver_skipUQ.py is as follows:

    #########   start sample  ########## 
    c = model.get_learned_conditioning(opt.prompt)
    c = torch.concat(opt.sample_batch_size * [c], dim=0)
    exp_dir = f'./dpm_solver_2_exp/skipUQ/{opt.prompt}_train{opt.train_la_data_size}_step{opt.timesteps}_S{opt.mc_size}/'
    os.makedirs(exp_dir, exist_ok=True)
    total_n_samples = opt.total_n_samples
    if total_n_samples % opt.sample_batch_size != 0:
        raise ValueError("Total samples for sampling must be divided exactly by opt.sample_batch_size, but got {} and {}".format(total_n_samples, opt.sample_batch_size))
    n_rounds = total_n_samples // opt.sample_batch_size
    var_sum = torch.zeros((opt.sample_batch_size, n_rounds)).to(device)
    sample_x = []
    var_x = [] # add
    img_id = 1000000
    precision_scope = autocast if opt.precision=="autocast" else nullcontext
    with torch.no_grad():
        with precision_scope("cuda"):
            with model.ema_scope():
                for loop in tqdm(
                    range(n_rounds), desc="Generating image samples for FID evaluation."
                ):
                    
                    xT, timestep, mc_sample_size  = torch.randn([opt.sample_batch_size, opt.C, opt.H // opt.f, opt.W // opt.f], device=device), opt.timesteps//2, opt.mc_size
                    T = t_seq[timestep]
                    if uq_array[timestep] == True:
                        xt_next = xT
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        eps_mu_t_next, eps_var_t_next = custom_ld(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c) 
                        cov_xt_next_epst_next = torch.zeros_like(xT).to(device)
                        _, model_s1, _ = conditioned_update(ns, xt_next, T, t_seq[timestep-1], custom_ld, eps_mu_t_next, pre_wuq=True, r1=0.5, c=c)
                        list_eps_mu_t_next_i = torch.unsqueeze(model_s1, dim=0)
                    else:
                        xt_next = xT
                        exp_xt_next, var_xt_next = xT, torch.zeros_like(xT).to(device)
                        eps_mu_t_next = custom_ld.accurate_forward(xT, get_model_input_time(ns, T).expand(xT.shape[0]), c=c)
                    
                    ####### Start skip UQ sampling  ######
                    for timestep in range(opt.timesteps//2, 0, -1):

                        if uq_array[timestep] == True:
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t, eps_var_t, cov_xt_epst = eps_mu_t_next, eps_var_t_next, cov_xt_next_epst_next
                            mc_eps_exp_t = torch.mean(list_eps_mu_t_next_i, dim=0)
                        else: 
                            xt = xt_next
                            exp_xt, var_xt = exp_xt_next, var_xt_next
                            eps_mu_t = eps_mu_t_next
                        
                        s, t = t_seq[timestep], t_seq[timestep-1]
                        if uq_array[timestep] == True:
                            eps_t= sample_from_gaussion(eps_mu_t, eps_var_t)
                            xt_next, _ , model_s1_var = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, pre_wuq=uq_array[timestep], mc_eps_exp_s1=mc_eps_exp_t)
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep], cov_xt_epst= cov_xt_epst, var_epst=model_s1_var)
                            # decide whether to see xt_next as a random variable
                            if uq_array[timestep-1] == True:
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                for _ in range(mc_sample_size):
                                    
                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                        else:
                            xt_next, model_s1 = conditioned_update(ns=ns, x=xt, s=s, t=t, custom_ld=custom_ld, model_s=eps_mu_t, pre_wuq=uq_array[timestep], c=c, r1=0.5)
                            exp_xt_next = conditioned_exp_iteration(exp_xt, ns, s, t, exp_s1= model_s1, pre_wuq=uq_array[timestep])
                            var_xt_next = conditioned_var_iteration(var_xt, ns, s, t, pre_wuq=uq_array[timestep])
                            if uq_array[timestep-1] == True:
                                list_xt_next_i, list_eps_mu_t_next_i=[], []
                                s_next = t_seq[timestep-1]
                                t_next = t_seq[timestep-2]
                                lambda_s_next, lambda_t_next = ns.marginal_lambda(s_next), ns.marginal_lambda(t_next)
                                h_next = lambda_t_next - lambda_s_next
                                lambda_s1_next = lambda_s_next + 0.5 * h_next
                                s1_next = ns.inverse_lambda(lambda_s1_next)
                                sigma_s1_next = ns.marginal_std(s1_next)
                                log_alpha_s_next, log_alpha_s1_next = ns.marginal_log_mean_coeff(s_next), ns.marginal_log_mean_coeff(s1_next)
                                phi_11_next = torch.expm1(0.5*h_next)

                                for _ in range(mc_sample_size):
                                    
                                    var_xt_next = torch.clamp(var_xt_next, min=0)
                                    xt_next_i = sample_from_gaussion(exp_xt_next, var_xt_next)
                                    list_xt_next_i.append(xt_next_i)
                                    model_t_i, model_t_i_var = custom_ld(xt_next_i, get_model_input_time(ns, s_next).expand(xt_next_i.shape[0]), c=c)
                                    xu_next_i = sample_from_gaussion(torch.exp(log_alpha_s1_next - log_alpha_s_next) * xt_next_i-(sigma_s1_next * phi_11_next) * model_t_i, \
                                                                    torch.square(sigma_s1_next * phi_11_next) * model_t_i_var)
                                    model_u_i, _ = custom_ld(xu_next_i, get_model_input_time(ns, s1_next).expand(xt_next_i.shape[0]), c=c)
                                    list_eps_mu_t_next_i.append(model_u_i)

                                eps_mu_t_next, eps_var_t_next = custom_ld(xt_next, get_model_input_time(ns, s_next).expand(xt_next.shape[0]), c=c)
                                list_xt_next_i = torch.stack(list_xt_next_i, dim=0).to(device)
                                list_eps_mu_t_next_i = torch.stack(list_eps_mu_t_next_i, dim=0).to(device)
                                cov_xt_next_epst_next = torch.mean(list_xt_next_i*list_eps_mu_t_next_i, dim=0)-exp_xt_next*torch.mean(list_eps_mu_t_next_i, dim=0)
                            else:
                                eps_mu_t_next = custom_ld.accurate_forward(xt_next, get_model_input_time(ns, t).expand(xt_next.shape[0]), c=c)

                    ####### Save variance and sample image  ######         
                    var_sum[:, loop] = var_xt_next.sum(dim=(1,2,3))
                    x_samples = model.decode_first_stage(xt_next)
                    # var_xt_next = model.decode_first_stage(var_xt_next)# add
                    x = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
                    # os.makedirs(os.path.join(exp_dir, 'sam/'), exist_ok=True)
                    # for i in range(x.shape[0]):
                    #     path = os.path.join(exp_dir, 'sam/', f"{img_id}.png")
                    #     tvu.save_image(x.cpu()[i].float(), path)
                    #     img_id += 1
                    sample_x.append(x)
                    var_x.append(var_xt_next) # add

                sample_x = torch.concat(sample_x, dim=0)
                var_x = torch.concat(var_x, dim=0)# add
                var = []
                for j in range(n_rounds):
                    var.append(var_sum[:, j])
                var = torch.concat(var, dim=0)
                sorted_var, sorted_indices = torch.sort(var, descending=True)
                reordered_sample_x = torch.index_select(sample_x, dim=0, index=sorted_indices.int())
                grid_sample_x = tvu.make_grid(reordered_sample_x, nrow=8, padding=2)
                tvu.save_image(grid_sample_x.cpu().float(), os.path.join(exp_dir, "sorted_sample.png"))

                print(f'Sampling {total_n_samples} images in {exp_dir}')
                torch.save(var_sum.cpu(), os.path.join(exp_dir, 'var_sum.pt'))

                var_x = var_x.mean(dim=1, keepdim=True) # add: average the variance over the latent channels
                reordered_var_x = torch.index_select(var_x, dim=0, index=sorted_indices.int()) # add
                grid_var_x = tvu.make_grid(reordered_var_x, nrow=12, padding=1, normalize=True) # add
                tvu.save_image(grid_var_x.cpu().float(), os.path.join(exp_dir, "sorted_var.png")) # add
@xiexh20

xiexh20 commented Mar 22, 2024

I have a similar issue. I tried to visualize the DDPM uncertainty of ImageNet generations, but the image is not very meaningful.

In Section 4.3, you write: "we sample a variety of latent states...estimate the empirical variance...as the final pixel-wise uncertainty".

How exactly did you sample? Did you draw Gaussian samples around the final exp_xt?

Thank you for your time and help!

@karrykkk
Owner

Thanks for your interest in our work!

For variance visualization of Stable Diffusion in the latent space, we save $E(z_0)$ and $Var(z_0)$ (exp_xt_next and var_xt_next in the xxUQ.py script) and resample $z_{0,1}, ..., z_{0,N}$ from the Gaussian distribution $\mathcal{N}(E(z_0), Var(z_0))$. Then we decode them to $x_{0,1}, ..., x_{0,N}$ and estimate the empirical variance as the final pixel-wise variance.

For the visualization code, you can refer to the script below. Feel free to ask if you have any further questions.

import torch
from matplotlib import pyplot as plt
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
from torchvision import transforms
import torchvision.utils as tvu

to_pil = transforms.ToPILImage()

def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)

    model.eval()
    return model


config = OmegaConf.load("configs/stable-diffusion/v1-inference.yaml")
device = torch.device("cuda:5")
model = load_model_from_config(config, "your_local_sd_ckpt").to(device)
# load the E(z_0) / Var(z_0) tensors saved by the sampling script
z_dev_list = []
z_exp_list = []


exp_dir = 'your_local_exp_dir'

id = 0
z_var_i = torch.load(f'{exp_dir}/z_var/{id}.pth')
z_exp_i = torch.load(f'{exp_dir}/z_exp/{id}.pth')
z_dev_i = torch.clamp(z_var_i,min=0)**0.5
z_dev_list.append(z_dev_i)
z_exp_list.append(z_exp_i)


def get_dev_x_from_z(dev, exp, N):
    # draw N latent samples from the Gaussian N(exp, dev^2)
    # (torch.randn_like, not torch.rand_like, so the sampling is Gaussian)
    z_list = []
    for i in range(N):
        z_list.append(exp + torch.randn_like(exp) * dev)

    # decode the latents to pixel space and take the per-pixel standard deviation
    Z = torch.stack(z_list, dim=0)
    X = model.decode_first_stage(Z.to(device))
    var_x = torch.var(X, dim=0)
    dev_x = var_x ** 0.5
    return dev_x

import os
os.makedirs(f'{exp_dir}/x_dev',exist_ok=True)

N = 15
for index in range(1):
    z_dev = z_dev_list[index]
    z_exp = z_exp_list[index]
    dev_x = get_dev_x_from_z(z_dev, z_exp, N)
    # scale the std up so the uncertainty map is visible when saved as an image
    tvu.save_image(dev_x * 100, f'{exp_dir}/x_dev/{id}.jpg')
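
For reference, here is a minimal sketch of how the sampling script could dump these tensors so the script above can load them. The directory layout (z_exp/, z_var/) mirrors the load paths above; exp_xt_next, var_xt_next, exp_dir, and img_id are the variables from the skipUQ sampling loop, and saving each latent as [C, H, W] (no batch dimension) is an assumption so that torch.stack in get_dev_x_from_z yields a [N, C, H, W] batch:

# Hedged sketch only: persist per-image E(z_0) and Var(z_0) after the sampling loop,
# assuming exp_xt_next / var_xt_next hold the final latent mean and variance
# as in the xxUQ.py scripts.
import os
import torch

os.makedirs(f'{exp_dir}/z_exp', exist_ok=True)
os.makedirs(f'{exp_dir}/z_var', exist_ok=True)
for i in range(exp_xt_next.shape[0]):
    torch.save(exp_xt_next[i].cpu(), f'{exp_dir}/z_exp/{img_id + i}.pth')
    torch.save(var_xt_next[i].cpu(), f'{exp_dir}/z_var/{img_id + i}.pth')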

@cilevanmarken

Hi! I have successfully created uncertainty maps for Stable Diffusion. However, the uncertainty maps I generated for DDIM_and_guided by visualizing the var do not align with the results in your paper. Could you kindly provide the visualization code for this? Thank you in advance.

@karrykkk
Owner

karrykkk commented May 17, 2024

Hi👋~ Thank you for your interest in our work! For CELEBA uncertainty visualization using the DDIM sampler, you can try the Python script ./ddpm_and_guided/ddim_skipUQ_visualization.py with this bash configuration:

DEVICES="5"
data="celeba"
steps="100"
mc_size="10"
sample_batch_size="16"
total_n_sample="16"
train_la_data_size="5000"
DIS="uniform"
fixed_class="10"
seed=123

CUDA_VISIBLE_DEVICES=$DEVICES python ddim_skipUQ_visualization.py \
--config $data".yml" --timesteps=$steps --skip_type=$DIS --train_la_batch_size 32 \
--mc_size=$mc_size --sample_batch_size=$sample_batch_size --fixed_class=$fixed_class --train_la_data_size=$train_la_data_size \
--total_n_sample=$total_n_sample --seed=$seed

@cilevanmarken

Thank you for your fast reply! However, when running the given visualization code on ImageNet instead of CELEBA (with the settings from your last post as well as the standard settings from the ddim.sh file), the generated uncertainty maps still don't make much sense. Do you have any pointers as to why this might be? Or does the visualization code for ImageNet differ from that for CELEBA? Thanks in advance!

[Attached images: visualize_sample, visualize_var]

@karrykkk
Owner

Hi @cilevanmarken ~

For ImageNet visualization, as the dataset grows you need to increase train_la_data_size: the posterior is fitted on roughly #total_dataset_size/train_la_data_size samples, so a larger value means fewer fitting samples and hence a larger variance. For example, you will get the following results after changing train_la_data_size=500000 in the bash script above.

[Attached images: visualize_sample, visualize_var]
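
As a rough, purely illustrative sketch of this relationship (assuming the standard train-split sizes and the subset formula described above):

# Hypothetical arithmetic only: how train_la_data_size scales the Laplace
# fitting subset, assuming subset_size = total_dataset_size // train_la_data_size.
celeba_train, imagenet_train = 162_770, 1_281_167  # standard train-split sizes
print(celeba_train // 5000)      # -> 32 images for the CELEBA setting above
print(imagenet_train // 500000)  # -> 2 images for ImageNet, hence larger variance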

@LibertyRoamer

Hi! I would like to ask a question. If I change to a different dataset that only contains 200 images, what value should I set for train_la_data_size? Or would a small dataset like this lead to suboptimal results?
