@@ -36,32 +36,12 @@ enum class HttpMethod {
36
36
37
37
class HttpRequest {
38
38
public:
39
- std::string path;
40
- std::unordered_map<std::string, std::string> headers;
41
- std::string body;
42
- json parsedJson;
43
- HttpMethod method;
44
- std::string getMethod () {
45
- switch (method) {
46
- case HttpMethod::METHOD_GET:
47
- return " GET" ;
48
- case HttpMethod::METHOD_POST:
49
- return " POST" ;
50
- case HttpMethod::METHOD_PUT:
51
- return " PUT" ;
52
- case HttpMethod::METHOD_DELETE:
53
- return " DELETE" ;
54
- case HttpMethod::METHOD_UNKNOWN:
55
- default :
56
- return " UNKNOWN" ;
57
- }
58
- }
59
- };
39
+ static HttpRequest read (Socket& socket) {
40
+ HttpRequest req (&socket);
60
41
61
- class HttpParser {
62
- public:
63
- static HttpRequest parseRequest (const std::string& request) {
64
- HttpRequest httpRequest;
42
+ std::vector<char > httpRequest = socket.readHttpRequest ();
43
+ // Parse the HTTP request
44
+ std::string request = std::string (httpRequest.begin (), httpRequest.end ());
65
45
66
46
// Split request into lines
67
47
std::istringstream iss (request);
@@ -72,8 +52,8 @@ class HttpParser {
72
52
std::istringstream lineStream (line);
73
53
std::string methodStr, path;
74
54
lineStream >> methodStr >> path;
75
- httpRequest .method = parseMethod (methodStr);
76
- httpRequest .path = path;
55
+ req .method = parseMethod (methodStr);
56
+ req .path = path;
77
57
78
58
// Parse headers
79
59
while (std::getline (iss, line) && line != " \r " ) {
@@ -85,56 +65,102 @@ class HttpParser {
85
65
value.erase (std::remove_if (value.begin (), value.end (), [](unsigned char c) {
86
66
return std::isspace (c) || !std::isprint (c);
87
67
}), value.end ());
88
- httpRequest .headers [key] = value;
68
+ req .headers [key] = value;
89
69
}
90
70
}
91
71
92
72
// Parse body
93
- std::getline (iss, httpRequest .body , ' \0 ' );
73
+ std::getline (iss, req .body , ' \0 ' );
94
74
95
- if (httpRequest .body .size () > 0 ) {
75
+ if (req .body .size () > 0 ) {
96
76
// printf("body: %s\n", httpRequest.body.c_str());
97
- httpRequest .parsedJson = json::parse (httpRequest .body );
77
+ req .parsedJson = json::parse (req .body );
98
78
}
99
- return httpRequest ;
79
+ return req ;
100
80
}
101
- private:
81
+
102
82
/// Maps the method token of an HTTP request line to an HttpMethod.
///
/// The comparison is exact and case-sensitive. Any token other than
/// GET, POST, PUT or DELETE yields HttpMethod::METHOD_UNKNOWN.
static HttpMethod parseMethod(const std::string& method) {
    // The token is produced by `lineStream >> methodStr`, which never
    // contains whitespace, so the literals must not carry any padding.
    if (method == "GET") return HttpMethod::METHOD_GET;
    if (method == "POST") return HttpMethod::METHOD_POST;
    if (method == "PUT") return HttpMethod::METHOD_PUT;
    if (method == "DELETE") return HttpMethod::METHOD_DELETE;
    return HttpMethod::METHOD_UNKNOWN;
}
89
+
90
+ private:
91
+ Socket* socket;
92
+ public:
93
+ std::string path;
94
+ std::unordered_map<std::string, std::string> headers;
95
+ std::string body;
96
+ json parsedJson;
97
+ HttpMethod method;
98
+
99
+ HttpRequest (Socket* socket) {
100
+ this ->socket = socket;
101
+ }
102
+
103
+ std::string getMethod () {
104
+ if (method == HttpMethod::METHOD_GET) return " GET" ;
105
+ if (method == HttpMethod::METHOD_POST) return " POST" ;
106
+ if (method == HttpMethod::METHOD_PUT) return " PUT" ;
107
+ if (method == HttpMethod::METHOD_DELETE) return " DELETE" ;
108
+ return " UNKNOWN" ;
109
+ }
110
+
111
+ void writeNotFound () {
112
+ const char * data = " HTTP/1.1 404 Not Found\r\n " ;
113
+ socket->write (data, strlen (data));
114
+ }
115
+
116
+ void writeJson (std::string json) {
117
+ std::ostringstream buffer;
118
+ buffer << " HTTP/1.1 200 OK\r\n "
119
+ << " Content-Type: application/json; charset=utf-8\r\n "
120
+ << " Content-Length: " << json.length () << " \r\n\r\n " << json;
121
+ std::string data = buffer.str ();
122
+ socket->write (data.c_str (), data.size ());
123
+ }
124
+
125
+ void writeStreamStartChunk () {
126
+ std::ostringstream buffer;
127
+ buffer << " HTTP/1.1 200 OK\r\n "
128
+ << " Content-Type: text/event-stream; charset=utf-8\r\n "
129
+ << " Connection: close\r\n "
130
+ << " Transfer-Encoding: chunked\r\n\r\n " ;
131
+ std::string data = buffer.str ();
132
+ socket->write (data.c_str (), data.size ());
133
+ }
134
+
135
+ void writeStreamChunk (const std::string data) {
136
+ std::ostringstream buffer;
137
+ buffer << std::hex << data.size () << " \r\n " << data << " \r\n " ;
138
+ std::string d = buffer.str ();
139
+ socket->write (d.c_str (), d.size ());
140
+ }
141
+
142
+ void writeStreamEndChunk () {
143
+ const char * endChunk = " 0000\r\n\r\n " ;
144
+ socket->write (endChunk, strlen (endChunk));
114
145
}
115
146
};
116
147
117
148
struct Route {
118
149
std::string path;
119
150
HttpMethod method;
120
- std::function<void (Socket&, HttpRequest&)> handler;
151
+ std::function<void (HttpRequest&)> handler;
121
152
};
122
153
123
154
class Router {
124
155
public:
125
- static void routeRequest (Socket& client_socket, HttpRequest& request, std::vector<Route>& routes) {
156
+ static void resolve ( HttpRequest& request, std::vector<Route>& routes) {
126
157
for (const auto & route : routes) {
127
158
if (request.method == route.method && request.path == route.path ) {
128
- route.handler (client_socket, request);
159
+ route.handler (request);
129
160
return ;
130
161
}
131
162
}
132
- notFoundHandler (client_socket);
133
- }
134
- private:
135
- static void notFoundHandler (Socket& client_socket) {
136
- const char * data = " HTTP/1.1 404 Not Found\r\n " ;
137
- client_socket.write (data, strlen (data));
163
+ request.writeNotFound ();
138
164
}
139
165
};
140
166
@@ -154,38 +180,7 @@ std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage>
154
180
return oss.str ();
155
181
}
156
182
157
- void writeJsonResponse (Socket& socket, std::string json) {
158
- std::ostringstream buffer;
159
- buffer << " HTTP/1.1 200 OK\r\n "
160
- << " Content-Type: application/json; charset=utf-8\r\n "
161
- << " Content-Length: " << json.length () << " \r\n\r\n " << json;
162
- std::string data = buffer.str ();
163
- socket.write (data.c_str (), data.size ());
164
- }
165
-
166
- void writeStreamStartChunk (Socket& socket) {
167
- std::ostringstream buffer;
168
- buffer << " HTTP/1.1 200 OK\r\n "
169
- << " Content-Type: text/event-stream; charset=utf-8\r\n "
170
- << " Connection: close\r\n "
171
- << " Transfer-Encoding: chunked\r\n\r\n " ;
172
- std::string data = buffer.str ();
173
- socket.write (data.c_str (), data.size ());
174
- }
175
-
176
- void writeStreamChunk (Socket& socket, const std::string data) {
177
- std::ostringstream buffer;
178
- buffer << std::hex << data.size () << " \r\n " << data << " \r\n " ;
179
- std::string d = buffer.str ();
180
- socket.write (d.c_str (), d.size ());
181
- }
182
-
183
- void writeStreamEndChunk (Socket& socket) {
184
- const char * endChunk = " 0000\r\n\r\n " ;
185
- socket.write (endChunk, strlen (endChunk));
186
- }
187
-
188
- void writeChatCompletionChunk (Socket &socket, const std::string &delta, const bool stop){
183
+ void writeChatCompletionChunk (HttpRequest &request, const std::string &delta, const bool stop){
189
184
ChunkChoice choice;
190
185
if (stop) {
191
186
choice.finish_reason = " stop" ;
@@ -196,15 +191,15 @@ void writeChatCompletionChunk(Socket &socket, const std::string &delta, const bo
196
191
197
192
std::ostringstream buffer;
198
193
buffer << " data: " << ((json)chunk).dump () << " \r\n\r\n " ;
199
- writeStreamChunk (socket, buffer.str ());
194
+ request. writeStreamChunk (buffer.str ());
200
195
201
196
if (stop) {
202
- writeStreamChunk (socket, " data: [DONE]" );
203
- writeStreamEndChunk (socket );
197
+ request. writeStreamChunk (" data: [DONE]" );
198
+ request. writeStreamEndChunk ();
204
199
}
205
200
}
206
201
207
- void handleCompletionsRequest (Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
202
+ void handleCompletionsRequest (HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
208
203
InferenceParams params;
209
204
params.temperature = args->temperature ;
210
205
params.top_p = args->topp ;
@@ -244,7 +239,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
244
239
generated.get_allocator ().allocate (params.max_tokens );
245
240
246
241
if (params.stream ) {
247
- writeStreamStartChunk (socket );
242
+ request. writeStreamStartChunk ();
248
243
}
249
244
250
245
int promptLength = params.prompt .length ();
@@ -298,7 +293,7 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
298
293
299
294
generated.push_back (string);
300
295
if (params.stream ) {
301
- writeChatCompletionChunk (socket , string, false );
296
+ writeChatCompletionChunk (request , string, false );
302
297
}
303
298
}
304
299
}
@@ -310,16 +305,16 @@ void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* i
310
305
completion.usage = ChatUsage (nPromptTokens, generated.size (), nPromptTokens + generated.size ());
311
306
312
307
std::string chatJson = ((json)completion).dump ();
313
- writeJsonResponse (socket, chatJson);
308
+ request. writeJson ( chatJson);
314
309
} else {
315
- writeChatCompletionChunk (socket , " " , true );
310
+ writeChatCompletionChunk (request , " " , true );
316
311
}
317
312
printf (" 🔶\n " );
318
313
fflush (stdout);
319
314
}
320
315
321
- void handleModelsRequest (Socket& client_socket, HttpRequest& request) {
322
- writeJsonResponse (client_socket,
316
+ void handleModelsRequest (HttpRequest& request) {
317
+ request. writeJson (
323
318
" { \" object\" : \" list\" ,"
324
319
" \" data\" : ["
325
320
" { \" id\" : \" dl\" , \" object\" : \" model\" , \" created\" : 0, \" owned_by\" : \" user\" }"
@@ -334,26 +329,21 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
334
329
{
335
330
" /v1/chat/completions" ,
336
331
HttpMethod::METHOD_POST,
337
- std::bind (&handleCompletionsRequest, std::placeholders::_1, std::placeholders::_2, inference, tokenizer, sampler, args, spec)
332
+ std::bind (&handleCompletionsRequest, std::placeholders::_1, inference, tokenizer, sampler, args, spec)
338
333
},
339
334
{
340
335
" /v1/models" ,
341
336
HttpMethod::METHOD_GET,
342
- std::bind (&handleModelsRequest, std::placeholders::_1, std::placeholders::_2 )
337
+ std::bind (&handleModelsRequest, std::placeholders::_1)
343
338
}
344
339
};
345
340
346
341
while (true ) {
347
342
try {
348
- // Accept incoming connection
349
343
Socket client = server->accept ();
350
- // Read the HTTP request
351
- std::vector<char > httpRequest = client.readHttpRequest ();
352
- // Parse the HTTP request
353
- HttpRequest request = HttpParser::parseRequest (std::string (httpRequest.begin (), httpRequest.end ()));
354
- // Handle the HTTP request
344
+ HttpRequest request = HttpRequest::read (client);
355
345
printf (" 🔷 %s %s\n " , request.getMethod ().c_str (), request.path .c_str ());
356
- Router::routeRequest (client, request, routes);
346
+ Router::resolve ( request, routes);
357
347
} catch (ReadSocketException& ex) {
358
348
printf (" Read socket error: %d %s\n " , ex.code , ex.message );
359
349
} catch (WriteSocketException& ex) {
0 commit comments