From 863faea1d229203c6e77f3d92bb1fd03352fda01 Mon Sep 17 00:00:00 2001 From: Steffen Cruz Date: Fri, 8 Dec 2023 14:06:54 -0600 Subject: [PATCH] Add ocr_subnet dir, and TODOs --- ocr_subnet/__init__.py | 33 +++ ocr_subnet/base/__init__.py | 0 ocr_subnet/base/miner.py | 215 ++++++++++++++++++++ ocr_subnet/base/neuron.py | 168 ++++++++++++++++ ocr_subnet/base/validator.py | 332 +++++++++++++++++++++++++++++++ ocr_subnet/protocol.py | 48 +++++ ocr_subnet/utils/__init__.py | 3 + ocr_subnet/utils/config.py | 177 ++++++++++++++++ ocr_subnet/utils/misc.py | 112 +++++++++++ ocr_subnet/utils/uids.py | 63 ++++++ ocr_subnet/validator/__init__.py | 2 + ocr_subnet/validator/corrupt.py | 71 +++++++ ocr_subnet/validator/forward.py | 110 ++++++++++ ocr_subnet/validator/generate.py | 97 +++++++++ ocr_subnet/validator/reward.py | 138 +++++++++++++ ocr_subnet/validator/utils.py | 87 ++++++++ 16 files changed, 1656 insertions(+) create mode 100644 ocr_subnet/__init__.py create mode 100644 ocr_subnet/base/__init__.py create mode 100644 ocr_subnet/base/miner.py create mode 100644 ocr_subnet/base/neuron.py create mode 100644 ocr_subnet/base/validator.py create mode 100644 ocr_subnet/protocol.py create mode 100644 ocr_subnet/utils/__init__.py create mode 100644 ocr_subnet/utils/config.py create mode 100644 ocr_subnet/utils/misc.py create mode 100644 ocr_subnet/utils/uids.py create mode 100644 ocr_subnet/validator/__init__.py create mode 100644 ocr_subnet/validator/corrupt.py create mode 100644 ocr_subnet/validator/forward.py create mode 100644 ocr_subnet/validator/generate.py create mode 100644 ocr_subnet/validator/reward.py create mode 100644 ocr_subnet/validator/utils.py diff --git a/ocr_subnet/__init__.py b/ocr_subnet/__init__.py new file mode 100644 index 0000000..4854a3f --- /dev/null +++ b/ocr_subnet/__init__.py @@ -0,0 +1,33 @@ +# The MIT License (MIT) +# Copyright © 2023 Yuma Rao +# TODO(developer): Set your name +# Copyright © 2023 + +# Permission is hereby granted, free of 
# TODO(developer): Change this value when updating your code base.
# Version of the template module, written as "major.minor.patch".
__version__ = "0.0.0"
version_split = __version__.split(".")

# Collapse the three version components into a single integer:
# major is weighted by 1000, minor by 10 and patch by 1.
__spec_version__ = sum(
    weight * int(version_split[position])
    for position, weight in enumerate((1000, 10, 1))
)

