Skip to content

Commit df6f2a7

Browse files
authored
Merge pull request #101 from Datura-ai/hotfix-main-bittensor
Hotfix main bittensor
2 parents 5078171 + 0b9e982 commit df6f2a7

File tree

7 files changed

+136
-42
lines changed

7 files changed

+136
-42
lines changed

cortext/protocol.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ class StreamPrompting(bt.StreamingSynapse):
196196
)
197197

198198
seed: int = pydantic.Field(
199-
default="1234",
199+
default=1234,
200200
title="Seed",
201201
description="Seed for text generation. This attribute is immutable and cannot be updated.",
202202
)

cursor/app.py

+30
Original file line numberDiff line numberDiff line change
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
import httpx

app = FastAPI()
# NOTE(review): "https" on 0.0.0.0 is unusual -- confirm the upstream
# scheme/host (and certificate) before deploying.
organic_server_url = "https://0.0.0.0:8000"


@app.post("/apis.datura.ai/s18_sigma")
async def stream_prompt():
    """Proxy the organic server's response to the caller as a byte stream.

    Bug fixed: the previous version did ``await client.stream(...)``, but
    ``AsyncClient.stream`` returns an async context manager, not a
    coroutine, so the call raised ``TypeError``; it also closed the client
    (via ``async with``) before the response body was consumed.  Here the
    request is sent with ``stream=True`` and the response/client are closed
    only after the stream generator is exhausted.

    Raises:
        HTTPException: with the upstream status for 4xx/5xx responses, or
            500 when the upstream server cannot be reached.
    """
    client = httpx.AsyncClient()
    try:
        request = client.build_request("GET", organic_server_url)
        response = await client.send(request, stream=True)
    except httpx.RequestError:
        await client.aclose()
        raise HTTPException(status_code=500, detail="Failed to fetch data from external server")
    try:
        response.raise_for_status()  # Raise an error for bad responses (4xx, 5xx)
    except httpx.HTTPStatusError as e:
        await response.aclose()
        await client.aclose()
        raise HTTPException(status_code=e.response.status_code, detail=str(e))

    async def stream_generator():
        # Forward chunks as they arrive; release the connection and client
        # whether the stream finishes normally or the caller disconnects.
        try:
            async for chunk in response.aiter_bytes():
                yield chunk
        finally:
            await response.aclose()
            await client.aclose()

    # Forward the upstream data to our caller as it is received.
    return StreamingResponse(stream_generator(), media_type="application/json")

server/app/curd.py

+61-23
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import json
2+
import time
13
import traceback
24

35
import psycopg2
4-
import os
6+
from psycopg2.extras import RealDictCursor
57
from typing import List
8+
69
from . import models, schemas
7-
from .database import cur, TABEL_NAME, conn, DATABASE_URL
10+
from .database import TABEL_NAME, DATABASE_URL
811
from fastapi import HTTPException
912

1013

@@ -20,10 +23,27 @@ def create_items(items: List[schemas.ItemCreate]):
2023
conn = psycopg2.connect(DATABASE_URL)
2124
# Create a cursor object to interact with the database
2225
cur = conn.cursor()
23-
query = f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp) VALUES (%s, %s, %s, %s, %s, %s)"
26+
query = (f"INSERT INTO {TABEL_NAME} (p_key, question, answer, provider, model, timestamp, miner_hot_key, miner_uid"
27+
f", score, similarity, vali_uid, timeout, time_taken, epoch_num, cycle_num, block_num"
28+
f", name) VALUES (%s, %s, %s, %s, %s, %s"
29+
f", %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)")
2430
datas = []
2531
for item in items:
26-
datas.append((item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp))
32+
question = json.loads(item.question)
33+
miner_uid = question.get("miner_info", {}).get("miner_id") or 99999
34+
miner_hot_key = question.get("miner_info", {}).get("miner_hotkey") or ""
35+
score = question.get("score") or 0
36+
similarity = question.get("similarity") or 0
37+
vali_uid = question.get("validator_info").get("vali_uid")
38+
timeout = question.get("timeout")
39+
time_taken = question.get("time_taken")
40+
epoch_num = question.get("epoch_num")
41+
cycle_num = question.get("cycle_num")
42+
block_num = question.get("block_num")
43+
name = question.get("name") or ""
44+
datas.append((item.p_key, item.question, item.answer, item.provider, item.model, item.timestamp, miner_hot_key,
45+
miner_uid, score, similarity, vali_uid, timeout,
46+
time_taken, epoch_num, cycle_num, block_num, name))
2747
try:
2848
if conn.closed:
2949
print("connection is closed already")
@@ -35,27 +55,45 @@ def create_items(items: List[schemas.ItemCreate]):
3555
raise HTTPException(status_code=500, detail=f"Internal Server Error {err}")
3656

