Skip to content

Commit 14e4c2b

Browse files
committed
refactor GET /resources endpoint
1 parent 66dd789 commit 14e4c2b

File tree

2 files changed

+127
-97
lines changed

2 files changed

+127
-97
lines changed

.DS_Store

8 KB
Binary file not shown.

src/server/rest.py

+127-97
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313

1414
from chain.processor import ChainProcessor
1515
from orchestration import ContainerManager, DataStore, Guardian, Orchestrator
16-
from server.utils import is_local_ip
16+
from server.exceptions import BadRequestError, ForbiddenError, RESTServerError, RitualNodeError
17+
from server.utils import handle_errors, is_local_ip
1718
from shared import AsyncTask, JobResult
1819
from shared.config import ConfigChain, ConfigServer
1920
from shared.message import (
@@ -159,11 +160,12 @@ async def info() -> Tuple[Response, int]:
159160
),
160161
200,
161162
)
162-
163+
163164
@self._app.route("/resources", methods=["GET"])
164165
@rate_limit(
165166
self._rate_limit.num_requests, timedelta(seconds=self._rate_limit.period)
166167
)
168+
@handle_errors
167169
async def resources() -> Tuple[Response, int]:
168170
"""Collects container resources
169171
@@ -175,10 +177,14 @@ async def resources() -> Tuple[Response, int]:
175177
"""
176178
model_id = request.args.get("model_id")
177179

178-
return (
179-
jsonify(await self._orchestrator.collect_service_resources(model_id)),
180-
200,
181-
)
180+
try:
181+
resources = await self._orchestrator.collect_service_resources(model_id)
182+
return jsonify(resources), 200
183+
except Exception as e:
184+
raise RESTServerError(
185+
"Failed to collect resources",
186+
{"model_id": model_id, "error": str(e)}
187+
)
182188

183189
def filter_create_job(func): # type: ignore
184190
"""Decorator to filter and preprocess incoming off-chain messages"""
@@ -198,20 +204,24 @@ async def wrapper() -> Any:
198204
# Get the IP address of the client
199205
client_ip = request.remote_addr
200206
if not client_ip:
201-
return (
202-
jsonify({"error": "Could not get client IP address"}),
203-
400,
204-
)
207+
raise BadRequestError("Could not get client IP address")
205208

206209
# Parse message data, inject uuid and client IP
207210
job_id = str(uuid4()) # Generate a unique job ID
208211
log.debug(
209212
"Received new off-chain raw message", msg=data, job_id=job_id
210213
)
211-
parsed: OffchainMessage = from_union(
212-
OffchainMessage,
213-
{"id": job_id, "ip": client_ip, **data},
214-
)
214+
215+
try:
216+
parsed: OffchainMessage = from_union(
217+
OffchainMessage,
218+
{"id": job_id, "ip": client_ip, **data},
219+
)
220+
except Exception as e:
221+
raise BadRequestError(
222+
"Failed to parse message",
223+
{"error": str(e)}
224+
)
215225