# Import all submodules.
from . import protocol
from . import base
from . import validator
class BaseMinerNeuron(BaseNeuron):
    """
    Base class for Bittensor miners.

    Creates the axon that receives requests, attaches the request handlers to
    it and provides a run loop (foreground or background thread) that keeps
    the miner synced with the chain.

    ``self.forward``, ``self.blacklist`` and ``self.priority`` are attached to
    the axon but are not defined in this class; concrete miners must supply
    them.
    """

    def __init__(self, config=None):
        """
        Build the miner on top of the shared neuron setup.

        Args:
            config: Optional configuration object forwarded to ``BaseNeuron``.
        """
        super().__init__(config=config)

        # Warn if allowing incoming requests from anyone.
        if not self.config.blacklist.force_validator_permit:
            bt.logging.warning(
                "You are allowing non-validators to send requests to your miner. This is a security risk."
            )
        if self.config.blacklist.allow_non_registered:
            bt.logging.warning(
                "You are allowing non-registered entities to send requests to your miner. This is a security risk."
            )

        # The axon handles request processing, allowing validators to send this miner requests.
        self.axon = bt.axon(wallet=self.wallet, port=self.config.axon.port)

        # Attach the functions that are called when servicing a request.
        bt.logging.info("Attaching forward function to miner axon.")
        self.axon.attach(
            forward_fn=self.forward,
            blacklist_fn=self.blacklist,
            priority_fn=self.priority,
        )
        bt.logging.info(f"Axon created: {self.axon}")

        # Run-loop bookkeeping used by run() and the background-thread helpers.
        self.should_exit: bool = False
        self.is_running: bool = False
        self.thread: threading.Thread = None
        self.lock = asyncio.Lock()

    def run(self):
        """
        Run the miner's main loop until ``should_exit`` is set.

        Steps performed:
            1. ``sync()`` once, which checks registration before anything else.
            2. Serve the axon on the configured netuid and start it.
            3. Repeatedly wait until ``epoch_length`` blocks have passed since
               this miner's ``last_update`` entry in the metagraph, then
               ``sync()`` again and increment ``step``.

        A ``KeyboardInterrupt`` stops the axon and ends the process. Any other
        exception is logged with its traceback, after which this method
        returns (the loop is not resumed).

        Raises:
            SystemExit: After a keyboard interrupt has been handled.
        """

        # Check that miner is registered on the network.
        self.sync()

        # Serve passes the axon information to the network + netuid we are hosting on.
        # This will auto-update if the axon port or external ip has changed.
        bt.logging.info(
            f"Serving miner axon {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}"
        )
        self.axon.serve(netuid=self.config.netuid, subtensor=self.subtensor)

        # Start the miner's axon, making it active on the network.
        self.axon.start()

        bt.logging.info(f"Miner starting at block: {self.block}")

        # This loop maintains the miner's operations until intentionally stopped.
        try:
            while not self.should_exit:
                while (
                    self.block - self.metagraph.last_update[self.uid]
                    < self.config.neuron.epoch_length
                ):
                    # Wait before checking again.
                    time.sleep(1)

                    # Leave the wait early on shutdown; one final sync still
                    # runs below before the outer loop re-checks the flag.
                    if self.should_exit:
                        break

                # Sync metagraph and potentially set weights.
                self.sync()
                self.step += 1

        # If someone intentionally stops the miner, it'll safely terminate operations.
        except KeyboardInterrupt:
            self.axon.stop()
            bt.logging.success("Miner killed by keyboard interrupt.")
            # Raise directly instead of calling the interactive `exit()` helper,
            # which is injected by the `site` module and not meant for programs.
            raise SystemExit

        # Unforeseen errors are logged; the loop has already been left, so
        # run() returns after logging.
        except Exception:
            bt.logging.error(traceback.format_exc())

    def run_in_background_thread(self):
        """
        Start ``run()`` in a daemon thread, unless one is already running.
        """
        if not self.is_running:
            bt.logging.debug("Starting miner in background thread.")
            self.should_exit = False
            self.thread = threading.Thread(target=self.run, daemon=True)
            self.thread.start()
            self.is_running = True
            bt.logging.debug("Started")

    def stop_run_thread(self):
        """
        Ask the background run loop to stop and wait up to 5 seconds for it.
        """
        if self.is_running:
            bt.logging.debug("Stopping miner in background thread.")
            self.should_exit = True
            self.thread.join(5)
            self.is_running = False
            bt.logging.debug("Stopped")

    def __enter__(self):
        """
        Start the miner in a background thread when entering a ``with`` block.

        Returns:
            This miner instance.
        """
        self.run_in_background_thread()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stop the background run loop when leaving a ``with`` block.

        Args:
            exc_type: The type of the exception that caused the context to be exited.
                      None if the context was exited without an exception.
            exc_value: The instance of the exception that caused the context to be exited.
                       None if the context was exited without an exception.
            traceback: A traceback object encoding the stack trace.
                       None if the context was exited without an exception.
        """
        self.stop_run_thread()

    def set_weights(self):
        """
        Self-assign a weight of 1 to this miner's UID and 0 to every other UID,
        then submit those weights to the chain.

        Failures are logged and swallowed; the success message is only logged
        when no exception was raised.
        """
        try:
            # --- query the chain for the most current number of peers on the network
            chain_weights = torch.zeros(
                self.subtensor.subnetwork_n(netuid=self.metagraph.netuid)
            )
            chain_weights[self.uid] = 1

            # --- Set weights.
            self.subtensor.set_weights(
                wallet=self.wallet,
                netuid=self.metagraph.netuid,
                uids=torch.arange(0, len(chain_weights)),
                weights=chain_weights.to("cpu"),
                wait_for_inclusion=False,
                version_key=self.spec_version,
            )

        except Exception as e:
            bt.logging.error(
                f"Failed to set weights on chain with exception: {e}"
            )
        else:
            # Only reached on success, so `chain_weights` is always bound here
            # and a failed attempt is never reported as "Set weights".
            bt.logging.info(f"Set weights: {chain_weights}")

    def resync_metagraph(self):
        """Resyncs the metagraph with the current state of the chain."""
        bt.logging.info("resync_metagraph()")

        # Sync the metagraph.
        self.metagraph.sync(subtensor=self.subtensor)
class BaseNeuron(ABC):
    """
    Abstract base class shared by all neurons (miners and validators).

    Creates the wallet, subtensor and metagraph objects, verifies that the
    hotkey is registered, and provides ``sync()``, which re-checks
    registration, resyncs the metagraph and sets weights once enough blocks
    have elapsed, and then saves state.
    """

    @classmethod
    def check_config(cls, config: "bt.Config"):
        """Validate ``config`` via the module-level ``check_config`` helper."""
        check_config(cls, config)

    @classmethod
    def add_args(cls, parser):
        """Register this neuron's command-line arguments on ``parser``."""
        add_args(cls, parser)

    @classmethod
    def config(cls):
        """Return the configuration produced by the module-level ``config`` helper."""
        return config(cls)

    subtensor: "bt.subtensor"
    wallet: "bt.wallet"
    metagraph: "bt.metagraph"
    spec_version: int = spec_version

    @property
    def block(self):
        """Current block as returned by ``ttl_get_block``."""
        return ttl_get_block(self)

    def __init__(self, config=None):
        """
        Set up configuration, logging and the core Bittensor objects.

        Args:
            config: Optional configuration merged over this class's defaults.
                When falsy, ``BaseNeuron.config()`` is used instead.

        Raises:
            SystemExit: If the wallet's hotkey is not registered on the netuid.
        """
        base_config = copy.deepcopy(config or BaseNeuron.config())
        # NOTE: from here on the instance attribute `config` shadows the
        # `config()` classmethod for this instance.
        self.config = self.config()
        self.config.merge(base_config)
        self.check_config(self.config)

        # Set up logging with the provided configuration and directory.
        # NOTE(review): this reads `config.full_path`; whether that is the same
        # value as a `config.neuron.full_path` set elsewhere is not visible
        # here - confirm.
        bt.logging(config=self.config, logging_dir=self.config.full_path)

        # If a gpu is required, set the device to cuda:N (e.g. cuda:0)
        self.device = self.config.neuron.device

        # Log the configuration for reference.
        bt.logging.info(self.config)

        # Build Bittensor objects
        # These are core Bittensor classes to interact with the network.
        bt.logging.info("Setting up bittensor objects.")

        # The wallet holds the cryptographic key pairs for the neuron.
        self.wallet = bt.wallet(config=self.config)
        bt.logging.info(f"Wallet: {self.wallet}")

        # The subtensor is our connection to the Bittensor blockchain.
        self.subtensor = bt.subtensor(config=self.config)
        bt.logging.info(f"Subtensor: {self.subtensor}")

        # The metagraph holds the state of the network, letting us know about other validators and miners.
        self.metagraph = self.subtensor.metagraph(self.config.netuid)
        bt.logging.info(f"Metagraph: {self.metagraph}")

        # Check that the hotkey is registered before proceeding further.
        self.check_registered()

        # The UID is the position of our hotkey in the metagraph's hotkey list.
        self.uid = self.metagraph.hotkeys.index(
            self.wallet.hotkey.ss58_address
        )
        bt.logging.info(
            f"Running neuron on subnet: {self.config.netuid} with uid {self.uid} using network: {self.subtensor.chain_endpoint}"
        )
        self.step = 0

    @abstractmethod
    async def forward(self, synapse: bt.Synapse) -> bt.Synapse:
        ...

    @abstractmethod
    def run(self):
        ...

    def sync(self):
        """
        Synchronize this neuron with the network.

        Re-checks registration, resyncs the metagraph and sets weights when
        their respective conditions hold, and always saves state afterwards.
        """
        # Ensure miner or validator hotkey is still registered on the network.
        self.check_registered()

        if self.should_sync_metagraph():
            self.resync_metagraph()

        if self.should_set_weights():
            self.set_weights()

        # Always save state.
        self.save_state()

    def check_registered(self):
        """
        Terminate the process if the wallet's hotkey is not registered.

        Raises:
            SystemExit: With exit status 1 when the hotkey is not registered
                on ``config.netuid``.
        """
        # --- Check for registration.
        if not self.subtensor.is_hotkey_registered(
            netuid=self.config.netuid,
            hotkey_ss58=self.wallet.hotkey.ss58_address,
        ):
            bt.logging.error(
                f"Wallet: {self.wallet} is not registered on netuid {self.config.netuid}."
                f" Please register the hotkey using `btcli subnets register` before trying again"
            )
            # This is an error path: exit with a non-zero status, and do not
            # depend on the interactive `exit()` helper injected by `site`.
            raise SystemExit(1)

    def _blocks_since_last_update(self):
        """Blocks elapsed since this neuron's ``last_update`` entry in the metagraph."""
        return self.block - self.metagraph.last_update[self.uid]

    def should_sync_metagraph(self):
        """
        Return whether more than ``epoch_length`` blocks have elapsed since
        this neuron's last update.
        """
        return self._blocks_since_last_update() > self.config.neuron.epoch_length

    def should_set_weights(self) -> bool:
        """
        Return whether weights should be set during this sync.

        Always False on the first step and when weight setting is disabled;
        otherwise True once more than ``epoch_length`` blocks have elapsed
        since this neuron's last update.
        """
        # Don't set weights on initialization.
        if self.step == 0:
            return False

        # Respect the switch that disables weight setting altogether.
        if self.config.neuron.disable_set_weights:
            return False

        # Check if enough epoch blocks have elapsed since the last update.
        return self._blocks_since_last_update() > self.config.neuron.epoch_length

    def save_state(self):
        """Default hook: only warns that saving state is not implemented."""
        bt.logging.warning(
            "save_state() not implemented for this neuron. You can implement this function to save model checkpoints or other useful data."
        )

    def load_state(self):
        """Default hook: only warns that loading state is not implemented."""
        bt.logging.warning(
            "load_state() not implemented for this neuron. You can implement this function to load model checkpoints or other useful data."
        )
