Skip to content

Commit c3f1173

Browse files
author
acer-king
committed
cursor support
1 parent 694f2c5 commit c3f1173

File tree

2 files changed

+45
-8
lines changed

2 files changed

+45
-8
lines changed

cursor/app/core/middleware.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,29 @@
1-
import time
2-
from fastapi import HTTPException
1+
from fastapi import FastAPI, Depends, HTTPException, Request
2+
from cursor.app.core.config import config
3+
from starlette.middleware.base import BaseHTTPMiddleware
4+
from starlette.responses import JSONResponse
5+
# Your predefined valid API keys
6+
VALID_API_KEYS = {config.api_key}
37

48

5-
async def verify_api_key_rate_limit(config, api_key):
6-
# NOTE: abit dangerous but very useful
7-
pass
9+
class APIKeyMiddleware(BaseHTTPMiddleware):
10+
async def dispatch(self, request: Request, call_next):
11+
# Get the API key from the `Authorization` header
12+
if request.method == "OPTIONS":
13+
return await call_next(request)
14+
15+
if not request.headers.get("Authorization"):
16+
return JSONResponse(
17+
{"detail": "Invalid or missing API Key"}, status_code=401
18+
)
19+
20+
api_key = request.headers.get("Authorization").split(" ")[1]
21+
22+
# Validate the API key
23+
if not api_key or api_key not in VALID_API_KEYS:
24+
return JSONResponse(
25+
{"detail": "Invalid or missing API Key"}, status_code=401
26+
)
27+
28+
# Proceed to the next middleware or route handler
29+
return await call_next(request)

validators/weight_setter.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from substrateinterface import SubstrateInterface
1111
from functools import partial
1212
from typing import Tuple
13+
14+
from fastapi import HTTPException
1315
import bittensor as bt
1416
from bittensor import StreamingSynapse
1517

@@ -24,7 +26,10 @@
2426
from validators.task_manager import TaskMgr
2527
from cortext.dendrite import CortexDendrite
2628
from cortext.axon import CortexAxon
27-
from fastapi import HTTPException
29+
from cursor.app.endpoints.text import chat
30+
from cursor.app.endpoints.generic import models
31+
from cursor.app.core.middleware import APIKeyMiddleware
32+
2833

2934
scoring_organic_timeout = 60
3035
NUM_INTERVALS_PER_CYCLE = 10
@@ -104,7 +109,6 @@ def __init__(self, config, cache: QueryResponseCache, loop=None):
104109
# organic_thread.start()
105110
self.loop.create_task(self.consume_organic_queries())
106111

107-
108112
def start_axon_server(self):
109113
asyncio.run(self.consume_organic_queries())
110114

@@ -557,11 +561,23 @@ async def consume_organic_queries(self):
557561
forward_fn=self.embeddings,
558562
blacklist_fn=self.blacklist_embeddings,
559563
)
564+
self.cursor_setup()
560565
self.axon.serve(netuid=self.netuid, subtensor=self.subtensor)
561566
print(f"axon: {self.axon}")
562567
self.axon.start()
563568
bt.logging.info(f"Running validator on uid: {self.my_uid}")
564569

570+
def cursor_setup(self):
571+
self.axon.router.add_api_route(
572+
"/v1/chat/completions",
573+
chat,
574+
methods=["POST", "OPTIONS"],
575+
tags=["StreamPrompting"],
576+
response_model=None
577+
)
578+
self.axon.router.add_api_route("/v1/models", models, methods=["GET"], tags=["Text"], response_model=None)
579+
self.axon.app.add_middleware(APIKeyMiddleware)
580+
565581
def get_scoring_tasks_from_query_responses(self, queries_to_process):
566582

567583
grouped_query_resps = defaultdict(list)
@@ -617,7 +633,6 @@ async def process_queries_from_database(self):
617633
queries_to_process = self.query_database.copy()
618634
self.query_database.clear()
619635

620-
621636
self.synthetic_task_done = False
622637
bt.logging.info("start scoring process")
623638
start_time = time.time()

0 commit comments

Comments
 (0)