
Commit f9955d5

Correct spacing issues and adjust for non-local version of facexlib
1 parent 27d41af commit f9955d5

3 files changed (+65 / -50 lines)

app.py

Lines changed: 28 additions & 28 deletions
```diff
@@ -29,7 +29,7 @@ class ModelVersion:
     STAGE_2 = "aes_stage2"
 
     DEFAULT_VERSION = STAGE_2
-    
+
     ENABLE_ANTI_BLUR_DEFAULT = False
     ENABLE_REALISM_DEFAULT = False
 
@@ -60,13 +60,13 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
     global pipeline
 
     if (
-        pipeline 
-        and loaded_pipeline_config["enable_realism"] == enable_realism 
+        pipeline
+        and loaded_pipeline_config["enable_realism"] == enable_realism
         and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
         and model_version == loaded_pipeline_config["model_version"]
     ):
         return
-    
+
     loaded_pipeline_config["enable_realism"] = enable_realism
     loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
     loaded_pipeline_config["model_version"] = model_version
@@ -96,15 +96,15 @@ def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
 
 
 def generate_image(
-    input_image, 
-    control_image, 
-    prompt, 
-    seed, 
+    input_image,
+    control_image,
+    prompt,
+    seed,
     width,
     height,
-    guidance_scale, 
-    num_steps, 
-    infusenet_conditioning_scale, 
+    guidance_scale,
+    num_steps,
+    infusenet_conditioning_scale,
     infusenet_guidance_start,
     infusenet_guidance_end,
     enable_realism,
@@ -175,15 +175,15 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
     4. *[Optional] Adjust advanced hyperparameters or apply optional LoRAs to meet personal needs.* Please refer to **important usage tips** under the Generated Image field.
     5. **Click the "Generate" button to generate an image.** Enjoy!
     """)
-    
+
     with gr.Row():
         with gr.Column(scale=3):
             with gr.Row():
                 ui_id_image = gr.Image(label="Identity Image", type="pil", scale=3, height=370, min_width=100)
 
                 with gr.Column(scale=2, min_width=100):
                     ui_control_image = gr.Image(label="Control Image [Optional]", type="pil", height=370, min_width=100)
-            
+
             ui_prompt_text = gr.Textbox(label="Prompt", value="Portrait, 4K, high quality, cinematic")
             ui_model_version = gr.Dropdown(
                 label="Model Version",
@@ -231,42 +231,42 @@ def generate_examples(id_image, control_image, prompt_text, seed, enable_realism
         )
 
     ui_btn_generate.click(
-        generate_image, 
+        generate_image,
         inputs=[
-            ui_id_image, 
-            ui_control_image, 
-            ui_prompt_text, 
-            ui_seed, 
+            ui_id_image,
+            ui_control_image,
+            ui_prompt_text,
+            ui_seed,
             ui_width,
             ui_height,
-            ui_guidance_scale, 
-            ui_num_steps, 
-            ui_infusenet_conditioning_scale, 
-            ui_infusenet_guidance_start, 
+            ui_guidance_scale,
+            ui_num_steps,
+            ui_infusenet_conditioning_scale,
+            ui_infusenet_guidance_start,
             ui_infusenet_guidance_end,
             ui_enable_realism,
             ui_enable_anti_blur,
             ui_model_version
-        ], 
-        outputs=[image_output], 
+        ],
+        outputs=[image_output],
         concurrency_id="gpu"
     )
 
     with gr.Accordion("Local Gradio Demo for Developers", open=False):
         gr.Markdown(
            'Please refer to our GitHub repository to [run the InfiniteYou-FLUX gradio demo locally](https://github.com/bytedance/InfiniteYou#local-gradio-demo).'
        )
-    
+
    gr.Markdown(
        """
        ---
-        ### 📜 Disclaimer and Licenses 
+        ### 📜 Disclaimer and Licenses
        Some images in this demo are from public domains or generated by models. These pictures are intended solely to show the capabilities of our research. If you have any concerns, please contact us, and we will promptly remove any inappropriate content.
-        
+
        The use of the released code, model, and demo must strictly adhere to the respective licenses. Our code is released under the Apache 2.0 License, and our model is released under the Creative Commons Attribution-NonCommercial 4.0 International Public License for academic research purposes only. Any manual or automatic downloading of the face models from [InsightFace](https://github.com/deepinsight/insightface), the [FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) base model, LoRAs, etc., must follow their original licenses and be used only for academic research purposes.
 
        This research aims to positively impact the Generative AI field. Users are granted freedom to create images using this tool, but they must comply with local laws and use it responsibly. The developers do not assume any responsibility for potential misuse.
-        
+
        ### 📖 Citation
 
        If you find InfiniteYou useful for your research or applications, please cite our paper:
```
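For context on the event wiring in the last hunk: `concurrency_id="gpu"` is Gradio's event-level concurrency grouping, where all events tagged with the same id share one queue slot, so at most one GPU-bound job runs at a time. A minimal self-contained sketch of the same pattern (the `run` handler and component names below are placeholders, not from app.py):

```python
import gradio as gr

def run(prompt):
    # Placeholder for a GPU-bound handler like generate_image.
    return f"generated: {prompt}"

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out = gr.Textbox(label="Result")
    btn = gr.Button("Generate")
    # Events sharing a concurrency_id share one worker slot.
    btn.click(run, inputs=[inp], outputs=[out], concurrency_id="gpu")

# default_concurrency_limit=1 serializes jobs within the queue; whether
# app.py configures the queue this way is an assumption.
demo.queue(default_concurrency_limit=1).launch()
```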

pipelines/pipeline_infu_flux.py

Lines changed: 34 additions & 19 deletions
```diff
@@ -21,7 +21,8 @@
 import numpy as np
 import torch
 from diffusers.models import FluxControlNetModel
-from facexlib.recognition import init_recognition_model
+from facexlib.recognition import Backbone
+from facexlib.utils import load_file_from_url
 from huggingface_hub import snapshot_download
 from insightface.app import FaceAnalysis
 from insightface.utils import face_align
@@ -30,6 +31,20 @@
 from .pipeline_flux_infusenet import FluxInfuseNetPipeline
 from .resampler import Resampler
 
+def init_recognition_model(model_name, half=False, device='cuda', model_rootpath=None):
+    if model_name == 'arcface':
+        model = Backbone(num_layers=50, drop_ratio=0.6, mode='ir_se').to(device).eval()
+        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/recognition_arcface_ir_se50.pth'
+    else:
+        raise NotImplementedError(f'{model_name} is not implemented.')
+
+    model_path = load_file_from_url(
+        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
+    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
+    model.eval()
+    model = model.to(device)
+    return model
+
 def seed_everything(seed, deterministic=False):
     """Set random seed.
 
@@ -44,8 +59,8 @@ def seed_everything(seed, deterministic=False):
     np.random.seed(seed)
     torch.manual_seed(seed)
     if torch.cuda.is_available():
-        torch.cuda.manual_seed(seed) 
-        torch.cuda.manual_seed_all(seed) 
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
     elif torch.backends.mps.is_available():
         torch.mps.manual_seed(seed)
     os.environ['PYTHONHASHSEED'] = str(seed)
@@ -100,34 +115,34 @@ def resize_and_pad_image(source_img, target_img_size):
     # Get original and target sizes
     source_img_size = source_img.size
     target_width, target_height = target_img_size
-    
+
     # Determine the new size based on the shorter side of target_img
     if target_width <= target_height:
         new_width = target_width
         new_height = int(target_width * (source_img_size[1] / source_img_size[0]))
     else:
         new_height = target_height
         new_width = int(target_height * (source_img_size[0] / source_img_size[1]))
-    
+
     # Resize the source image using LANCZOS interpolation for high quality
     resized_source_img = source_img.resize((new_width, new_height), Image.LANCZOS)
-    
+
     # Compute padding to center resized image
     pad_left = (target_width - new_width) // 2
     pad_top = (target_height - new_height) // 2
-    
+
     # Create a new image with white background
     padded_img = Image.new("RGB", target_img_size, (255, 255, 255))
     padded_img.paste(resized_source_img, (pad_left, pad_top))
-    
+
     return padded_img
 
 
 class InfUFluxPipeline:
     def __init__(
-        self, 
-        base_model_path, 
-        infu_model_path, 
+        self,
+        base_model_path,
+        infu_model_path,
         insightface_root_path = './',
         image_proj_num_tokens=8,
         infu_flux_version='v1.0',
@@ -136,7 +151,7 @@ def __init__(
 
         self.infu_flux_version = infu_flux_version
         self.model_version = model_version
-        
+
         # Load pipeline
         try:
             infusenet_path = os.path.join(infu_model_path, 'InfuseNetModel')
@@ -195,15 +210,15 @@ def __init__(
         self.image_proj_model = image_proj_model
 
         # Load face encoder
-        self.app_640 = FaceAnalysis(name='antelopev2', 
+        self.app_640 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_640.prepare(ctx_id=0, det_size=(640, 640))
 
-        self.app_320 = FaceAnalysis(name='antelopev2', 
+        self.app_320 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_320.prepare(ctx_id=0, det_size=(320, 320))
 
-        self.app_160 = FaceAnalysis(name='antelopev2', 
+        self.app_160 = FaceAnalysis(name='antelopev2',
                                     root=insightface_root_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
         self.app_160.prepare(ctx_id=0, det_size=(160, 160))
 
@@ -225,7 +240,7 @@ def _detect_face(self, id_image_cv2):
         face_info = self.app_640.get(id_image_cv2)
         if len(face_info) > 0:
             return face_info
-        
+
         face_info = self.app_320.get(id_image_cv2)
         if len(face_info) > 0:
             return face_info
@@ -246,14 +261,14 @@ def __call__(
         infusenet_conditioning_scale = 1.0,
         infusenet_guidance_start = 0.0,
         infusenet_guidance_end = 1.0,
-    ):        
+    ):
         # Extract ID embeddings
         print('Preparing ID embeddings')
         id_image_cv2 = cv2.cvtColor(np.array(id_image), cv2.COLOR_RGB2BGR)
         face_info = self._detect_face(id_image_cv2)
         if len(face_info) == 0:
             raise ValueError('No face detected in the input ID image')
-        
+
         face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
         landmark = face_info['kps']
         id_embed = extract_arcface_bgr_embedding(id_image_cv2, landmark, self.arcface_model)
@@ -266,7 +281,7 @@ def __call__(
         id_embed = id_embed.repeat(1, 1, 1)
         id_embed = id_embed.view(bs_embed * 1, seq_len, -1)
         id_embed = id_embed.to(device=torch.empty(1).device, dtype=torch.bfloat16)
-        
+
         # Load control image
         print('Preparing the control image')
         if control_image is not None:
```
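The vendored `init_recognition_model` above replaces the helper that is no longer importable from stock `facexlib.recognition`: it builds the ArcFace IR-SE50 backbone directly and fetches its weights with `load_file_from_url`, which caches the download under `facexlib/weights`. A rough usage sketch follows; the preprocessing is the common ArcFace recipe (aligned 112×112 BGR crop scaled to [-1, 1]) and is an assumption here, not the pipeline's exact code:

```python
import numpy as np
import torch

# Build the ArcFace backbone; weights download on first use, then cache.
arcface = init_recognition_model('arcface', device='cuda')

# Stand-in for an aligned 112x112 BGR face crop
# (e.g. produced by insightface's face_align.norm_crop).
face_bgr = np.zeros((112, 112, 3), dtype=np.uint8)

x = torch.from_numpy(face_bgr).permute(2, 0, 1).float()  # HWC -> CHW
x = (x / 255.0 - 0.5) / 0.5                              # assumed [-1, 1] scaling
x = x.unsqueeze(0).to('cuda')

with torch.no_grad():
    id_embed = arcface(x)  # one 512-dim identity embedding per face
print(id_embed.shape)      # torch.Size([1, 512])
```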

test.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -49,7 +49,7 @@ def main():
 
     # Set cuda device
     if torch.cuda.is_available():
-        torch.cuda.set_device(args.cuda_device) 
+        torch.cuda.set_device(args.cuda_device)
     elif torch.backends.mps.is_available():
         torch.set_default_device("mps:0")
     print(f'Using cuda device: {torch.empty(1).device}')
@@ -73,7 +73,7 @@ def main():
     if args.enable_anti_blur_lora:
         loras.append([os.path.join(lora_dir, 'flux_anti_blur_lora.safetensors'), 'anti_blur', 1.0])
     pipe.load_loras(loras)
-    
+
     # Perform inference
     if args.seed == 0:
         args.seed = torch.seed() & 0xFFFFFFFF
@@ -88,7 +88,7 @@ def main():
         infusenet_guidance_start=args.infusenet_guidance_start,
         infusenet_guidance_end=args.infusenet_guidance_end,
     )
-    
+
     # Save results
     os.makedirs(args.out_results_dir, exist_ok=True)
     index = len(os.listdir(args.out_results_dir))
```
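One detail in the surrounding test.py code: when `args.seed` is 0, the script falls back to `torch.seed() & 0xFFFFFFFF`, which draws a fresh non-deterministic 64-bit seed from torch and masks it down to 32 bits so it stays within the range common seeding APIs accept. A standalone illustration:

```python
import torch

seed = torch.seed() & 0xFFFFFFFF  # random 64-bit seed truncated to 32 bits
assert 0 <= seed <= 0xFFFFFFFF
print(f'Using seed: {seed}')
```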
