diff --git a/skynet/auth/jwt.py b/skynet/auth/jwt.py index 86e094b..35ad8ae 100644 --- a/skynet/auth/jwt.py +++ b/skynet/auth/jwt.py @@ -40,7 +40,12 @@ async def authorize(jwt_incoming: str) -> dict: raise HTTPException(status_code=401, detail=f'Failed to retrieve public key. {kid}') try: - return jwt.decode(jwt_incoming, public_key, algorithms=['RS256', 'HS512'], audience=asap_pub_keys_auds) + decoded = jwt.decode(jwt_incoming, public_key, algorithms=['RS256', 'HS512'], audience=asap_pub_keys_auds) + + if decoded.get('appId') is None: + decoded['appId'] = kid + + return decoded except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Expired token.") except Exception: diff --git a/skynet/modules/ttt/summaries/jobs.py b/skynet/modules/ttt/summaries/jobs.py index fa54888..50cf688 100644 --- a/skynet/modules/ttt/summaries/jobs.py +++ b/skynet/modules/ttt/summaries/jobs.py @@ -169,7 +169,7 @@ async def update_done_job(job: Job, result: str, processor: Processors, has_fail await db.lrem(RUNNING_JOBS_KEY, 0, job.id) if updated_job.status != JobStatus.SKIPPED: - SUMMARY_DURATION_METRIC.observe(updated_job.computed_duration) + SUMMARY_DURATION_METRIC.observe(updated_job.computed_duration, {'app_id': updated_job.metadata.app_id}) SUMMARY_FULL_DURATION_METRIC.observe(updated_job.computed_full_duration) SUMMARY_INPUT_LENGTH_METRIC.observe(len(updated_job.payload.text)) diff --git a/skynet/modules/ttt/summaries/v1/models.py b/skynet/modules/ttt/summaries/v1/models.py index 3b703ff..ed84274 100644 --- a/skynet/modules/ttt/summaries/v1/models.py +++ b/skynet/modules/ttt/summaries/v1/models.py @@ -37,6 +37,7 @@ class DocumentPayload(BaseModel): class DocumentMetadata(BaseModel): + app_id: str | None = None customer_id: str | None = None diff --git a/skynet/modules/ttt/summaries/v1/router.py b/skynet/modules/ttt/summaries/v1/router.py index a3db395..5bd84d5 100644 --- a/skynet/modules/ttt/summaries/v1/router.py +++ b/skynet/modules/ttt/summaries/v1/router.py @@ -18,6 +18,14 @@ def get_customer_id(request: Request) -> str: return id +def get_app_id(request: Request) -> str: + return request.state.decoded_jwt.get('appId') if hasattr(request.state, 'decoded_jwt') else None + + +def get_metadata(request: Request) -> DocumentMetadata: + return DocumentMetadata(app_id=get_app_id(request), customer_id=get_customer_id(request)) + + @api_version(1) @router.post("/action-items") async def get_action_items(payload: DocumentPayload, request: Request) -> JobId: @@ -25,9 +33,7 @@ async def get_action_items(payload: DocumentPayload, request: Request) -> JobId: Starts a job to extract action items from the given payload. """ - return await create_job( - job_type=JobType.ACTION_ITEMS, payload=payload, metadata=DocumentMetadata(customer_id=get_customer_id(request)) - ) + return await create_job(job_type=JobType.ACTION_ITEMS, payload=payload, metadata=get_metadata(request)) @api_version(1) @@ -37,9 +43,7 @@ async def get_summary(payload: DocumentPayload, request: Request) -> JobId: Starts a job to summarize the given payload. """ - return await create_job( - job_type=JobType.SUMMARY, payload=payload, metadata=DocumentMetadata(customer_id=get_customer_id(request)) - ) + return await create_job(job_type=JobType.SUMMARY, payload=payload, metadata=get_metadata(request)) @api_version(1)