Skip to content

Commit 535d906

Browse files
committed
server: update
1 parent 2210257 commit 535d906

File tree

1 file changed

+125
-35
lines changed

1 file changed

+125
-35
lines changed

examples/server/main.cpp

+125-35
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ const char* sample_method_str[] = {
4040
"dpm++2s_a",
4141
"dpm++2m",
4242
"dpm++2mv2",
43+
"ipndm",
44+
"ipndm_v",
4345
"lcm",
4446
};
4547

@@ -48,7 +50,9 @@ const char* schedule_str[] = {
4850
"default",
4951
"discrete",
5052
"karras",
53+
"exponential",
5154
"ays",
55+
"gits",
5256
};
5357

5458
enum SDMode {
@@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
676680
printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str());
677681
}
678682

679-
void parseJsonPrompt(std::string json_str, SDParams* params) {
683+
void parseJsonPrompt(std::string json_str, SDParams& params) {
680684
using namespace nlohmann;
681685
json payload = json::parse(json_str);
682686
// if no exception, the request is a json object
683687
// now we try to get the new param values from the payload object
684688
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
685689
try {
686690
std::string prompt = payload["prompt"];
687-
params->prompt = prompt;
691+
params.prompt = prompt;
688692
} catch (...) {
689693
}
690694
try {
691695
std::string negative_prompt = payload["negative_prompt"];
692-
params->negative_prompt = negative_prompt;
696+
params.negative_prompt = negative_prompt;
693697
} catch (...) {
694698
}
695699
try {
696-
int clip_skip = payload["clip_skip"];
697-
params->clip_skip = clip_skip;
700+
int clip_skip = payload["clip_skip"];
701+
params.clip_skip = clip_skip;
698702
} catch (...) {
699703
}
700704
try {
701-
float cfg_scale = payload["cfg_scale"];
702-
params->cfg_scale = cfg_scale;
705+
float cfg_scale = payload["cfg_scale"];
706+
params.cfg_scale = cfg_scale;
703707
} catch (...) {
704708
}
705709
try {
706-
float guidance = payload["guidance"];
707-
params->guidance = guidance;
710+
float guidance = payload["guidance"];
711+
params.guidance = guidance;
708712
} catch (...) {
709713
}
710714
try {
711-
int width = payload["width"];
712-
params->width = width;
715+
int width = payload["width"];
716+
params.width = width;
713717
} catch (...) {
714718
}
715719
try {
716-
int height = payload["height"];
717-
params->height = height;
720+
int height = payload["height"];
721+
params.height = height;
718722
} catch (...) {
719723
}
720724
try {
@@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
727731
}
728732
}
729733
if (sample_method_found >= 0) {
730-
params->sample_method = (sample_method_t)sample_method_found;
734+
params.sample_method = (sample_method_t)sample_method_found;
731735
} else {
732736
sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown sampling method: %s\n", sample_method.c_str());
733737
}
734738
} catch (...) {
735739
}
736740
try {
737-
int sample_steps = payload["sample_steps"];
738-
params->sample_steps = sample_steps;
741+
int sample_steps = payload["sample_steps"];
742+
params.sample_steps = sample_steps;
739743
} catch (...) {
740744
}
741745
try {
742746
int64_t seed = payload["seed"];
743-
params->seed = seed;
747+
params.seed = seed;
744748
} catch (...) {
745749
}
746750
try {
747-
int batch_count = payload["batch_count"];
748-
params->batch_count = batch_count;
751+
int batch_count = payload["batch_count"];
752+
params.batch_count = batch_count;
749753
} catch (...) {
750754
}
751755

@@ -759,57 +763,135 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
759763
}
760764
try {
761765
float control_strength = payload["control_strength"];
762-
// params->control_strength = control_strength;
766+
// params.control_strength = control_strength;
763767
// LOG_WARN("control_strength is not supported yet\n");
764768
sd_log(sd_log_level_t::SD_LOG_WARN, "control_strength is not supported yet\n", params);
765769
} catch (...) {
766770
}
767771
try {
768772
float style_strength = payload["style_strength"];
769-
// params->style_strength = style_strength;
773+
// params.style_strength = style_strength;
770774
// LOG_WARN("style_strength is not supported yet\n");
771775
sd_log(sd_log_level_t::SD_LOG_WARN, "style_strength is not supported yet\n", params);
772776
} catch (...) {
773777
}
774778
try {
775-
bool normalize_input = payload["normalize_input"];
776-
params->normalize_input = normalize_input;
779+
bool normalize_input = payload["normalize_input"];
780+
params.normalize_input = normalize_input;
777781
} catch (...) {
778782
}
779783
try {
780784
std::string input_id_images_path = payload["input_id_images_path"];
781785
// TODO replace with b64 image maybe?
782-
params->input_id_images_path = input_id_images_path;
786+
params.input_id_images_path = input_id_images_path;
783787
} catch (...) {
784788
}
785789
try {
786790
std::string slg_scale = payload["slg_scale"];
787-
params->slg_scale = stof(slg_scale);
791+
params.slg_scale = stof(slg_scale);
788792
} catch (...) {
789793
}
790794
// TODO: more slg settings (layers, start and end)
791795
try {
792796
std::vector<int> skip_layers = payload["skip_layers"].get<std::vector<int>>();
793-
params->skip_layers.clear();
797+
params.skip_layers.clear();
794798
for (int i = 0; i < skip_layers.size(); i++) {
795-
params->skip_layers.push_back(skip_layers[i]);
799+
params.skip_layers.push_back(skip_layers[i]);
796800
}
797801
} catch (...) {
798802
}
799803
try {
800804
// skip_layer_start
801-
float skip_layer_start = payload["skip_layer_start"].get<float>();
802-
params->skip_layer_start = skip_layer_start;
805+
float skip_layer_start = payload["skip_layer_start"].get<float>();
806+
params.skip_layer_start = skip_layer_start;
803807
} catch (...) {
804808
}
805809
try {
806810
// skip_layer_end
807-
float skip_layer_end = payload["skip_layer_end"].get<float>();
808-
params->skip_layer_end = skip_layer_end;
811+
float skip_layer_end = payload["skip_layer_end"].get<float>();
812+
params.skip_layer_end = skip_layer_end;
809813
} catch (...) {
810814
}
811815
}
812816

