Skip to content

Commit 951b3f8

Browse files
committed
server: update
1 parent 2210257 commit 951b3f8

File tree

1 file changed

+47
-35
lines changed

1 file changed

+47
-35
lines changed

examples/server/main.cpp

+47-35
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ const char* sample_method_str[] = {
4040
"dpm++2s_a",
4141
"dpm++2m",
4242
"dpm++2mv2",
43+
"ipndm",
44+
"ipndm_v",
4345
"lcm",
4446
};
4547

@@ -48,7 +50,9 @@ const char* schedule_str[] = {
4850
"default",
4951
"discrete",
5052
"karras",
53+
"exponential",
5154
"ays",
55+
"gits",
5256
};
5357

5458
enum SDMode {
@@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
676680
printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str());
677681
}
678682

679-
void parseJsonPrompt(std::string json_str, SDParams* params) {
683+
void parseJsonPrompt(std::string json_str, SDParams& params) {
680684
using namespace nlohmann;
681685
json payload = json::parse(json_str);
682686
// if no exception, the request is a json object
683687
// now we try to get the new param values from the payload object
684688
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
685689
try {
686690
std::string prompt = payload["prompt"];
687-
params->prompt = prompt;
691+
params.prompt = prompt;
688692
} catch (...) {
689693
}
690694
try {
691695
std::string negative_prompt = payload["negative_prompt"];
692-
params->negative_prompt = negative_prompt;
696+
params.negative_prompt = negative_prompt;
693697
} catch (...) {
694698
}
695699
try {
696-
int clip_skip = payload["clip_skip"];
697-
params->clip_skip = clip_skip;
700+
int clip_skip = payload["clip_skip"];
701+
params.clip_skip = clip_skip;
698702
} catch (...) {
699703
}
700704
try {
701-
float cfg_scale = payload["cfg_scale"];
702-
params->cfg_scale = cfg_scale;
705+
float cfg_scale = payload["cfg_scale"];
706+
params.cfg_scale = cfg_scale;
703707
} catch (...) {
704708
}
705709
try {
706-
float guidance = payload["guidance"];
707-
params->guidance = guidance;
710+
float guidance = payload["guidance"];
711+
params.guidance = guidance;
708712
} catch (...) {
709713
}
710714
try {
711-
int width = payload["width"];
712-
params->width = width;
715+
int width = payload["width"];
716+
params.width = width;
713717
} catch (...) {
714718
}
715719
try {
716-
int height = payload["height"];
717-
params->height = height;
720+
int height = payload["height"];
721+
params.height = height;
718722
} catch (...) {
719723
}
720724
try {
@@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
727731
}
728732
}
729733
if (sample_method_found >= 0) {
730-
params->sample_method = (sample_method_t)sample_method_found;
734+
params.sample_method = (sample_method_t)sample_method_found;
731735
} else {
732736
sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown sampling method: %s\n", sample_method.c_str());
733737
}
734738
} catch (...) {
735739
}
736740
try {
737-
int sample_steps = payload["sample_steps"];
738-
params->sample_steps = sample_steps;
741+
int sample_steps = payload["sample_steps"];
742+
params.sample_steps = sample_steps;
739743
} catch (...) {
740744
}
741745
try {
742746
int64_t seed = payload["seed"];
743-
params->seed = seed;
747+
params.seed = seed;
744748
} catch (...) {
745749
}
746750
try {
747-
int batch_count = payload["batch_count"];
748-
params->batch_count = batch_count;
751+
int batch_count = payload["batch_count"];
752+
params.batch_count = batch_count;
749753
} catch (...) {
750754
}
751755

