-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtext+cam2image.py
88 lines (73 loc) · 4.82 KB
/
text+cam2image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import logging
import math
import os
from diffusers import StableDiffusionPipeline
from camera_embed import CameraSettingEmbedding
from inference import embed_camera_settings
from diffusers.schedulers import PNDMScheduler
import random
import torch
from safetensors import safe_open
import os
def parse_args(argv=None):
    """Parse command-line options for text + camera-settings image generation.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets callers and tests parse options without touching ``sys.argv``.

    Returns:
        argparse.Namespace holding the model IDs, generation settings,
        output naming options and the camera settings.
    """
    parser = argparse.ArgumentParser(description="Simple example of an inference script.")

    # Model sources.
    parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-2", help="Stable Diffusion Model ID")
    parser.add_argument("--camera_setting_embedding_id", type=str, default="ishengfang/Camera-Settings-as-Tokens-SD2",
                        help="Camera Setting Embedding Model ID")

    # Generation / runtime settings.
    parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of diffusion steps")
    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")

    # Output location and file naming.
    parser.add_argument("--output_dir", type=str, default="results", help="Output directory")
    parser.add_argument("--output_basename", type=str, default="woman_cherry_blossom_trees", help="Output file basename")

    # Text conditioning.
    parser.add_argument("--prompt", type=str, help="Prompt for the model",
                        default="half body portrait of a beautiful Portuguese woman, pale skin, brown hair with blonde highlights, wearing jeans, nature and cherry blossom trees in background")
    parser.add_argument("--negative_prompt", type=str, help="Negative prompt for the model",
                        default="ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra legs, mutated hands fused fingers, too many fingers, long neck")
    parser.add_argument("--seed", type=int, default=87, help="Random seed")

    # Camera settings injected into the prompt embeddings.
    parser.add_argument("--focal_length", type=float, default=50.0, help="Focal length in mm")
    parser.add_argument("--f_number", type=float, default=4.0, help="Aperture f-number (shown as f/<value>)")
    parser.add_argument("--iso_speed_rating", type=float, default=100.0, help="ISO speed rating")
    parser.add_argument("--exposure_time", type=float, default=0.01, help="Exposure time")
    parser.add_argument("--lora_scale", type=float, default=1.0,
                        help="Weight of the camera-settings LoRA adapter")

    return parser.parse_args(argv)
def _filename_number(value):
    """Format a numeric setting for use inside an output file name.

    Integral values drop the fractional part ("50.0" -> "50"), which matches
    the previous ``str(int(value))`` naming. Non-integral values keep their
    fraction with "." replaced by "_" ("35.5" -> "35_5"). The old ``int()``
    truncation mapped e.g. 35.5 and 35 to the same name, so one run
    silently overwrote the other.
    """
    if float(value).is_integer():
        return str(int(value))
    return str(value).replace('.', '_')


def _build_save_name(args):
    """Return the PNG file name encoding the camera settings (and seed) in *args*."""
    focal_length = _filename_number(args.focal_length)
    f_number = str(args.f_number).replace('.', '_')
    iso_speed_rating = _filename_number(args.iso_speed_rating)
    exposure_time = str(args.exposure_time).replace('.', '_')
    save_name = (f'{args.output_basename}+{focal_length}mm_f{f_number}'
                 f'_ISO{iso_speed_rating}_ET{exposure_time}')
    if args.seed is not None:
        save_name += f'_seed{args.seed}'
    return save_name + '.png'


def main():
    """Generate one image from a text prompt plus camera settings and save it.

    Loads the Stable Diffusion pipeline (with a PNDM scheduler) and the
    camera-setting embedding model named on the command line, and attaches
    the camera-settings LoRA adapter. It then encodes the prompt, lets
    ``embed_camera_settings`` fold the camera settings into the prompt
    embeddings, runs the pipeline, and writes the image as a PNG under
    ``--output_dir``.
    """
    args = parse_args()

    # Base text-to-image pipeline; the scheduler is loaded from the same repo.
    pipeline = StableDiffusionPipeline.from_pretrained(args.model_id)
    pipeline.scheduler = PNDMScheduler.from_pretrained(args.model_id, subfolder="scheduler")

    # The camera embedder and its LoRA both come from the camera-settings repo.
    cam_embed = CameraSettingEmbedding.from_pretrained(args.camera_setting_embedding_id, subfolder="cam_embed")
    pipeline.load_lora_weights(args.camera_setting_embedding_id, adapter_name="camera-settings-as-tokens")
    pipeline.set_adapters(["camera-settings-as-tokens"], adapter_weights=[args.lora_scale])
    pipeline.to(args.device)
    cam_embed.to(args.device)

    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Output directory: {args.output_dir}")
    print(f"Generating image with focal length {args.focal_length}mm and f/{args.f_number}")
    print(f"iso_speed_rating: {args.iso_speed_rating}, exposure_time: {args.exposure_time}")
    print(f"Seed: {args.seed}")

    # NOTE(review): `_execution_device` is a private diffusers attribute;
    # after `pipeline.to(args.device)` it presumably equals args.device.
    generator = torch.Generator(device=pipeline._execution_device)
    # --seed defaults to 87, so from the CLI this branch is always taken.
    if args.seed is not None:
        generator.manual_seed(args.seed)

    with torch.no_grad():
        # Classifier-free guidance is on, so negative embeddings are produced too.
        prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
            args.prompt,
            args.device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=args.negative_prompt,
        )
        prompt_embeds, negative_prompt_embeds = embed_camera_settings(
            args.focal_length, args.f_number,
            args.iso_speed_rating, args.exposure_time,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            cam_embed=cam_embed, device=args.device,
        )

    image = pipeline(prompt_embeds=prompt_embeds,
                     negative_prompt_embeds=negative_prompt_embeds,
                     num_inference_steps=args.num_inference_steps,
                     generator=generator).images[0]

    image.save(os.path.join(args.output_dir, _build_save_name(args)))


if __name__ == "__main__":
    main()