Skip to content

Commit 4793f78

Browse files
authored
ChequeDetection (#260)
1 parent 13d0a11 commit 4793f78

File tree

9 files changed

+3736
-0
lines changed

9 files changed

+3736
-0
lines changed

advanced_tutorials/fraud_cheque_detection/1_feature_pipeline.ipynb

+1,378
Large diffs are not rendered by default.

advanced_tutorials/fraud_cheque_detection/2_training_pipeline.ipynb

+573
Large diffs are not rendered by default.

advanced_tutorials/fraud_cheque_detection/3_inference_pipeline.ipynb

+1,241
Large diffs are not rendered by default.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# ---------------------------------------------------------------------------
# Dataset
# ---------------------------------------------------------------------------
# Identifier of the cheque dataset (looks like a Hub-style "owner/name" id;
# confirm against the code that loads it).
DATASET_NAME = "shivi/cheques_sample_data"

# ---------------------------------------------------------------------------
# Donut
# ---------------------------------------------------------------------------
# Base checkpoint; "nielsr/donut-base" was noted as an alternative.
DONUT_BASE_REPO = "naver-clova-ix/donut-base"
# Fine-tuned cheque-parsing checkpoint.
DONUT_FT_REPO = "shivi/donut-cheque-parser"
# Input image size (axis order is not established here -- TODO confirm).
IMAGE_SIZE = [960, 720]
# Maximum generated sequence length for Donut's text decoder.
MAX_LENGTH = 768

# ---------------------------------------------------------------------------
# Task tokens
# ---------------------------------------------------------------------------
TASK_START_TOKEN = "<parse-cheque>"
# NOTE(review): deliberately kept identical to TASK_START_TOKEN, as in the
# original; presumably the prompt ends with the same task token -- confirm.
TASK_END_TOKEN = "<parse-cheque>"

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
BATCH_SIZE = 1
NUM_WORKERS = 4
MAX_EPOCHS = 30
VAL_CHECK_INTERVAL = 0.2
CHECK_VAL_EVERY_N_EPOCH = 1
GRADIENT_CLIP_VAL = 1.0
LEARNING_RATE = 3e-5
VERBOSE = True

# ---------------------------------------------------------------------------
# Hardware
# ---------------------------------------------------------------------------
ACCELERATOR = "gpu"
DEVICE_NUM = 1
PRECISION = 16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from textblob import TextBlob
2+
from word2number import w2n
3+
4+
def spell_check(text):
    """
    Checks and corrects the spelling of a given text.

    A value of None, an empty or whitespace-only string, or the literal
    'missing' (in any letter case) is treated as missing data and is never
    sent to the spell checker.

    Parameters:
    - text (str): The text whose spelling is to be checked.

    Returns:
    - tuple: A tuple containing a boolean indicating if the original spelling was correct, and the corrected text.
      For missing input the tuple is (False, 'missing').
    """
    # No text at all is missing data rather than an error
    if text is None:
        return False, 'missing'

    # Convert the text to lower case to standardize it
    text_lower = text.lower()

    # Return early if the text is 'missing' or blank; stripping first means ''
    # and runs of spaces are caught too, not only a single space
    if text_lower.strip() in ('missing', ''):
        return False, 'missing'

    # Correct the text using TextBlob
    text_corrected = str(TextBlob(text_lower).correct())

    # Determine if the original text was spelled correctly
    spelling_is_correct = text_lower == text_corrected

    return spelling_is_correct, text_corrected.strip()
28+
29+
30+
def amount_letter_number_match(amount_in_text_corrected, amount_in_number):
    """
    Checks whether an amount written in words equals the amount written in figures.

    Parameters:
    - amount_in_text_corrected (str): The amount in words (already spell-corrected).
    - amount_in_number (str): The amount in figures.

    Returns:
    - bool or tuple: True when both amounts are equal, False when they differ or
      cannot be converted, or a (False, message) tuple when either input is the
      literal 'missing'.
    """
    # Missing data short-circuits with an explanation; a missing amount in
    # words takes precedence over a missing amount in figures.
    if amount_in_text_corrected == 'missing':
        return False, 'Amount in words is missing'
    if amount_in_number == 'missing':
        return False, 'Amount in numbers is missing'

    try:
        words_value = w2n.word_to_num(amount_in_text_corrected)
        figures_value = int(amount_in_number)
    except Exception:
        # Best-effort: any conversion failure (e.g. spell correction turned
        # "for" into something other than "four") counts as a mismatch.
        return False

    return words_value == figures_value
55+
56+
57+
def get_amount_match_column(amount_in_text_corrected, amount_in_number):
    """
    Reduces the result of amount_letter_number_match to a single value.

    Parameters:
    - amount_in_text_corrected (str): The text representation of an amount, corrected for spelling.
    - amount_in_number (str): The numeric representation of an amount.

    Returns:
    - bool: The match result itself, or the first element of the result when
      it is a tuple (the tuple form carries an extra message for missing data).
    """
    outcome = amount_letter_number_match(amount_in_text_corrected, amount_in_number)

    # Tuple results bundle a message; only the leading flag is kept
    return outcome[0] if isinstance(outcome, tuple) else outcome
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import torch
2+
import re
3+
import numpy as np
4+
from transformers import (
5+
DonutProcessor,
6+
VisionEncoderDecoderModel,
7+
)
8+
from features.cheque import (
9+
spell_check,
10+
amount_letter_number_match,
11+
)
12+
13+
# Run on the GPU whenever CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
15+
16+
def load_cheque_parser(folder_name):
    """
    Loads the cheque parser (processor and model) from a pretrained directory.

    Both the DonutProcessor and the VisionEncoderDecoderModel are read from
    `folder_name`; the model is then placed on the module-level `device`
    ("cuda" when available, otherwise "cpu").

    Parameters:
    - folder_name (str): The directory where the processor and model's pretrained weights are stored.

    Returns:
    - tuple: (processor, model), with the model already moved to `device`.
    """
    processor = DonutProcessor.from_pretrained(folder_name)

    # `.to()` hands back the same model instance, so loading and placement chain
    model = VisionEncoderDecoderModel.from_pretrained(folder_name).to(device)

    return processor, model
39+
40+
41+
def parse_text(image, processor, model):
    """
    Runs the cheque parser on an image and returns the extracted fields.

    Parameters:
    - image: The image from which to parse text.
    - processor: The processor used for image preprocessing, tokenization and decoding.
    - model: The pre-trained VisionEncoderDecoderModel used to generate text from image data.

    Returns:
    - dict: The key/value pairs found under 'cheque_details' in the parsed
      output, flattened into one dictionary; empty-string values are replaced
      by 'missing'.
    """
    tokenizer = processor.tokenizer

    # The decoder is primed with the task prompt token
    prompt_ids = tokenizer(
        "<parse-cheque>",
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    # Preprocess the image into model-ready pixel values
    pixel_values = processor(image, return_tensors="pt").pixel_values

    generation = model.generate(
        pixel_values.to(device),
        decoder_input_ids=prompt_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        bad_words_ids=[[tokenizer.unk_token_id]],  # never emit the unknown token
        return_dict_in_generate=True,
    )

    # Only the first decoded sequence is used
    decoded = processor.batch_decode(generation.sequences)[0]

    # Drop end-of-sequence and padding markers
    for special_token in (tokenizer.eos_token, tokenizer.pad_token):
        decoded = decoded.replace(special_token, "")

    # Drop the first remaining tag, i.e. the leading task prompt token
    decoded = re.sub(r"<.*?>", "", decoded, count=1).strip()

    parsed = processor.token2json(decoded)

    # Flatten the list of single-field mappings; later duplicates overwrite earlier ones
    details = {}
    for attribute in parsed['cheque_details']:
        for key, value in attribute.items():
            details[key] = value if value != '' else 'missing'

    return details
96+
97+
98+
def evaluate_cheque_fraud(parsed_data, model_fraud_detection):
    """
    Classifies a parsed cheque as valid or fraudulent from two derived features:
    whether the amount in words is spelled correctly, and whether it matches
    the amount in figures.

    Parameters:
    - parsed_data (dict): Parsed cheque data; the 'amt_in_words' and 'amt_in_figures' entries are read.
    - model_fraud_detection: A trained model exposing `predict`, called with a single-row 2D array.

    Returns:
    - tuple: ('valid' or 'fraud', spelling check message, amount match message).
      'valid' is returned when the first prediction equals 1.
    """
    spelling_is_correct, amount_in_text_corrected = spell_check(parsed_data['amt_in_words'])

    amount_match = amount_letter_number_match(
        amount_in_text_corrected,
        parsed_data['amt_in_figures'],
    )

    # The match result may be a (flag, message) tuple; the model only gets the flag
    if isinstance(amount_match, tuple):
        amount_match_value = amount_match[0]
    else:
        amount_match_value = amount_match

    # Single row with the two features, in the order the model is given them
    features = np.array([[spelling_is_correct, amount_match_value]])
    prediction = model_fraud_detection.predict(features)

    verdict = np.where(prediction[0] == 1, 'valid', 'fraud').item()

    spelling = f'Spelling is correct: {spelling_is_correct}'
    # NOTE(review): this interpolates the raw match result, so a tuple (with its
    # message) appears verbatim when data is missing -- confirm this is intended.
    value_match = f'Numeric and alphabetic values match: {amount_match}'

    return verdict, spelling, value_match

0 commit comments

Comments
 (0)