216226
# Filter message through guardian
217227
filtered = self._guardian.process_message(parsed)
@@ -225,19 +235,20 @@ async def wrapper() -> Any:
225235
err=filtered.error,
226236
**filtered.params,
227237
)
228-
return (
229-
jsonify(
230-
{"error": filtered.error, "params": filtered.params}
231-
),
232-
405,
238+
raise ForbiddenError(
239+
filtered.error,
240+
filtered.params
233241
)
234242

235243
# Call actual endpoint function
236244
return await func(message=filtered)
237245

246+
except RitualNodeError:
247+
# Pass through our custom exceptions
248+
raise
238249
except Exception as e:
239-
log.error(f"Error in endpoint preprocessing: {e}")
240-
return jsonify({"error": f"Internal server error: {str(e)}"}), 500
250+
log.error(f"Error in endpoint preprocessing: {e}", exc_info=True)
251+
raise RESTServerError(f"Internal server error", {"message": str(e)})
241252

242253
return wrapper
243254

@@ -246,6 +257,7 @@ async def wrapper() -> Any:
246257
@rate_limit(
247258
self._rate_limit.num_requests, timedelta(seconds=self._rate_limit.period)
248259
)
260+
@handle_errors
249261
async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
250262
"""Creates new off-chain job (direct compute request or subscription)
251263
@@ -277,7 +289,7 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
277289
# Should only reach this point if chain is enabled (else, filtered
278290
# out upstream)
279291
if self._processor is None:
280-
raise RuntimeError("Chain not enabled")
292+
raise RESTServerError("Chain not enabled")
281293

282294
# Submit delegated subscription message to processor
283295
create_task(self._processor.track(message))
@@ -286,6 +298,8 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
286298
# fetched via REST so it would be misleading. They are tracked
287299
# on-chain instead
288300
return_obj = {}
301+
else:
302+
raise BadRequestError("Unsupported message type", {"type": message.type})
289303

290304
# Return created message ID
291305
log.debug(
@@ -297,16 +311,14 @@ async def create_job(message: OffchainMessage) -> Tuple[Response, int]:
297311
id=str(message.id),
298312
)
299313
return jsonify(return_obj), 200
314+
except RitualNodeError:
315+
# Pass through our custom exceptions
316+
raise
300317
except Exception as e:
301-
# Return error
302-
log.error(
303-
"Processed REST response",
304-
endpoint=request.path,
305-
method=request.method,
306-
status=500,
307-
err=str(e),
318+
raise RESTServerError(
319+
"Could not enqueue job",
320+
{"error": str(e)}
308321
)
309-
return jsonify({"error": f"Could not enqueue job: {str(e)}"}), 500
310322

311323
@self._app.route("/api/jobs/stream", methods=["POST"])
312324
@filter_create_job # type: ignore
@@ -357,6 +369,7 @@ async def generator() -> Any:
357369
@rate_limit(
358370
self._rate_limit.num_requests, timedelta(seconds=self._rate_limit.period)
359371
)
372+
@handle_errors
360373
async def create_job_batch() -> Tuple[Response, int]:
361374
"""Creates off-chain jobs in batch (direct compute requests / subscriptions)
362375
@@ -365,27 +378,36 @@ async def create_job_batch() -> Tuple[Response, int]:
365378
"""
366379
try:
367380
# Collect JSON body
368-
data = await request.get_json(force=True)
381+
try:
382+
data = await request.get_json(force=True)
383+
except Exception:
384+
raise BadRequestError("Invalid JSON body")
369385

370386
# Get the IP address of the client
371387
client_ip = request.remote_addr
372388
if not client_ip:
373-
return jsonify({"error": "Could not get client IP address"}), 400
389+
raise BadRequestError("Could not get client IP address")
374390

375391
log.debug("Received new off-chain raw message batch", msg=data)
376392

377393
# If data is not an array, return error
378394
if not isinstance(data, list):
379-
return jsonify({"error": "Expected a list"}), 400
395+
raise BadRequestError("Expected a list of job requests")
380396

381397
# Inject uuid and client IP to each message
382-
parsed: list[OffchainMessage] = [
383-
from_union(
384-
OffchainMessage,
385-
{"id": str(uuid4()), "ip": client_ip, **item},
398+
try:
399+
parsed: list[OffchainMessage] = [
400+
from_union(
401+
OffchainMessage,
402+
{"id": str(uuid4()), "ip": client_ip, **item},
403+
)
404+
for item in data
405+
]
406+
except Exception as e:
407+
raise BadRequestError(
408+
"Failed to parse message batch",
409+
{"error": str(e)}
386410
)
387-
for item in data
388-
]
389411

390412
# Filter messages through guardian
391413
filtered = cast(
@@ -419,7 +441,8 @@ async def create_job_batch() -> Tuple[Response, int]:
419441
# Should only reach this point if chain is enabled (else,
420442
# filtered out upstream)
421443
if self._processor is None:
422-
raise RuntimeError("Chain not enabled")
444+
results.append({"error": "Chain not enabled"})
445+
continue
423446

424447
# Submit filtered delegated subscription message to processor
425448
create_task(
@@ -429,7 +452,9 @@ async def create_job_batch() -> Tuple[Response, int]:
429452
)
430453
results.append({})
431454
else:
432-
results.append({"error": "Could not parse message"})
455+
results.append({"error": "Unsupported message type"})
456+
else:
457+
results.append({"error": "Invalid message format"})
433458

434459
# Return created message IDs or errors
435460
log.debug(
@@ -440,26 +465,31 @@ async def create_job_batch() -> Tuple[Response, int]:
440465
results=results,
441466
)
442467
return jsonify(results), 200
468+
469+
except RitualNodeError:
470+
# Pass through our custom exceptions
471+
raise
443472
except Exception as e:
444-
# Return error
445-
log.error(
446-
"Processed REST response",
447-
endpoint=request.path,
448-
method=request.method,
449-
status=500,
450-
err=str(e),
473+
raise RESTServerError(
474+
"Failed to process batch request",
475+
{"error": str(e)}
451476
)
452-
return jsonify({"error": f"Could not enqueue job: {str(e)}"}), 500
453477

454478
@self._app.route("/api/jobs", methods=["GET"])
455479
@rate_limit(
456480
self._rate_limit.num_requests, timedelta(seconds=self._rate_limit.period)
457481
)
482+
@handle_errors
458483
async def get_job() -> Tuple[Response, int]:
484+
"""Get job status by ID or list all jobs for a client.
485+
486+
Returns:
487+
Response: Job data or list of job IDs
488+
"""
459489
# Get the IP address of the client
460490
client_ip = request.remote_addr
461491
if not client_ip:
462-
return jsonify({"error": "Could not get client IP address"}), 400
492+
raise BadRequestError("Could not get client IP address")
463493

464494
# Get job ID from query
465495
job_ids = request.args.getlist("id")
@@ -485,6 +515,7 @@ async def get_job() -> Tuple[Response, int]:
485515
return jsonify(data), 200
486516

487517
@self._app.route("/api/status", methods=["PUT"])
518+
@handle_errors
488519
async def store_job_status() -> Tuple[Response, int]:
489520
"""Stores job status in data store"""
490521

@@ -494,56 +525,55 @@ async def store_job_status() -> Tuple[Response, int]:
494525
"Unauthorized attempt to store job status",
495526
remote_addr=request.remote_addr,
496527
)
497-
return jsonify({"error": "Unauthorized"}), 403
528+
raise ForbiddenError("Unauthorized access",
529+
{"remote_addr": str(request.remote_addr)})
498530

499-
try:
500-
# Collect JSON body
501-
data = await request.get_json(force=True)
531+
# Collect JSON body
532+
data = await request.get_json(force=True)
502533

503-
# Get the IP address of the client
504-
client_ip = request.remote_addr
505-
if not client_ip:
506-
return jsonify({"error": "Could not get client IP address"}), 400
507-
508-
log.debug("Received new result", result=data)
509-
510-
# Create off-chain message with client IP
511-
parsed: OffchainMessage = from_union(
512-
OffchainMessage,
513-
{
514-
"id": data["id"],
515-
"ip": client_ip,
516-
"containers": data["containers"],
517-
"data": {},
518-
},
519-
)
534+
# Get the IP address of the client
535+
client_ip = request.remote_addr
536+
if not client_ip:
537+
raise BadRequestError("Could not get client IP address")
538+
539+
log.debug("Received new result", result=data)
540+
541+
# Validate required fields
542+
if "id" not in data:
543+
raise BadRequestError("Missing required field: id")
544+
if "status" not in data:
545+
raise BadRequestError("Missing required field: status")
546+
if "containers" not in data:
547+
raise BadRequestError("Missing required field: containers")
548+
549+
# Create off-chain message with client IP
550+
parsed: OffchainMessage = from_union(
551+
OffchainMessage,
552+
{
553+
"id": data["id"],
554+
"ip": client_ip,
555+
"containers": data["containers"],
556+
"data": {},
557+
},
558+
)
520559

521-
# Store job status
522-
match data["status"]:
523-
case "success":
524-
self._store.set_success(parsed, [])
525-
for container in data["containers"]:
526-
self._store.track_container_status(container, "success")
527-
case "failed":
528-
self._store.set_failed(parsed, [])
529-
for container in data["containers"]:
530-
self._store.track_container_status(container, "failed")
531-
case "running":
532-
self._store.set_running(parsed)
533-
case _:
534-
return jsonify({"error": "Status is invalid"}), 400
535-
536-
return jsonify(), 200
537-
except Exception as e:
538-
# Return error
539-
log.error(
540-
"Processed REST response",
541-
endpoint=request.path,
542-
method=request.method,
543-
status=500,
544-
err=e,
545-
)
546-
return jsonify({"error": "Could not store job status"}), 500
560+
# Store job status
561+
match data["status"]:
562+
case "success":
563+
self._store.set_success(parsed, [])
564+
for container in data["containers"]:
565+
self._store.track_container_status(container, "success")
566+
case "failed":
567+
self._store.set_failed(parsed, [])
568+
for container in data["containers"]:
569+
self._store.track_container_status(container, "failed")
570+
case "running":
571+
self._store.set_running(parsed)
572+
case _:
573+
raise BadRequestError("Invalid status value",
574+
{"status": data["status"]})
575+
576+
return jsonify(), 200
547577

548578
async def run_forever(self: RESTServer) -> None:
549579
"""Main RESTServer lifecycle loop. Uses production hypercorn server"""

0 commit comments

Comments
 (0)