@@ -40,6 +40,8 @@ const char* sample_method_str[] = {
40
40
" dpm++2s_a" ,
41
41
" dpm++2m" ,
42
42
" dpm++2mv2" ,
43
+ " ipndm" ,
44
+ " ipndm_v" ,
43
45
" lcm" ,
44
46
};
45
47
@@ -48,7 +50,9 @@ const char* schedule_str[] = {
48
50
" default" ,
49
51
" discrete" ,
50
52
" karras" ,
53
+ " exponential" ,
51
54
" ays" ,
55
+ " gits" ,
52
56
};
53
57
54
58
enum SDMode {
@@ -676,45 +680,45 @@ static void log_server_request(const httplib::Request& req, const httplib::Respo
676
680
printf (" request: %s %s (%s)\n " , req.method .c_str (), req.path .c_str (), req.body .c_str ());
677
681
}
678
682
679
- void parseJsonPrompt (std::string json_str, SDParams* params) {
683
+ void parseJsonPrompt (std::string json_str, SDParams& params) {
680
684
using namespace nlohmann ;
681
685
json payload = json::parse (json_str);
682
686
// if no exception, the request is a json object
683
687
// now we try to get the new param values from the payload object
684
688
// const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path
685
689
try {
686
690
std::string prompt = payload[" prompt" ];
687
- params-> prompt = prompt;
691
+ params. prompt = prompt;
688
692
} catch (...) {
689
693
}
690
694
try {
691
695
std::string negative_prompt = payload[" negative_prompt" ];
692
- params-> negative_prompt = negative_prompt;
696
+ params. negative_prompt = negative_prompt;
693
697
} catch (...) {
694
698
}
695
699
try {
696
- int clip_skip = payload[" clip_skip" ];
697
- params-> clip_skip = clip_skip;
700
+ int clip_skip = payload[" clip_skip" ];
701
+ params. clip_skip = clip_skip;
698
702
} catch (...) {
699
703
}
700
704
try {
701
- float cfg_scale = payload[" cfg_scale" ];
702
- params-> cfg_scale = cfg_scale;
705
+ float cfg_scale = payload[" cfg_scale" ];
706
+ params. cfg_scale = cfg_scale;
703
707
} catch (...) {
704
708
}
705
709
try {
706
- float guidance = payload[" guidance" ];
707
- params-> guidance = guidance;
710
+ float guidance = payload[" guidance" ];
711
+ params. guidance = guidance;
708
712
} catch (...) {
709
713
}
710
714
try {
711
- int width = payload[" width" ];
712
- params-> width = width;
715
+ int width = payload[" width" ];
716
+ params. width = width;
713
717
} catch (...) {
714
718
}
715
719
try {
716
- int height = payload[" height" ];
717
- params-> height = height;
720
+ int height = payload[" height" ];
721
+ params. height = height;
718
722
} catch (...) {
719
723
}
720
724
try {
@@ -727,25 +731,25 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
727
731
}
728
732
}
729
733
if (sample_method_found >= 0 ) {
730
- params-> sample_method = (sample_method_t )sample_method_found;
734
+ params. sample_method = (sample_method_t )sample_method_found;
731
735
} else {
732
736
sd_log (sd_log_level_t ::SD_LOG_WARN, " Unknown sampling method: %s\n " , sample_method.c_str ());
733
737
}
734
738
} catch (...) {
735
739
}
736
740
try {
737
- int sample_steps = payload[" sample_steps" ];
738
- params-> sample_steps = sample_steps;
741
+ int sample_steps = payload[" sample_steps" ];
742
+ params. sample_steps = sample_steps;
739
743
} catch (...) {
740
744
}
741
745
try {
742
746
int64_t seed = payload[" seed" ];
743
- params-> seed = seed;
747
+ params. seed = seed;
744
748
} catch (...) {
745
749
}
746
750
try {
747
- int batch_count = payload[" batch_count" ];
748
- params-> batch_count = batch_count;
751
+ int batch_count = payload[" batch_count" ];
752
+ params. batch_count = batch_count;
749
753
} catch (...) {
750
754
}
751
755
@@ -759,53 +763,53 @@ void parseJsonPrompt(std::string json_str, SDParams* params) {
759
763
}
760
764
try {
761
765
float control_strength = payload[" control_strength" ];
762
- // params-> control_strength = control_strength;
766
+ // params. control_strength = control_strength;
763
767
// LOG_WARN("control_strength is not supported yet\n");
764
768
sd_log (sd_log_level_t ::SD_LOG_WARN, " control_strength is not supported yet\n " , params);
765
769
} catch (...) {
766
770
}
767
771
try {
768
772
float style_strength = payload[" style_strength" ];
769
- // params-> style_strength = style_strength;
773
+ // params. style_strength = style_strength;
770
774
// LOG_WARN("style_strength is not supported yet\n");
771
775
sd_log (sd_log_level_t ::SD_LOG_WARN, " style_strength is not supported yet\n " , params);
772
776
} catch (...) {
773
777
}
774
778
try {
775
- bool normalize_input = payload[" normalize_input" ];
776
- params-> normalize_input = normalize_input;
779
+ bool normalize_input = payload[" normalize_input" ];
780
+ params. normalize_input = normalize_input;
777
781
} catch (...) {
778
782
}
779
783
try {
780
784
std::string input_id_images_path = payload[" input_id_images_path" ];
781
785
// TODO replace with b64 image maybe?
782
- params-> input_id_images_path = input_id_images_path;
786
+ params. input_id_images_path = input_id_images_path;
783
787
} catch (...) {
784
788
}
785
789
try {
786
790
std::string slg_scale = payload[" slg_scale" ];
787
- params-> slg_scale = stof (slg_scale);
791
+ params. slg_scale = stof (slg_scale);
788
792
} catch (...) {
789
793
}
790
794
// TODO: more slg settings (layers, start and end)
791
795
try {
792
796
std::vector<int > skip_layers = payload[" skip_layers" ].get <std::vector<int >>();
793
- params-> skip_layers .clear ();
797
+ params. skip_layers .clear ();
794
798
for (int i = 0 ; i < skip_layers.size (); i++) {
795
- params-> skip_layers .push_back (skip_layers[i]);
799
+ params. skip_layers .push_back (skip_layers[i]);
796
800
}
797
801
} catch (...) {
798
802
}
799
803
try {
800
804
// skip_layer_start
801
- float skip_layer_start = payload[" skip_layer_start" ].get <float >();
802
- params-> skip_layer_start = skip_layer_start;
805
+ float skip_layer_start = payload[" skip_layer_start" ].get <float >();
806
+ params. skip_layer_start = skip_layer_start;
803
807
} catch (...) {
804
808
}
805
809
try {
806
810
// skip_layer_end
807
- float skip_layer_end = payload[" skip_layer_end" ].get <float >();
808
- params-> skip_layer_end = skip_layer_end;
811
+ float skip_layer_end = payload[" skip_layer_end" ].get <float >();
812
+ params. skip_layer_end = skip_layer_end;
809
813
} catch (...) {
810
814
}
811
815
}
@@ -863,7 +867,7 @@ const float sd_latent_rgb_proj[4][3]{
863
867
{-0.2829 , 0.1762 , 0.2721 },
864
868
{-0.2120 , -0.2616 , -0.7177 }};
865
869
866
- void step_callback ( int step, struct ggml_tensor * latents, enum SDVersion version) {
870
+ void proj_latents ( struct ggml_tensor * latents, enum SDVersion version, uint8_t * data ) {
867
871
const int channel = 3 ;
868
872
int width = latents->ne [0 ];
869
873
int height = latents->ne [1 ];
@@ -876,7 +880,7 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
876
880
877
881
if (version == VERSION_SD3_5_2B || version == VERSION_SD3_5_8B || version == VERSION_SD3_2B) {
878
882
latent_rgb_proj = sd3_latent_rgb_proj;
879
- } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) {
883
+ } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE ) {
880
884
latent_rgb_proj = flux_latent_rgb_proj;
881
885
} else {
882
886
// unknown model
@@ -897,7 +901,6 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
897
901
// unknown latent space
898
902
return ;
899
903
}
900
- uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
901
904
int data_head = 0 ;
902
905
for (int j = 0 ; j < height; j++) {
903
906
for (int i = 0 ; i < width; i++) {
@@ -925,6 +928,15 @@ void step_callback(int step, struct ggml_tensor* latents, enum SDVersion version
925
928
data[data_head++] = (uint8_t )(b * 255 .);
926
929
}
927
930
}
931
+ }
932
+
933
+ void step_callback (int step, struct ggml_tensor * latents, enum SDVersion version) {
934
+ const int channel = 3 ;
935
+ int width = latents->ne [0 ];
936
+ int height = latents->ne [1 ];
937
+ int dim = latents->ne [2 ];
938
+ uint8_t * data = (uint8_t *)malloc (width * height * channel * sizeof (uint8_t ));
939
+ proj_latents (latents, version, data);
928
940
stbi_write_png (" latent-preview.png" , width, height, channel, data, 0 );
929
941
free (data);
930
942
}
@@ -982,7 +994,7 @@ int main(int argc, const char* argv[]) {
982
994
983
995
try {
984
996
std::string json_str = req.body ;
985
- parseJsonPrompt (json_str, & params);
997
+ parseJsonPrompt (json_str, params);
986
998
} catch (json::parse_error& e) {
987
999
// assume the request is just a prompt
988
1000
// LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what());
0 commit comments