Skip to content

Commit b2c2377

Browse files
authored
Merge pull request #248 from mrhan1993/dev
Support image custom format
2 parents 0fbb004 + f2dd06a commit b2c2377

File tree

5 files changed

+54
-15
lines changed

5 files changed

+54
-15
lines changed

fooocusapi/api_utils.py

+1
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
167167
inpaint_additional_prompt=inpaint_additional_prompt,
168168
image_prompts=image_prompts,
169169
advanced_params=advanced_params,
170+
save_extension=req.save_extension,
170171
require_base64=req.require_base64,
171172
)
172173

fooocusapi/file_utils.py

+34-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
import numpy as np
66
from PIL import Image
77
import uuid
8+
import json
9+
from pathlib import Path
10+
from PIL.PngImagePlugin import PngInfo
11+
12+
813

914
output_dir = os.path.abspath(os.path.join(
1015
os.path.dirname(__file__), '..', 'outputs', 'files'))
@@ -13,16 +18,41 @@
1318
static_serve_base_url = 'http://127.0.0.1:8888/files/'
1419

1520

16-
def save_output_file(img: np.ndarray) -> str:
21+
def save_output_file(img: np.ndarray, image_meta: dict = None,
22+
image_name: str = '', extension: str = 'png') -> str:
23+
"""
24+
Save np image to file
25+
Args:
26+
img: np.ndarray image to save
27+
image_meta: dict of image metadata
28+
image_name: str of image name
29+
extension: str of image extension
30+
Returns:
31+
str of file name
32+
"""
1733
current_time = datetime.datetime.now()
1834
date_string = current_time.strftime("%Y-%m-%d")
1935

20-
filename = os.path.join(date_string, str(uuid.uuid4()) + '.png')
36+
image_name = str(uuid.uuid4()) if image_name == '' else image_name
37+
38+
filename = os.path.join(date_string, image_name + '.' + extension)
2139
file_path = os.path.join(output_dir, filename)
2240

41+
if extension not in ['png', 'jpg', 'webp']:
42+
extension = 'png'
43+
44+
if image_meta is None:
45+
image_meta = {}
46+
47+
meta = None
48+
if extension == 'png':
49+
meta = PngInfo()
50+
meta.add_text("params", json.dumps(image_meta))
51+
2352
os.makedirs(os.path.dirname(file_path), exist_ok=True)
24-
Image.fromarray(img).save(file_path)
25-
return filename
53+
Image.fromarray(img).save(file_path, format=extension,
54+
pnginfo=meta, optimize=True)
55+
return Path(filename).as_posix()
2656

2757

2858
def delete_output_file(filename: str):

fooocusapi/models.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ class Text2ImgRequest(BaseModel):
123123
refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
124124
loras: List[Lora] = Field(default=default_loras_model)
125125
advanced_params: AdvancedParams | None = AdvancedParams()
126+
save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
126127
require_base64: bool = Field(default=False, description="Return base64 data of generated image")
127128
async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generataion result later")
128129
webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
@@ -214,6 +215,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for upsa
214215
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
215216
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
216217
advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
218+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
217219
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
218220
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
219221
):
@@ -226,7 +228,8 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for upsa
226228
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
227229
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
228230
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
229-
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
231+
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension,
232+
require_base64=require_base64, async_process=async_process)
230233

231234

232235
class ImgInpaintOrOutpaintRequest(Text2ImgRequest):
@@ -262,6 +265,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for inpa
262265
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
263266
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
264267
advanced_params: str| None = Form(default=None, description="Advanced parameters in JSON"),
268+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
265269
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
266270
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
267271
):
@@ -281,7 +285,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for inpa
281285
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
282286
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
283287
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
284-
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
288+
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension, require_base64=require_base64, async_process=async_process)
285289

286290

