|
10 | 10 | from substrateinterface import SubstrateInterface
|
11 | 11 | from functools import partial
|
12 | 12 | from typing import Tuple
|
| 13 | + |
| 14 | +from fastapi import HTTPException |
13 | 15 | import bittensor as bt
|
14 | 16 | from bittensor import StreamingSynapse
|
15 | 17 |
|
|
24 | 26 | from validators.task_manager import TaskMgr
|
25 | 27 | from cortext.dendrite import CortexDendrite
|
26 | 28 | from cortext.axon import CortexAxon
|
27 |
| -from fastapi import HTTPException |
| 29 | +from cursor.app.endpoints.text import chat |
| 30 | +from cursor.app.endpoints.generic import models |
| 31 | +from cursor.app.core.middleware import APIKeyMiddleware |
| 32 | + |
28 | 33 |
|
29 | 34 | scoring_organic_timeout = 60
|
30 | 35 | NUM_INTERVALS_PER_CYCLE = 10
|
@@ -104,7 +109,6 @@ def __init__(self, config, cache: QueryResponseCache, loop=None):
|
104 | 109 | # organic_thread.start()
|
105 | 110 | self.loop.create_task(self.consume_organic_queries())
|
106 | 111 |
|
107 |
| - |
108 | 112 | def start_axon_server(self):
|
109 | 113 | asyncio.run(self.consume_organic_queries())
|
110 | 114 |
|
@@ -557,11 +561,23 @@ async def consume_organic_queries(self):
|
557 | 561 | forward_fn=self.embeddings,
|
558 | 562 | blacklist_fn=self.blacklist_embeddings,
|
559 | 563 | )
|
| 564 | + self.cursor_setup() |
560 | 565 | self.axon.serve(netuid=self.netuid, subtensor=self.subtensor)
|
561 | 566 | print(f"axon: {self.axon}")
|
562 | 567 | self.axon.start()
|
563 | 568 | bt.logging.info(f"Running validator on uid: {self.my_uid}")
|
564 | 569 |
|
| 570 | + def cursor_setup(self): |
| 571 | + self.axon.router.add_api_route( |
| 572 | + "/v1/chat/completions", |
| 573 | + chat, |
| 574 | + methods=["POST", "OPTIONS"], |
| 575 | + tags=["StreamPrompting"], |
| 576 | + response_model=None |
| 577 | + ) |
| 578 | + self.axon.router.add_api_route("/v1/models", models, methods=["GET"], tags=["Text"], response_model=None) |
| 579 | + self.axon.app.add_middleware(APIKeyMiddleware) |
| 580 | + |
565 | 581 | def get_scoring_tasks_from_query_responses(self, queries_to_process):
|
566 | 582 |
|
567 | 583 | grouped_query_resps = defaultdict(list)
|
@@ -617,7 +633,6 @@ async def process_queries_from_database(self):
|
617 | 633 | queries_to_process = self.query_database.copy()
|
618 | 634 | self.query_database.clear()
|
619 | 635 |
|
620 |
| - |
621 | 636 | self.synthetic_task_done = False
|
622 | 637 | bt.logging.info("start scoring process")
|
623 | 638 | start_time = time.time()
|
|
0 commit comments