@@ -5,6 +5,7 @@
 """
 
 import asyncio
+from concurrent.futures import ProcessPoolExecutor
 import inspect
 import logging
 import random
@@ -56,6 +57,7 @@
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
+from async_substrate_interface.utils.decoding_attempt import _decode_query_map, _decode_scale_with_runtime
 
 if TYPE_CHECKING:
     from websockets.asyncio.client import ClientConnection
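The two added imports wire a ProcessPoolExecutor into the query-map path: the per-item decoding loop moves out of this module into _decode_query_map so the work can be shipped to worker processes. Judging only from the call sites below, the helper's signature is presumably along these lines (a sketch inferred from this diff, not the actual code in decoding_attempt.py):

def _decode_query_map(
    result_group_changes,    # raw (hex key, hex value) pairs from the RPC response
    prefix,                  # storage key prefix stripped before decoding each key
    registry,                # picklable scale registry (runtime.registry.registry)
    param_types,
    params,
    value_type,
    key_hashers,
    ignore_decoding_errors,
):
    ...  # returns a list of [item_key, item_value] pairs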
@@ -413,6 +415,7 @@ def __init__(
         last_key: Optional[str] = None,
         max_results: Optional[int] = None,
         ignore_decoding_errors: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None
     ):
         self.records = records
         self.page_size = page_size
@@ -425,6 +428,7 @@ def __init__(
         self.params = params
         self.ignore_decoding_errors = ignore_decoding_errors
         self.loading_complete = False
+        self.executor = executor
         self._buffer = iter(self.records)  # Initialize the buffer with initial records
 
     async def retrieve_next_page(self, start_key) -> list:
@@ -437,6 +441,7 @@ async def retrieve_next_page(self, start_key) -> list:
             start_key=start_key,
             max_results=self.max_results,
             ignore_decoding_errors=self.ignore_decoding_errors,
+            executor=self.executor
         )
         if len(result.records) < self.page_size:
             self.loading_complete = True
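Forwarding self.executor here means pages fetched lazily during iteration are decoded in the same pool as the initial page. Usage sketch (names assumed):

result = await substrate.query_map("System", "Account", executor=pool)
async for key, value in result:  # each follow-up page decodes via `pool`
    ...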
@@ -862,6 +867,7 @@ async def encode_scale(
             await self._wait_for_registry(_attempt, _retries)
         return self._encode_scale(type_string, value)
 
+
     async def decode_scale(
         self,
         type_string: str,
@@ -898,7 +904,7 @@ async def decode_scale(
         else:
             return obj
 
-    async def load_runtime(self, runtime):
+    def load_runtime(self, runtime):
         self.runtime = runtime
 
         # Update type registry
@@ -954,7 +960,7 @@ async def init_runtime(
         )
 
         if self.runtime and runtime_version == self.runtime.runtime_version:
-            return
+            return self.runtime
 
         runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
         if not runtime:
@@ -990,7 +996,7 @@ async def init_runtime(
                 runtime_version=runtime_version, runtime=runtime
             )
 
-        await self.load_runtime(runtime)
+        self.load_runtime(runtime)
 
         if self.ss58_format is None:
             # Check and apply runtime constants
@@ -1000,6 +1006,7 @@ async def init_runtime(
 
             if ss58_prefix_constant:
                 self.ss58_format = ss58_prefix_constant
+        return runtime
 
     async def create_storage_key(
         self,
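load_runtime no longer awaits anything, so it becomes a plain method, and init_runtime now returns the active Runtime on both the cached early-exit path and the full initialization path. Callers such as query_map below can therefore capture the runtime directly instead of reading it back from self.runtime. A minimal sketch of the new contract (caller name assumed):

runtime = await substrate.init_runtime(block_hash=block_hash)
# both return paths hand back the runtime that was just loaded
assert runtime is substrate.runtime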
@@ -2858,6 +2865,7 @@ async def query_map(
         page_size: int = 100,
         ignore_decoding_errors: bool = False,
         reuse_block_hash: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None
     ) -> AsyncQueryMapResult:
         """
         Iterates over all key-pairs located at the given module and storage_function. The storage
@@ -2892,12 +2900,11 @@ async def query_map(
         Returns:
             AsyncQueryMapResult object
         """
-        hex_to_bytes_ = hex_to_bytes
         params = params or []
         block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
         if block_hash:
             self.last_block_hash = block_hash
-        await self.init_runtime(block_hash=block_hash)
+        runtime = await self.init_runtime(block_hash=block_hash)
 
         metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
         if not metadata_pallet:
@@ -2952,19 +2959,6 @@ async def query_map(
         result = []
         last_key = None
 
-        def concat_hash_len(key_hasher: str) -> int:
-            """
-            Helper function to avoid if statements
-            """
-            if key_hasher == "Blake2_128Concat":
-                return 16
-            elif key_hasher == "Twox64Concat":
-                return 8
-            elif key_hasher == "Identity":
-                return 0
-            else:
-                raise ValueError("Unsupported hash type")
-
         if len(result_keys) > 0:
             last_key = result_keys[-1]
 
@@ -2975,51 +2969,36 @@ def concat_hash_len(key_hasher: str) -> int:
 
             if "error" in response:
                 raise SubstrateRequestException(response["error"]["message"])
-
             for result_group in response["result"]:
-                for item in result_group["changes"]:
-                    try:
-                        # Determine type string
-                        key_type_string = []
-                        for n in range(len(params), len(param_types)):
-                            key_type_string.append(
-                                f"[u8; {concat_hash_len(key_hashers[n])}]"
-                            )
-                            key_type_string.append(param_types[n])
-
-                        item_key_obj = await self.decode_scale(
-                            type_string=f"({', '.join(key_type_string)})",
-                            scale_bytes=bytes.fromhex(item[0][len(prefix):]),
-                            return_scale_obj=True,
-                        )
-
-                        # strip key_hashers to use as item key
-                        if len(param_types) - len(params) == 1:
-                            item_key = item_key_obj[1]
-                        else:
-                            item_key = tuple(
-                                item_key_obj[key + 1]
-                                for key in range(len(params), len(param_types) + 1, 2)
-                            )
-
-                    except Exception as _:
-                        if not ignore_decoding_errors:
-                            raise
-                        item_key = None
-
-                    try:
-                        item_bytes = hex_to_bytes_(item[1])
-
-                        item_value = await self.decode_scale(
-                            type_string=value_type,
-                            scale_bytes=item_bytes,
-                            return_scale_obj=True,
-                        )
-                    except Exception as _:
-                        if not ignore_decoding_errors:
-                            raise
-                        item_value = None
-                    result.append([item_key, item_value])
+                if executor:
+                    # every argument handed to the process pool must be
+                    # picklable, hence the inner runtime.registry.registry
+                    decoded = await asyncio.get_running_loop().run_in_executor(
+                        executor,
+                        _decode_query_map,
+                        result_group["changes"],
+                        prefix,
+                        runtime.registry.registry,
+                        param_types,
+                        params,
+                        value_type,
+                        key_hashers,
+                        ignore_decoding_errors,
+                    )
+                else:
+                    decoded = _decode_query_map(
+                        result_group["changes"],
+                        prefix,
+                        runtime.registry.registry,
+                        param_types,
+                        params,
+                        value_type,
+                        key_hashers,
+                        ignore_decoding_errors,
+                    )
+                # extend rather than overwrite so every result group is kept
+                result.extend(decoded)
 
         return AsyncQueryMapResult(
             records=result,
             page_size=page_size,
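run_in_executor pickles the callable and each positional argument into a worker process, which is presumably why the plain runtime.registry.registry object is handed over rather than the wrapper that holds it, and why _decode_query_map must be a module-level function. A minimal, self-contained model of the pattern (stand-in function, not the library's):

import asyncio
from concurrent.futures import ProcessPoolExecutor

def cpu_bound(raw: bytes) -> int:  # module-level so it can be pickled
    return sum(raw)  # stand-in for CPU-heavy SCALE decoding

async def main():
    with ProcessPoolExecutor() as pool:
        loop = asyncio.get_running_loop()
        # the event loop stays free to service websocket traffic meanwhile
        print(await loop.run_in_executor(pool, cpu_bound, b"\x01\x02\x03"))

asyncio.run(main())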
@@ -3031,6 +3010,7 @@ def concat_hash_len(key_hasher: str) -> int:
             last_key=last_key,
             max_results=max_results,
             ignore_decoding_errors=ignore_decoding_errors,
+            executor=executor
         )
 
     async def submit_extrinsic(
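The executor is strictly opt-in: omitting it keeps decoding in-process, so existing callers are unaffected. An end-to-end sketch of the new path (class name and endpoint assumed, not taken from this diff):

import asyncio
from concurrent.futures import ProcessPoolExecutor
from async_substrate_interface import AsyncSubstrateInterface  # import path assumed

async def main():
    with ProcessPoolExecutor(max_workers=4) as pool:
        async with AsyncSubstrateInterface("wss://node.example.com") as substrate:
            accounts = await substrate.query_map("System", "Account", executor=pool)
            print(len(accounts.records))  # first page, decoded in the pool

asyncio.run(main())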