Skip to content

Commit 77da1e0

Browse files
authored
Merge pull request #5 from steffencruz/second-pass
Second pass
2 parents a3af5da + 62fc434 commit 77da1e0

23 files changed

+495
-1155
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ Before you proceed with the installation of the subnet, note the following:
8686
- `neurons/miner.py`: Use `pytesseract` for OCR, and use `OCRSynapse` to communicate with validator
8787

8888
### Remaining changes to be done
89-
In addition to the above files, we would also update the following files:
89+
In addition to the above files, we have also updated the following files:
9090
- `README.md`: This file contains the documentation for your project. Update this file to reflect your project's documentation.
9191
- `CONTRIBUTING.md`: This file contains the instructions for contributing to your project. Update this file to reflect your project's contribution guidelines.
9292
- `template/__init__.py`: This file contains the version of your project.

neurons/miner.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# The MIT License (MIT)
22
# Copyright © 2023 Yuma Rao
3-
# TODO(developer): Set your name
4-
# Copyright © 2023 <your name>
53

64
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
75
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
@@ -22,18 +20,16 @@
2220
import bittensor as bt
2321
import pytesseract
2422

25-
# Bittensor Miner Template:
23+
# Bittensor OCR Miner
2624
import ocr_subnet
2725

28-
from ocr_subnet.utils.serialize import deserialize_image
29-
3026
# import base miner class which takes care of most of the boilerplate
3127
from ocr_subnet.base.miner import BaseMinerNeuron
3228

3329

3430
class Miner(BaseMinerNeuron):
3531
"""
36-
Your miner neuron class. You should use this class to define your miner's behavior. In particular, you should replace the forward function with your own logic. You may also want to override the blacklist and priority functions according to your needs.
32+
OCR miner neuron class. You may also want to override the blacklist and priority functions according to your needs.
3733
3834
This class inherits from the BaseMinerNeuron class, which in turn inherits from BaseNeuron. The BaseNeuron class takes care of routine tasks such as setting up wallet, subtensor, metagraph, logging directory, parsing config, etc. You can override any of the methods in BaseNeuron if you need to customize the behavior.
3935
@@ -45,6 +41,7 @@ def __init__(self, config=None):
4541

4642
# TODO(developer): Anything specific to your use case you can do here
4743

44+
4845
async def forward(
4946
self, synapse: ocr_subnet.protocol.OCRSynapse
5047
) -> ocr_subnet.protocol.OCRSynapse:
@@ -58,27 +55,35 @@ async def forward(
5855
ocr_subnet.protocol.OCRSynapse: The synapse object with the 'response' field set to the extracted data.
5956
6057
"""
58+
# Get image data
59+
image = ocr_subnet.utils.image.deserialize(base64_string=synapse.base64_image)
6160

62-
image = deserialize_image(base64_string=synapse.base64_image)
6361
# Use pytesseract to get the data
6462
data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
6563

6664
response = []
67-
# Loop over each item in the 'text' part of the data
6865
for i in range(len(data['text'])):
6966
if data['text'][i].strip() != '': # This filters out empty text results
7067
x1, y1, width, height = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
68+
if width * height < 10: # This filters out small boxes (likely noise)
69+
continue
70+
7171
x2, y2 = x1 + width, y1 + height
7272

7373
# Here we don't have font information, so we'll omit that.
7474
# Pytesseract does not extract font family or size information.
7575
entry = {
76-
'index': i,
7776
'position': [x1, y1, x2, y2],
7877
'text': data['text'][i]
7978
}
8079
response.append(entry)
8180

81+
# Merge together words into sections, which are on the same line (same y value) and are close together (small distance in x)
82+
response = ocr_subnet.utils.process.group_and_merge_boxes(response)
83+
84+
# Sort sections by y, then sort by x so that they read left to right and top to bottom
85+
response = sorted(response, key=lambda item: (item['position'][1], item['position'][0]))
86+
8287
# Attach response to synapse and return it.
8388
synapse.response = response
8489

neurons/validator.py

+50-16
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# The MIT License (MIT)
22
# Copyright © 2023 Yuma Rao
3-
# TODO(developer): Set your name
4-
# Copyright © 2023 <your name>
53

64
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
75
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
@@ -17,21 +15,20 @@
1715
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1816
# DEALINGS IN THE SOFTWARE.
1917

20-
18+
import os
2119
import time
22-
23-
# Bittensor
20+
import hashlib
2421
import bittensor as bt
2522

26-
from ocr_subnet.validator import forward
23+
import ocr_subnet
2724

2825
# import base validator class which takes care of most of the boilerplate
2926
from ocr_subnet.base.validator import BaseValidatorNeuron
3027

3128