class BaseValidatorNeuron(BaseNeuron):
    """
    Base class for Bittensor validators. Your validator should inherit from this class.

    Keeps a per-UID score tensor that is updated as an exponential moving
    average of rewards, and turns those scores into on-chain weights.
    """

    def __init__(self, config=None):
        """
        Build the validator on top of the shared neuron setup.

        Args:
            config: Optional configuration object forwarded to ``BaseNeuron``.
        """
        super().__init__(config=config)

        # Save a copy of the hotkeys to local memory.
        self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)

        # Dendrite lets us send messages to other nodes (axons) in the network.
        self.dendrite = bt.dendrite(wallet=self.wallet)
        bt.logging.info(f"Dendrite: {self.dendrite}")

        # Set up initial scoring weights for validation
        bt.logging.info("Building validation weights.")
        self.scores = torch.zeros_like(self.metagraph.S, dtype=torch.float32)

        # Init sync with the network. Updates the metagraph.
        self.sync()

        # Always define `axon` so run() can refer to it even when the axon is
        # switched off or could not be created.
        self.axon = None

        # Serve axon to enable external connections.
        if not self.config.neuron.axon_off:
            self.serve_axon()
        else:
            bt.logging.warning("axon off, not serving ip to chain.")

        # Event loop used by run() to drive the async forward passes.
        self.loop = asyncio.get_event_loop()

        # Run-loop bookkeeping used by run() and the background-thread helpers.
        self.should_exit: bool = False
        self.is_running: bool = False
        self.thread: threading.Thread = None
        self.lock = asyncio.Lock()

    def serve_axon(self):
        """
        Create an axon and serve it on the chain.

        Both steps are best-effort: a failure is logged and swallowed.
        """
        bt.logging.info("serving ip to chain...")
        try:
            self.axon = bt.axon(wallet=self.wallet, config=self.config)
        except Exception as e:
            bt.logging.error(
                f"Failed to create Axon initialize with exception: {e}"
            )
            return

        try:
            self.subtensor.serve_axon(
                netuid=self.config.netuid,
                axon=self.axon,
            )
        except Exception as e:
            bt.logging.error(f"Failed to serve Axon with exception: {e}")

    async def concurrent_forward(self):
        """Run ``neuron.num_concurrent_forwards`` forward passes concurrently."""
        coroutines = [
            self.forward()
            for _ in range(self.config.neuron.num_concurrent_forwards)
        ]
        await asyncio.gather(*coroutines)

    def run(self):
        """
        Run the validator's main loop until ``should_exit`` is set.

        Each iteration runs the concurrent forward passes, then ``sync()``
        (which may resync the metagraph and set weights) and increments
        ``step``.

        A ``KeyboardInterrupt`` stops the axon (if one exists) and ends the
        process. Any other exception is logged together with its traceback,
        after which this method returns.

        Raises:
            SystemExit: After a keyboard interrupt has been handled.
        """
        # Imported locally: only needed on the error path below.
        from traceback import format_exception

        # Check that validator is registered on the network.
        self.sync()

        bt.logging.info(
            f"Running validator {self.axon} on network: {self.config.subtensor.chain_endpoint} with netuid: {self.config.netuid}"
        )

        bt.logging.info(f"Validator starting at block: {self.block}")

        # This loop maintains the validator's operations until intentionally stopped.
        try:
            while True:
                bt.logging.info(f"step({self.step}) block({self.block})")

                # Run multiple forwards concurrently.
                self.loop.run_until_complete(self.concurrent_forward())

                # Check if we should exit.
                if self.should_exit:
                    break

                # Sync metagraph and potentially set weights.
                self.sync()

                self.step += 1

        # If someone intentionally stops the validator, it'll safely terminate operations.
        except KeyboardInterrupt:
            # The axon is None when it is switched off or failed to start.
            if self.axon is not None:
                self.axon.stop()
            bt.logging.success("Validator killed by keyboard interrupt.")
            # Raise directly instead of calling the interactive `exit()` helper.
            raise SystemExit

        # Unforeseen errors are logged; the loop has already been left, so
        # run() returns after logging.
        except Exception as err:
            bt.logging.error("Error during validation", str(err))
            # Log the formatted traceback text. `print_exception` returns None,
            # so passing its result to the logger would only log "None".
            bt.logging.error(
                "".join(format_exception(type(err), err, err.__traceback__))
            )

    def run_in_background_thread(self):
        """
        Start ``run()`` in a daemon thread, unless one is already running.
        """
        if not self.is_running:
            bt.logging.debug("Starting validator in background thread.")
            self.should_exit = False
            self.thread = threading.Thread(target=self.run, daemon=True)
            self.thread.start()
            self.is_running = True
            bt.logging.debug("Started")

    def stop_run_thread(self):
        """
        Ask the background run loop to stop and wait up to 5 seconds for it.
        """
        if self.is_running:
            bt.logging.debug("Stopping validator in background thread.")
            self.should_exit = True
            self.thread.join(5)
            self.is_running = False
            bt.logging.debug("Stopped")

    def __enter__(self):
        """
        Start the validator in a background thread when entering a ``with`` block.

        Returns:
            This validator instance.
        """
        self.run_in_background_thread()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Stop the background run loop when leaving a ``with`` block.

        Args:
            exc_type: The type of the exception that caused the context to be exited.
                      None if the context was exited without an exception.
            exc_value: The instance of the exception that caused the context to be exited.
                       None if the context was exited without an exception.
            traceback: A traceback object encoding the stack trace.
                       None if the context was exited without an exception.
        """
        self.stop_run_thread()

    def set_weights(self):
        """
        Normalize the current scores into weights and set them on chain.

        NaN scores are reported with a warning and treated as 0 so that they
        cannot turn the whole normalized weight vector into NaN.
        """
        scores = self.scores
        nan_mask = torch.isnan(scores)
        if nan_mask.any():
            bt.logging.warning(
                "Scores contain NaN values. This may be due to a lack of responses from miners, or a bug in your reward functions."
            )
            # Replace only the NaN entries with 0; self.scores is left untouched.
            scores = torch.where(nan_mask, torch.zeros_like(scores), scores)

        # L1-normalize the scores so the weights sum to 1 (all-zero stays zero).
        raw_weights = torch.nn.functional.normalize(scores, p=1, dim=0)
        sorted_values, sorted_uids = raw_weights.sort()
        bt.logging.trace("raw_weights", raw_weights)
        bt.logging.trace("top10 values", sorted_values)
        bt.logging.trace("top10 uids", sorted_uids)

        # Process the raw weights to final_weights via subtensor limitations.
        (
            processed_weight_uids,
            processed_weights,
        ) = bt.utils.weight_utils.process_weights_for_netuid(
            uids=self.metagraph.uids.to("cpu"),
            weights=raw_weights.to("cpu"),
            netuid=self.config.netuid,
            subtensor=self.subtensor,
            metagraph=self.metagraph,
        )
        bt.logging.trace("processed_weights", processed_weights)
        bt.logging.trace("processed_weight_uids", processed_weight_uids)

        # Set the weights on chain via our subtensor connection.
        self.subtensor.set_weights(
            wallet=self.wallet,
            netuid=self.config.netuid,
            uids=processed_weight_uids,
            weights=processed_weights,
            wait_for_finalization=False,
            version_key=self.spec_version,
        )

        bt.logging.info(f"Set weights: {processed_weights}")

    def resync_metagraph(self):
        """Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""
        bt.logging.info("resync_metagraph()")

        # Copies state of metagraph before syncing.
        previous_metagraph = copy.deepcopy(self.metagraph)

        # Sync the metagraph.
        self.metagraph.sync(subtensor=self.subtensor)

        # Nothing else to do when the metagraph axon info has not changed.
        if previous_metagraph.axons == self.metagraph.axons:
            return

        bt.logging.info(
            "Metagraph updated, re-syncing hotkeys, dendrite pool and moving averages"
        )
        # Zero out all hotkeys that have been replaced. zip() stops at the
        # shorter list, so a metagraph with fewer hotkeys than our local copy
        # cannot cause an out-of-range lookup.
        for uid, (known_hotkey, current_hotkey) in enumerate(
            zip(self.hotkeys, self.metagraph.hotkeys)
        ):
            if known_hotkey != current_hotkey:
                self.scores[uid] = 0  # hotkey has been replaced

        # Check to see if the metagraph has changed size.
        # If so, we need to add new hotkeys and moving averages.
        if len(self.hotkeys) < len(self.metagraph.hotkeys):
            # Grow the score tensor, keeping the scores we already have.
            new_moving_average = torch.zeros((self.metagraph.n)).to(
                self.device
            )
            min_len = min(len(self.hotkeys), len(self.scores))
            new_moving_average[:min_len] = self.scores[:min_len]
            self.scores = new_moving_average

        # Update the hotkeys.
        self.hotkeys = copy.deepcopy(self.metagraph.hotkeys)

    def update_scores(self, rewards: torch.FloatTensor, uids: List[int]):
        """
        Performs exponential moving average on the scores based on the rewards received from the miners.

        Args:
            rewards: One reward per entry in ``uids``. NaN values are replaced with 0.
            uids: UIDs the rewards belong to; assumed to be mutually exclusive.
        """

        # Check if rewards contains NaN values.
        if torch.isnan(rewards).any():
            bt.logging.warning(f"NaN values detected in rewards: {rewards}")
            # Replace any NaN values in rewards with 0.
            rewards = torch.nan_to_num(rewards, 0)

        # Place each reward at its UID's position; positions that are not in
        # `uids` keep their current score, so the update below leaves them unchanged.
        # shape: [ metagraph.n ]
        scattered_rewards: torch.FloatTensor = self.scores.scatter(
            0, torch.tensor(uids).to(self.device), rewards
        ).to(self.device)
        bt.logging.debug(f"Scattered rewards: {scattered_rewards}")

        # Update scores with rewards produced by this step.
        # shape: [ metagraph.n ]
        alpha: float = self.config.neuron.moving_average_alpha
        self.scores = alpha * scattered_rewards + (
            1 - alpha
        ) * self.scores.to(self.device)
        bt.logging.debug(f"Updated moving avg scores: {self.scores}")

    def save_state(self):
        """Saves the step, scores and hotkeys of the validator to ``state.pt``."""
        bt.logging.info("Saving validator state.")

        # Save the state of the validator to file.
        torch.save(
            {
                "step": self.step,
                "scores": self.scores,
                "hotkeys": self.hotkeys,
            },
            self.config.neuron.full_path + "/state.pt",
        )

    def load_state(self):
        """Loads the step, scores and hotkeys of the validator from ``state.pt``."""
        bt.logging.info("Loading validator state.")

        # Load the state of the validator from file.
        # NOTE(review): torch.load unpickles the file; only load state files
        # written by this validator.
        state = torch.load(self.config.neuron.full_path + "/state.pt")
        self.step = state["step"]
        self.scores = state["scores"]
        self.hotkeys = state["hotkeys"]


