Skip to content

Commit e41cfc7

Browse files
committed
llama: Refactor string_split to use template specialization, fixes parsing strings with spaces
1 parent 190a37d commit e41cfc7

File tree

4 files changed

+22
-21
lines changed

4 files changed

+22
-21
lines changed

common/arg.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,13 @@ static void common_params_handle_model_default(common_params & params) {
128128
}
129129
params.hf_file = params.model;
130130
} else if (params.model.empty()) {
131-
params.model = fs_get_cache_file(string_split(params.hf_file, '/').back());
131+
params.model = fs_get_cache_file(string_split<std::string>(params.hf_file, '/').back());
132132
}
133133
} else if (!params.model_url.empty()) {
134134
if (params.model.empty()) {
135-
auto f = string_split(params.model_url, '#').front();
136-
f = string_split(f, '?').front();
137-
params.model = fs_get_cache_file(string_split(f, '/').back());
135+
auto f = string_split<std::string>(params.model_url, '#').front();
136+
f = string_split<std::string>(f, '?').front();
137+
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
138138
}
139139
} else if (params.model.empty()) {
140140
params.model = DEFAULT_MODEL_PATH;
@@ -879,7 +879,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
879879
{"--samplers"}, "SAMPLERS",
880880
string_format("samplers that will be used for generation in the order, separated by \';\'\n(default: %s)", sampler_type_names.c_str()),
881881
[](common_params & params, const std::string & value) {
882-
const auto sampler_names = string_split(value, ';');
882+
const auto sampler_names = string_split<std::string>(value, ';');
883883
params.sparams.samplers = common_sampler_types_from_names(sampler_names, true);
884884
}
885885
).set_sparam());

common/common.cpp

-13
Original file line numberDiff line numberDiff line change
@@ -416,19 +416,6 @@ std::string string_format(const char * fmt, ...) {
416416
return std::string(buf.data(), size);
417417
}
418418

419-
std::vector<std::string> string_split(std::string input, char separator) {
420-
std::vector<std::string> parts;
421-
size_t separator_pos = input.find(separator);
422-
while (separator_pos != std::string::npos) {
423-
std::string part = input.substr(0, separator_pos);
424-
parts.emplace_back(part);
425-
input = input.substr(separator_pos + 1);
426-
separator_pos = input.find(separator);
427-
}
428-
parts.emplace_back(input);
429-
return parts;
430-
}
431-
432419
std::string string_strip(const std::string & str) {
433420
size_t start = 0;
434421
size_t end = str.size();

common/common.h

+16-2
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,6 @@ bool set_process_priority(enum ggml_sched_priority prio);
380380
LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
381381
std::string string_format(const char * fmt, ...);
382382

383-
std::vector<std::string> string_split(std::string input, char separator);
384-
385383
std::string string_strip(const std::string & str);
386384
std::string string_get_sortable_timestamp();
387385

@@ -401,6 +399,22 @@ static std::vector<T> string_split(const std::string & str, char delim) {
401399
return values;
402400
}
403401

402+
template<>
403+
std::vector<std::string> string_split<std::string>(const std::string & input, char separator)
404+
{
405+
std::vector<std::string> parts;
406+
size_t begin_pos = 0;
407+
size_t separator_pos = input.find(separator);
408+
while (separator_pos != std::string::npos) {
409+
std::string part = input.substr(begin_pos, separator_pos - begin_pos);
410+
parts.emplace_back(part);
411+
begin_pos = separator_pos + 1;
412+
separator_pos = input.find(separator, begin_pos);
413+
}
414+
parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos));
415+
return parts;
416+
}
417+
404418
bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
405419
void string_process_escapes(std::string & input);
406420

examples/server/server.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2612,7 +2612,7 @@ int main(int argc, char ** argv) {
26122612
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
26132613
server_state current_state = state.load();
26142614
if (current_state == SERVER_STATE_LOADING_MODEL) {
2615-
auto tmp = string_split(req.path, '.');
2615+
auto tmp = string_split<std::string>(req.path, '.');
26162616
if (req.path == "/" || tmp.back() == "html") {
26172617
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
26182618
res.status = 503;

0 commit comments

Comments
 (0)