Experimental: multiprocessing for speeding up query map decodes. #84

Draft · wants to merge 2 commits into base: staging
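This draft adds an optional executor argument to AsyncSubstrateInterface.query_map (threaded through AsyncQueryMapResult pagination as well), so the CPU-bound SCALE decoding of each result page can run in a process pool instead of blocking the event loop. A minimal usage sketch; the endpoint URL and storage map are illustrative placeholders, not taken from this PR:

    import asyncio
    from concurrent.futures import ProcessPoolExecutor

    from async_substrate_interface.async_substrate import AsyncSubstrateInterface


    async def main():
        # Endpoint and storage map below are placeholders for illustration.
        async with AsyncSubstrateInterface("wss://node.example.com:443") as substrate:
            with ProcessPoolExecutor() as executor:
                result = await substrate.query_map(
                    "System",
                    "Account",
                    executor=executor,  # new optional argument from this PR
                )
                async for key, value in result:
                    print(key, value)


    if __name__ == "__main__":
        asyncio.run(main())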
125 changes: 63 additions & 62 deletions async_substrate_interface/async_substrate.py
@@ -56,9 +56,14 @@
 )
 from async_substrate_interface.utils.storage import StorageKey
 from async_substrate_interface.type_registry import _TYPE_REGISTRY
+from async_substrate_interface.utils.decoding_attempt import (
+    decode_query_map,
+    _decode_scale_with_runtime,
+)

 if TYPE_CHECKING:
     from websockets.asyncio.client import ClientConnection
+    from concurrent.futures import ProcessPoolExecutor

 ResultHandler = Callable[[dict, Any], Awaitable[tuple[dict, bool]]]

@@ -413,6 +418,7 @@ def __init__(
         last_key: Optional[str] = None,
         max_results: Optional[int] = None,
         ignore_decoding_errors: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ):
         self.records = records
         self.page_size = page_size
@@ -425,6 +431,7 @@ def __init__(
         self.params = params
         self.ignore_decoding_errors = ignore_decoding_errors
         self.loading_complete = False
+        self.executor = executor
         self._buffer = iter(self.records)  # Initialize the buffer with initial records

     async def retrieve_next_page(self, start_key) -> list:
@@ -437,6 +444,7 @@ async def retrieve_next_page(self, start_key) -> list:
             start_key=start_key,
             max_results=self.max_results,
             ignore_decoding_errors=self.ignore_decoding_errors,
+            executor=self.executor,
         )
         if len(result.records) < self.page_size:
             self.loading_complete = True
@@ -898,7 +906,7 @@ async def decode_scale(
         else:
             return obj

-    async def load_runtime(self, runtime):
+    def load_runtime(self, runtime):
         self.runtime = runtime

         # Update type registry
@@ -954,7 +962,7 @@ async def init_runtime(
        )

        if self.runtime and runtime_version == self.runtime.runtime_version:
-            return
+            return self.runtime

        runtime = self.runtime_cache.retrieve(runtime_version=runtime_version)
        if not runtime:
@@ -990,7 +998,7 @@
                 runtime_version=runtime_version, runtime=runtime
             )

-        await self.load_runtime(runtime)
+        self.load_runtime(runtime)

         if self.ss58_format is None:
             # Check and apply runtime constants
@@ -1000,6 +1008,7 @@

         if ss58_prefix_constant:
             self.ss58_format = ss58_prefix_constant
+        return runtime

     async def create_storage_key(
         self,
@@ -2858,6 +2867,7 @@ async def query_map(
         page_size: int = 100,
         ignore_decoding_errors: bool = False,
         reuse_block_hash: bool = False,
+        executor: Optional["ProcessPoolExecutor"] = None,
     ) -> AsyncQueryMapResult:
         """
         Iterates over all key-pairs located at the given module and storage_function. The storage
@@ -2892,12 +2902,11 @@
         Returns:
             AsyncQueryMapResult object
         """
-        hex_to_bytes_ = hex_to_bytes
         params = params or []
         block_hash = await self._get_current_block_hash(block_hash, reuse_block_hash)
         if block_hash:
             self.last_block_hash = block_hash
-        await self.init_runtime(block_hash=block_hash)
+        runtime = await self.init_runtime(block_hash=block_hash)

         metadata_pallet = self.runtime.metadata.get_metadata_pallet(module)
         if not metadata_pallet:
@@ -2952,19 +2961,6 @@
         result = []
         last_key = None

-        def concat_hash_len(key_hasher: str) -> int:
-            """
-            Helper function to avoid if statements
-            """
-            if key_hasher == "Blake2_128Concat":
-                return 16
-            elif key_hasher == "Twox64Concat":
-                return 8
-            elif key_hasher == "Identity":
-                return 0
-            else:
-                raise ValueError("Unsupported hash type")
-
         if len(result_keys) > 0:
             last_key = result_keys[-1]

@@ -2975,51 +2971,55 @@ def concat_hash_len(key_hasher: str) -> int:

         if "error" in response:
             raise SubstrateRequestException(response["error"]["message"])
-
         for result_group in response["result"]:
-            for item in result_group["changes"]:
-                try:
-                    # Determine type string
-                    key_type_string = []
-                    for n in range(len(params), len(param_types)):
-                        key_type_string.append(
-                            f"[u8; {concat_hash_len(key_hashers[n])}]"
-                        )
-                        key_type_string.append(param_types[n])
-
-                    item_key_obj = await self.decode_scale(
-                        type_string=f"({', '.join(key_type_string)})",
-                        scale_bytes=bytes.fromhex(item[0][len(prefix) :]),
-                        return_scale_obj=True,
-                    )
-
-                    # strip key_hashers to use as item key
-                    if len(param_types) - len(params) == 1:
-                        item_key = item_key_obj[1]
-                    else:
-                        item_key = tuple(
-                            item_key_obj[key + 1]
-                            for key in range(len(params), len(param_types) + 1, 2)
-                        )
-
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_key = None
-
-                try:
-                    item_bytes = hex_to_bytes_(item[1])
-
-                    item_value = await self.decode_scale(
-                        type_string=value_type,
-                        scale_bytes=item_bytes,
-                        return_scale_obj=True,
-                    )
-                except Exception as _:
-                    if not ignore_decoding_errors:
-                        raise
-                    item_value = None
-                result.append([item_key, item_value])
+            if executor:
+                # print(
+                #     ("prefix", type("prefix")),
+                #     ("runtime_registry", type(runtime.registry)),
+                #     ("param_types", type(param_types)),
+                #     ("params", type(params)),
+                #     ("value_type", type(value_type)),
+                #     ("key_hasher", type(key_hashers)),
+                #     ("ignore_decoding_errors", type(ignore_decoding_errors)),
+                # )
+                result = await asyncio.get_running_loop().run_in_executor(
+                    executor,
+                    decode_query_map,
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
+                # max_workers = executor._max_workers
+                # result_group_changes_groups = [result_group["changes"][i:i + max_workers] for i in range(0, len(result_group["changes"]), max_workers)]
+                # all_results = executor.map(
+                #     self._decode_query_map,
+                #     result_group["changes"],
+                #     repeat(prefix),
+                #     repeat(runtime.registry),
+                #     repeat(param_types),
+                #     repeat(params),
+                #     repeat(value_type),
+                #     repeat(key_hashers),
+                #     repeat(ignore_decoding_errors)
+                # )
+                # for r in all_results:
+                #     result.extend(r)
+            else:
+                result = decode_query_map(
+                    result_group["changes"],
+                    prefix,
+                    runtime.registry.registry,
+                    param_types,
+                    params,
+                    value_type,
+                    key_hashers,
+                    ignore_decoding_errors,
+                )
         return AsyncQueryMapResult(
             records=result,
             page_size=page_size,
@@ -3031,6 +3031,7 @@ def concat_hash_len(key_hasher: str) -> int:
             last_key=last_key,
             max_results=max_results,
             ignore_decoding_errors=ignore_decoding_errors,
+            executor=executor,
         )

     async def submit_extrinsic(
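Two details make the hand-off above work. First, init_runtime now returns the Runtime it loads (and load_runtime becomes synchronous), so query_map has a runtime object whose registry it can ship to the worker. Second, the decoding moved into the module-level decode_query_map, because ProcessPoolExecutor pickles the callable and every argument before sending them to a worker process: a closure like the old concat_hash_len cannot cross that boundary, and the registry is apparently passed in its JSON form (note the PortableRegistry.from_json call on the worker side) for the same reason. A stripped-down sketch of the pattern, with all names illustrative:

    import asyncio
    from concurrent.futures import ProcessPoolExecutor


    def cpu_bound_decode(raw_items: list, registry_json: str) -> list:
        # Runs in a worker process. Unpicklable state (the registry, in the
        # PR) is shipped in a serializable form and rebuilt here; the decode
        # below is a stand-in for the real SCALE decoding.
        return [item.hex() for item in raw_items]


    async def main():
        loop = asyncio.get_running_loop()
        with ProcessPoolExecutor() as executor:
            # run_in_executor takes only positional arguments; the function
            # and its arguments must all be picklable (module-level callable,
            # plain data).
            decoded = await loop.run_in_executor(
                executor, cpu_bound_decode, [b"\x01\x02"], "{}"
            )
            print(decoded)


    if __name__ == "__main__":
        asyncio.run(main())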
4 changes: 3 additions & 1 deletion async_substrate_interface/sync_substrate.py
@@ -525,7 +525,9 @@ def __enter__(self):
         return self

     def __del__(self):
-        self.close()
+        self.ws.close()
+        print("DELETING SUBSTRATE")
+        # self.ws.protocol.fail(code=1006)  # ABNORMAL_CLOSURE

     def initialize(self):
         """
104 changes: 104 additions & 0 deletions async_substrate_interface/utils/decoding_attempt.py
@@ -0,0 +1,104 @@
+from scalecodec import ss58_encode
+
+from async_substrate_interface.utils import hex_to_bytes
+from bt_decode import decode as decode_by_type_string, PortableRegistry
+from bittensor_wallet.utils import SS58_FORMAT
+
+
+class ScaleObj:
+    def __init__(self, value):
+        self.value = value
+
+
+def _decode_scale_with_runtime(
+    type_string: str,
+    scale_bytes: bytes,
+    runtime_registry: "PortableRegistry",
+    return_scale_obj: bool = False,
+):
+    if scale_bytes == b"":
+        return None
+    if type_string == "scale_info::0":  # Is an AccountId
+        # Decode AccountId bytes to SS58 address
+        return ss58_encode(scale_bytes, SS58_FORMAT)
+    else:
+        obj = decode_by_type_string(type_string, runtime_registry, scale_bytes)
+        if return_scale_obj:
+            return ScaleObj(obj)
+        else:
+            return obj
+
+
+def decode_query_map(
+    result_group_changes,
+    prefix,
+    runtime_registry,
+    param_types,
+    params,
+    value_type,
+    key_hashers,
+    ignore_decoding_errors,
+):
+    def concat_hash_len(key_hasher: str) -> int:
+        """
+        Helper function to avoid if statements
+        """
+        if key_hasher == "Blake2_128Concat":
+            return 16
+        elif key_hasher == "Twox64Concat":
+            return 8
+        elif key_hasher == "Identity":
+            return 0
+        else:
+            raise ValueError("Unsupported hash type")
+
+    hex_to_bytes_ = hex_to_bytes
+    runtime_registry = PortableRegistry.from_json(runtime_registry)
+
+    result = []
+    for item in result_group_changes:
+        try:
+            # Determine type string
+            key_type_string = []
+            for n in range(len(params), len(param_types)):
+                key_type_string.append(f"[u8; {concat_hash_len(key_hashers[n])}]")
+                key_type_string.append(param_types[n])
+
+            item_key_obj = _decode_scale_with_runtime(
+                f"({', '.join(key_type_string)})",
+                bytes.fromhex(item[0][len(prefix) :]),
+                runtime_registry,
+                False,
+            )
+
+            # strip key_hashers to use as item key
+            if len(param_types) - len(params) == 1:
+                item_key = item_key_obj[1]
+            else:
+                item_key = tuple(
+                    item_key_obj[key + 1]
+                    for key in range(len(params), len(param_types) + 1, 2)
+                )
+
+        except Exception as _:
+            if not ignore_decoding_errors:
+                raise
+            item_key = None
+
+        try:
+            item_bytes = hex_to_bytes_(item[1])
+
+            item_value = _decode_scale_with_runtime(
+                value_type, item_bytes, runtime_registry, True
+            )
+
+        except Exception as _:
+            if not ignore_decoding_errors:
+                raise
+            item_value = None
+        result.append([item_key, item_value])
+    return result
+
+
+if __name__ == "__main__":
+    pass
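The concat_hash_len helper reflects how Substrate's transparent storage hashers lay out map keys: Blake2_128Concat stores a 16-byte hash followed by the raw key, Twox64Concat an 8-byte hash, and Identity no prefix at all. The [u8; N] entries in the generated tuple type string exist only to consume those hash prefixes, which is why the decoded tuple is then indexed at the positions after each prefix to recover the actual keys. A self-contained sketch of the type-string construction (the function name is illustrative, not part of the PR):

    def build_key_type_string(params, param_types, key_hashers) -> str:
        # Mirrors the loop in decode_query_map: each remaining map parameter
        # contributes a fixed-size hash prefix to skip, then the key type.
        concat_len = {"Blake2_128Concat": 16, "Twox64Concat": 8, "Identity": 0}
        parts = []
        for n in range(len(params), len(param_types)):
            parts.append(f"[u8; {concat_len[key_hashers[n]]}]")
            parts.append(param_types[n])
        return f"({', '.join(parts)})"


    # A map keyed by a single Blake2_128Concat-hashed AccountId decodes as
    # 16 hash bytes followed by the account id itself:
    assert build_key_type_string([], ["scale_info::0"], ["Blake2_128Concat"]) == (
        "([u8; 16], scale_info::0)"
    )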