@@ -759,53 +763,53 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
759763
}
760764
try {
761765
float control_strength = payload["control_strength"];
762-
// params->control_strength = control_strength;
766+
// params.control_strength = control_strength;
763767
// LOG_WARN("control_strength is not supported yet\n");
764768
sd_log(sd_log_level_t::SD_LOG_WARN, "control_strength is not supported yet\n", params);
765769
} catch (...) {
766770
}
767771
try {
768772
float style_strength = payload["style_strength"];
769-
// params->style_strength = style_strength;
773+
// params.style_strength = style_strength;
770774
// LOG_WARN("style_strength is not supported yet\n");
771775
sd_log(sd_log_level_t::SD_LOG_WARN, "style_strength is not supported yet\n", params);
772776
} catch (...) {
773777
}
774778
try {
775-
bool normalize_input = payload["normalize_input"];
776-
params->normalize_input = normalize_input;
779+
bool normalize_input = payload["normalize_input"];
780+
params.normalize_input = normalize_input;
777781
} catch (...) {
778782
}
779783
try {
780784
std::string input_id_images_path = payload["input_id_images_path"];
781785
// TODO replace with b64 image maybe?
782-
params->input_id_images_path = input_id_images_path;
786+
params.input_id_images_path = input_id_images_path;
783787
} catch (...) {
784788
}
785789
try {
786790
std::string slg_scale = payload["slg_scale"];
787-
params->slg_scale = stof(slg_scale);
791+
params.slg_scale = stof(slg_scale);
788792
} catch (...) {
789793
}
790794
// TODO: more slg settings (layers, start and end)
791795
try {
792796
std::vector<int> skip_layers = payload["skip_layers"].get<std::vector<int>>();
793-
params->skip_layers.clear();
797+
params.skip_layers.clear();
794798
for (int i = 0; i < skip_layers.size(); i++) {
795-
params->skip_layers.push_back(skip_layers[i]);
799+
params.skip_layers.push_back(skip_layers[i]);
796800
}
797801
} catch (...) {
798802
}
799803
try {
800804
// skip_layer_start
801-
float skip_layer_start = payload["skip_layer_start"].get<float>();
802-
params->skip_layer_start = skip_layer_start;
805+
float skip_layer_start = payload["skip_layer_start"].get<float>();
806+
params.skip_layer_start = skip_layer_start;
803807
} catch (...) {
804808
}
805809
try {
806810
// skip_layer_end
807-
float skip_layer_end = payload["skip_layer_end"].get<float>();
808-
params->skip_layer_end = skip_layer_end;
811+
float skip_layer_end = payload["skip_layer_end"].get<float>();
812+
params.skip_layer_end = skip_layer_end;
809813
} catch (...) {
810814
}
811815
}
@@ -863,7 +867,7 @@ const float sd_latent_rgb_proj[4][3]{
863867
{-0.2829, 0.1762, 0.2721},
864868
{-0.2120, -0.2616, -0.7177}};
865869

866-
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
870+
void proj_latents(struct ggml_tensor* latents, enum SDVersion version, uint8_t* data) {
867871
const int channel = 3;
868872
int width = latents->ne[0];
869873
int height = latents->ne[1];
@@ -876,7 +880,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
876880

877881
if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
878882
latent_rgb_proj = sd3_latent_rgb_proj;
879-
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
883+
} else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
880884
latent_rgb_proj = flux_latent_rgb_proj;
881885
} else {
882886
// unknown model
@@ -897,7 +901,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
897901
// unknown latent space
898902
return;
899903
}
900-
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
901904
int data_head = 0;
902905
for (int j = 0; j < height; j++) {
903906
for (int i = 0; i < width; i++) {
@@ -925,6 +928,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
925928
data[data_head++] = (uint8_t)(b * 255.);
926929
}
927930
}
931+
}
932+
933+
void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version) {
934+
const int channel = 3;
935+
int width = latents->ne[0];
936+
int height = latents->ne[1];
937+
int dim = latents->ne[2];
938+
uint8_t* data = (uint8_t*)malloc(width * height * channel * sizeof(uint8_t));
939+
proj_latents(latents, version, data);
928940
stbi_write_png("latent-preview.png", width, height, channel, data, 0);
929941
free(data);
930942
}
@@ -982,7 +994,7 @@ int main(int argc, const char* argv[]) {
982994

983995
try {
984996
std::string json_str = req.body;
985-
parseJsonPrompt(json_str, &params);
997+
parseJsonPrompt(json_str, params);
986998
} catch (json::parse_error& e) {
987999
// assume the request is just a prompt
9881000
// LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());

0 commit comments

Comments
 (0)