Skip to content

Commit 45da8a4

Browse files
authored
Merge pull request #83 from Datura-ai/feature/code-refactor
Feature/code refactor
2 parents 5b05097 + cb13615 commit 45da8a4

File tree

7 files changed

+63
-48
lines changed

7 files changed

+63
-48
lines changed

cortext/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
# version must stay on line 22
22-
__version__ = "4.0.3"
22+
__version__ = "4.0.4"
2323
version_split = __version__.split(".")
2424
__spec_version__ = (
2525
(1000 * int(version_split[0]))

organic.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -206,4 +206,6 @@ async def main():
206206
response = await query_miner(dendrite, axon_to_use, synapse, timeout, streaming)
207207

208208
if __name__ == "__main__":
209-
asyncio.run(main())
209+
asyncio.run(main())
210+
211+

validators/config.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(self):
1616
self.ASYNC_TIME_OUT = int(os.getenv('ASYNC_TIME_OUT', 60))
1717
self.BT_SUBTENSOR_NETWORK = 'test' if self.ENV == 'test' else 'finney'
1818
self.SLEEP_PER_ITERATION = 1
19+
self.IMAGE_VALIDATOR_CHOOSE_PROBABILITY = 0.03
1920

2021
@staticmethod
2122
def check_required_env_vars():
@@ -58,9 +59,11 @@ def get_config() -> bt.config:
5859

5960
bt.axon.check_config(bt_config_)
6061
bt.logging.check_config(bt_config_)
62+
63+
local_host_str = ['local', '127.0.0.1', '0.0.0.0']
6164
if 'test' in bt_config_.subtensor.chain_endpoint:
6265
bt_config_.subtensor.network = 'test'
63-
elif 'local' in bt_config_.subtensor.chain_endpoint:
66+
elif any(word in bt_config_.subtensor.chain_endpoint for word in local_host_str):
6467
bt_config_.subtensor.network = 'local'
6568
else:
6669
bt_config_.subtensor.network = 'finney'

validators/services/validators/base_validator.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ async def load_questions(self, available_uids, item_type: str = "text", vision=F
4343
for index, uid in enumerate(available_uids):
4444

4545
if item_type == "images":
46-
messages = await utils.get_question("images", len(available_uids))
47-
content = " ".join(messages)
46+
content = await utils.get_question("images", len(available_uids))
4847
self.uid_to_questions[uid] = content # Store messages for each UID
4948
elif item_type == "text":
5049
question = await utils.get_question("text", len(available_uids), vision)

validators/services/validators/constants.py

+40
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,43 @@
1010
"claude-3-5-sonnet-20240620"]
1111
DEFAULT_NUM_UID_PICK = 30
1212
DEFAULT_NUM_UID_PICK_ANTHROPIC = 1
13+
TEXT_VALI_MODELS_WEIGHTS = {
14+
"AnthropicBedrock": {
15+
"anthropic.claude-v2:1": 1
16+
},
17+
"OpenAI": {
18+
"gpt-4o": 1,
19+
"gpt-4-1106-preview": 1,
20+
"gpt-3.5-turbo": 1,
21+
"gpt-3.5-turbo-16k": 1,
22+
"gpt-3.5-turbo-0125": 1,
23+
},
24+
"Gemini": {
25+
"gemini-pro": 1,
26+
"gemini-1.5-flash": 1,
27+
"gemini-1.5-pro": 1,
28+
},
29+
"Anthropic": {
30+
"claude-3-5-sonnet-20240620": 1,
31+
"claude-3-opus-20240229": 1,
32+
"claude-3-sonnet-20240229": 1,
33+
"claude-3-haiku-20240307": 1
34+
},
35+
"Groq": {
36+
"gemma-7b-it": 1,
37+
"llama3-70b-8192": 1,
38+
"llama3-8b-8192": 1,
39+
"mixtral-8x7b-32768": 1,
40+
},
41+
"Bedrock": {
42+
"anthropic.claude-3-sonnet-20240229-v1:0": 1,
43+
"cohere.command-r-v1:0": 1,
44+
"meta.llama2-70b-chat-v1": 1,
45+
"amazon.titan-text-express-v1": 1,
46+
"mistral.mistral-7b-instruct-v0:2": 1,
47+
"ai21.j2-mid-v1": 1,
48+
"anthropic.claude-3-5-sonnet-20240620-v1:0": 1,
49+
"anthropic.claude-3-opus-20240229-v1:0": 1,
50+
"anthropic.claude-3-haiku-20240307-v1:0": 1
51+
}
52+
}

validators/services/validators/text_validator.py

+8-30
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
from typing import AsyncIterator
44

5+
from cortext.reward import model
56
from validators.services.bittensor import bt_validator as bt
67
from . import constants
78
import cortext.reward
@@ -99,9 +100,11 @@ async def start_query(self, available_uids):
99100
syn = StreamPrompting(messages=messages, model=self.model, seed=self.seed, max_tokens=self.max_tokens,
100101
temperature=self.temperature, provider=self.provider, top_p=self.top_p,
101102
top_k=self.top_k)
103+
104+
image = image if image else ''
102105
bt.logging.info(
103106
f"Sending {syn.model} {self.query_type} request to uid: {uid}, "
104-
f"timeout {self.timeout}: {syn.messages[0]['content']}"
107+
f"timeout {self.timeout}: {syn.messages[0]['content']} {image}"
105108
)
106109
task = self.query_miner(self.metagraph, uid, syn)
107110
query_tasks.append(task)
@@ -119,35 +122,10 @@ def select_random_provider_and_model(self):
119122
self.provider = random.choice(providers)
120123
self.num_uids_to_pick = constants.DEFAULT_NUM_UID_PICK
121124

122-
if self.provider == "AnthropicBedrock":
123-
self.model = "anthropic.claude-v2:1"
124-
125-
elif self.provider == "OpenAI":
126-
models = ["gpt-4o", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0125"]
127-
self.model = random.choice(models)
128-
129-
elif self.provider == "Gemini":
130-
models = ["gemini-pro", "gemini-1.5-flash", "gemini-1.5-pro"]
131-
self.model = random.choice(models)
132-
133-
elif self.provider == "Anthropic":
134-
models = ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-sonnet-20240229",
135-
"claude-3-haiku-20240307"]
136-
self.model = random.choice(models)
137-
138-
elif self.provider == "Groq":
139-
models = ["gemma-7b-it", "llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768"]
140-
self.model = random.choice(models)
141-
142-
elif self.provider == "Bedrock":
143-
models = [
144-
"anthropic.claude-3-sonnet-20240229-v1:0", "cohere.command-r-v1:0",
145-
"meta.llama2-70b-chat-v1", "amazon.titan-text-express-v1",
146-
"mistral.mistral-7b-instruct-v0:2", "ai21.j2-mid-v1", "anthropic.claude-3-5-sonnet-20240620-v1:0"
147-
"anthropic.claude-3-opus-20240229-v1:0",
148-
"anthropic.claude-3-haiku-20240307-v1:0"
149-
]
150-
self.model = random.choice(models)
125+
model_to_weights = constants.TEXT_VALI_MODELS_WEIGHTS[self.provider]
126+
self.model = random.choices(list(model_to_weights.keys()),
127+
weights=list(model_to_weights.values()), k=1)[0]
128+
151129
return self.num_uids_to_pick
152130

153131
def should_i_score(self):

validators/weight_setter.py

+6-13
Original file line numberDiff line numberDiff line change
@@ -219,29 +219,22 @@ async def perform_synthetic_scoring_and_update_weights(self):
219219
for uid, score in uid_to_scores.items():
220220
self.total_scores[uid] += score
221221

222-
steps_since_last_update = steps_passed % iterations_per_set_weights
223-
224-
if steps_since_last_update == iterations_per_set_weights - 1:
225-
await self.update_weights(steps_passed)
226-
else:
227-
bt.logging.info(
228-
f"Updating weights in {iterations_per_set_weights - steps_since_last_update - 1} iterations."
229-
)
230-
231222
# if we want to slow down the speed of the validator steps
232223
await asyncio.sleep(app_config.SLEEP_PER_ITERATION)
233224

234225
if (self.subtensor.block - cur_block) >= 360:
235-
print("refreshing metagraph...")
226+
bt.logging.info("refreshing metagraph...")
236227
cur_block = self.subtensor.block
237228
await self.refresh_metagraph()
229+
bt.logging.info("updating weights...")
230+
await self.update_weights(steps_passed)
238231

239232
@staticmethod
240233
def select_validator():
241234
rand = random.random()
242235
text_validator = ValidatorRegistryMeta.get_class('TextValidator')()
243236
image_validator = ValidatorRegistryMeta.get_class('ImageValidator')()
244-
if rand < 0.9:
237+
if rand > app_config.IMAGE_VALIDATOR_CHOOSE_PROBABILITY:
245238
bt.logging.info("text_validator is selected.")
246239
return text_validator
247240
else:
@@ -250,14 +243,14 @@ def select_validator():
250243

251244
async def get_available_uids(self):
252245
"""Get a dictionary of available UIDs and their axons asynchronously."""
246+
await self.dendrite.aclose_session()
253247
tasks = {uid.item(): self.check_uid(self.metagraph.axons[uid.item()], uid.item()) for uid in
254248
self.metagraph.uids}
255249
results = await asyncio.gather(*tasks.values())
256250

257251
# Create a dictionary of UID to axon info for active UIDs
258252
available_uids = {uid: axon_info for uid, axon_info in zip(tasks.keys(), results) if axon_info is not None}
259253

260-
await self.dendrite.aclose_session()
261254
return available_uids
262255

263256
async def check_uid(self, axon, uid):
@@ -326,7 +319,7 @@ async def set_weights(self, scores):
326319
wallet=self.wallet,
327320
uids=self.metagraph.uids,
328321
weights=self.moving_average_scores,
329-
wait_for_inclusion=False,
322+
wait_for_inclusion=True,
330323
version_key=cortext.__weights_version__,
331324
)
332325
)

0 commit comments

Comments
 (0)