Skip to content

Commit 9ba399d

Browse files
elk-cloner and ngxson authored
server : add support for "encoding_format": "base64" to the */embeddings endpoints (ggml-org#10967)
* add support for base64
* fix base64 test
* improve test

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
1 parent 2cd43f4 commit 9ba399d

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

Diff for: examples/server/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ endforeach()
3434
add_executable(${TARGET} ${TARGET_SRCS})
3535
install(TARGETS ${TARGET} RUNTIME)
3636

37+
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
3738
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
3839

3940
if (LLAMA_SERVER_SSL)

Diff for: examples/server/server.cpp

+12-1
Original file line numberDiff line numberDiff line change
@@ -3790,6 +3790,17 @@ int main(int argc, char ** argv) {
37903790
return;
37913791
}
37923792

3793+
bool use_base64 = false;
3794+
if (body.count("encoding_format") != 0) {
3795+
const std::string& format = body.at("encoding_format");
3796+
if (format == "base64") {
3797+
use_base64 = true;
3798+
} else if (format != "float") {
3799+
res_error(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
3800+
return;
3801+
}
3802+
}
3803+
37933804
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
37943805
for (const auto & tokens : tokenized_prompts) {
37953806
// this check is necessary for models that do not add BOS token to the input
@@ -3841,7 +3852,7 @@ int main(int argc, char ** argv) {
38413852
}
38423853

38433854
// write JSON response
3844-
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
3855+
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
38453856
res_ok(res, root);
38463857
};
38473858

Diff for: examples/server/tests/unit/test_embedding.py

+41
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import base64
2+
import struct
13
import pytest
24
from openai import OpenAI
35
from utils import *
@@ -194,3 +196,42 @@ def test_embedding_usage_multiple():
194196
assert res.status_code == 200
195197
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
196198
assert res.body['usage']['prompt_tokens'] == 2 * 9
199+
200+
201+
def test_embedding_openai_library_base64():
    """The /v1/embeddings endpoint must honor encoding_format=base64 and the
    decoded payload must match the default float-format response."""
    server.start()
    input_text = "Test base64 embedding output"

    # Request the embedding with the default ("float") encoding first.
    res_float = server.make_request("POST", "/v1/embeddings", data={
        "input": input_text,
    })
    assert res_float.status_code == 200
    reference_vec = res_float.body["data"][0]["embedding"]

    # Request the same embedding encoded as base64.
    res_b64 = server.make_request("POST", "/v1/embeddings", data={
        "input": input_text,
        "encoding_format": "base64",
    })
    assert res_b64.status_code == 200
    assert "data" in res_b64.body
    assert len(res_b64.body["data"]) == 1

    entry = res_b64.body["data"][0]
    assert "embedding" in entry
    assert isinstance(entry["embedding"], str)

    # The string must be valid base64 wrapping a packed array of 32-bit floats.
    raw = base64.b64decode(entry["embedding"])
    n_floats = len(raw) // 4  # sizeof(float) == 4 bytes
    decoded = struct.unpack(f"{n_floats}f", raw)
    assert len(decoded) > 0
    assert all(isinstance(x, float) for x in decoded)
    assert len(decoded) == len(reference_vec)

    # Element-wise, the decoded vector must agree with the float response.
    for got, want in zip(decoded, reference_vec):
        assert abs(got - want) < EPSILON

Diff for: examples/server/utils.hpp

+22-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "common.h"
44
#include "log.h"
55
#include "llama.h"
6+
#include "common/base64.hpp"
67

78
#ifndef NDEBUG
89
// crash the server in debug mode, otherwise send an http 500 error
@@ -613,16 +614,31 @@ static json oaicompat_completion_params_parse(
613614
return llama_params;
614615
}
615616

616-
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
617+
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
617618
json data = json::array();
618619
int32_t n_tokens = 0;
619620
int i = 0;
620621
for (const auto & elem : embeddings) {
621-
data.push_back(json{
622-
{"embedding", json_value(elem, "embedding", json::array())},
623-
{"index", i++},
624-
{"object", "embedding"}
625-
});
622+
json embedding_obj;
623+
624+
if (use_base64) {
625+
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
626+
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
627+
size_t data_size = vec.size() * sizeof(float);
628+
embedding_obj = {
629+
{"embedding", base64::encode(data_ptr, data_size)},
630+
{"index", i++},
631+
{"object", "embedding"},
632+
{"encoding_format", "base64"}
633+
};
634+
} else {
635+
embedding_obj = {
636+
{"embedding", json_value(elem, "embedding", json::array())},
637+
{"index", i++},
638+
{"object", "embedding"}
639+
};
640+
}
641+
data.push_back(embedding_obj);
626642

627643
n_tokens += json_value(elem, "tokens_evaluated", 0);
628644
}

0 commit comments

Comments (0)