Skip to content

Commit ad7498f

Browse files
authored
chore: refactor http request a bit. (b4rtaz#72)
1 parent df1d360 commit ad7498f

File tree

1 file changed

+94
-104
lines changed

1 file changed

+94
-104
lines changed

src/apps/dllama-api/dllama-api.cpp

Lines changed: 94 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,12 @@ enum class HttpMethod {
3636

3737
class HttpRequest {
3838
public:
39-
std::string path;
40-
std::unordered_map<std::string, std::string> headers;
41-
std::string body;
42-
json parsedJson;
43-
HttpMethod method;
44-
std::string getMethod() {
45-
switch(method) {
46-
case HttpMethod::METHOD_GET:
47-
return "GET";
48-
case HttpMethod::METHOD_POST:
49-
return "POST";
50-
case HttpMethod::METHOD_PUT:
51-
return "PUT";
52-
case HttpMethod::METHOD_DELETE:
53-
return "DELETE";
54-
case HttpMethod::METHOD_UNKNOWN:
55-
default:
56-
return "UNKNOWN";
57-
}
58-
}
59-
};
39+
static HttpRequest read(Socket& socket) {
40+
HttpRequest req(&socket);
6041

61-
class HttpParser {
62-
public:
63-
static HttpRequest parseRequest(const std::string& request) {
64-
HttpRequest httpRequest;
42+
std::vector<char> httpRequest = socket.readHttpRequest();
43+
// Parse the HTTP request
44+
std::string request = std::string(httpRequest.begin(), httpRequest.end());
6545

6646
// Split request into lines
6747
std::istringstream iss(request);
@@ -72,8 +52,8 @@ class HttpParser {
7252
std::istringstream lineStream(line);
7353
std::string methodStr, path;
7454
lineStream >> methodStr >> path;
75-
httpRequest.method = parseMethod(methodStr);
76-
httpRequest.path = path;
55+
req.method = parseMethod(methodStr);
56+
req.path = path;
7757

7858
// Parse headers
7959
while (std::getline(iss, line) && line != "\r") {
@@ -85,56 +65,102 @@ class HttpParser {
8565
value.erase(std::remove_if(value.begin(), value.end(), [](unsigned char c) {
8666
return std::isspace(c) || !std::isprint(c);
8767
}), value.end());
88-
httpRequest.headers[key] = value;
68+
req.headers[key] = value;
8969
}
9070
}
9171

9272
// Parse body
93-
std::getline(iss, httpRequest.body, '\0');
73+
std::getline(iss, req.body, '\0');
9474

95-
if (httpRequest.body.size() > 0) {
75+
if (req.body.size() > 0) {
9676
// printf("body: %s\n", httpRequest.body.c_str());
97-
httpRequest.parsedJson = json::parse(httpRequest.body);
77+
req.parsedJson = json::parse(req.body);
9878
}
99-
return httpRequest;
79+
return req;
10080
}
101-
private:
81+
10282
static HttpMethod parseMethod(const std::string& method) {
103-
if (method == "GET") {
104-
return HttpMethod::METHOD_GET;
105-
} else if (method == "POST") {
106-
return HttpMethod::METHOD_POST;
107-
} else if (method == "PUT") {
108-
return HttpMethod::METHOD_PUT;
109-
} else if (method == "DELETE") {
110-
return HttpMethod::METHOD_DELETE;
111-
} else {
112-
return HttpMethod::METHOD_UNKNOWN;
113-
}
83+
if (method == "GET") return HttpMethod::METHOD_GET;
84+
if (method == "POST") return HttpMethod::METHOD_POST;
85+
if (method == "PUT") return HttpMethod::METHOD_PUT;
86+
if (method == "DELETE") return HttpMethod::METHOD_DELETE;
87+
return HttpMethod::METHOD_UNKNOWN;
88+
}
89+
90+
private:
91+
Socket* socket;
92+
public:
93+
std::string path;
94+
std::unordered_map<std::string, std::string> headers;
95+
std::string body;
96+
json parsedJson;
97+
HttpMethod method;
98+
99+
HttpRequest(Socket* socket) {
100+
this->socket = socket;
101+
}
102+
103+
std::string getMethod() {
104+
if (method == HttpMethod::METHOD_GET) return "GET";
105+
if (method == HttpMethod::METHOD_POST) return "POST";
106+
if (method == HttpMethod::METHOD_PUT) return "PUT";
107+
if (method == HttpMethod::METHOD_DELETE) return "DELETE";
108+
return "UNKNOWN";
109+
}
110+
111+
void writeNotFound() {
112+
const char* data = "HTTP/1.1 404 Not Found\r\n";
113+
socket->write(data, strlen(data));
114+
}
115+
116+
void writeJson(std::string json) {
117+
std::ostringstream buffer;
118+
buffer << "HTTP/1.1 200 OK\r\n"
119+
<< "Content-Type: application/json; charset=utf-8\r\n"
120+
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
121+
std::string data = buffer.str();
122+
socket->write(data.c_str(), data.size());
123+
}
124+
125+
void writeStreamStartChunk() {
126+
std::ostringstream buffer;
127+
buffer << "HTTP/1.1 200 OK\r\n"
128+
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
129+
<< "Connection: close\r\n"
130+
<< "Transfer-Encoding: chunked\r\n\r\n";
131+
std::string data = buffer.str();
132+
socket->write(data.c_str(), data.size());
133+
}
134+
135+
void writeStreamChunk(const std::string data) {
136+
std::ostringstream buffer;
137+
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
138+
std::string d = buffer.str();
139+
socket->write(d.c_str(), d.size());
140+
}
141+
142+
void writeStreamEndChunk() {
143+
const char* endChunk = "0000\r\n\r\n";
144+
socket->write(endChunk, strlen(endChunk));
114145
}
115146
};
116147

117148
struct Route {
118149
std::string path;
119150
HttpMethod method;
120-
std::function<void(Socket&, HttpRequest&)> handler;
151+
std::function<void(HttpRequest&)> handler;
121152
};
122153

123154
class Router {
124155
public:
125-
static void routeRequest(Socket& client_socket, HttpRequest& request, std::vector<Route>& routes) {
156+
static void resolve(HttpRequest& request, std::vector<Route>& routes) {
126157
for (const auto& route : routes) {
127158
if (request.method == route.method && request.path == route.path) {
128-
route.handler(client_socket, request);
159+
route.handler(request);
129160
return;
130161
}
131162
}
132-
notFoundHandler(client_socket);
133-
}
134-
private:
135-
static void notFoundHandler(Socket& client_socket) {
136-
const char* data = "HTTP/1.1 404 Not Found\r\n";
137-
client_socket.write(data, strlen(data));
163+
request.writeNotFound();
138164
}
139165
};
140166

@@ -154,38 +180,7 @@ std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage>
154180
return oss.str();
155181
}
156182

157-
void writeJsonResponse(Socket& socket, std::string json) {
158-
std::ostringstream buffer;
159-
buffer << "HTTP/1.1 200 OK\r\n"
160-
<< "Content-Type: application/json; charset=utf-8\r\n"
161-
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
162-
std::string data = buffer.str();
163-
socket.write(data.c_str(), data.size());
164-
}
165-
166-
void writeStreamStartChunk(Socket& socket) {
167-
std::ostringstream buffer;
168-
buffer << "HTTP/1.1 200 OK\r\n"
169-
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
170-
<< "Connection: close\r\n"
171-
<< "Transfer-Encoding: chunked\r\n\r\n";
172-
std::string data = buffer.str();
173-
socket.write(data.c_str(), data.size());
174-
}
175-
176-
void writeStreamChunk(Socket& socket, const std::string data) {
177-
std::ostringstream buffer;
178-
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
179-
std::string d = buffer.str();
180-
socket.write(d.c_str(), d.size());
181-
}
182-
183-
void writeStreamEndChunk(Socket& socket) {
184-
const char* endChunk = "0000\r\n\r\n";
185-
socket.write(endChunk, strlen(endChunk));
186-
}
187-
188-
void writeChatCompletionChunk(Socket &socket, const std::string &delta, const bool stop){
183+
void writeChatCompletionChunk(HttpRequest &request, const std::string &delta, const bool stop){
189184
ChunkChoice choice;
190185
if (stop) {
191186
choice.finish_reason = "stop";
@@ -196,15 +191,15 @@ void writeChatCompletionChunk(Socket &socket, const std::string &delta, const bo
196191

197192
std::ostringstream buffer;
198193
buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
199-
writeStreamChunk(socket, buffer.str());
194+
request.writeStreamChunk(buffer.str());
200195

201196
if (stop) {
202-
writeStreamChunk(socket, "data: [DONE]");
203-
writeStreamEndChunk(socket);
197+
request.writeStreamChunk("data: [DONE]");
198+
request.writeStreamEndChunk();
204199
}
205200
}
206201

207-
void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
202+
void handleCompletionsRequest(HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
208203
InferenceParams params;
209204
params.temperature = args->temperature;
210205
params.top_p = args->topp;
@@ -244,7 +239,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
244239
generated.get_allocator().allocate(params.max_tokens);
245240

246241
if (params.stream) {
247-
writeStreamStartChunk(socket);
242+
request.writeStreamStartChunk();
248243
}
249244

250245
int promptLength = params.prompt.length();
@@ -298,7 +293,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
298293

299294
generated.push_back(string);
300295
if (params.stream) {
301-
writeChatCompletionChunk(socket, string, false);
296+
writeChatCompletionChunk(request, string, false);
302297
}
303298
}
304299
}
@@ -310,16 +305,16 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
310305
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());
311306

312307
std::string chatJson = ((json)completion).dump();
313-
writeJsonResponse(socket, chatJson);
308+
request.writeJson(chatJson);
314309
} else {
315-
writeChatCompletionChunk(socket, "", true);
310+
writeChatCompletionChunk(request, "", true);
316311
}
317312
printf("🔶\n");
318313
fflush(stdout);
319314
}
320315

321-
void handleModelsRequest(Socket& client_socket, HttpRequest& request) {
322-
writeJsonResponse(client_socket,
316+
void handleModelsRequest(HttpRequest& request) {
317+
request.writeJson(
323318
"{ \"object\": \"list\","
324319
"\"data\": ["
325320
"{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }"
@@ -334,26 +329,21 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
334329
{
335330
"/v1/chat/completions",
336331
HttpMethod::METHOD_POST,
337-
std::bind(&handleCompletionsRequest, std::placeholders::_1, std::placeholders::_2, inference, tokenizer, sampler, args, spec)
332+
std::bind(&handleCompletionsRequest, std::placeholders::_1, inference, tokenizer, sampler, args, spec)
338333
},
339334
{
340335
"/v1/models",
341336
HttpMethod::METHOD_GET,
342-
std::bind(&handleModelsRequest, std::placeholders::_1, std::placeholders::_2)
337+
std::bind(&handleModelsRequest, std::placeholders::_1)
343338
}
344339
};
345340

346341
while (true) {
347342
try {
348-
// Accept incoming connection
349343
Socket client = server->accept();
350-
// Read the HTTP request
351-
std::vector<char> httpRequest = client.readHttpRequest();
352-
// Parse the HTTP request
353-
HttpRequest request = HttpParser::parseRequest(std::string(httpRequest.begin(), httpRequest.end()));
354-
// Handle the HTTP request
344+
HttpRequest request = HttpRequest::read(client);
355345
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
356-
Router::routeRequest(client, request, routes);
346+
Router::resolve(request, routes);
357347
} catch (ReadSocketException& ex) {
358348
printf("Read socket error: %d %s\n", ex.code, ex.message);
359349
} catch (WriteSocketException& ex) {

0 commit comments

Comments
 (0)