class OCRSynapse(bt.Synapse):
    """
    A simple OCR synapse protocol representation which uses bt.Synapse as its base.
    This protocol enables communication between the miner and the validator.

    Attributes:
    - image: The request payload to be processed by the miner.
      NOTE(review): originally described as "a pdf image" but declared as
      ``int`` - confirm the intended type.
    - response: List[dict] containing data extracted from the image.
    """

    # Required request input, filled by sending dendrite caller.
    image: int

    # Optional request output, filled by receiving axon.
    response: typing.Optional[typing.List[dict]] = None

    def deserialize(self) -> typing.Optional[typing.List[dict]]:
        """
        Deserialize the miner response.

        Returns:
        - The ``response`` field as-is: a list of dictionaries containing the
          extracted data, or None when no response has been filled in.
        """
        return self.response
def check_config(cls, config: "bt.Config"):
    r"""Validate the config and prepare the neuron's working directory.

    Runs bittensor's logging config check, stores the expanded path
    ``<logging_dir>/<wallet name>/<wallet hotkey>/netuid<netuid>/<neuron name>``
    in ``config.neuron.full_path``, creates that directory when it is missing and,
    unless ``config.neuron.dont_save_events`` is set, registers a loguru ``EVENTS``
    level together with an ``events.log`` sink inside the directory.

    Args:
        cls: Not read by this function.
        config: Config namespace; ``config.neuron.full_path`` is written in place.
    """
    bt.logging.check_config(config)

    path_parts = (
        config.logging.logging_dir,  # TODO: change from ~/.bittensor/miners to ~/.bittensor/neurons
        config.wallet.name,
        config.wallet.hotkey,
        config.netuid,
        config.neuron.name,
    )
    full_path = os.path.expanduser("{}/{}/{}/netuid{}/{}".format(*path_parts))
    print("full path:", full_path)

    config.neuron.full_path = os.path.expanduser(full_path)
    if not os.path.exists(config.neuron.full_path):
        os.makedirs(config.neuron.full_path, exist_ok=True)

    if config.neuron.dont_save_events:
        return

    # Add custom event logger for the events.
    logger.level("EVENTS", no=38, icon="📝")
    events_log = os.path.join(config.neuron.full_path, "events.log")
    logger.add(
        events_log,
        rotation=config.neuron.events_retention_size,
        serialize=True,
        enqueue=True,
        backtrace=False,
        diagnose=False,
        level="EVENTS",
        format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
    )


