import torch
import re
import numpy as np
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel,
)
from features.cheque import (
    spell_check,
    amount_letter_number_match,
)

# Determine the device to use based on the availability of CUDA (GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_cheque_parser(folder_name):
    """
    Loads a cheque parsing processor and model from a specified directory and
    moves the model to the appropriate device.

    This function loads a DonutProcessor and a VisionEncoderDecoderModel. The
    model is moved to a GPU if available, otherwise to CPU.

    Parameters:
    - folder_name (str): The directory where the processor and model's
      pretrained weights are stored.

    Returns:
    - tuple: A tuple containing the loaded processor and model.
    """
    # Load the DonutProcessor from the pretrained model directory
    processor = DonutProcessor.from_pretrained(folder_name)

    # Load the VisionEncoderDecoderModel from the pretrained model directory
    model = VisionEncoderDecoderModel.from_pretrained(folder_name)

    # Move the model to the available device (GPU or CPU)
    model.to(device)

    # Return the processor and model as a tuple
    return processor, model
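
# Usage sketch (the checkpoint folder name below is illustrative, not defined
# by this module); the directory is expected to contain Donut-style processor
# and model weights saved with save_pretrained:
#
#   processor, model = load_cheque_parser("donut-cheque-parser")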


def parse_text(image, processor, model):
    """
    Parses text from an image using a pre-trained model and processor, and
    formats the output.

    Parameters:
    - image: The image from which to parse text.
    - processor: The processor instance equipped with methods for handling
      input preprocessing and decoding outputs.
    - model: The pre-trained VisionEncoderDecoderModel used to generate text
      from image data.

    Returns:
    - dict: A dictionary containing parsed and formatted cheque details.
    """
    # Prepare the initial task prompt and get decoder input IDs from the tokenizer
    task_prompt = "<parse-cheque>"
    decoder_input_ids = processor.tokenizer(
        task_prompt,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    # Convert the image to pixel values suitable for the model input
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Generate outputs from the model using the provided pixel values and decoder inputs
    outputs = model.generate(
        pixel_values.to(device),  # Ensure pixel values are on the correct device (CPU/GPU)
        decoder_input_ids=decoder_input_ids.to(device),  # Move decoder inputs to the correct device
        max_length=model.decoder.config.max_position_embeddings,  # Set the maximum output length
        pad_token_id=processor.tokenizer.pad_token_id,  # Define the padding token
        eos_token_id=processor.tokenizer.eos_token_id,  # Define the end-of-sequence token
        use_cache=True,  # Enable caching to improve performance
        bad_words_ids=[[processor.tokenizer.unk_token_id]],  # Prevent generation of unknown tokens
        return_dict_in_generate=True,  # Return outputs in a dictionary format
    )

    # Decode the output sequences to text
    sequence = processor.batch_decode(outputs.sequences)[0]

    # Remove the end-of-sequence and padding special tokens
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")

    # Remove the initial task prompt token
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()

    # Convert the cleaned sequence to JSON (named to avoid shadowing the stdlib `json`)
    json_output = processor.token2json(sequence)

    # Flatten the list of single-key dicts, replacing empty values with 'missing'
    return {
        key: (value if value != '' else 'missing')
        for attribute in json_output['cheque_details']
        for key, value in attribute.items()
    }
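
# Usage sketch (illustrative): `processor` and `model` come from
# load_cheque_parser, and the image is assumed to be a PIL image.
#
#   from PIL import Image
#   cheque = Image.open("cheque.jpg").convert("RGB")
#   details = parse_text(cheque, processor, model)
#
# The returned keys depend on the checkpoint's output schema; only
# 'amt_in_words' and 'amt_in_figures' are relied on by evaluate_cheque_fraud
# below.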


def evaluate_cheque_fraud(parsed_data, model_fraud_detection):
    """
    Evaluates potential fraud in a cheque by analyzing the consistency of
    spelling and the match between numerical and textual representations of
    the amount.

    Parameters:
    - parsed_data (dict): Dictionary containing parsed data from a cheque,
      including amounts in words and figures.
    - model_fraud_detection: A trained model used to predict whether the
      cheque is valid or fraudulent.

    Returns:
    - tuple: A tuple containing the fraud evaluation result ('valid' or
      'fraud'), the spelling check message, and the amount match message.
    """
    # Check spelling for the amount in words and correct it if necessary
    spelling_is_correct, amount_in_text_corrected = spell_check(
        parsed_data['amt_in_words'],
    )

    # Check if the corrected amount in words matches the amount in figures
    amount_match = amount_letter_number_match(
        amount_in_text_corrected,
        parsed_data['amt_in_figures'],
    )

    # Handle the case where amount_match is a tuple, using only the first element if so
    amount_match_value = amount_match[0] if isinstance(amount_match, tuple) else amount_match

    # Prepare the input for the fraud detection model
    model_input = np.array([spelling_is_correct, amount_match_value])

    # Predict fraud using the model, reshaping input to match the expected format
    prediction = model_fraud_detection.predict(
        model_input.reshape(1, -1)
    )

    # Construct messages regarding the spelling and value match
    spelling = f'Spelling is correct: {spelling_is_correct}'
    value_match = f'Numeric and alphabetic values match: {amount_match}'

    # Return the evaluation result along with explanatory messages
    return np.where(prediction[0] == 1, 'valid', 'fraud').item(), spelling, value_match
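

# Minimal end-to-end sketch, kept under a __main__ guard so importing this
# module has no side effects. The checkpoint folder, image path, and the
# joblib-loaded fraud-detection model are illustrative assumptions, not
# artifacts provided by this module.
if __name__ == "__main__":
    import joblib
    from PIL import Image

    # Load the Donut processor/model pair (assumed local checkpoint directory)
    processor, model = load_cheque_parser("donut-cheque-parser")

    # Parse a sample cheque image (the path is a placeholder)
    cheque_image = Image.open("sample_cheque.jpg").convert("RGB")
    parsed_cheque = parse_text(cheque_image, processor, model)
    print(parsed_cheque)

    # Score the parsed fields with a previously trained fraud-detection model
    fraud_model = joblib.load("fraud_detection_model.joblib")
    verdict, spelling_msg, match_msg = evaluate_cheque_fraud(parsed_cheque, fraud_model)
    print(verdict, spelling_msg, match_msg)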