3757

def get_items(req_body: models.RequestBody):
    """Search stored items.

    Builds a filtered, sorted, paginated SELECT from *req_body* and returns
    ``{"records": rows, "limit": limit, "skip": skip}``.

    Fixes over the previous version:
      * all filter values are bound as query parameters instead of being
        interpolated into the SQL string (injection via provider/model/name);
      * ``min_similarity`` now filters the ``similarity`` column (it
        previously compared against ``score``);
      * the ``name`` filter is no longer emitted unquoted (SQL error);
      * an empty filter set no longer produces invalid ``where  order by``;
      * ``sort_by`` is checked against the table's columns (it cannot be a
        bound parameter, so whitelist it);
      * the connection and cursor are always closed.
    """
    # Sortable columns -- must match the table schema in database.py.
    allowed_sort_columns = {
        "p_key", "question", "answer", "provider", "model", "timestamp",
        "miner_hot_key", "miner_uid", "score", "similarity", "vali_uid",
        "timeout", "time_taken", "epoch_num", "cycle_num", "block_num", "name",
    }
    sort_by = req_body.sort_by or "miner_uid"
    if sort_by not in allowed_sort_columns:
        raise HTTPException(status_code=400, detail=f"cannot sort by {sort_by}")
    # sort_order is pattern-validated by the model; normalize defensively.
    sort_order = "desc" if str(req_body.sort_order).lower() == "desc" else "asc"

    conditions = []
    params = []
    filters = req_body.filters
    if filters:
        # Falsy values (0, "", None) mean "not filtered", matching the
        # previous version's truthiness checks.
        if filters.min_score:
            conditions.append("score >= %s")
            params.append(filters.min_score)
        if filters.min_similarity:
            conditions.append("similarity >= %s")
            params.append(filters.min_similarity)
        if filters.provider:
            conditions.append("provider = %s")
            params.append(filters.provider)
        if filters.model:
            conditions.append("model = %s")
            params.append(filters.model)
        if filters.min_timestamp:
            conditions.append("timestamp >= %s")
            params.append(filters.min_timestamp)
        if filters.max_timestamp:
            conditions.append("timestamp <= %s")
            params.append(filters.max_timestamp)
        if filters.epoch_num:
            conditions.append("epoch_num = %s")
            params.append(filters.epoch_num)
        if filters.block_num:
            conditions.append("block_num = %s")
            params.append(filters.block_num)
        if filters.cycle_num:
            conditions.append("cycle_num = %s")
            params.append(filters.cycle_num)
        if filters.name:
            conditions.append("name = %s")
            params.append(filters.name)
    if req_body.search:
        # Numeric search keys match the miner uid; anything else does a
        # substring match on the hotkey.
        if str(req_body.search).isdigit():
            conditions.append("miner_uid = %s")
            params.append(int(req_body.search))
        else:
            conditions.append("miner_hot_key like %s")
            params.append(f"%{req_body.search}%")

    where_clause = f"where {' and '.join(conditions)}" if conditions else ""
    query = (f"SELECT * FROM {TABEL_NAME} {where_clause} "
             f"order by {sort_by} {sort_order} limit %s offset %s;")

    conn = psycopg2.connect(DATABASE_URL)
    try:
        # RealDictCursor returns each row as a column->value dict.
        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute(query, params + [req_body.limit, req_body.skip])
            items = cur.fetchall()
        finally:
            cur.close()
    finally:
        conn.close()
    return {"records": items, "limit": req_body.limit, "skip": req_body.skip}
5997

6098

6199
def get_item(p_key: int):