def add_args(cls, parser):
    """
    Register the subnet's command-line options on ``parser``.

    The neuron role is "miner" when the lower-cased class name contains "miner",
    otherwise "validator". Shared options are registered first, followed by the
    options of the detected role.

    Args:
        cls: Class whose ``__name__`` decides the role.
        parser: Parser object exposing ``add_argument``.
    """
    neuron_type = "miner" if "miner" in cls.__name__.lower() else "validator"

    # Each entry is (option strings, keyword arguments for add_argument).
    shared_options = [
        # Netuid Arg: The netuid of the subnet to connect to.
        (("--netuid",), dict(type=int, help="Subnet netuid", default=1)),
        (
            ("--neuron.name",),
            dict(
                type=str,
                help="Trials for this neuron go in neuron.root / (wallet_cold - wallet_hot) / neuron.name. ",
                default=neuron_type,
            ),
        ),
        (
            ("--neuron.device",),
            dict(type=str, help="Device to run on.", default="cpu"),
        ),
        (
            ("--neuron.epoch_length",),
            dict(
                type=int,
                help="The default epoch length (how often we set weights, measured in 12 second blocks).",
                default=100,
            ),
        ),
        (
            ("--neuron.events_retention_size",),
            dict(type=str, help="Events retention size.", default="2 GB"),
        ),
        (
            ("--neuron.dont_save_events",),
            dict(
                action="store_true",
                help="If set, we dont save events to a log file.",
                default=False,
            ),
        ),
    ]

    validator_options = [
        (
            ("--neuron.num_concurrent_forwards",),
            dict(
                type=int,
                help="The number of concurrent forwards running at any time.",
                default=1,
            ),
        ),
        (
            ("--neuron.sample_size",),
            dict(
                type=int,
                help="The number of miners to query in a single step.",
                default=10,
            ),
        ),
        (
            ("--neuron.disable_set_weights",),
            dict(
                action="store_true",
                help="Disables setting weights.",
                default=False,
            ),
        ),
        (
            ("--neuron.moving_average_alpha",),
            dict(
                type=float,
                help="Moving average alpha parameter, how much to add of the new observation.",
                default=0.05,
            ),
        ),
        (
            # Note: the validator needs to serve an Axon with their IP or they may
            # be blacklisted by the firewall of serving peers on the network.
            ("--neuron.axon_off", "--axon_off"),
            dict(
                action="store_true",
                help="Set this flag to not attempt to serve an Axon.",
                default=False,
            ),
        ),
        (
            ("--neuron.vpermit_tao_limit",),
            dict(
                type=int,
                help="The maximum number of TAO allowed to query a validator with a vpermit.",
                default=4096,
            ),
        ),
    ]

    miner_options = [
        (
            ("--blacklist.force_validator_permit",),
            dict(
                action="store_true",
                help="If set, we will force incoming requests to have a permit.",
                default=False,
            ),
        ),
        (
            ("--blacklist.allow_non_registered",),
            dict(
                action="store_true",
                help="If set, miners will accept queries from non registered entities. (Dangerous!)",
                default=False,
            ),
        ),
    ]

    role_options = validator_options if neuron_type == "validator" else miner_options
    for option_strings, option_kwargs in shared_options + role_options:
        parser.add_argument(*option_strings, **option_kwargs)


def config(cls):
    """
    Build the configuration object for this miner or validator.

    Creates a fresh argument parser, lets the bittensor wallet, subtensor, logging
    and axon components add their arguments (in that order), then the class itself,
    and hands the parser to ``bt.config``.
    """
    parser = argparse.ArgumentParser()
    for component in (bt.wallet, bt.subtensor, bt.logging, bt.axon):
        component.add_args(parser)
    cls.add_args(parser)
    return bt.config(parser)
# LRU Cache with TTL
def ttl_cache(maxsize: int = 128, typed: bool = False, ttl: int = -1):
    """
    Decorator factory: an LRU cache whose entries also expire after `ttl` seconds.

    Expiry works by passing an extra "time bucket" number as the first cache key.
    The bucket is `floor(elapsed / ttl)` measured from the moment the decorator was
    created, so all entries become unreachable together each time the bucket changes
    (an entry can therefore expire sooner than `ttl` seconds after it was stored).

    Args:
        maxsize (int): Maximum number of cached entries; least recently used entries
            are replaced beyond that. Defaults to 128.
        typed (bool): If True, arguments of different types are cached separately,
            e.g. f(3) and f(3.0). Defaults to False.
        ttl (int): Bucket length in seconds. Non-positive values are replaced by
            65536 seconds. Defaults to -1.

    Returns:
        Callable: A decorator that caches the return values of the wrapped function.

    Example:
        @ttl_cache(ttl=10)
        def get_data(param):
            # Expensive data retrieval operation
            return data
    """
    if ttl <= 0:
        ttl = 65536
    hash_gen = _ttl_hash_gen(ttl)

    def wrapper(func: Callable) -> Callable:
        @lru_cache(maxsize, typed)
        def ttl_func(ttl_hash, *args, **kwargs):
            # ttl_hash only takes part in the cache key; the wrapped function never sees it.
            return func(*args, **kwargs)

        def wrapped(*args, **kwargs) -> Any:
            th = next(hash_gen)
            return ttl_func(th, *args, **kwargs)

        return update_wrapper(wrapped, func)

    return wrapper


def _ttl_hash_gen(seconds: int):
    """
    Endless generator used by `ttl_cache` to number time intervals.

    Args:
        seconds (int): Length of one interval in seconds.

    Yields:
        int: Number of whole intervals elapsed since the generator was first advanced.
    """
    start_time = time.time()
    while True:
        yield floor((time.time() - start_time) / seconds)


