13
13
14
14
from chain .processor import ChainProcessor
15
15
from orchestration import ContainerManager , DataStore , Guardian , Orchestrator
16
- from server .utils import is_local_ip
16
+ from server .exceptions import BadRequestError , ForbiddenError , RESTServerError , RitualNodeError
17
+ from server .utils import handle_errors , is_local_ip
17
18
from shared import AsyncTask , JobResult
18
19
from shared .config import ConfigChain , ConfigServer
19
20
from shared .message import (
@@ -159,11 +160,12 @@ async def info() -> Tuple[Response, int]:
159
160
),
160
161
200 ,
161
162
)
162
-
163
+
163
164
@self ._app .route ("/resources" , methods = ["GET" ])
164
165
@rate_limit (
165
166
self ._rate_limit .num_requests , timedelta (seconds = self ._rate_limit .period )
166
167
)
168
+ @handle_errors
167
169
async def resources () -> Tuple [Response , int ]:
168
170
"""Collects container resources
169
171
@@ -175,10 +177,14 @@ async def resources() -> Tuple[Response, int]:
175
177
"""
176
178
model_id = request .args .get ("model_id" )
177
179
178
- return (
179
- jsonify (await self ._orchestrator .collect_service_resources (model_id )),
180
- 200 ,
181
- )
180
+ try :
181
+ resources = await self ._orchestrator .collect_service_resources (model_id )
182
+ return jsonify (resources ), 200
183
+ except Exception as e :
184
+ raise RESTServerError (
185
+ "Failed to collect resources" ,
186
+ {"model_id" : model_id , "error" : str (e )}
187
+ )
182
188
183
189
def filter_create_job (func ): # type: ignore
184
190
"""Decorator to filter and preprocess incoming off-chain messages"""
@@ -198,20 +204,24 @@ async def wrapper() -> Any:
198
204
# Get the IP address of the client
199
205
client_ip = request .remote_addr
200
206
if not client_ip :
201
- return (
202
- jsonify ({"error" : "Could not get client IP address" }),
203
- 400 ,
204
- )
207
+ raise BadRequestError ("Could not get client IP address" )
205
208
206
209
# Parse message data, inject uuid and client IP
207
210
job_id = str (uuid4 ()) # Generate a unique job ID
208
211
log .debug (
209
212
"Received new off-chain raw message" , msg = data , job_id = job_id
210
213
)
211
- parsed : OffchainMessage = from_union (
212
- OffchainMessage ,
213
- {"id" : job_id , "ip" : client_ip , ** data },
214
- )
214
+
215
+ try :
216
+ parsed : OffchainMessage = from_union (
217
+ OffchainMessage ,
218
+ {"id" : job_id , "ip" : client_ip , ** data },
219
+ )
220
+ except Exception as e :
221
+ raise BadRequestError (
222
+ "Failed to parse message" ,
223
+ {"error" : str (e )}
224
+ )
215
225
216
226
# Filter message through guardian
217
227
filtered = self ._guardian .process_message (parsed )
@@ -225,19 +235,20 @@ async def wrapper() -> Any:
225
235
err = filtered .error ,
226
236
** filtered .params ,
227
237
)
228
- return (
229
- jsonify (
230
- {"error" : filtered .error , "params" : filtered .params }
231
- ),
232
- 405 ,
238
+ raise ForbiddenError (
239
+ filtered .error ,
240
+ filtered .params
233
241
)
234
242
235
243
# Call actual endpoint function
236
244
return await func (message = filtered )
237
245
246
+ except RitualNodeError :
247
+ # Pass through our custom exceptions
248
+ raise
238
249
except Exception as e :
239
- log .error (f"Error in endpoint preprocessing: { e } " )
240
- return jsonify ({ "error" : f"Internal server error: { str (e )} " }), 500
250
+ log .error (f"Error in endpoint preprocessing: { e } " , exc_info = True )
251
+ raise RESTServerError ( f"Internal server error" , { "message" : str (e )})
241
252
242
253
return wrapper
243
254
@@ -246,6 +257,7 @@ async def wrapper() -> Any:
246
257
@rate_limit (
247
258
self ._rate_limit .num_requests , timedelta (seconds = self ._rate_limit .period )
248
259
)
260
+ @handle_errors
249
261
async def create_job (message : OffchainMessage ) -> Tuple [Response , int ]:
250
262
"""Creates new off-chain job (direct compute request or subscription)
251
263
@@ -277,7 +289,7 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
277
289
# Should only reach this point if chain is enabled (else, filtered
278
290
# out upstream)
279
291
if self ._processor is None :
280
- raise RuntimeError ("Chain not enabled" )
292
+ raise RESTServerError ("Chain not enabled" )
281
293
282
294
# Submit delegated subscription message to processor
283
295
create_task (self ._processor .track (message ))
@@ -286,6 +298,8 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
286
298
# fetched via REST so it would be misleading. They are tracked
287
299
# on-chain instead
288
300
return_obj = {}
301
+ else :
302
+ raise BadRequestError ("Unsupported message type" , {"type" : message .type })
289
303
290
304
# Return created message ID
291
305
log .debug (
@@ -297,16 +311,14 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
297
311
id = str (message .id ),
298
312
)
299
313
return jsonify (return_obj ), 200
314
+ except RitualNodeError :
315
+ # Pass through our custom exceptions
316
+ raise
300
317
except Exception as e :
301
- # Return error
302
- log .error (
303
- "Processed REST response" ,
304
- endpoint = request .path ,
305
- method = request .method ,
306
- status = 500 ,
307
- err = str (e ),
318
+ raise RESTServerError (
319
+ "Could not enqueue job" ,
320
+ {"error" : str (e )}
308
321
)
309
- return jsonify ({"error" : f"Could not enqueue job: { str (e )} " }), 500
310
322
311
323
@self ._app .route ("/api/jobs/stream" , methods = ["POST" ])
312
324
@filter_create_job # type: ignore
@@ -357,6 +369,7 @@ async def generator() -> Any:
357
369
@rate_limit (
358
370
self ._rate_limit .num_requests , timedelta (seconds = self ._rate_limit .period )
359
371
)
372
+ @handle_errors
360
373
async def create_job_batch () -> Tuple [Response , int ]:
361
374
"""Creates off-chain jobs in batch (direct compute requests / subscriptions)
362
375
@@ -365,27 +378,36 @@ async def create_job_batch() -> Tuple[Response, int]:
365
378
"""
366
379
try :
367
380
# Collect JSON body
368
- data = await request .get_json (force = True )
381
+ try :
382
+ data = await request .get_json (force = True )
383
+ except Exception :
384
+ raise BadRequestError ("Invalid JSON body" )
369
385
370
386
# Get the IP address of the client
371
387
client_ip = request .remote_addr
372
388
if not client_ip :
373
- return jsonify ({ "error" : " Could not get client IP address"}), 400
389
+ raise BadRequestError ( " Could not get client IP address")
374
390
375
391
log .debug ("Received new off-chain raw message batch" , msg = data )
376
392
377
393
# If data is not an array, return error
378
394
if not isinstance (data , list ):
379
- return jsonify ({ "error" : " Expected a list" }), 400
395
+ raise BadRequestError ( " Expected a list of job requests" )
380
396
381
397
# Inject uuid and client IP to each message
382
- parsed : list [OffchainMessage ] = [
383
- from_union (
384
- OffchainMessage ,
385
- {"id" : str (uuid4 ()), "ip" : client_ip , ** item },
398
+ try :
399
+ parsed : list [OffchainMessage ] = [
400
+ from_union (
401
+ OffchainMessage ,
402
+ {"id" : str (uuid4 ()), "ip" : client_ip , ** item },
403
+ )
404
+ for item in data
405
+ ]
406
+ except Exception as e :
407
+ raise BadRequestError (
408
+ "Failed to parse message batch" ,
409
+ {"error" : str (e )}
386
410
)
387
- for item in data
388
- ]
389
411
390
412
# Filter messages through guardian
391
413
filtered = cast (
@@ -419,7 +441,8 @@ async def create_job_batch() -> Tuple[Response, int]:
419
441
# Should only reach this point if chain is enabled (else,
420
442
# filtered out upstream)
421
443
if self ._processor is None :
422
- raise RuntimeError ("Chain not enabled" )
444
+ results .append ({"error" : "Chain not enabled" })
445
+ continue
423
446
424
447
# Submit filtered delegated subscription message to processor
425
448
create_task (
@@ -429,7 +452,9 @@ async def create_job_batch() -> Tuple[Response, int]:
429
452
)
430
453
results .append ({})
431
454
else :
432
- results .append ({"error" : "Could not parse message" })
455
+ results .append ({"error" : "Unsupported message type" })
456
+ else :
457
+ results .append ({"error" : "Invalid message format" })
433
458
434
459
# Return created message IDs or errors
435
460
log .debug (
@@ -440,26 +465,31 @@ async def create_job_batch() -> Tuple[Response, int]:
440
465
results = results ,
441
466
)
442
467
return jsonify (results ), 200
468
+
469
+ except RitualNodeError :
470
+ # Pass through our custom exceptions
471
+ raise
443
472
except Exception as e :
444
- # Return error
445
- log .error (
446
- "Processed REST response" ,
447
- endpoint = request .path ,
448
- method = request .method ,
449
- status = 500 ,
450
- err = str (e ),
473
+ raise RESTServerError (
474
+ "Failed to process batch request" ,
475
+ {"error" : str (e )}
451
476
)
452
- return jsonify ({"error" : f"Could not enqueue job: { str (e )} " }), 500
453
477
454
478
@self ._app .route ("/api/jobs" , methods = ["GET" ])
455
479
@rate_limit (
456
480
self ._rate_limit .num_requests , timedelta (seconds = self ._rate_limit .period )
457
481
)
482
+ @handle_errors
458
483
async def get_job () -> Tuple [Response , int ]:
484
+ """Get job status by ID or list all jobs for a client.
485
+
486
+ Returns:
487
+ Response: Job data or list of job IDs
488
+ """
459
489
# Get the IP address of the client
460
490
client_ip = request .remote_addr
461
491
if not client_ip :
462
- return jsonify ({ "error" : " Could not get client IP address"}), 400
492
+ raise BadRequestError ( " Could not get client IP address")
463
493
464
494
# Get job ID from query
465
495
job_ids = request .args .getlist ("id" )
@@ -485,6 +515,7 @@ async def get_job() -> Tuple[Response, int]:
485
515
return jsonify (data ), 200
486
516
487
517
@self ._app .route ("/api/status" , methods = ["PUT" ])
518
+ @handle_errors
488
519
async def store_job_status () -> Tuple [Response , int ]:
489
520
"""Stores job status in data store"""
490
521
@@ -494,56 +525,55 @@ async def store_job_status() -> Tuple[Response, int]:
494
525
"Unauthorized attempt to store job status" ,
495
526
remote_addr = request .remote_addr ,
496
527
)
497
- return jsonify ({"error" : "Unauthorized" }), 403
528
+ raise ForbiddenError ("Unauthorized access" ,
529
+ {"remote_addr" : str (request .remote_addr )})
498
530
499
- try :
500
- # Collect JSON body
501
- data = await request .get_json (force = True )
531
+ # Collect JSON body
532
+ data = await request .get_json (force = True )
502
533
503
- # Get the IP address of the client
504
- client_ip = request .remote_addr
505
- if not client_ip :
506
- return jsonify ({"error" : "Could not get client IP address" }), 400
507
-
508
- log .debug ("Received new result" , result = data )
509
-
510
- # Create off-chain message with client IP
511
- parsed : OffchainMessage = from_union (
512
- OffchainMessage ,
513
- {
514
- "id" : data ["id" ],
515
- "ip" : client_ip ,
516
- "containers" : data ["containers" ],
517
- "data" : {},
518
- },
519
- )
534
+ # Get the IP address of the client
535
+ client_ip = request .remote_addr
536
+ if not client_ip :
537
+ raise BadRequestError ("Could not get client IP address" )
538
+
539
+ log .debug ("Received new result" , result = data )
540
+
541
+ # Validate required fields
542
+ if "id" not in data :
543
+ raise BadRequestError ("Missing required field: id" )
544
+ if "status" not in data :
545
+ raise BadRequestError ("Missing required field: status" )
546
+ if "containers" not in data :
547
+ raise BadRequestError ("Missing required field: containers" )
548
+
549
+ # Create off-chain message with client IP
550
+ parsed : OffchainMessage = from_union (
551
+ OffchainMessage ,
552
+ {
553
+ "id" : data ["id" ],
554
+ "ip" : client_ip ,
555
+ "containers" : data ["containers" ],
556
+ "data" : {},
557
+ },
558
+ )
520
559
521
- # Store job status
522
- match data ["status" ]:
523
- case "success" :
524
- self ._store .set_success (parsed , [])
525
- for container in data ["containers" ]:
526
- self ._store .track_container_status (container , "success" )
527
- case "failed" :
528
- self ._store .set_failed (parsed , [])
529
- for container in data ["containers" ]:
530
- self ._store .track_container_status (container , "failed" )
531
- case "running" :
532
- self ._store .set_running (parsed )
533
- case _:
534
- return jsonify ({"error" : "Status is invalid" }), 400
535
-
536
- return jsonify (), 200
537
- except Exception as e :
538
- # Return error
539
- log .error (
540
- "Processed REST response" ,
541
- endpoint = request .path ,
542
- method = request .method ,
543
- status = 500 ,
544
- err = e ,
545
- )
546
- return jsonify ({"error" : "Could not store job status" }), 500
560
+ # Store job status
561
+ match data ["status" ]:
562
+ case "success" :
563
+ self ._store .set_success (parsed , [])
564
+ for container in data ["containers" ]:
565
+ self ._store .track_container_status (container , "success" )
566
+ case "failed" :
567
+ self ._store .set_failed (parsed , [])
568
+ for container in data ["containers" ]:
569
+ self ._store .track_container_status (container , "failed" )
570
+ case "running" :
571
+ self ._store .set_running (parsed )
572
+ case _:
573
+ raise BadRequestError ("Invalid status value" ,
574
+ {"status" : data ["status" ]})
575
+
576
+ return jsonify (), 200
547
577
548
578
async def run_forever (self : RESTServer ) -> None :
549
579
"""Main RESTServer lifecycle loop. Uses production hypercorn server"""
0 commit comments