From 6c1da9fd3419bf3a92450af939b51c34d50ce6c6 Mon Sep 17 00:00:00 2001 From: Jacob Schreiber Date: Thu, 31 Oct 2024 16:59:02 +0000 Subject: [PATCH] ADD symmetric tomtom --- ...ulatory_Features_a_Model_Has_Learned.ipynb | 2 +- tangermeme/plot.py | 2 +- tangermeme/seqlet.py | 11 +- tangermeme/tools/fimo.py | 178 +++++---- tangermeme/tools/symmetric_tomtom.py | 345 ++++++++++++++++++ tangermeme/tools/tomtom.py | 114 +++--- tangermeme/utils.py | 2 +- tests/tools/test_tomtom.py | 104 ++++++ 8 files changed, 622 insertions(+), 136 deletions(-) create mode 100644 tangermeme/tools/symmetric_tomtom.py diff --git a/docs/vignettes/Inspecting_What_Cis-Regulatory_Features_a_Model_Has_Learned.ipynb b/docs/vignettes/Inspecting_What_Cis-Regulatory_Features_a_Model_Has_Learned.ipynb index 203bed5..402bd98 100644 --- a/docs/vignettes/Inspecting_What_Cis-Regulatory_Features_a_Model_Has_Learned.ipynb +++ b/docs/vignettes/Inspecting_What_Cis-Regulatory_Features_a_Model_Has_Learned.ipynb @@ -81,7 +81,7 @@ "\n", "Marginalization is the process of substituting in a motif and observing the change in model output before and after the substitution. Usually, this substitution is done into a background region so that one can estimate the \"marginal\" effect that each motif has on model predictions in isolation, and this value is averaged over many background sequencces.\n", "\n", - "An important complication in using marginalizations is the choice of background. Uniformly generated random background sequences are likely much higher in GC content than the genome these models were trained on and so they may not be the best choice, although they are very convenient to use. The best choice is usually to use regions of the genome that are known to not exhibit the activity you care about but have similar GC content, e.g., regions that are not peaks, regions that are not accessible, non-gene body non-promoter regions, etc. The second best choice is to use randomly generated sequences with proportions of each character derived from the genome. As a reference: for `chr1` after excluding N's these proportions are `[0.2910, 0.2085, 0.2087, 0.2918]`. When prototyping, it is okay to use uniformly randomly generated sequences, but make sure to switch before presenting your results. Models can be very sensitive to local GC content and so when you substitute a motif with a very different content you may get a \"mirage\" that is driven entirely by this local change in GC content.\n", + "An important complication in using marginalizations is the choice of background. Uniformly generated random background sequences are likely much higher in GC content than the genome these models were trained on and so they may not be the best choice, although they are very convenient to use. The best choice is usually to use regions of the genome that are known to not exhibit the activity you care about but have similar GC content and dinucleotide content, e.g., regions that are not peaks, regions that are not accessible, non-gene body non-promoter regions, etc. The second best choice is to use randomly generated sequences with proportions of each character derived from the genome. As a reference: for `chr1` after excluding N's these proportions are `[0.2910, 0.2085, 0.2087, 0.2918]`. When prototyping, it is okay to use uniformly randomly generated sequences, but make sure to switch before presenting your results. 
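The base proportions quoted above plug directly into a background-sequence generator. A minimal sketch of that "second best choice" (not part of the patch; plain numpy rather than tangermeme utilities, and the helper name and sizes are illustrative):

```python
import numpy as np

# A, C, G, T frequencies for chr1 after excluding N's, as quoted above.
probs = [0.2910, 0.2085, 0.2087, 0.2918]
rng = np.random.default_rng(0)

def random_background(n_seqs, length):
	"""Sample one-hot background sequences with genome-derived base content."""
	idx = rng.choice(4, size=(n_seqs, length), p=probs)
	X = np.zeros((n_seqs, 4, length), dtype=np.float32)
	np.put_along_axis(X, idx[:, None, :], 1.0, axis=1)
	return X

X_bg = random_background(100, 2000)  # ~100 backgrounds is usually enough
```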
Models can be very sensitive to local GC content and can also be sensitive to dinucleotide content and so when you substitute a motif with a very different content you may get a \"mirage\" that is driven entirely by this local change in GC content.\n", "\n", "In general, more background sequences is better as long as you have the compute for it but this betterness drops off quickly after around 100 sequences. " ] diff --git a/tangermeme/plot.py b/tangermeme/plot.py index 4115a81..f40e440 100644 --- a/tangermeme/plot.py +++ b/tangermeme/plot.py @@ -153,7 +153,7 @@ def plot_logo(X_attr, ax, color=None, annotations=None, start=None, end=None, elif show_extra: s = motif_start - motifs[motif_start:motif_start+len(motif)*2, i] = 1 + motifs[motif_start:motif_start+len(str(motif))*2, i] = 1 y_offset += -0.1 + 0.2*(n_tracks) + 0.1*(i-n_tracks) plt.text(motif_start, -ylim*(y_offset+0.1), motif, diff --git a/tangermeme/seqlet.py b/tangermeme/seqlet.py index f77e4e6..5554733 100644 --- a/tangermeme/seqlet.py +++ b/tangermeme/seqlet.py @@ -73,7 +73,7 @@ def _laplacian_null(X_sum, num_to_samp=10000, random_state=1234): prob_pos = len(pos_values) / len(values) urand = torch.from_numpy(numpy.random.RandomState(random_state).uniform( - size=(num_to_samp, 2))).to(X_sum.device) + size=(num_to_samp, 2))) icdf = numpy.log(1 - urand[:, 1]) @@ -210,7 +210,7 @@ def _isotonic_thresholds(values, null_values, increasing, target_fdr, def tfmodisco_seqlets(X_attr, window_size=21, flank=10, target_fdr=0.2, min_passing_frac=0.03, max_passing_frac=0.2, - weak_threshold_for_counting_sign=0.8, device='cuda'): + weak_threshold_for_counting_sign=0.8): """Extract seqlets using the procedure from TF-MoDISco. Seqlets are contiguous spans of high attribution characters. This method @@ -280,7 +280,7 @@ def tfmodisco_seqlets(X_attr, window_size=21, flank=10, target_fdr=0.2, values = X_sum.flatten() if len(values) > 1000000: values = torch.from_numpy(numpy.random.RandomState(1234).choice( - a=values, size=1000000, replace=False)).to(device) + a=values, size=1000000, replace=False)) pos_values = values[values >= 0] neg_values = values[values < 0] @@ -315,8 +315,9 @@ def tfmodisco_seqlets(X_attr, window_size=21, flank=10, target_fdr=0.2, X_sum[idxs] = numpy.abs(X_sum[idxs]) X_sum[~idxs] = -numpy.inf - X_sum[:, :flank] = -numpy.inf - X_sum[:, -flank:] = -numpy.inf + if flank > 0: + X_sum[:, :flank] = -numpy.inf + X_sum[:, -flank:] = -numpy.inf seqlets = _iterative_extract_seqlets(X_sum=X_sum, window_size=window_size, flank=flank, suppress=suppress) diff --git a/tangermeme/tools/fimo.py b/tangermeme/tools/fimo.py index 4d20fa6..e9d487c 100644 --- a/tangermeme/tools/fimo.py +++ b/tangermeme/tools/fimo.py @@ -13,10 +13,7 @@ from tqdm import tqdm - -LOG_2 = math.log(2) - -@numba.njit('float64(float64, float64)', cache=True) +@numba.njit('float64(float64, float64)', fastmath=True, cache=True) def logaddexp2(x, y): """Calculate the logaddexp in a numerically stable manner in base 2. @@ -41,36 +38,17 @@ def logaddexp2(x, y): The result of log2(pow(2, x) + pow(2, y)) """ - vmax, vmin = max(x, y), min(x, y) - return vmax + math.log(math.pow(2, vmin - vmax) + 1) / LOG_2 - - -@numba.njit('void(int32[:, :], float64[:], float64[:], int32, int32, float64)', - cache=True) -def _fast_pwm_to_cdf(int_log_pwm, old_logpdf, logpdf, alphabet_length, - seq_length, log_bg): - """A fast internal function for running the dynamic programming algorithm. 
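For intuition about what the reworked `_pwm_to_mapping` in this hunk computes, here is a toy, pure-Python version of the same dynamic-programming idea: bin the per-position log-odds scores, convolve the per-position score distributions under a uniform background, and turn the tail sums into a score-to-p-value table. This is illustrative only; the patch's numba implementation works in log space via `logaddexp2` to avoid underflow on long motifs.

```python
import numpy as np

# Toy 2-position PWM; rows are A, C, G, T and columns are positions.
pwm = np.array([[0.7, 0.1],
                [0.1, 0.1],
                [0.1, 0.1],
                [0.1, 0.7]])
log_pwm = np.log2(pwm / 0.25)        # log-odds against a uniform background

bin_size = 0.1
int_pwm = np.round(log_pwm / bin_size).astype(int)   # integerized scores

# Distribution of the total (binned) score of a random k-mer: start with the
# first position and convolve in the remaining positions one at a time.
dist = {}
for c in range(4):
	dist[int_pwm[c, 0]] = dist.get(int_pwm[c, 0], 0.0) + 0.25

for i in range(1, int_pwm.shape[1]):
	new = {}
	for s, p in dist.items():
		for c in range(4):
			new[s + int_pwm[c, i]] = new.get(s + int_pwm[c, i], 0.0) + p * 0.25
	dist = new

# P(total score >= s) for each bin: the score -> p-value mapping.
scores = np.array(sorted(dist))
pvals = np.cumsum([dist[s] for s in scores[::-1]])[::-1]
```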
- - This function is written in numba to speed up the dynamic programming used - to convert score bins into log p-values. This is not meant to be used - externally. - """ - - for i in range(1, seq_length): - logpdf[:] = -numpy.inf - - for j, x in enumerate(old_logpdf): - if x != -numpy.inf: - for k in range(alphabet_length): - offset = int_log_pwm[i, k] + if x == float("-inf") and y == float("-inf"): + return float("-inf") - v1 = logpdf[j + offset] - v2 = log_bg + x - logpdf[j + offset] = logaddexp2(v1, v2) + if x == float("inf") or y == float("inf"): + return float("inf") - old_logpdf[:] = logpdf + vmax, vmin = max(x, y), min(x, y) + return vmax + math.log2(math.pow(2, vmin - vmax) + 1) +@numba.njit(cache=True) def _pwm_to_mapping(log_pwm, bin_size): """An internal method for calculating score <-> log p-value mappings. @@ -108,33 +86,72 @@ def _pwm_to_mapping(log_pwm, bin_size): associated with each score bin. """ + n, l = log_pwm.shape + log_bg = math.log2(0.25) - int_log_pwm = numpy.round(log_pwm / bin_size).astype(numpy.int32).T.copy() - - smallest = int(numpy.min(numpy.cumsum(numpy.min(int_log_pwm, axis=-1), - axis=-1))) - largest = int(numpy.max(numpy.cumsum(numpy.max(int_log_pwm, axis=-1), - axis=-1))) + log_pwm.shape[1] - - logpdf = -numpy.inf * numpy.ones(largest - smallest + 1) - for i in range(log_pwm.shape[0]): - idx = int_log_pwm[0, i] - smallest - logpdf[idx] = numpy.logaddexp2(logpdf[idx], log_bg) - - old_logpdf = logpdf.copy() - logpdf[:] = 0 - - _fast_pwm_to_cdf(int_log_pwm, old_logpdf, logpdf, log_pwm.shape[0], - log_pwm.shape[1], log_bg) - - log1mcdf = logpdf.copy() + int_log_pwm = numpy.round(log_pwm / bin_size).astype(numpy.int32) + + smallest, largest = 9999999, -9999999 + log_pwm_min_csum, log_pwm_max_csum = 0, 0 + for i in range(l): + log_pwm_min = 9999999 + log_pwm_max = -9999999 + + for j in range(n): + log_pwm_min = min(log_pwm_min, int_log_pwm[j, i]) + log_pwm_max = max(log_pwm_max, int_log_pwm[j, i]) + + log_pwm_min_csum += log_pwm_min + log_pwm_max_csum += log_pwm_max + + smallest = min(smallest, log_pwm_min_csum) + largest = max(largest, log_pwm_max_csum) + + largest += l + + logpdf = numpy.empty(largest - smallest + 1) + old_logpdf = -numpy.inf * numpy.ones(largest - smallest + 1) + for i in range(n): + idx = int_log_pwm[i, 0] - smallest + old_logpdf[idx] = logaddexp2(old_logpdf[idx], log_bg) + + for i in range(1, l): + for j in range(largest - smallest + 1): + logpdf[j] = -numpy.inf + + for j, x in enumerate(old_logpdf): + if x != -numpy.inf: + for k in range(n): + idx = j + int_log_pwm[k, i] + logpdf[idx] = logaddexp2(logpdf[idx], log_bg + x) + + for j in range(largest - smallest + 1): + old_logpdf[j] = logpdf[j] + for i in range(len(logpdf) - 2, -1, -1): - log1mcdf[i] = numpy.logaddexp2(log1mcdf[i], log1mcdf[i + 1]) + logpdf[i] = logaddexp2(logpdf[i], logpdf[i + 1]) + + return smallest, logpdf + + +@numba.njit(parallel=True, cache=True) +def _all_pwm_to_mapping(motifs, motif_lengths, bin_size): + n = len(motif_lengths) - 1 + + smallests = numpy.empty(n, dtype='int64') + logpdfs = [numpy.empty(0) for i in range(n)] + + for i in numba.prange(n): + s, e = motif_lengths[i], motif_lengths[i+1] - return smallest, log1mcdf + smallest, logpdf = _pwm_to_mapping(motifs[:, s:e], bin_size) + smallests[i] = smallest + logpdfs[i] = logpdf + return smallests, logpdfs -@numba.njit(parallel=True, fastmath=True) + +@numba.njit(parallel=True, fastmath=True, cache=True) def _fast_hits(X, chrom_lengths, pwm, pwm_lengths, score_threshold, bin_size, smallest, score_to_pvals, 
score_to_pval_lengths): n_motifs = len(pwm_lengths) - 1 @@ -180,14 +197,15 @@ def _fast_hits(X, chrom_lengths, pwm, pwm_lengths, score_threshold, bin_size, return hits -@numba.njit +@numba.njit(cache=True) def _fast_convert(X, mapping): for i in range(X.shape[0]): X[i] = mapping[X[i]] def fimo(motifs, sequences, alphabet=['A', 'C', 'G', 'T'], bin_size=0.1, - eps=0.0001, threshold=0.0001, reverse_complement=True, dim=0): + eps=0.0001, threshold=0.0001, reverse_complement=True, return_counts=False, + dim=0): """An implementation of the FIMO algorithm from the MEME suite. This function implements the "Finding Individual Motif Instances" (FIMO) @@ -234,6 +252,11 @@ def fimo(motifs, sequences, alphabet=['A', 'C', 'G', 'T'], bin_size=0.1, Whether to scan each motif and also the reverse complements. Default is True. + return_counts: bool, optioal + Whether to only return the count of the number of matches instead of + dataframes containing information about each match. If True, the return + will be a single array. Default is False + dim: 0 or 1, optional Whether to return one dataframe for each motif containing all hits for that motif across all examples (0, default) or one dataframe for each @@ -243,12 +266,13 @@ def fimo(motifs, sequences, alphabet=['A', 'C', 'G', 'T'], bin_size=0.1, Returns ------- - hits: list of pandas.DataFrames + hits: list of pandas.DataFrames or numpy.ndarray A list of pandas.DataFrames containing motif hits, where the exact - semantics of each dataframe are determined by `dim`. + semantics of each dataframe are determined by `dim`. Alternatively, + a numpy array of just the number of counts per motif if return_counts + is set to True. """ - tic = time.time() log_threshold = math.log2(threshold) # Extract the motifs and potentially the reverse complements @@ -267,36 +291,28 @@ def fimo(motifs, sequences, alphabet=['A', 'C', 'G', 'T'], bin_size=0.1, # Initialize arrays to store motif properties n_motifs = len(motifs) - motif_pwms, motif_names, motif_lengths = [], [], [0] - _score_to_pvals, _score_to_pvals_lengths = [], [0] - _smallest = numpy.empty(n_motifs, dtype=numpy.int32) - _score_thresholds = numpy.empty(n_motifs, dtype=numpy.float32) + motif_names = numpy.array([name for name, _ in motifs]) + motif_lengths = [0] + [pwm.shape[-1] for _, pwm in motifs] + motif_lengths = numpy.cumsum(motif_lengths).astype(numpy.uint64) - # Fill out these motif properties - for i, (name, motif) in enumerate(motifs): - motif_names.append(name) - motif_lengths.append(motif.shape[-1]) - - motif_pwm = numpy.log2(motif + eps) - math.log2(0.25) - motif_pwms.append(motif_pwm) + motif_pwms = numpy.concatenate([pwm for _, pwm in motifs], axis=-1) + motif_pwms = numpy.log2(motif_pwms + eps) - math.log2(0.25) - smallest, mapping = _pwm_to_mapping(motif_pwm, bin_size) - _smallest[i] = smallest - _score_to_pvals.append(mapping) - _score_to_pvals_lengths.append(len(mapping)) + _smallest, _score_to_pvals = _all_pwm_to_mapping(motif_pwms, motif_lengths, + bin_size) + _score_to_pvals_lengths = [0] + _score_thresholds = numpy.empty(n_motifs, dtype=numpy.float32) + + for i in range(n_motifs): + _score_to_pvals_lengths.append(len(_score_to_pvals[i])) idx = numpy.where(_score_to_pvals[i] < log_threshold)[0] if len(idx) > 0: - _score_thresholds[i] = (idx[0] + smallest) * bin_size + _score_thresholds[i] = (idx[0] + _smallest[i]) * bin_size else: _score_thresholds[i] = float("inf") - # Convert these back to numpy arrays - motif_pwms = numpy.concatenate(motif_pwms, axis=-1) - motif_names = 
numpy.array(motif_names) - motif_lengths = numpy.cumsum(motif_lengths).astype(numpy.uint64) - _score_to_pvals = numpy.concatenate(_score_to_pvals) _score_to_pvals_lengths = numpy.cumsum(_score_to_pvals_lengths) @@ -344,6 +360,12 @@ def fimo(motifs, sequences, alphabet=['A', 'C', 'G', 'T'], bin_size=0.1, names = ['sequence_name', 'start', 'end', 'score', 'p-value'] n_ = n_motifs // 2 if reverse_complement else n_motifs + if return_counts: + counts = numpy.zeros(n_, dtype='int32') + for i in range(n_): + counts[i] = len(hits[i]) + (len(hits[i + n_]) if reverse_complement else 0) + return counts + for i in range(n_): if reverse_complement: hits_ = pandas.DataFrame(hits[i] + hits[i + n_], columns=names) diff --git a/tangermeme/tools/symmetric_tomtom.py b/tangermeme/tools/symmetric_tomtom.py new file mode 100644 index 0000000..e20cf29 --- /dev/null +++ b/tangermeme/tools/symmetric_tomtom.py @@ -0,0 +1,345 @@ +# symmetric_tomtom.py +# Contact: Jacob Schreiber + +import time +import math +import numpy +import numba +import torch + +from numba import njit +from numba import prange +from numpy import uint8, uint64 + +from .tomtom import _binned_median +from .tomtom import _pairwise_max +from .tomtom import _merge_rc_results + +from .tomtom import _integer_distances_and_histogram +from .tomtom import _p_values +from .tomtom import tomtom + + +@njit +def _p_value_backgrounds(f, A, B, A_csum, nq, n_bins, t_max, offset): + """An internal function that calculates the backgrounds for p-values. + + This method takes in the histogram of integerized scores `f` and returns + the background probabilities of each overlap achieving a given score. + These scores are calculated for the complete overlap of the query and + target, but also for all overhangs where only part of the query and the + target are overlapping (on either end). Additionally, background + probabilities are calculated for all spans across the query for when the + target is smaller than the query and has to be scanned against it.
+ """ + + n = n_bins*nq + nq*offset + nqm1 = uint64(nq-1) + + # Clear A + for i in range(nq): + i = uint64(i) + for j in range(n): + j = uint64(j) + A[0, i, j] = 0 + A[1, i, j] = 0 + + for i in range(nq): + c = offset * (nq - i - 1) + i, c = uint64(i), uint64(c) + im1, nqmi, nqmi1 = uint64(i-1), uint64(nq-i), uint64(nq-i-1) + + if i == 0: + for k in range(1, n_bins+1): + k = uint64(k) + A[0, 0, k+c] = f[0, k] + A[1, nqm1, k+c] = f[nqm1, k] + else: + for k in range(n_bins*i+1): + k = uint64(k) + a0 = A[0, im1, k+c+offset] + a1 = A[1, nqmi, k+c+offset] + + if a0 > 0: + for l in range(1, n_bins+1): + l = uint64(l) + A[0, i, l+k+c] += a0 * f[i, l] + + if a1 > 0: + for l in range(1, n_bins+1): + l = uint64(l) + A[1, nqmi1, l+k+c] += a1 * f[nqmi1, l] + + for k in range(n): + k, km1 = uint64(k), uint64(k-1) + + if k > n_bins*(i+1)+c: + A_csum[0, i, k] = 1 + A_csum[1, nqmi1, k] = 1 + else: + A_csum[0, i, k] = A[0, i, k] + A_csum[1, nqmi1, k] = A[1, nqmi1, k] + if k > 0: + A_csum[0, i, k] += A_csum[0, i, km1] + A_csum[1, nqmi1, k] += A_csum[1, nqmi1, km1] + + ### + + B[0] = -1 + for i in range(1, nq): + _pairwise_max(B[i-1], A[0, i-1], A_csum[0, i-1], B[i], n) + _pairwise_max(B[i], A[1, nq-i], A_csum[1, nq-i], B[i], n) + + for i in range(nq, t_max+1): + _pairwise_max(B[i-1], A[0, nq-1], A_csum[0, nq-1], B[i], n) + + # Again, `axis` is not implemented for cumsum + for i in range(B.shape[0]): + for j in range(1, n): + B[i, j] += B[i, j-1] + + for j in range(n): + B[i, j] = 1 - B[i, j] + + +@njit(parallel=True) +def _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, n_nearest, + n_score_bins, n_median_bins, n_cache, reverse_complement): + """An internal function implementing the TOMTOM algorithm. + + This internal function is necessary to handle the numba component of the + implementation. Here, scratchboard memory is allocated for each thread and + the main parallel loop is called. Additionally, if reverse complements are + being considered, values are merged across both strands. + """ + + T_max = max(T_lens) + + Q_offsets = numpy.zeros(len(Q_lens)+1, dtype='int64') + Q_offsets[1:] = numpy.cumsum(Q_lens) + Q_max = max(Q_lens) + + n_in_targets = len(T_lens) // 2 if reverse_complement else len(T_lens) + n_out_targets = n_in_targets if n_nearest == -1 else n_nearest + n_outputs = 5 if n_nearest == -1 else 6 + nt = T.shape[-1] + + # Re-usable workspace for each thread instead of re-allocating + # and freeing large arrays for each example. + n = numba.get_num_threads() + n_len = Q_max*n_score_bins + Q_max*n_cache + + _gamma = numpy.empty((n, nt, Q_max), dtype='float64') + _gamma_int = numpy.empty((n, nt, Q_max), dtype='int8') + _f = numpy.empty((n, Q_max, n_score_bins+1), dtype='float64') + + _A = numpy.empty((n, Q_max, Q_max, n_len), dtype='float64') + _B = numpy.empty((n, T_max+1, n_len), dtype='float64') + _A_csum = numpy.empty((n, Q_max, Q_max, n_len), dtype='float64') + + _medians = numpy.empty((n, Q_max), dtype='float64') + _median_bins = numpy.empty((n, n_median_bins, 2), dtype='float64') + + _results = numpy.empty((n, len(T_lens), 5), dtype='float64') + results = numpy.empty((len(Q_lens), n_out_targets, n_outputs), + dtype='float64') + + for i in prange(len(Q_lens)): + nq = Q_lens[i] + pid = numba.get_thread_id() + + offset = _integer_distances_and_histogram(Q, T, _gamma[pid], + _gamma_int[pid], _f[pid], _medians[pid], _median_bins[pid], Q_norm, + T_norm, rr_counts, Q_offsets[i], nq, n_score_bins) + + if offset > n_cache: + print("Offset is larger than `n_cache`. 
Please increase `n_cache`" + " to at least ", offset) + + _p_value_backgrounds(_f[pid], _A[pid], _B[pid], _A_csum[pid], nq, + n_score_bins, T_max, offset) + + _p_values(_gamma_int[pid], _B[pid], rr_inv, T_lens, i, nq, offset, + _results[pid]) + + if reverse_complement == 1: + _merge_rc_results(_results[pid]) + else: + _results[pid, :, 4] = 0 + + if n_nearest == -1: + results[i] = _results[pid, :n_in_targets] + else: + idxs = numpy.argsort(_results[pid, :n_in_targets, 0])[:n_nearest] + results[i, :, :5] = _results[pid, idxs] + results[i, :, 5] = idxs + + # Enforce symmetry + if n_nearest == -1: + for i in range(results.shape[-1]): + for j in range(results.shape[0]): + for k in range(j): + results[j, k, i] = results[k, j, i] + + return results + + +def symmetric_tomtom(Xs, n_score_bins=100, n_median_bins=1000, + n_target_bins=100, n_cache=100, reverse_complement=True, n_jobs=-1): + """A method for assigning p-values to motif similarity. + + This method implements the TOMTOM algorithm for assigning p-values to motif + similarity scores. TOMTOM accounts for several issues that arise when + motifs are scanned against each other, including correctly calculating + scores for overlaps and accounting for motif length and information content + within the motifs. + + Unlike `tomtom`, this function compares a single set of motifs against + itself (and, optionally, the reverse complements of those motifs) and + returns one symmetric matrix per statistic. + + At a high level, TOMTOM works by calculating a background distribution of + scores for each position in the query and then uses dynamic programming to + calculate a distribution of scores for each span of matches, allowing for + potential overhangs on either side. + + Importantly, this method implements the "complete score" version of TOMTOM + which is more robust to edge effects. The "incomplete score" is not a good + score and so is not implemented. + + + Parameters + ---------- + Xs: list of numpy.ndarrays or torch.Tensors with shape (len(alphabet), len) + A list of motifs to compare against each other. Each motif must have a + shape according to the PyTorch format where the length is the last + dimension. If these are PyTorch tensors they will be internally + converted to numpy.ndarrays. + + n_score_bins: int, optional + The number of bins to use when discretizing scores. A higher number is + not necessarily better because you need the data to support each bin + in the distribution. This is `t` from the TOMTOM paper. Default is 100. + + n_median_bins: int, optional + The number of bins to use when approximating the medians. More bins + means higher precision when estimating the median but can also cause it + to take linearly longer. Default is 1000. + + n_target_bins: int or None, optional + Whether to use approximate hashing to speed up calculations by merging + target columns that are similar. This can significantly speed up + calculations and reduce memory at the cost of approximation.
Each value + in each column is binned and targets are merged together if all values + fall within the same bins, e.g., if both columns after binning are + [5, 11, 0, 1]. This parameter sets the number of bins to use when + discretizing the values in the target columns. Fewer bins means more + targets get merged together, which can speed up the calculations, but + also means that the resulting p-values are less accurate. Conversely, + more bins means that fewer targets get merged together, giving more + accurate p-values at the cost of speed. If None, don't use approximate + hashing. Default is 100. + + n_cache: int, optional + A cache size to use when allocating the scratchpad. A higher number will + linearly increase the amount of memory used but will not increase the + amount of compute needed. Default is 100. + + reverse_complement: bool, optional + Whether to also compare each motif against the reverse complement of + every other motif and merge the scores and p-values accordingly. + Default is True. + + n_jobs: int, optional + The number of threads for numba to use when parallelizing the + processing of query sequences. If -1, use all available threads. + Default is -1. + + + Returns + ------- + best_p_values: torch.Tensor, shape=(len(Xs), len(Xs)) + The p-value of the best alignment between each pair of motifs. + + best_scores: torch.Tensor, shape=(len(Xs), len(Xs)) + The scores of the best alignment between each pair of motifs. + + best_offsets: torch.Tensor, shape=(len(Xs), len(Xs)) + The offset of the best alignment between each pair of motifs. + + best_overlaps: torch.Tensor, shape=(len(Xs), len(Xs)) + The overlap of the best alignment between each pair of motifs. + + best_strands: torch.Tensor, shape=(len(Xs), len(Xs)) + The strand for the best alignment between each pair of motifs.
+ """ + + if n_jobs != -1: + _n_jobs = numba.get_num_threads() + numba.set_num_threads(n_jobs) + + if isinstance(Xs[0], torch.Tensor): + Xs = [X.numpy(force=True) for X in Xs] + + # Enforce ordering + X_lens = numpy.array([X.shape[-1] for X in Xs], dtype='int64') + X_idxs = numpy.argsort(X_lens, kind='stable') + Xs = [Xs[idx] for idx in X_idxs] + + Q_lens = numpy.array([X.shape[-1] for X in Xs], dtype='int64') + Q = numpy.concatenate(Xs, axis=-1) + Q_norm = (Q ** 2).sum(axis=0) + + if reverse_complement: + Xs = Xs + [X[::-1, ::-1] for X in Xs] + + T_lens = numpy.array([X.shape[-1] for X in Xs], dtype='int64') + T = numpy.concatenate(Xs, axis=-1) + T_norm = (T ** 2).sum(axis=0) + + # Proceeds normally from here + if Q_norm.max() == 0 or T_norm.max() == 0: + raise ValueError("Cannot have all-zeroes as targets or query.") + + if n_target_bins is not None: + T_min = T.min(axis=-1, keepdims=True) + T_max = T.max(axis=-1, keepdims=True) + T_max[T_max == T_min] = T_min[T_max == T_min] + 1 + + T_ints = numpy.around((T - T_min) / (T_max - T_min) * (n_target_bins-1)) + T_ints = T_ints.T.dot(n_target_bins ** numpy.arange(len(T))[:, None]) + _, rr_idxs, rr_inv, rr_counts = numpy.unique(T_ints.flatten(), + return_index=True, return_inverse=True, return_counts=True) + + T = T[:, rr_idxs] + T_norm = T_norm[rr_idxs] + rr_inv = rr_inv.astype('uint64') + else: + rr_inv = numpy.arange(T.shape[-1]) + rr_counts = numpy.ones_like(rr_inv) + + ### + + results = _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, + -1, n_score_bins, n_median_bins, n_cache, int(reverse_complement)) + + if n_jobs != -1: + numba.set_num_threads(_n_jobs) + + ### Undo swap + + X_idxs2 = numpy.argsort(X_idxs) + return torch.from_numpy(results[X_idxs2][:, X_idxs2]).permute(2, 0, 1) diff --git a/tangermeme/tools/tomtom.py b/tangermeme/tools/tomtom.py index ca6399d..366a321 100644 --- a/tangermeme/tools/tomtom.py +++ b/tangermeme/tools/tomtom.py @@ -1,6 +1,7 @@ # tomtom.py # Contact: Jacob Schreiber +import time import math import numpy import numba @@ -50,8 +51,8 @@ def _binned_median(x, bins, x_min, x_max, counts): @njit -def _integer_distances_and_histogram(X, Y, gamma, f, Z, medians, median_bins, - X_norm, Y_norm, Y_counts, nq_csum, nq, n_bins): +def _integer_distances_and_histogram(X, Y, gamma, gamma_int, f, medians, + median_bins, X_norm, Y_norm, Y_counts, nq_csum, nq, n_bins): """An internal function for integerized scores and the histogram. This function is the main workhorse for the TOMTOM algorithm. 
It contains @@ -79,39 +80,37 @@ def _integer_distances_and_histogram(X, Y, gamma, f, Z, medians, median_bins, z = -math.sqrt(z) if z > 0 else 0 z_max_ = max(z_max_, z) z_min_ = min(z_min_, z) - Z[i, j] = z + gamma[j, i] = z # Subtract out the median from each row - m = _binned_median(Z[i], median_bins, z_min_, z_max_, + m = _binned_median(gamma[:, i], median_bins, z_min_, z_max_, Y_counts) medians[i] = m - z_min_before = z_min - z_min = min(z_min, z_min_ - m) z_max = max(z_max, z_max_ - m) # Find the minimum value and the number of bins needed to get there i_min = int(math.floor(z_min)) #offset bin_scale = int(math.floor(n_bins / (z_max - i_min))) #scale + offset = -i_min * bin_scale + + for i in range(nq): + medians[i] = medians[i] + i_min f[:] = 0 - ys = numpy.sum(Y_counts) + # Convert the distances to bins and record the histogram of counts for i in range(nq): + k = nq - i - 1 for j in range(Y.shape[-1]): - x = (Z[i, j] - i_min - medians[i]) * bin_scale + x = math.floor((gamma[j, i] - medians[i]) * bin_scale + 0.5) - if x >= 0: - x_int = uint64(math.floor(x + 0.5)) - else: - x_int = uint64(math.floor(x - 0.5)) - - gamma[i, j] = x_int - f[i, x_int] += Y_counts[j] / ys + gamma_int[j, k] = x - offset + f[i, uint64(x)] += Y_counts[j] / ys - return uint64(-i_min * bin_scale) + return uint64(offset) @njit @@ -214,7 +213,7 @@ def _p_value_backgrounds(f, A, B, A_csum, nq, n_bins, t_max, offset): @njit -def _p_values(gamma, B_cdfs, rr_inv, T_lens, nq, offset, n_bins, results): +def _p_values(gamma, B_cdfs, rr_inv, T_lens, iq, nq, offset, results): """An internal function for calculating the best match and p-values. This function will take in the integerized score matrix `gamma` and @@ -225,36 +224,42 @@ def _p_values(gamma, B_cdfs, rr_inv, T_lens, nq, offset, n_bins, results): score to the background distribution. 
""" - total_offset = 0 + n = len(T_lens) // 2 + total_offset = uint64(0) + + max_nt = gamma.shape[0] + t_sums = numpy.empty(max_nt+nq-1, dtype='int16') + for i, nt in enumerate(T_lens): + nt = uint64(nt) + results[i, 0] = 1 results[i, 1] = 0 - - for k in range(-nq + 1, nt): - score = 0 - overlap = 0 - - if k < 0: - for j in range(-k, nq): - if nt < nq and overlap == nt: - break - - j, t_idx = uint64(j), uint64(total_offset + j + k) - score += gamma[j, rr_inv[t_idx]] - overlap += 1 - else: - for j in range(min(nq, nt-k)): - j, t_idx = uint64(j), uint64(total_offset + j + k) - score += gamma[j, rr_inv[t_idx]] - overlap += 1 - - score = score + (nq - overlap) * offset + + if i <= iq or (i >= n and i <= (n + iq)): + total_offset += nt + continue + + for k in range(nt+nq-1): + k = uint64(k) + t_sums[k] = nq * offset + + for k in range(nt): + k = uint64(k) + k_idx = uint64(rr_inv[total_offset + k]) + for l in range(nq): + l = uint64(l) + t_sums[k+l] += gamma[k_idx, l] + + for k in range(nt+nq-1): + score = t_sums[k] + overlap = min(k+1, nq) - max(0, k-nt+1) if score >= results[i, 1]: if score == results[i, 1] and results[i, 2] >= overlap: continue results[i, 0] = B_cdfs[nt, uint64(score-1)] results[i, 1] = score - results[i, 2] = k + results[i, 2] = k - nq + 1 results[i, 3] = overlap total_offset += nt @@ -308,8 +313,8 @@ def _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, n_nearest, n = numba.get_num_threads() n_len = Q_max*n_score_bins + Q_max*n_cache - _Z = numpy.empty((n, Q_max, nt), dtype='float64') - _gamma = numpy.empty((n, Q_max, nt), dtype='int32') + _gamma = numpy.empty((n, nt, Q_max), dtype='float64') + _gamma_int = numpy.empty((n, nt, Q_max), dtype='int8') _f = numpy.empty((n, Q_max, n_score_bins+1), dtype='float64') _A = numpy.empty((n, Q_max, Q_max, n_len), dtype='float64') @@ -327,9 +332,9 @@ def _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, n_nearest, nq = Q_lens[i] pid = numba.get_thread_id() - offset = _integer_distances_and_histogram(Q, T, _gamma[pid], _f[pid], - _Z[pid], _medians[pid], _median_bins[pid], Q_norm, T_norm, - rr_counts, Q_offsets[i], nq, n_score_bins) + offset = _integer_distances_and_histogram(Q, T, _gamma[pid], + _gamma_int[pid], _f[pid], _medians[pid], _median_bins[pid], Q_norm, + T_norm, rr_counts, Q_offsets[i], nq, n_score_bins) if offset > n_cache: print("Offset is larger than `n_cache`. 
Please increase `n_cache`" @@ -338,8 +343,8 @@ def _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, n_nearest, _p_value_backgrounds(_f[pid], _A[pid], _B[pid], _A_csum[pid], nq, n_score_bins, T_max, offset) - _p_values(_gamma[pid], _B[pid], rr_inv, T_lens, nq, offset, - n_score_bins, _results[pid]) + _p_values(_gamma_int[pid], _B[pid], rr_inv, T_lens, -1, nq, offset, + _results[pid]) if reverse_complement == 1: _merge_rc_results(_results[pid]) @@ -353,7 +358,8 @@ def _tomtom(Q, T, Q_lens, T_lens, Q_norm, T_norm, rr_inv, rr_counts, n_nearest, results[i, :, :5] = _results[pid, idxs] results[i, :, 5] = idxs - return results + + return results def tomtom(Qs, Ts, n_nearest=None, n_score_bins=100, n_median_bins=1000, @@ -469,7 +475,6 @@ def tomtom(Qs, Ts, n_nearest=None, n_score_bins=100, n_median_bins=1000, if isinstance(Ts[0], torch.Tensor): Ts = [T.numpy(force=True) for T in Ts] - Q_lens = numpy.array([Q.shape[-1] for Q in Qs], dtype='int64') Q = numpy.concatenate(Qs, axis=-1) Q_norm = (Q ** 2).sum(axis=0) @@ -481,9 +486,15 @@ def tomtom(Qs, Ts, n_nearest=None, n_score_bins=100, n_median_bins=1000, T = numpy.concatenate(Ts, axis=-1) T_norm = (T ** 2).sum(axis=0) + if Q_norm.max() == 0 or T_norm.max() == 0: + raise ValueError("Cannot have all-zeroes as targets or query.") + if n_target_bins is not None: - T_ints = numpy.around(T / T.max(axis=-1, keepdims=True) * - (n_target_bins-1)) + T_min = T.min(axis=-1, keepdims=True) + T_max = T.max(axis=-1, keepdims=True) + T_max[T_max == T_min] = T_min[T_max == T_min] + 1 + + T_ints = numpy.around((T - T_min) / (T_max - T_min) * (n_target_bins-1)) T_ints = T_ints.T.dot(n_target_bins ** numpy.arange(len(T))[:, None]) _, rr_idxs, rr_inv, rr_counts = numpy.unique(T_ints.flatten(), return_index=True, return_inverse=True, return_counts=True) @@ -491,6 +502,9 @@ def tomtom(Qs, Ts, n_nearest=None, n_score_bins=100, n_median_bins=1000, T = T[:, rr_idxs] T_norm = T_norm[rr_idxs] rr_inv = rr_inv.astype('uint64') + else: + rr_inv = numpy.arange(T.shape[-1]) + rr_counts = numpy.ones_like(rr_inv) ### diff --git a/tangermeme/utils.py b/tangermeme/utils.py index 9e3991f..f901d3f 100644 --- a/tangermeme/utils.py +++ b/tangermeme/utils.py @@ -1,4 +1,4 @@ -# ablate.py +# utils.py # Author: Jacob Schreiber import numpy diff --git a/tests/tools/test_tomtom.py b/tests/tools/test_tomtom.py index 411dbd1..5f68aae 100644 --- a/tests/tools/test_tomtom.py +++ b/tests/tools/test_tomtom.py @@ -183,6 +183,58 @@ def test_tomtom(): 0., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1]) +def test_tomtom_reverse_complement(): + pwms = generate_random_meme(n=20) + p0, scores0, offsets0, overlaps0, strands0 = tomtom(pwms, pwms) + p1, scores1, offsets1, overlaps1, strands1 = tomtom(pwms, + [p[::-1, ::-1] for p in pwms]) + + assert_array_almost_equal(p0, p1, 4) + assert_array_almost_equal(scores0, scores1, 4) + assert_array_almost_equal(offsets0, offsets1, 4) + assert_array_almost_equal(overlaps0, overlaps1, 4) + assert_array_almost_equal(strands0, 1-strands1) + + +def test_tomtom_homomotifs(): + pwms = generate_random_meme(n=5) + all_a = numpy.array([ + [1, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0] + ]) + + p, scores, offsets, overlaps, strands = tomtom(pwms, [all_a]) + assert_array_almost_equal(p[:, 0], [1, 1, 1, 1, 1], 4) + assert_array_almost_equal(scores[:, 0], [1600, 700, 400, 1800, 800], 4) + assert_array_almost_equal(torch.abs(offsets[:, 0]), [13, 4, 1, 15, 5], 4) + assert_array_almost_equal(overlaps[:, 0], [3, 3, 3, 3, 3], 4) + 
assert_array_almost_equal(strands[:, 0], [1, 1, 1, 1, 1]) + + p, scores, offsets, overlaps, strands = tomtom([all_a], pwms) + assert_array_almost_equal(p[0], [0.4936, 1.0000, 0.7399, 0.9794, 0.2521], 4) + assert_array_almost_equal(scores[0], [381., 360., 371., 374., 381.], 4) + assert_array_almost_equal(offsets[0], [-1., 6., -2., 1., 0.], 4) + assert_array_almost_equal(overlaps[0], [3., 1., 2., 4., 4.], 4) + assert_array_almost_equal(strands[0], [1., 1., 1., 0., 1.]) + + +def test_tomtom_zeroes(): + pwms = generate_random_meme(n=5) + all_zeroes = numpy.array([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]) + + assert_raises(ValueError, tomtom, [all_zeroes], pwms) + assert_raises(ValueError, tomtom, pwms, [all_zeroes]) + + def test_tomtom_subsets(): pwms = generate_random_meme(n=20) p, scores, offsets, overlaps, strands = tomtom(pwms[:2], pwms) @@ -222,6 +274,58 @@ def test_tomtom_subsets(): assert_array_almost_equal(strands, strands3[:2], 6) +def test_tomtom_self(): + pwms = generate_random_meme(n=1) + p, scores, offsets, overlaps, strands = tomtom(pwms, pwms) + + assert p.shape == (1, 1) + assert scores.shape == (1, 1) + assert offsets.shape == (1, 1) + assert overlaps.shape == (1, 1) + assert strands.shape == (1, 1) + + assert_array_almost_equal(p, [[4.4409e-16]], 6) + assert_array_almost_equal(scores, [[1346.]]) + assert_array_almost_equal(offsets, [[0.]]) + assert_array_almost_equal(overlaps, [[16.]]) + assert_array_almost_equal(strands, [[0.]]) + + +def test_tomtom_selfp1(): + pwms = generate_random_meme(n=1) + p, scores, offsets, overlaps, strands = tomtom(pwms, [p + 1 for p in pwms]) + + assert p.shape == (1, 1) + assert scores.shape == (1, 1) + assert offsets.shape == (1, 1) + assert overlaps.shape == (1, 1) + assert strands.shape == (1, 1) + + assert_array_almost_equal(p, [[-1.7764e-15]], 6) + assert_array_almost_equal(scores, [[1491.]]) + assert_array_almost_equal(offsets, [[0.]]) + assert_array_almost_equal(overlaps, [[16.]]) + assert_array_almost_equal(strands, [[0.]]) + + +def test_tomtom_self_rc(): + pwms = generate_random_meme(n=1) + p, scores, offsets, overlaps, strands = tomtom(pwms, + [p[::-1, ::-1] for p in pwms]) + + assert p.shape == (1, 1) + assert scores.shape == (1, 1) + assert offsets.shape == (1, 1) + assert overlaps.shape == (1, 1) + assert strands.shape == (1, 1) + + assert_array_almost_equal(p, [[4.4409e-16]], 6) + assert_array_almost_equal(scores, [[1346.]]) + assert_array_almost_equal(offsets, [[0.]]) + assert_array_almost_equal(overlaps, [[16.]]) + assert_array_almost_equal(strands, [[1.]]) + + def test_tomtom_meme(): pwms = list(read_meme("tests/data/test.meme").values()) p, scores, offsets, overlaps, strands = tomtom(pwms[:1], pwms)
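The tests above exercise `tomtom`; for completeness, here is a hedged sketch of how the new `symmetric_tomtom` entry point added by this patch might be called. The unpacking order assumes the same p-value/score/offset/overlap/strand convention that `tomtom` documents, and the example PWMs are arbitrary.

```python
import numpy
from tangermeme.tools.symmetric_tomtom import symmetric_tomtom

# Any list of (len(alphabet), length) PWMs; lengths may differ between motifs.
pwms = [numpy.random.dirichlet(numpy.ones(4), size=l).T for l in (8, 12, 15)]

# One call scores every motif against every other motif and returns a single
# (5, n, n) tensor whose slices are symmetric across the diagonal.
p_values, scores, offsets, overlaps, strands = symmetric_tomtom(pwms)
```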