server/app/database.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,18 @@ async def create_table(app):
2626
answer TEXT,
2727
provider VARCHAR(100),
2828
model VARCHAR(100),
29-
timestamp FLOAT
29+
timestamp FLOAT,
30+
miner_hot_key VARCHAR(100),
31+
miner_uid INTEGER,
32+
score FLOAT,
33+
similarity FLOAT,
34+
vali_uid INTEGER,
35+
timeout INTEGER,
36+
time_taken INTEGER,
37+
epoch_num INTEGER,
38+
cycle_num INTEGER,
39+
block_num INTEGER,
40+
name VARCHAR(100)
3041
);
3142
"""
3243

@@ -35,6 +46,9 @@ async def create_table(app):
3546
conn.commit() # Save changes
3647
create_index_query = f"""
3748
CREATE INDEX IF NOT EXISTS question_answer_index ON {TABEL_NAME} (provider, model);
49+
CREATE INDEX IF NOT EXISTS miner_id_index ON {TABEL_NAME} (miner_uid);
50+
CREATE INDEX IF NOT EXISTS miner_hot_key_index ON {TABEL_NAME} (miner_hot_key);
51+
CREATE INDEX IF NOT EXISTS idx_score_sim_timestamp ON {TABEL_NAME} (score, similarity, timestamp);
3852
"""
3953
cur.execute(create_index_query)
4054
conn.commit()

server/app/main.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from contextlib import asynccontextmanager
22
from fastapi import FastAPI, Depends, HTTPException
3+
from fastapi.middleware.cors import CORSMiddleware
34
from . import curd, models, schemas
45
from .database import create_table, conn, cur
56
from typing import List
@@ -13,6 +14,10 @@ async def lifespan(app: FastAPI):
1314

1415

1516
app = FastAPI(lifespan=lifespan)
# CORSMiddleware defaults to allow_methods=("GET",) and no extra request
# headers, which would make browsers fail the CORS preflight for the
# POST /items/search endpoint -- allow them explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_methods=["*"],
    allow_headers=["*"],
)
1621

1722

1823
@app.on_event("shutdown")
@@ -28,9 +33,9 @@ def create_item(items: List[schemas.ItemCreate]):
2833

2934

3035
# Read all items
@app.post("/items/search")
def read_items(req_body: models.RequestBody):
    """Search stored items using the filters/sort/pagination in *req_body*."""
    return curd.get_items(req_body)
3540

3641

server/app/models.py

+17-10
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,23 @@
33

44

55
class Filters(BaseModel):
    """Optional search filters; falsy values mean "not filtered"."""
    min_score: float = Field(0, ge=0)
    min_similarity: float = Field(0.0, ge=0)
    model: Optional[str] = Field("")
    provider: Optional[str] = Field("")
    # Bug fix: these are Optional[int], but the defaults were "" -- the
    # wrong type, which only slipped through because pydantic does not
    # validate defaults.  None is the correct "unset" value.
    epoch_num: Optional[int] = Field(None)
    cycle_num: Optional[int] = Field(None)
    block_num: Optional[int] = Field(None)
    name: Optional[str] = Field("")
    min_timestamp: Optional[float] = Field(0, ge=0)
    max_timestamp: Optional[float] = Field(999999999999, ge=0)


class RequestBody(BaseModel):
    """Body of POST /items/search: filters plus search/sort/pagination."""
    filters: Optional[Filters] = Field(..., description="filter for searching")
    search: Optional[Union[int, str]] = Field(..., description="An integer ID or a string search key")
    sort_by: Optional[str] = Field("miner_uid", description="Field to sort by")
    sort_order: Optional[str] = Field("desc", pattern="^(asc|desc)$", description="Sorting order, 'asc' or 'desc'")
    skip: Optional[int] = Field(0, description="skip for pagination")
    limit: Optional[int] = Field(100, description="limit for pagination")

validators/weight_setter.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import asyncio
2-
import concurrent
32
import random
43
import threading
5-
import traceback
4+
import json
65

76
import torch
87
import time
9-
import requests
108

119
from black.trans import defaultdict
1210
from substrateinterface import SubstrateInterface
@@ -15,7 +13,6 @@
1513
import bittensor as bt
1614
from bittensor import StreamingSynapse
1715
import cortext
18-
import json
1916
from starlette.types import Send
2017

2118
from cortext.protocol import IsAlive, StreamPrompting, ImageResponse, Embeddings
@@ -26,6 +23,7 @@
2623
from validators.task_manager import TaskMgr
2724
from cortext.dendrite import CortexDendrite
2825
from cortext.axon import CortexAxon
26+
from fastapi import HTTPException
2927

3028
scoring_organic_timeout = 60
3129
NUM_INTERVALS_PER_CYCLE = 10
@@ -483,6 +481,8 @@ async def embeddings(self, synapse: Embeddings) -> Embeddings:
483481

484482
async def prompt(self, synapse: StreamPrompting) -> StreamingSynapse.BTStreamingResponse:
485483
bt.logging.info(f"Received {synapse}")
484+
if len(json.dumps(synapse.messages)) > 1024:
485+
raise HTTPException(status_code=413, detail="Request entity too large")
486486

487487
async def _prompt(query_synapse: StreamPrompting, send: Send):
488488
query_synapse.deserialize_flag = False

0 commit comments

Comments
 (0)