refactor: combine preprocessor with conformer

The preprocessor and the conformer peak picker required many of the same attributes.
Instead of passing these attributes back and forth, the two classes have been combined.
The preprocessor methods are private because they are more complex.

Also move some preprocessor functions to transformations
jcharkow committed Jan 24, 2024
1 parent 03ccb2c commit 04ace37
Showing 4 changed files with 296 additions and 356 deletions.
32 changes: 31 additions & 1 deletion massdash/dataProcessing/transformations.py
@@ -51,4 +51,34 @@ def equalize2D(arr: np.array, num_bins: int):

# use linear interpolation of cdf to find new pixel values
image_equalized = np.interp(arr.flatten(), bins[:-1], cdf)
return image_equalized.reshape(arr.shape)

def min_max_scale(data, min: float=None, max: float=None) -> np.ndarray:
"""
Perform min-max scaling on the input data.
Args:
data (numpy.ndarray): The input data to be scaled.
min (float, optional): The minimum value for scaling. If not provided, the minimum value of the data will be used.
max (float, optional): The maximum value for scaling. If not provided, the maximum value of the data will be used.
Returns:
numpy.ndarray: The scaled data.
"""
min = data.min() if min is None else min
max = data.max() if max is None else max

return np.nan_to_num((data - min) / (max - min))

def sigmoid(x: np.ndarray) -> np.ndarray:
"""
Compute the sigmoid function for the given input.
Args:
x (np.ndarray): The input array.
Returns:
np.ndarray: The element-wise sigmoid of the input.
"""
return 1 / (1 + np.exp(-x))
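
A minimal sketch of how these two helpers behave (the arrays below are made-up values for illustration, not taken from the repository; min_max_scale and sigmoid are used later in ConformerPeakPicker to normalize traces and to convert logits to probabilities):

import numpy as np
from massdash.dataProcessing.transformations import min_max_scale, sigmoid

trace = np.array([0.0, 5.0, 10.0])        # hypothetical intensity trace
min_max_scale(trace)                      # -> array([0. , 0.5, 1. ])
min_max_scale(trace, min=0.0, max=20.0)   # -> array([0.  , 0.25, 0.5 ])

logits = np.array([-2.0, 0.0, 2.0])       # hypothetical model outputs
sigmoid(logits)                           # -> array([0.119, 0.5, 0.881]) (approx.)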
280 changes: 265 additions & 15 deletions massdash/peakPickers/ConformerPeakPicker.py
@@ -4,19 +4,23 @@
"""

import os
from typing import List
from typing import List, Literal

import numpy as np

# Preprocess
from ..preprocess.ConformerPreprocessor import ConformerPreprocessor
# Structs
from ..structs.TransitionGroup import TransitionGroup
from ..structs.TransitionGroupFeature import TransitionGroupFeature
from ..loaders.SpectralLibraryLoader import SpectralLibraryLoader
# Utils
from ..util import check_package
from ..util import LOGGER
from ..util import check_package, LOGGER
# Transformations
from ..dataProcessing.transformations import min_max_scale, sigmoid


onnxruntime, ONNXRUNTIME_AVAILABLE = check_package("onnxruntime")
torch, TORCH_AVAILABLE = check_package("torch")
binary_recall_at_fixed_precision, TORCHMETRICS_AVAILABLE = check_package("torchmetrics", "functional.classification.binary_recall_at_fixed_precision")

class ConformerPeakPicker:
"""
@@ -37,16 +41,14 @@ class ConformerPeakPicker:
_convertConformerFeatureToTransitionGroupFeatures: Convert conformer predicted feature to TransitionGroupFeatures.
"""

def __init__(self, library_file: str, pretrained_model_file: str, prediction_threshold: float = 0.5, prediction_type: str = "logits"):
def __init__(self, library_file: str, pretrained_model_file: str, prediction_threshold: float = 0.5, prediction_type: Literal['logits', 'sigmoided', 'binarized'] = "logits"):
"""
Initialize the ConformerPeakPicker class.
Args:
transition_group (TransitionGroup): The transition group object.
pretrained_model_file (str): The path to the pretrained model file.
window_size (int, optional): The window size for peak picking. Defaults to 175.
prediction_threshold (float, optional): The prediction threshold for peak picking. Defaults to 0.5.
prediction_type (str, optional): The prediction type for peak picking. Defaults to "logits".
prediction_type (str): The prediction type for peak picking. Defaults to "logits". Valid options are ["logits", "sigmoided", "binarized"].
"""
self.pretrained_model_file = pretrained_model_file
self.prediction_threshold = prediction_threshold
@@ -96,19 +98,21 @@ def pick(self, transition_group, max_int_transition: int=1000) -> List[Transitio
Returns:
List[TransitionGroupFeature]: The list of transition group features.
"""
# Transform data into required input

LOGGER.info("Loading model...")
self.load_model()

# Transform data into required input
LOGGER.info("Preprocessing data...")
conformer_preprocessor = ConformerPreprocessor(transition_group, self.window_size)
input_data = conformer_preprocessor.preprocess(self.library)
transition_group_adjusted = transition_group.adjust_length(self.window_size) # adjust length of the transition group data to the window size
input_data = self._preprocess(transition_group_adjusted)
LOGGER.info("Predicting...")
ort_input = {self.onnx_session.get_inputs()[0].name: input_data}
ort_output = self.onnx_session.run(None, ort_input)
LOGGER.info("Getting predicted boundaries...")
peak_info = conformer_preprocessor.find_top_peaks(ort_output[0], ["precursor"], self.prediction_threshold, self.prediction_type)
peak_info = self._find_top_peaks(ort_output[0], ["precursor"])
# Get actual peak boundaries
peak_info = conformer_preprocessor.get_peak_boundaries(peak_info)
peak_info = self._get_peak_boundaries(transition_group_adjusted, peak_info)
LOGGER.info(f"Peak info: {peak_info}")
return self._convertConformerFeatureToTransitionGroupFeatures(peak_info, max_int_transition)

@@ -131,4 +135,250 @@ def _convertConformerFeatureToTransitionGroupFeatures(self, peak_info: dict, max
rightBoundary=f['rt_end'],
areaIntensity=max_int_transition,
consensusApex=f['rt_apex']))
        return transitionGroupFeatures

@staticmethod
def _find_thresholds(preds: np.ndarray, target: np.ndarray, min_precisions: List[float]) -> List[tuple]:
"""
Finds the thresholds for a given set of predictions and targets based on minimum precisions.
Args:
preds (numpy.ndarray): Array of predicted values.
target (numpy.ndarray): Array of target values.
min_precisions (List[float]): List of minimum precision values.
Returns:
list: List of tuples containing minimum precision, recall, and threshold values.
"""
if not TORCH_AVAILABLE:
raise ImportError("PyTorch is required for finding thresholds, but not installed.")
if not TORCHMETRICS_AVAILABLE:
raise ImportError("torchmetrics is required for finding thresholds, but not installed.")

preds = torch.from_numpy(preds)
target = torch.from_numpy(target)
thresholds = []

for min_precision in min_precisions:
recall, threshold = binary_recall_at_fixed_precision(
preds, target, min_precision
)
thresholds.append((min_precision, recall, threshold))

return thresholds

def _preprocess(self, transition_group) -> np.ndarray:
"""
Preprocesses the data by scaling and transforming it into a numpy array.
Code adapted from CAPE
Args:
transition_group (TransitionGroup): The transition group object to preprocess
Returns:
np.ndarray: The preprocessed data as a numpy array with shape (1, 21, window_size).
"""

# Data array ordering:
# Row index 0-5: ms2 (sample min-max normalized)
# Row index 6-11: ms2 (trace min-max normalized)
# Row index 12: ms1
# Row index 13-18: library intensity
# Row index 19: library retention time diff
# Row index 20: precursor charge

if len(transition_group.transitionData) != 6:
raise ValueError(f"Transition group must have 6 transitions, but has {len(transition_group.transitionData)}.")

# initialize empty numpy array
data = np.empty((0, self.window_size), float)
lib_int_data = np.empty((0, self.window_size), float)

for chrom in transition_group.transitionData:
# append ms2 intensity data to data
data = np.append(data, [chrom.intensity], axis=0)

lib_int = self.library.get_fragment_library_intensity(transition_group.sequence, transition_group.precursor_charge, chrom.label)
lib_int = np.repeat(lib_int, self.window_size)
lib_int_data = np.append(lib_int_data, [lib_int], axis=0)

# initialize empty numpy array to store scaled data
new_data = np.empty((21, len(transition_group.transitionData[0].intensity)), float)

## MS2 data (sample min-max normalized)
new_data[0:6] = min_max_scale(data)

## MS2 trace data (trace min-max normalized)
for j in range(6, 12):
new_data[j : j + 1] = min_max_scale(
data[j - 6 : j - 5]
)

## MS1 trace data
prec_int = transition_group.precursorData[0].intensity

# append ms1 intensity data to data
new_data[12] = min_max_scale(prec_int)

## Library intensity data
# append library intensity data to data
new_data[13:19] = min_max_scale(lib_int_data)

## Library retention time diff
# Find the middle index of the array
middle_index = len(data[0]) // 2

# Create a new array with the same length as the original_data
tmp_arr = np.zeros_like(data[0], dtype=float)

# Set the middle point to 0
tmp_arr[middle_index] = 0

# Increment by one on either side of the middle point
for i in range(1, middle_index + 1):
if middle_index - i >= 0:
tmp_arr[middle_index - i] = i
if middle_index + i < len(data[0]):
tmp_arr[middle_index + i] = i

new_data[19] = tmp_arr

## Add charge state
new_data[20] = transition_group.precursor_charge * np.ones(len(data[0]))

## Convert to float32
new_data = new_data.astype(np.float32)

# convert the shape to (1, 21, window_size)
new_data = np.expand_dims(new_data, axis=0)

return new_data

def _find_top_peaks(self, preds: np.ndarray, seq_classes: List[str] = ['input_precursor']):
"""
Find the top peaks in the predictions.
Args:
preds (numpy.ndarray): The predictions.
seq_classes (list): The sequence classes to group the peaks by, e.g. ['precursor_id']
Returns:
dict: A dictionary containing the peak information for each sequence class.
The dictionary has the following structure:
{
"sequence_class_1": [
{
"max_idx": int,
"start_idx": int,
"end_idx": int,
"score": float
},
...
],
"sequence_class_2": [
...
],
...
}
"""
peak_info = {}

if self.prediction_type == "logits":
preds = sigmoid(preds)

for row_idx in range(preds.shape[0]):
if self.prediction_type == "logits" or self.prediction_type == "sigmoided":
max_col_idx = np.argmax(preds[row_idx])
score = preds[row_idx][max_col_idx]

if preds[row_idx][max_col_idx] < self.prediction_threshold:
continue

peak_start = max_col_idx
peak_end = max_col_idx

for col_idx in range(max_col_idx, -1, -1):
if preds[row_idx, col_idx] < preds[row_idx, max_col_idx] * 0.5:
peak_start = col_idx
break

for col_idx in range(max_col_idx, preds.shape[1]):
if preds[row_idx, col_idx] < preds[row_idx, max_col_idx] * 0.5:
peak_end = col_idx
break
elif self.prediction_type == "binarized":
if np.sum(preds[row_idx]) == 0:
continue

try:
diff = np.diff(preds[row_idx])
peak_start = (np.where(diff == 1)[0] + 1)[0]
peak_end = np.where(diff == -1)[0][0]
max_col_idx = (peak_start + peak_end) // 2
score = np.inf
except IndexError:
continue
else:
raise NotImplementedError("Prediction type {} is not supported. Currently supported prediction types include [logits, sigmoided, binarized]".format(self.prediction_type))

if seq_classes[row_idx] in peak_info:
peak_info[seq_classes[row_idx]].append(
{
"max_idx": max_col_idx,
"start_idx": peak_start,
"end_idx": peak_end,
"score": score,
}
)
else:
peak_info[seq_classes[row_idx]] = [
{
"max_idx": max_col_idx,
"start_idx": peak_start,
"end_idx": peak_end,
"score": score,
}
]

return peak_info

def _get_peak_boundaries(self, transition_group, peak_info: dict):
"""
Adjusts the peak boundaries in the peak_info dictionary based on the window size and the dimensions of the input rt_array.
Calculates the actual RT values from the rt_array and appends them to the peak_info dictionary.
Args:
transition_group (TransitionGroup): The transition group whose RT array is used to map indexes to retention times.
peak_info (dict): A dictionary containing information about the peaks.
Returns:
dict: The updated peak_info dictionary with adjusted peak boundaries and RT values.
"""
rt_array = transition_group.transitionData[0].data
if rt_array.shape[0] != self.window_size:
LOGGER.warning(f"input_data {rt_array.shape[0]} was trimmed to {self.window_size}; adjusting peak_info indexes to map back to the original data's dimensions")
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['max_idx_org'] = peak_info[key][i]['max_idx']
peak_info[key][i]['start_idx_org'] = peak_info[key][i]['start_idx']
peak_info[key][i]['end_idx_org'] = peak_info[key][i]['end_idx']
new_max_idx = peak_info[key][i]['max_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if new_max_idx >= 0:
peak_info[key][i]['max_idx'] = new_max_idx

new_start_idx = peak_info[key][i]['start_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)
if new_start_idx >= 0:
peak_info[key][i]['start_idx'] = new_start_idx

peak_info[key][i]['end_idx'] = peak_info[key][i]['end_idx'] + (self.window_size // 2) - (rt_array.shape[0] // 2)

# get actual RT value from RT array and append to peak_info
for key in peak_info.keys():
for i in range(len(peak_info[key])):
peak_info[key][i]['rt_apex'] = rt_array[peak_info[key][i]['max_idx']]
peak_info[key][i]['rt_start'] = rt_array[peak_info[key][i]['start_idx']]
peak_info[key][i]['rt_end'] = rt_array[peak_info[key][i]['end_idx']]
peak_info[key][i]['int_apex'] = np.max([tg.intensity[peak_info[key][i]['max_idx']] for tg in transition_group.transitionData])

return peak_info
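
With the preprocessor folded into the picker, the full flow is driven through ConformerPeakPicker alone. A rough usage sketch of the refactored call path (the file paths and the transition_group object below are placeholders, not taken from the repository):

from massdash.peakPickers.ConformerPeakPicker import ConformerPeakPicker

# hypothetical inputs: a spectral library file, an ONNX conformer model, and a
# TransitionGroup with 6 transitions loaded elsewhere in massdash
picker = ConformerPeakPicker(
    library_file="library.pqp",              # placeholder path
    pretrained_model_file="conformer.onnx",  # placeholder path
    prediction_threshold=0.5,
    prediction_type="logits",                # or "sigmoided" / "binarized"
)

# pick() loads the model, length-adjusts and preprocesses the transition group via
# the private _preprocess / _find_top_peaks / _get_peak_boundaries methods, and
# returns a list of TransitionGroupFeature objects
features = picker.pick(transition_group)
for f in features:
    print(f.leftBoundary, f.rightBoundary, f.consensusApex)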
