Commit a5ca68a

Experimental: multiprocessing for speeding up query map decodes.
1 parent 77f9ec6 commit a5ca68a

File tree: 3 files changed (+159 −63)
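
As context for the change, here is a minimal sketch of how the new `executor` argument might be used from calling code. The endpoint URL, the `System.Account` storage map, and the `async with` / `async for` usage are illustrative assumptions, not part of this commit:

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor

from async_substrate_interface.async_substrate import AsyncSubstrateInterface


async def main():
    # Worker processes are used only for decoding the query-map pages.
    with ProcessPoolExecutor(max_workers=4) as executor:
        async with AsyncSubstrateInterface(
            "wss://entrypoint-finney.opentensor.ai:443"  # example endpoint
        ) as substrate:
            result = await substrate.query_map(
                module="System",
                storage_function="Account",
                executor=executor,  # new in this commit; defaults to None (decode in-process)
            )
            async for key, value in result:
                print(key, value)


if __name__ == "__main__":
    asyncio.run(main())
```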

async_substrate_interface/async_substrate.py (+57 −62)
@@ -5,6 +5,7 @@
 """
 
 import asyncio
+from concurrent.futures import ProcessPoolExecutor
 import inspect
 import logging
 import random
@@ -56,6 +57,7 @@
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
+from async_substrate_interface.utils.decoding_attempt import _decode_query_map, _decode_scale_with_runtime
 
 if TYPE_CHECKING:
     from websockets.asyncio.client import ClientConnection
@@ -413,6 +415,7 @@ def __init__(
         last_key: Optional[str] = None,
         max_results: Optional[int] = None,
         ignore_decoding_errors: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ):
         self.records = records
         self.page_size = page_size
@@ -425,6 +428,7 @@ def __init__(
         self.params = params
         self.ignore_decoding_errors = ignore_decoding_errors
         self.loading_complete = False
+        self.executor = executor
        self._buffer = iter(self.records)  # Initialize the buffer with initial records
 
     async def retrieve_next_page(self, start_key) -> list:
@@ -437,6 +441,7 @@ async def retrieve_next_page(self, start_key) -> list:
             start_key=start_key,
             max_results=self.max_results,
             ignore_decoding_errors=self.ignore_decoding_errors,
+            executor=self.executor,
         )
         if len(result.records) < self.page_size:
             self.loading_complete = True
@@ -862,6 +867,7 @@ async def encode_scale(
             await self._wait_for_registry(_attempt, _retries)
         return self._encode_scale(type_string, value)
 
+
     async def decode_scale(
         self,
         type_string: str,
@@ -898,7 +904,7 @@ async def decode_scale(
         else:
             return obj
 
-    async def load_runtime(self, runtime):
+    def load_runtime(self, runtime):
         self.runtime = runtime
 
         # Update type registry
@@ -954,7 +960,7 @@ async def init_runtime(
         )
 
         if self.runtime and runtime_version == self.runtime.runtime_version:
-            return
+            return self.runtime
 
         runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
         if not runtime:
@@ -990,7 +996,7 @@ async def init_runtime(
                 runtime_version=runtime_version, runtime=runtime
             )
 
-        await self.load_runtime(runtime)
+        self.load_runtime(runtime)
 
         if self.ss58_format is None:
             # Check and apply runtime constants
@@ -1000,6 +1006,7 @@ async def init_runtime(
 
             if ss58_prefix_constant:
                 self.ss58_format = ss58_prefix_constant
+        return runtime
 
     async def create_storage_key(
         self,
@@ -2858,6 +2865,7 @@ async def query_map(
         page_size: int = 100,
         ignore_decoding_errors: bool = False,
         reuse_block_hash: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ) -> AsyncQueryMapResult:
         """
         Iterates over all key-pairs located at the given module and storage_function. The storage
@@ -2892,12 +2900,11 @@ async def query_map(
         Returns:
             AsyncQueryMapResult object
         """
-        hex_to_bytes_ = hex_to_bytes
         params = params or []
         block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
         if block_hash:
             self.last_block_hash = block_hash
-        await self.init_runtime(block_hash=block_hash)
+        runtime = await self.init_runtime(block_hash=block_hash)
 
         metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
         if not metadata_pallet:
@@ -2952,19 +2959,6 @@ async def query_map(
         result = []
         last_key = None
 
-        def concat_hash_len(key_hasher: str) -> int:
-            """
-            Helper function to avoid if statements
-            """
-            if key_hasher == "Blake2_128Concat":
-                return 16
-            elif key_hasher == "Twox64Concat":
-                return 8
-            elif key_hasher == "Identity":
-                return 0
-            else:
-                raise ValueError("Unsupported hash type")
-
         if len(result_keys) > 0:
             last_key = result_keys[-1]
@@ -2975,51 +2969,51 @@ def concat_hash_len(key_hasher: str) -> int:
 
         if "error" in response:
             raise SubstrateRequestException(response["error"]["message"])
-
         for result_group in response["result"]:
-            for item in result_group["changes"]:
-                try:
-                    # Determine type string
-                    key_type_string = []
-                    for n in range(len(params), len(param_types)):
-                        key_type_string.append(
-                            f"[u8; {concat_hash_len(key_hashers[n])}]"
-                        )
-                        key_type_string.append(param_types[n])
-
-                    item_key_obj = await self.decode_scale(
-                        type_string=f"({', '.join(key_type_string)})",
-                        scale_bytes=bytes.fromhex(item[0][len(prefix) :]),
-                        return_scale_obj=True,
-                    )
-
-                    # strip key_hashers to use as item key
-                    if len(param_types) - len(params) == 1:
-                        item_key = item_key_obj[1]
-                    else:
-                        item_key = tuple(
-                            item_key_obj[key + 1]
-                            for key in range(len(params), len(param_types) + 1, 2)
-                        )
-
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_key = None
-
-                try:
-                    item_bytes = hex_to_bytes_(item[1])
-
-                    item_value = await self.decode_scale(
-                        type_string=value_type,
-                        scale_bytes=item_bytes,
-                        return_scale_obj=True,
-                    )
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_value = None
-                result.append([item_key, item_value])
+            if executor:
+                # Offload decoding of this page of changes to a worker process.
+                # Only picklable arguments are passed, so the registry travels in
+                # its JSON form (runtime.registry.registry) and is rebuilt there.
+                decoded = await asyncio.get_running_loop().run_in_executor(
+                    executor,
+                    _decode_query_map,
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
+                # TODO: alternatively, chunk result_group["changes"] and fan the
+                # chunks out across the pool with executor.map.
+            else:
+                decoded = _decode_query_map(
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
+            result.extend(decoded)
         return AsyncQueryMapResult(
             records=result,
             page_size=page_size,
@@ -3031,6 +3025,7 @@ def concat_hash_len(key_hasher: str) -> int:
             last_key=last_key,
             max_results=max_results,
             ignore_decoding_errors=ignore_decoding_errors,
+            executor=executor,
         )
 
     async def submit_extrinsic(
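
The `run_in_executor` branch above follows the standard pattern for offloading CPU-bound work to a `ProcessPoolExecutor`: the callable must be importable at module level and its arguments picklable, which is why the decode helper lives in a separate module and receives the registry in JSON form. A self-contained sketch of that general pattern (the function and data here are illustrative, not code from this repository):

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor


def cpu_bound_decode(raw_items: list[str]) -> list[int]:
    # Must be a module-level function so the worker process can import it;
    # arguments and return values travel between processes via pickle.
    return [int(item, 16) for item in raw_items]


async def main():
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=2) as executor:
        # Runs cpu_bound_decode in a worker process without blocking the event loop.
        decoded = await loop.run_in_executor(executor, cpu_bound_decode, ["0x0a", "0xff"])
        print(decoded)  # [10, 255]


if __name__ == "__main__":
    asyncio.run(main())
```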

async_substrate_interface/sync_substrate.py (+3 −1)
@@ -525,7 +525,9 @@ def __enter__(self):
         return self
 
     def __del__(self):
-        self.close()
+        self.ws.close()
+        print("DELETING SUBSTRATE")
+        # self.ws.protocol.fail(code=1006)  # ABNORMAL_CLOSURE
 
     def initialize(self):
         """
async_substrate_interface/utils/decoding_attempt.py (+99 −0)

@@ -0,0 +1,99 @@
+from scalecodec import ss58_encode
+
+from async_substrate_interface.utils import hex_to_bytes
+from bt_decode import decode as decode_by_type_string, PortableRegistry
+from bittensor_wallet.utils import SS58_FORMAT
+
+
+class ScaleObj:
+    def __init__(self, value):
+        self.value = value
+
+
+def _decode_scale_with_runtime(
+    type_string: str,
+    scale_bytes: bytes,
+    runtime_registry: PortableRegistry,
+    return_scale_obj: bool = False,
+):
+    if scale_bytes == b"":
+        return None
+    if type_string == "scale_info::0":  # Is an AccountId
+        # Decode AccountId bytes to SS58 address
+        return ss58_encode(scale_bytes, SS58_FORMAT)
+    else:
+        obj = decode_by_type_string(type_string, runtime_registry, scale_bytes)
+        if return_scale_obj:
+            return ScaleObj(obj)
+        else:
+            return obj
+
+
+def _decode_query_map(
+    result_group_changes,
+    prefix,
+    runtime_registry,
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+):
+    def concat_hash_len(key_hasher: str) -> int:
+        """
+        Helper function to avoid if statements
+        """
+        if key_hasher == "Blake2_128Concat":
+            return 16
+        elif key_hasher == "Twox64Concat":
+            return 8
+        elif key_hasher == "Identity":
+            return 0
+        else:
+            raise ValueError("Unsupported hash type")
+
+    hex_to_bytes_ = hex_to_bytes
+    # Rebuild the registry in this (worker) process from its JSON form,
+    # since the live PortableRegistry object is not picklable.
+    runtime_registry = PortableRegistry.from_json(runtime_registry)
+
+    result = []
+    for item in result_group_changes:
+        try:
+            # Determine type string
+            key_type_string = []
+            for n in range(len(params), len(param_types)):
+                key_type_string.append(f"[u8; {concat_hash_len(key_hashers[n])}]")
+                key_type_string.append(param_types[n])
+
+            item_key_obj = _decode_scale_with_runtime(
+                f"({', '.join(key_type_string)})",
+                bytes.fromhex(item[0][len(prefix):]),
+                runtime_registry,
+                False,
+            )
+
+            # strip key_hashers to use as item key
+            if len(param_types) - len(params) == 1:
+                item_key = item_key_obj[1]
+            else:
+                item_key = tuple(
+                    item_key_obj[key + 1]
+                    for key in range(len(params), len(param_types) + 1, 2)
+                )
+
+        except Exception:
+            if not ignore_decoding_errors:
+                raise
+            item_key = None
+
+        try:
+            item_bytes = hex_to_bytes_(item[1])
+
+            item_value = _decode_scale_with_runtime(
+                value_type,
+                item_bytes,
+                runtime_registry,
+                True,
+            )
+
+        except Exception:
+            if not ignore_decoding_errors:
+                raise
+            item_value = None
+        result.append([item_key, item_value])
+    return result
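
The executor branch in `query_map` leaves a note about a finer-grained alternative: splitting a page of `changes` into chunks and decoding them in parallel with `executor.map`, rather than handing the whole page to one worker. A rough sketch of that idea (the chunking helper and function names are assumptions, not part of this commit):

```python
from concurrent.futures import ProcessPoolExecutor
from itertools import repeat

from async_substrate_interface.utils.decoding_attempt import _decode_query_map


def _chunk(items: list, n: int) -> list[list]:
    # Split items into at most n roughly equal chunks (ceiling division for chunk size).
    size = max(1, -(-len(items) // n))
    return [items[i:i + size] for i in range(0, len(items), size)]


def decode_chunks_in_parallel(executor: ProcessPoolExecutor, changes, prefix,
                              registry_json, param_types, params, value_type,
                              key_hashers, ignore_decoding_errors):
    chunks = _chunk(changes, executor._max_workers)
    results = []
    # Each chunk is decoded in its own worker; repeat() duplicates the shared arguments.
    for partial in executor.map(
        _decode_query_map,
        chunks,
        repeat(prefix),
        repeat(registry_json),
        repeat(param_types),
        repeat(params),
        repeat(value_type),
        repeat(key_hashers),
        repeat(ignore_decoding_errors),
    ):
        results.extend(partial)
    return results
```

Because `executor.map` blocks until all chunks finish, the async code path would still need to dispatch this call off the event loop, for example via `asyncio.to_thread`.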