# 12 seconds updating block.
@ttl_cache(maxsize=1, ttl=12)
def ttl_get_block(self) -> int:
    """
    Return `self.subtensor.get_current_block()`, cached in 12 second intervals.

    The cache holds a single entry keyed on `self`, so repeated calls with the same
    instance inside one interval do not query the subtensor again.

    Returns:
        int: The value returned by `self.subtensor.get_current_block()`.

    Example:
        current_block = ttl_get_block(self)

    Note: self here is the miner or validator instance
    """
    return self.subtensor.get_current_block()


def check_uid_availability(
    metagraph: "bt.metagraph.Metagraph", uid: int, vpermit_tao_limit: int
) -> bool:
    """Check if uid is available. The UID should be available if it is serving and has less than vpermit_tao_limit stake
    Args:
        metagraph (:obj: bt.metagraph.Metagraph): Metagraph object
        uid (int): uid to be checked
        vpermit_tao_limit (int): Validator permit tao limit
    Returns:
        bool: True if uid is available, False otherwise
    """
    # Filter non serving axons.
    if not metagraph.axons[uid].is_serving:
        return False
    # Filter uids that hold a validator permit and more than vpermit_tao_limit stake.
    if metagraph.validator_permit[uid] and metagraph.S[uid] > vpermit_tao_limit:
        return False
    # Available otherwise.
    return True


def get_random_uids(
    self, k: int, exclude: List[int] = None
) -> torch.LongTensor:
    """Returns up to k available random uids from the metagraph.
    Args:
        k (int): Number of uids to return.
        exclude (List[int]): List of uids to exclude from the random sampling.
    Returns:
        uids (torch.LongTensor): Randomly sampled available uids.
    Notes:
        Non-excluded available uids are preferred. When there are fewer than `k` of
        them, excluded-but-available uids are used to fill the gap. If `k` is still
        larger than the number of available uids, all available uids are returned
        instead of raising.
    """
    vpermit_tao_limit = self.config.neuron.vpermit_tao_limit
    available = [
        uid
        for uid in range(self.metagraph.n.item())
        if check_uid_availability(self.metagraph, uid, vpermit_tao_limit)
    ]
    preferred = [
        uid for uid in available if exclude is None or uid not in exclude
    ]

    pool = list(preferred)
    shortfall = k - len(pool)
    if shortfall > 0:
        preferred_set = set(preferred)
        fallback = [uid for uid in available if uid not in preferred_set]
        # Never ask random.sample for more items than the fallback pool holds.
        pool += random.sample(fallback, min(shortfall, len(fallback)))

    sample_size = max(0, min(k, len(pool)))
    # Explicit dtype keeps an empty result a LongTensor as well.
    return torch.tensor(random.sample(pool, sample_size), dtype=torch.long)


def corrupt_image(input_pdf_path, output_pdf_path, theta=1, border=50, noise=0.1, scale=0.95, blur=1, spot=(100,100)):
    """Degrade every page of a PDF and save the result as a new PDF.

    Each effect is skipped when its parameter is None.

    Args:
        input_pdf_path: Path of the PDF to read.
        output_pdf_path: Path the corrupted PDF is written to.
        theta: Rotation angle passed to `Image.rotate` (with expand=True).
        border: Width in pixels of the darkened band drawn on the left and right edges.
        noise: Fraction of pixels that receive gaussian noise (sigma 50).
        scale: Factor applied to both page dimensions.
        blur: Radius of the gaussian blur.
        spot: (x, y) centre of a dark spot; pixels are darkened more the closer they are.

    Returns:
        The `output_pdf_path` that was written.
    """

    def clamp(value):
        # Channel values must stay inside the 8-bit range 0..255.
        return int(min(max(0, value), 255))

    # Convert PDF to images
    pages = pdf2image.convert_from_path(input_pdf_path)

    processed_images = []

    for image in pages:

        display(image)
        width, height = image.size

        # imitate a curled page by darkening the left and right edges
        if border is not None:
            for x in range(1, min(border, width)):
                tone = min(255, 256 - int(250 * (x / border - 1) ** 2))
                for y in range(height):
                    # leave rows untouched where the left-hand pixel is already (nearly) black
                    if min(image.getpixel((x, y))) < 20:
                        continue
                    image.putpixel((x, y), (tone, tone, tone))
                    image.putpixel((width - x, y), (tone, tone, tone))

        # Apply noise
        if noise is not None:
            draw = ImageDraw.Draw(image)
            for _ in range(int(width * height * noise)):
                x = random.randint(0, width - 1)
                y = random.randint(0, height - 1)
                # TODO: Parameterize
                delta = random.gauss(0, 50)
                rgb = tuple(clamp(val + delta) for val in image.getpixel((x, y)))
                draw.point((x, y), fill=rgb)

        if spot is not None:
            draw = ImageDraw.Draw(image)
            # The spot reuses the noise density; fall back to the default density (0.1)
            # so that noise=None does not break the spot effect.
            spot_density = noise if noise is not None else 0.1
            for _ in range(int(width * height * spot_density)):
                x = random.randint(0, width - 1)
                y = random.randint(0, height - 1)
                # TODO: Parameterize
                delta = 100000 / (1 + math.sqrt((spot[0] - x) ** 2 + (spot[1] - y) ** 2))
                rgb = tuple(clamp(val - delta) for val in image.getpixel((x, y)))
                draw.point((x, y), fill=rgb)

        # rescale the image
        if scale is not None:
            image = image.resize(size=(int(scale * width), int(scale * height)))

        # apply a rotation
        if theta is not None:
            image = image.rotate(theta, expand=True)

        # Apply blur
        if blur is not None:
            image = image.filter(ImageFilter.GaussianBlur(blur))

        display(image)

        processed_images.append(image)

    # Save processed images back as a PDF
    processed_images[0].save(output_pdf_path, "PDF", resolution=100.0, save_all=True, append_images=processed_images[1:])
    return output_pdf_path
def generate_image(image_type='invoice', corrupt=False):
    """
    Generates a random invoice PDF to be sent to the miner.

    Args:
    - image_type (str): Kind of document to generate; only 'invoice' is implemented.
    - corrupt (bool): When True, a corrupted copy is written next to the clean PDF
      (suffix `_corrupt`) and its path is returned instead.

    Returns:
    - str: Path of the generated PDF file.

    Raises:
    - NotImplementedError: If `image_type` is not 'invoice'.

    # TODO: return image and label (i.e. annotations for scoring)
    """
    # Validate first so an unsupported type does not create directories on disk.
    if image_type != 'invoice':
        raise NotImplementedError(f"Image type {image_type} not implemented.")

    root_dir = './data/images/'
    os.makedirs(root_dir, exist_ok=True)

    path = create_invoice(root_dir=root_dir)

    if corrupt:
        # Read the clean PDF and write the corrupted copy beside it.
        stem, extension = os.path.splitext(path)
        corrupt_path = f"{stem}_corrupt{extension}"
        corrupt_image(path, corrupt_path)
        path = corrupt_path

    return path


