Skip to content

Commit 2ed8fca

Browse files
committed
server: update lib API
1 parent bc79fb6 commit 2ed8fca

File tree

1 file changed

+37
-6
lines changed

1 file changed

+37
-6
lines changed

examples/server/main.cpp

+37-6
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ struct SDRequestParams {
142142
float skip_layer_end = 0.2;
143143
bool normalize_input = false;
144144

145+
float apg_eta = 1.0f;
146+
float apg_momentum = 0.0f;
147+
float apg_norm_threshold = 0.0f;
148+
float apg_norm_smoothing = 0.0f;
149+
145150
sd_preview_t preview_method = SD_PREVIEW_NONE;
146151
int preview_interval = 1;
147152
};
@@ -742,6 +747,28 @@ std::string get_image_params(SDParams params, int64_t seed) {
742747
}
743748
parameter_string += "Steps: " + std::to_string(params.lastRequest.sample_steps) + ", ";
744749
parameter_string += "CFG scale: " + std::to_string(params.lastRequest.cfg_scale) + ", ";
750+
if (params.lastRequest.apg_eta != 1) {
751+
parameter_string += "APG eta: " + std::to_string(params.lastRequest.apg_eta) + ", ";
752+
}
753+
if (params.lastRequest.apg_momentum != 0) {
754+
parameter_string += "CFG momentum: " + std::to_string(params.lastRequest.apg_momentum) + ", ";
755+
}
756+
if (params.lastRequest.apg_norm_threshold != 0) {
757+
parameter_string += "CFG normalization threshold: " + std::to_string(params.lastRequest.apg_norm_threshold) + ", ";
758+
if (params.lastRequest.apg_norm_smoothing != 0) {
759+
parameter_string += "CFG normalization threshold: " + std::to_string(params.lastRequest.apg_norm_smoothing) + ", ";
760+
}
761+
}
762+
if (params.lastRequest.slg_scale != 0 && params.lastRequest.skip_layers.size() != 0) {
763+
parameter_string += "SLG scale: " + std::to_string(params.lastRequest.cfg_scale) + ", ";
764+
parameter_string += "Skip layers: [";
765+
for (const auto& layer : params.lastRequest.skip_layers) {
766+
parameter_string += std::to_string(layer) + ", ";
767+
}
768+
parameter_string += "], ";
769+
parameter_string += "Skip layer start: " + std::to_string(params.lastRequest.skip_layer_start) + ", ";
770+
parameter_string += "Skip layer end: " + std::to_string(params.lastRequest.skip_layer_end) + ", ";
771+
}
745772
parameter_string += "Guidance: " + std::to_string(params.lastRequest.guidance) + ", ";
746773
parameter_string += "Seed: " + std::to_string(seed) + ", ";
747774
parameter_string += "Size: " + std::to_string(params.lastRequest.width) + "x" + std::to_string(params.lastRequest.height) + ", ";
@@ -1151,7 +1178,7 @@ bool parseJsonPrompt(std::string json_str, SDParams* params) {
11511178
}
11521179
} catch (...) {
11531180
}
1154-
1181+
// TODO SLG and APG params
11551182
return updatectx;
11561183
}
11571184

@@ -1427,11 +1454,15 @@ void start_server(SDParams params) {
14271454
params.lastRequest.style_ratio,
14281455
params.lastRequest.normalize_input,
14291456
params.input_id_images_path.c_str(),
1430-
params.lastRequest.skip_layers.data(),
1431-
params.lastRequest.skip_layers.size(),
1432-
params.lastRequest.slg_scale,
1433-
params.lastRequest.skip_layer_start,
1434-
params.lastRequest.skip_layer_end);
1457+
sd_slg_params_t{params.lastRequest.skip_layers.data(),
1458+
params.lastRequest.skip_layers.size(),
1459+
params.lastRequest.slg_scale,
1460+
params.lastRequest.skip_layer_start,
1461+
params.lastRequest.skip_layer_end},
1462+
sd_apg_params_t{params.lastRequest.apg_eta,
1463+
params.lastRequest.apg_momentum,
1464+
params.lastRequest.apg_norm_threshold,
1465+
params.lastRequest.apg_norm_smoothing});
14351466

14361467
if (results == NULL) {
14371468
printf("generate failed\n");

0 commit comments

Comments
 (0)