3229
class Validator(BaseValidatorNeuron):
3330
"""
34-
Your validator neuron class. You should use this class to define your validator's behavior. In particular, you should replace the forward function with your own logic.
31+
OCR validator neuron class.
3532
3633
This class inherits from the BaseValidatorNeuron class, which in turn inherits from BaseNeuron. The BaseNeuron class takes care of routine tasks such as setting up wallet, subtensor, metagraph, logging directory, parsing config, etc. You can override any of the methods in BaseNeuron if you need to customize the behavior.
3734
@@ -44,19 +41,56 @@ def __init__(self, config=None):
4441
bt.logging.info("load_state()")
4542
self.load_state()
4643

47-
# TODO(developer): Anything specific to your use case you can do here
44+
self.image_dir = './data/images/'
45+
if not os.path.exists(self.image_dir):
46+
os.makedirs(self.image_dir)
47+
4848

4949
async def forward(self):
5050
"""
51-
Validator forward pass. Consists of:
52-
- Generating the query
53-
- Querying the miners
54-
- Getting the responses
55-
- Rewarding the miners
56-
- Updating the scores
51+
The forward function is called by the validator every time step.
52+
53+
It consists of 3 important steps:
54+
- Generate a challenge for the miners (in this case it creates a synthetic invoice image)
55+
- Query the miners with the challenge
56+
- Score the responses from the miners
57+
58+
Args:
59+
self (:obj:`bittensor.neuron.Neuron`): The neuron object which contains all the necessary state for the validator.
60+
5761
"""
58-
# TODO(developer): Rewrite this function based on your protocol definition.
59-
return await forward(self)
62+
63+
# get_random_uids is an example method, but you can replace it with your own.
64+
miner_uids = ocr_subnet.utils.uids.get_random_uids(self, k=self.config.neuron.sample_size)
65+
66+
# make a hash from the timestamp
67+
filename = hashlib.md5(str(time.time()).encode()).hexdigest()
68+
69+
# Create a random image and load it.
70+
image_data = ocr_subnet.validator.generate.invoice(path=os.path.join(self.image_dir, f"{filename}.pdf"), corrupt=True)
71+
72+
# Create synapse object to send to the miner and attach the image.
73+
synapse = ocr_subnet.protocol.OCRSynapse(base64_image = image_data['base64_image'])
74+
75+
# The dendrite client queries the network.
76+
responses = self.dendrite.query(
77+
# Send the query to selected miner axons in the network.
78+
axons=[self.metagraph.axons[uid] for uid in miner_uids],
79+
# Pass the synapse to the miner.
80+
synapse=synapse,
81+
# Do not deserialize the response so that we have access to the raw response.
82+
deserialize=False,
83+
)
84+
85+
# Log the results for monitoring purposes.
86+
bt.logging.info(f"Received responses: {responses}")
87+
88+
rewards = ocr_subnet.validator.reward.get_rewards(self, labels=image_data['labels'], responses=responses)
89+
90+
bt.logging.info(f"Scored responses: {rewards}")
91+
92+
# Update the scores based on the rewards. You may want to define your own update_scores function for custom behavior.
93+
self.update_scores(rewards, miner_uids)
6094

6195

6296
# The main function parses the configuration and runs the validator.

ocr_subnet/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# The MIT License (MIT)
22
# Copyright © 2023 Yuma Rao
3-
# TODO(developer): Set your name
4-
# Copyright © 2023 <your name>
53

64
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
75
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
@@ -29,3 +27,4 @@
2927
from . import protocol
3028
from . import base
3129
from . import validator
30+
from . import utils

ocr_subnet/base/validator.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# The MIT License (MIT)
22
# Copyright © 2023 Yuma Rao
3-
# TODO(developer): Set your name
4-
# Copyright © 2023 <your name>
53

64
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
75
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation

ocr_subnet/protocol.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
# The MIT License (MIT)
22
# Copyright © 2023 Yuma Rao
3-
# TODO(developer): Set your name
4-
# Copyright © 2023 <your name>
53

64
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
75
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
@@ -17,32 +15,31 @@
1715
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1816
# DEALINGS IN THE SOFTWARE.
1917

20-
import typing
21-
import bittensor as bt
2218

19+
import bittensor as bt
20+
from typing import Optional, List
2321

2422
class OCRSynapse(bt.Synapse):
2523
"""
2624
A simple OCR synapse protocol representation which uses bt.Synapse as its base.
2725
This protocol enables communication betweenthe miner and the validator.
2826
2927
Attributes:
30-
- image: A pdf image to be processed by the miner.
28+
- base64_image: Base64 encoding of pdf image to be processed by the miner.
3129
- response: List[dict] containing data extracted from the image.
3230
"""
3331

3432
# Required request input, filled by sending dendrite caller. It is a base64 encoded string.
3533
base64_image: str
3634

3735
# Optional request output, filled by recieving axon.
38-
response: typing.Optional[typing.List[dict]] = None
36+
response: Optional[List[dict]] = None
3937

40-
def deserialize(self) -> int:
38+
def deserialize(self) -> List[dict]:
4139
"""
42-
Deserialize the miner response. This method retrieves the response from
43-
the miner in the form of `response`, maybe this also takes care of casting it to List[dict]?
40+
Deserialize the miner response.
4441
4542
Returns:
46-
- List[dict: The deserialized response, which is a list of dictionaries containing the extracted data.
43+
- List[dict]: The deserialized response, which is a list of dictionaries containing the extracted data.
4744
"""
4845
return self.response

ocr_subnet/utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from . import config
22
from . import misc
33
from . import uids
4+
from . import process

ocr_subnet/utils/image.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import io
2+
import fitz
3+
import base64
4+
5+
from typing import List
6+
from PIL import Image, ImageDraw
7+
8+
9+
def serialize(image: Image, format: str="JPEG") -> str:
10+
"""Converts PIL image to base64 string.
11+
"""
12+
13+
buffer = io.BytesIO()
14+
image.save(buffer, format=format)
15+
return buffer.getvalue()
16+
17+
18+
def deserialize(base64_string: str) -> Image:
19+
"""Converts base64 string to PIL image.
20+
"""
21+
22+
return Image.open(io.BytesIO(base64.b64decode(base64_string)))
23+
24+
25+
def load(pdf_path: str, page: int=0, zoom_x: float=1.0, zoom_y: float=1.0) -> Image:
26+
"""Loads pdf image and converts to PIL image
27+
"""
28+
29+
# Read the pdf into memory
30+
pdf = fitz.open(pdf_path)
31+
page = pdf[page]
32+
33+
# Set zoom factors for x and y axis (1.0 means 100%)
34+
mat = fitz.Matrix(zoom_x, zoom_y)
35+
pix = page.get_pixmap(matrix=mat)
36+
img_data = io.BytesIO(pix.tobytes('png'))
37+
38+
# convert to PIL image
39+
return Image.open(img_data)
40+
41+
def draw_boxes(image: Image, response: List[dict], color='red'):
42+
"""Draws boxes around text on the image
43+
"""
44+
45+
draw = ImageDraw.Draw(image)
46+
for item in response:
47+
draw.rectangle(item['position'], outline=color)
48+
49+
return image

ocr_subnet/utils/process.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# The MIT License (MIT)
2+
# Copyright © 2023 Yuma Rao
3+
4+
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
5+
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
6+
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
7+
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
8+
9+
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
10+
# the Software.
11+
12+
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
13+
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
14+
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
15+
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
16+
# DEALINGS IN THE SOFTWARE.
17+
18+
from typing import List
19+
20+
def group_and_merge_boxes(data: List[dict], xtol: int=25, ytol: int=5) -> List[dict]:
21+
"""
22+
Combines boxes that are close together into a single box so that the text is grouped into sections.
23+
24+
Args:
25+
- data (list): List of dictionaries containing the position, font and text of each section
26+
- xtol (int): Maximum distance between boxes in the x direction to be considered part of the same section
27+
- ytol (int): Maximum distance between boxes in the y direction to be considered part of the same section
28+
29+
Returns:
30+
- list: List of dictionaries containing the position, font and text of each section
31+
"""
32+
# Ensure all data items are valid and have a 'position' key
33+
data = [box for box in data if box is not None and 'position' in box]
34+
35+
# Step 1: Group boxes by lines
36+
lines = []
37+
for box in data:
38+
added_to_line = False
39+
for line in lines:
40+
if line and abs(line[0]['position'][1] - box['position'][1]) <= ytol:
41+
line.append(box)
42+
added_to_line = True
43+
break
44+
if not added_to_line:
45+
lines.append([box])
46+
47+
# Step 2: Sort and merge within each line
48+
merged_data = []
49+
for line in lines:
50+
line.sort(key=lambda item: item['position'][0]) # Sort by x1
51+
i = 0
52+
while i < len(line) - 1:
53+
box1 = line[i]['position']
54+
box2 = line[i + 1]['position']
55+
if abs(box1[2] - box2[0]) <= xtol: # Check horizontal proximity
56+
new_box = {'position': [min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3])],
57+
'text': line[i]['text'] + ' ' + line[i + 1]['text']}
58+
line[i] = new_box
59+
del line[i + 1]
60+
else:
61+
i += 1
62+
merged_data.extend(line)
63+
64+
return merged_data

ocr_subnet/utils/serialize.py

-19
This file was deleted.

0 commit comments

Comments
 (0)