def load_image(path):
    """
    Loads an image from the given path.

    Args:
    - path (str): The path to the image.

    Returns:
    - PIL.Image: The loaded image.

    NOTE(review): `generate_image` produces `.pdf` files; presumably the PDF has to be
    rasterised before `Image.open` can read it. Confirm against Pillow's supported
    input formats.
    """
    return Image.open(path)


async def forward(self):
    """
    The forward function is called by the validator every time step.

    It is responsible for querying the network and scoring the responses.

    Args:
        self (:obj:`bittensor.neuron.Neuron`): The neuron object which contains all the necessary state for the validator.

    """

    # get_random_uids is an example method, but you can replace it with your own.
    miner_uids = get_random_uids(self, k=self.config.neuron.sample_size)

    # Create a random image and load it.
    image_path = generate_image()
    image = load_image(image_path)

    # Create synapse object to send to the miner and attach the image.
    # TODO: it's probably not possible to send the image directly, so you'll need to encode it somehow.
    synapse = OCRSynapse(image=image)

    # The dendrite client queries the network.
    responses = self.dendrite.query(
        # Send the query to selected miner axons in the network.
        axons=[self.metagraph.axons[uid] for uid in miner_uids],
        # Pass the synapse to the miner.
        synapse=synapse,
        # All responses have the deserialize function called on them before returning.
        # You are encouraged to define your own deserialization function.
        deserialize=True,
    )

    # Log the results for monitoring purposes.
    bt.logging.info(f"Received responses: {responses}")

    # TODO: We need ground truth labels to score the responses!
    # get_rewards takes the image that was sent, not the step counter.
    rewards = get_rewards(self, image=image, responses=responses)

    bt.logging.info(f"Scored responses: {rewards}")
    # Update the scores based on the rewards. You may want to define your own update_scores function for custom behavior.
    self.update_scores(rewards, miner_uids)


def apply_invoice_template(invoice_data, filename):
    """
    Render `invoice_data` onto a single letter-size PDF page saved at `filename`.

    Args:
    - invoice_data (dict): Reads the keys company_name, company_address,
      company_city_zip, invoice_date, invoice_number, customer_name, terms and
      items (each item having desc, qty and cost).
    - filename (str): Path of the PDF to write.
    """
    c = canvas.Canvas(filename, pagesize=letter)
    c.setLineWidth(.3)
    c.setFont('Helvetica', 12)

    # Draw the invoice header
    c.drawString(30, 750, invoice_data['company_name'])
    c.drawString(30, 735, invoice_data['company_address'])
    c.drawString(30, 720, invoice_data['company_city_zip'])
    c.drawString(400, 750, "Invoice Date: " + invoice_data['invoice_date'])
    c.drawString(400, 735, "Invoice #: " + invoice_data['invoice_number'])

    # Draw the bill to section
    c.drawString(30, 690, "Bill To:")
    c.drawString(120, 690, invoice_data['customer_name'])

    # Table headers
    c.drawString(30, 650, "Description")
    c.drawString(300, 650, "Qty")
    c.drawString(460, 650, "Cost")
    c.line(30, 645, 560, 645)

    # List items, one row every 15 points moving down the page
    line_height = 625
    total = 0
    for item in invoice_data['items']:
        c.drawString(30, line_height, item['desc'])
        c.drawString(300, line_height, str(item['qty']))
        c.drawString(460, line_height, "${:.2f}".format(item['cost']))
        # The Cost column shows the unit cost; the total multiplies by quantity.
        total += item['qty'] * item['cost']
        line_height -= 15

    # Draw the total cost
    c.drawString(400, line_height - 15, f"Total: ${total:,.2f}")

    # Terms and Conditions
    c.drawString(30, line_height - 45, "Terms:")
    c.drawString(120, line_height - 45, invoice_data['terms'])

    c.save()


def create_invoice(root_dir):
    """
    Generate a random invoice PDF inside `root_dir`.

    Company, customer, date, invoice number and 3-15 line items are randomised;
    the file is named after a random sha256 hex string.

    Args:
    - root_dir (str): Directory the PDF is written to.

    Returns:
    - str: Path of the generated PDF.
    """
    items_list = [
        {"desc": "Web hosting", "cost": 100.00},
        {"desc": "Domain registration", "cost": 10.00},
        {"desc": "SSL certificate", "cost": 5.50},
        {"desc": "Web design", "cost": 500.00},
        {"desc": "Web development", "cost": 500.00},
        {"desc": "SEO", "cost": 100.00},
        {"desc": "Content creation", "cost": 300.00},
        {"desc": "Social media marketing", "cost": 400.00},
        {"desc": "Email marketing", "cost": 150.00},
        {"desc": "PPC advertising", "cost": 200.00},
        {"desc": "Analytics", "cost": 400.00},
        {"desc": "Consulting", "cost": 700.00},
        {"desc": "Training", "cost": 1200.00},
        {"desc": "Maintenance", "cost": 650.00},
        {"desc": "Support", "cost": 80.00},
        {"desc": "Graphic design", "cost": 310.00},
        {"desc": "Logo design", "cost": 140.00},
        {"desc": "Branding", "cost": 750.00},
    ]

    def random_items(n):
        # n distinct items in alphabetical order, each with a quantity of 1-5
        items = sorted(random.sample(items_list, k=n), key=lambda x: x['desc'])
        return [{**item, 'qty': random.randint(1, 5)} for item in items]

    fake = Faker()

    # Sample data for the invoice
    invoice_info = {
        "company_name": fake.company(),
        # Street line only: the template draws this on one line and puts city/zip on
        # the next, whereas a full multi-line address does not fit a single drawString.
        "company_address": fake.street_address(),
        "company_city_zip": f'{fake.city()}, {fake.zipcode()}',
        "company_phone": fake.phone_number(),
        "customer_name": fake.name(),
        "invoice_date": datetime.date.fromtimestamp(1700176424-random.random()*5e8).strftime("%B %d, %Y"),
        "invoice_number": f"INV{random.randint(1,10000):06}",
        "items": random_items(random.randint(3, 15)),
        "terms": "Payment due within 30 days"
    }

    # make a random hash for the filename
    filename = f"{fake.sha256()}.pdf"
    path = os.path.join(root_dir, filename)
    # Use the function and pass the data and the filename you want to save as
    apply_invoice_template(invoice_info, path)

    return path