817+
struct SDAPIParams {
818+
std::string prompt = "";
819+
std::string negative_prompt = "";
820+
// std::vector<std::string> styles = {};
821+
int seed = -1;
822+
// int subseed = -1;
823+
// float subseed_strength = 0.;
824+
// int seed_resize_from_h = -1;
825+
// int seed_resize_from_w = -1;
826+
std::string sampler_name = "";
827+
std::string scheduler = "";
828+
// int batch_size = 1; //batch processing
829+
int n_iter = 1; // batch_count
830+
int steps = 50;
831+
float cfg_scale = 7;
832+
int width = 512;
833+
int height = 512;
834+
// bool restore_faces = false;
835+
// bool tiling = false;
836+
// bool do_not_save_samples = false;
837+
// bool do_not_save_grid = false;
838+
// float eta = 0; //for ddim
839+
float denoising_strength = 0.75;
840+
// float s_min_uncond = 0;
841+
// float s_churn = 0;
842+
// float s_tmax = 0;
843+
// float s_tmin = 0;
844+
// float s_noise = 0;
845+
// nlohmann::json override_settings = {};
846+
// bool override_settings_restore_afterwards = true;
847+
// std::string refiner_checkpoint = "";
848+
// int refiner_switch_at = 0;
849+
// bool disable_extra_networks = false;
850+
// std::string firstpass_image = ""; //for highres_fix upscaling
851+
// nlohmann::json comments = {};
852+
853+
// // Highres_fix stuff
854+
// bool enable_hr = false;
855+
// int firstphase_width = 0;
856+
// int firstphase_height = 0;
857+
// int hr_scale = 2;
858+
// std::string hr_upscaler = "";
859+
// int hr_second_pass_steps = 0;
860+
// int hr_resize_x = 0;
861+
// int hr_resize_y = 0;
862+
// std::string hr_checkpoint_name = "";
863+
// std::string hr_sampler_name = "";
864+
// std::string hr_scheduler = "";
865+
// std::string hr_prompt = "";
866+
// std::string hr_negative_prompt = "";
867+
868+
// // img2img stuff
869+
std::vector<std::string> init_images = {};
870+
int resize_mode = 0;
871+
float image_cfg_scale = 0;
872+
std::string mask = "";
873+
int mask_blur_x = 4;
874+
int mask_blur_y = 4;
875+
int mask_blur = 0;
876+
bool mask_round = true;
877+
int inpainting_fill = 0;
878+
bool inpaint_full_res = true;
879+
int inpaint_full_res_padding = 0;
880+
int inpainting_mask_invert = 0;
881+
float initial_noise_multiplier = 0;
882+
std::string latent_mask = "";
883+
884+
// std::string force_task_id = "";
885+
std::string sampler_index = "Euler";
886+
// bool include_init_images = false;
887+
// std::string script_name = "";
888+
// nlohmann::json script_args = {};
889+
bool send_images = true;
890+
bool save_images = false;
891+
// nlohmann::json alwayson_scripts = {};
892+
// std::string infotext = "";
893+
};
894+
813895
// https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/latent_formats.py#L152-L169
814896
const float flux_latent_rgb_proj[16][3] = {
815897
{-0.0346, 0.0244, 0.0681},
@@ -863,7 +945,7 @@ const float sd_latent_rgb_proj[4][3]{
863945
{-0.2829, 0.1762, 0.2721},
864946
{-0.2120, -0.2616, -0.7177}};
865947

866-
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
948+
void proj_latents(struct ggml_tensor* latents, enum SDVersion version, uint8_t* data) {
867949
const int channel = 3;
868950
int width = latents->ne[0];
869951
int height = latents->ne[1];
@@ -876,7 +958,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
876958

877959
if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
878960
latent_rgb_proj = sd3_latent_rgb_proj;
879-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
961+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
880962
latent_rgb_proj = flux_latent_rgb_proj;
881963
} else {
882964
// unknown model
@@ -897,7 +979,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
897979
// unknown latent space
898980
return;
899981
}
900-
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
901982
int data_head = 0;
902983
for (int j = 0; j < height; j++) {
903984
for (int i = 0; i < width; i++) {
@@ -925,6 +1006,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
9251006
data[data_head++] = (uint8_t)(b * 255.);
9261007
}
9271008
}
1009+
}
1010+
1011+
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
1012+
const int channel = 3;
1013+
int width = latents->ne[0];
1014+
int height = latents->ne[1];
1015+
int dim = latents->ne[2];
1016+
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
1017+
proj_latents(latents, version, data);
9281018
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
9291019
free(data);
9301020
}
@@ -982,7 +1072,7 @@ int main(int argc, const char* argv[]) {
9821072

9831073
try {
9841074
std::string json_str = req.body;
985-
parseJsonPrompt(json_str, &params);
1075+
parseJsonPrompt(json_str, params);
9861076
} catch (json::parse_error& e) {
9871077
// assume the request is just a prompt
9881078
// LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());

0 commit comments

Comments
 (0)