287291
class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
@@ -343,6 +347,7 @@ def as_form(cls, input_image: UploadFile = Form(File(None), description="Init im
343347
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
344348
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
345349
advanced_params: str| None = Form(default=None, description="Advanced parameters in JSON"),
350+
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
346351
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
347352
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
348353
):
@@ -376,7 +381,7 @@ def as_form(cls, input_image: UploadFile = Form(File(None), description="Init im
376381
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
377382
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
378383
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
379-
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
384+
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension, require_base64=require_base64, async_process=async_process)
380385

381386

382387
class GeneratedImageResult(BaseModel):

fooocusapi/parameters.py

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def __init__(self, prompt: str,
101101
inpaint_additional_prompt: str | None,
102102
image_prompts: List[Tuple[np.ndarray, float, float, str]],
103103
advanced_params: List[any] | None,
104+
save_extension: str,
104105
require_base64: bool):
105106
self.prompt = prompt
106107
self.negative_prompt = negative_prompt
@@ -126,6 +127,7 @@ def __init__(self, prompt: str,
126127
self.inpaint_input_image = inpaint_input_image
127128
self.inpaint_additional_prompt = inpaint_additional_prompt
128129
self.image_prompts = image_prompts
130+
self.save_extension = save_extension
129131
self.require_base64 = require_base64
130132

131133
if advanced_params is None:

fooocusapi/worker.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,14 @@ def progressbar(_, number, text):
100100
print(f'[Fooocus] {text}')
101101
outputs.append(['preview', (number, text, None)])
102102

103-
def yield_result(_, imgs, tasks):
103+
def yield_result(_, imgs, tasks, extension='png'):
104104
if not isinstance(imgs, list):
105105
imgs = [imgs]
106106

107107
results = []
108108
for i, im in enumerate(imgs):
109109
seed = -1 if len(tasks) == 0 else tasks[i]['task_seed']
110-
img_filename = save_output_file(im)
110+
img_filename = save_output_file(img=im, extension=extension)
111111
results.append(ImageGenerationResult(im=img_filename, seed=str(seed), finish_reason=GenerationFinishReason.success))
112112
async_task.set_result(results, False)
113113
worker_queue.finish_task(async_task.job_id)
@@ -150,6 +150,7 @@ def yield_result(_, imgs, tasks):
150150
inpaint_input_image = params.inpaint_input_image
151151
inpaint_additional_prompt = params.inpaint_additional_prompt
152152
inpaint_mask_image_upload = None
153+
save_extension = params.save_extension
153154

154155
if inpaint_additional_prompt is None:
155156
inpaint_additional_prompt = ''
@@ -547,7 +548,7 @@ def yield_result(_, imgs, tasks):
547548
if direct_return:
548549
d = [('Upscale (Fast)', '2x')]
549550
log(uov_input_image, d)
550-
yield_result(async_task, uov_input_image, tasks)
551+
yield_result(async_task, uov_input_image, tasks, save_extension)
551552
return
552553

553554
tiled = True
@@ -693,7 +694,7 @@ def yield_result(_, imgs, tasks):
693694
cn_img = HWC3(cn_img)
694695
task[0] = core.numpy_to_pytorch(cn_img)
695696
if advanced_parameters.debugging_cn_preprocessor:
696-
yield_result(async_task, cn_img, tasks)
697+
yield_result(async_task, cn_img, tasks, save_extension)
697698
return
698699
for task in cn_tasks[flags.cn_cpds]:
699700
cn_img, cn_stop, cn_weight = task
@@ -705,7 +706,7 @@ def yield_result(_, imgs, tasks):
705706
cn_img = HWC3(cn_img)
706707
task[0] = core.numpy_to_pytorch(cn_img)
707708
if advanced_parameters.debugging_cn_preprocessor:
708-
yield_result(async_task, cn_img, tasks)
709+
yield_result(async_task, cn_img, tasks, save_extension)
709710
return
710711
for task in cn_tasks[flags.cn_ip]:
711712
cn_img, cn_stop, cn_weight = task
@@ -716,7 +717,7 @@ def yield_result(_, imgs, tasks):
716717

717718
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
718719
if advanced_parameters.debugging_cn_preprocessor:
719-
yield_result(async_task, cn_img, tasks)
720+
yield_result(async_task, cn_img, tasks, save_extension)
720721
return
721722
for task in cn_tasks[flags.cn_ip_face]:
722723
cn_img, cn_stop, cn_weight = task
@@ -730,7 +731,7 @@ def yield_result(_, imgs, tasks):
730731

731732
task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
732733
if advanced_parameters.debugging_cn_preprocessor:
733-
yield_result(async_task, cn_img, tasks)
734+
yield_result(async_task, cn_img, tasks, save_extension)
734735
return
735736

736737
all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
@@ -877,7 +878,7 @@ def callback(step, x0, x, total_steps, y):
877878
if async_task.finish_with_error:
878879
worker_queue.finish_task(async_task.job_id)
879880
return async_task.task_result
880-
yield_result(None, results, tasks)
881+
yield_result(None, results, tasks, save_extension)
881882
return
882883
except Exception as e:
883884
print('Worker error:', e)

0 commit comments

Comments
 (0)