# Loss function for OCR model
# ---------------------------
#
#   L = \sum_i ( \alpha_p L^p_i + \alpha_f L^f_i + \alpha_t L^t_i )
#
# L^p_i is the loss for section i based on positional/layout correctness. It is zero
# when the OCR model returns the exact box on the page, so it is proposed as one minus
# the intersection over union of the bounding boxes:
#
#   L^p_i = 1 - IOU(\hat{b}_i, b_i)
#
# where \hat{b}_i is the predicted bounding box and b_i the ground truth bounding box.
#
# L^f_i is the loss for section i based on font correctness. It is zero when the OCR
# model returns the exact font for the section (font family, font size and perhaps even
# colors). It is proposed as a delta between predicted and ground truth font plus the
# square of the difference in font size:
#
#   L^f_i = \alpha_f^f (1 - \delta(\hat{f}_i, f_i)) + \alpha_f^s (\hat{s}_i - s_i)^2
#
# L^t_i is the loss for section i based on text correctness. It is zero when the OCR
# model returns the exact text. It is proposed as the edit (Levenshtein) distance
# between predicted and ground truth text:
#
#   L^t_i = ED(\hat{t}_i, t_i)
#
# \alpha_p, \alpha_f, \alpha_t weight the loss terms. They impact the difficulty of the
# OCR challenge, as text correctness is likely much easier than position correctness.
#
# The loss is to be inverted into a reward which the miner maximizes:
#
#   R = 1 / L
#
# This probably needs some epsilon to avoid division by zero.


def score_section(image: "Image.Image", section: dict, alpha_p=1.0, alpha_f=1.0, alpha_t=1.0):
    """
    Score one section of a miner response as a weighted sum of three losses.

    Each loss is 0 for a perfect answer and 1.0 when the field is missing:
    - position: one minus the intersection over union of the bounding boxes,
    - font: the distance between the predicted font and the ground truth font,
    - text: the normalized edit distance between predicted and ground truth text.

    Args:
    - image (Image): Object carrying the ground truth, read as `image.position`,
      `image.font` and `image.text`.
      NOTE(review): how these attributes get attached is not visible here; confirm.
    - section (dict): The section to score, with optional keys 'position', 'font', 'text'.
    - alpha_p, alpha_f, alpha_t (float): Weights of the position, font and text losses.

    Returns:
    - float: The weighted loss for the section (lower is better).
    """
    max_loss = 1.0

    # position loss is one minus the IOU of the bounding boxes
    position_loss = max_loss
    predicted_box = section.get('position')
    if predicted_box:
        try:
            position_loss = 1.0 - get_iou(predicted_box, image.position)
        except (KeyError, TypeError, ValueError):
            # A malformed box counts as a missing one.
            position_loss = max_loss

    font_loss = max_loss
    predicted_font = section.get('font')
    if predicted_font:
        font_loss = get_font_distance(predicted_font, image.font)  # this should actually calculate the font loss

    text_loss = max_loss
    predicted_text = section.get('text')
    if predicted_text:
        text_loss = get_edit_distance(predicted_text, image.text)

    # TODO: convert loss to reward (invert and scale)
    # TODO: add time penalty
    return alpha_p * position_loss + alpha_f * font_loss + alpha_t * text_loss


def reward(image: "Image.Image", response: List[dict]) -> float:
    """
    Reward the miner response to the OCR request. This method returns a reward
    value for the miner, which is used to update the miner's score.

    Args:
    - image (Image): The image sent to the miner.
    - response (List[dict]): Response from the miner; None or empty yields 0.0.

    The expected fields in each section of the response are:
    - position: The bounding box of the section.
      NOTE(review): `get_iou` reads the keys 'x1', 'y1', 'x2', 'y2'; confirm the wire format.
    - font (dict): The font of the section e.g. {'family': 'Times New Roman', 'size':12}
    - text (str): The text of the section e.g. 'Hello World!'

    Returns:
    - float: Sum of the section scores.
      NOTE(review): section scores are still losses (see the TODO in `score_section`),
      so the value is not yet inverted into a true reward.
    """
    if not response:
        return 0.0

    return sum(score_section(image, section) for section in response)


def get_rewards(
    self,
    image: "Image.Image",
    responses: List[List[dict]],
) -> torch.FloatTensor:
    """
    Returns a tensor of rewards for the given image and responses.

    Args:
    - image (Image): The image sent to the miner.
    - responses (List[List[dict]]): A list of responses from the miner.

    Returns:
    - torch.FloatTensor: One reward per response, moved to `self.device`.
    """
    # Get all the reward results by iteratively calling your reward() function.
    return torch.FloatTensor(
        [reward(image, response) for response in responses]
    ).to(self.device)


def get_iou(bb1, bb2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.
    NOTE: Thanks to this guy! https://stackoverflow.com/questions/25349178/calculating-percentage-of-bounding-box-overlap-for-image-detector-evaluation

    Parameters
    ----------
    bb1 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    bb2 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner

    Returns
    -------
    float: Normalized between 0 and 1.

    Raises
    ------
    ValueError
        If either box does not satisfy x1 < x2 and y1 < y2.
    """
    # Explicit checks instead of assert: boxes may come from miners and asserts are
    # stripped when Python runs with -O.
    for label, box in (("bb1", bb1), ("bb2", bb2)):
        if not (box['x1'] < box['x2'] and box['y1'] < box['y2']):
            raise ValueError(f"{label} must satisfy x1 < x2 and y1 < y2, got {box!r}")

    # determine the coordinates of the intersection rectangle
    x_left = max(bb1['x1'], bb2['x1'])
    y_top = max(bb1['y1'], bb2['y1'])
    x_right = min(bb1['x2'], bb2['x2'])
    y_bottom = min(bb1['y2'], bb2['y2'])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The intersection of two axis-aligned bounding boxes is always an
    # axis-aligned bounding box
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    # compute the area of both AABBs
    bb1_area = (bb1['x2'] - bb1['x1']) * (bb1['y2'] - bb1['y1'])
    bb2_area = (bb2['x2'] - bb2['x1']) * (bb2['y2'] - bb2['y1'])

    # compute the intersection over union by taking the intersection
    # area and dividing it by the sum of prediction + ground-truth
    # areas - the intersection area
    return intersection_area / float(bb1_area + bb2_area - intersection_area)


def get_edit_distance(text1: str, text2: str):
    """Calculate the edit distance between two strings.

    Parameters
    ----------
    text1 : str
        The first string.
    text2 : str
        The second string.

    Returns
    -------
    float
        The edit distance between the two strings, divided by the length of the
        longer string. Two empty strings have distance 0.0.
    """
    longest = max(len(text1), len(text2))
    if longest == 0:
        # Both strings are empty: identical, and nothing to normalize by.
        return 0.0

    return editdistance.eval(text1, text2) / longest


def get_font_distance(font1: dict, font2: dict):
    """Calculate the distance between two fonts.

    Placeholder: the arguments are not inspected yet and 0 is always returned.

    Parameters
    ----------
    font1 : dict
        The first font.
    font2 : dict
        The second font.

    Returns
    -------
    float
        The distance between the two fonts. Intended to be normalized between 0 and 1.
    """
    return 0