From 7944034712fa053c35781296d177390af3ee30a9 Mon Sep 17 00:00:00 2001 From: Al-murphy Date: Wed, 21 Aug 2024 08:44:56 +0000 Subject: [PATCH] enable N return in characters() --- docs/whats_new.rst | 1 + tangermeme/utils.py | 36 ++++++++++++++++++++++++++++-------- tests/test_utils.py | 8 ++++++-- 3 files changed, 35 insertions(+), 10 deletions(-) diff --git a/docs/whats_new.rst b/docs/whats_new.rst index bc95b83..cfb5686 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -18,6 +18,7 @@ match - Various other changes: 1) Counts from regions that cannot be extracted from a provided bigwig file (such as for a missing chromosome) are now set to nan rather than 0. This will effect the threshold value used for filtering background regions. 2) Small change to the binning strategy for gc values, which could mean that matching loci generated in a previous version will not be reproduced exactly in all cases, even when using the same random seed. + 3) Enable the handling of 'N' in sequences or [0,0,0,0], i.e. an ambiguous genomic positions. Updated the `characters()` and the `_validate_input()` in `utils` module to enable this. Version 0.2.3 diff --git a/tangermeme/utils.py b/tangermeme/utils.py index 9b56509..6c4d004 100644 --- a/tangermeme/utils.py +++ b/tangermeme/utils.py @@ -8,8 +8,8 @@ from tqdm import tqdm -def _validate_input(X, name, shape=None, dtype=None, min_value=None, - max_value=None, ohe=False, ohe_dim=1): +def _validate_input(X, name, shape=None, dtype=None, min_value=None, + max_value=None, ohe=False, ohe_dim=1,allow_N=False): """An internal function for validating properties of the input. This function will take in an object and verify characteristics of it, such @@ -33,7 +33,7 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None, dtype: torch.dtype or None, optional The dtype the tensor must have. If not provided, no check is performed. Default is None. - + min_value: float or None, optional The minimum value that can be in the tensor, inclusive. If None, no check is performed. Default is None. @@ -45,6 +45,10 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None, ohe: bool, optional Whether the input must be a one-hot encoding, i.e., only consist of zeroes and ones. Default is False. + + allow_N: bool, optional + Whether to allow the return of the character 'N' in the sequence, i.e. + if pwm at a position is all 0's return N. Default is False. """ if not isinstance(X, torch.Tensor): @@ -63,7 +67,7 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None, raise ValueError("{} must have dtype {}".format(name, dtype)) if min_value is not None and X.min() < min_value: - raise ValueError("{} cannot have a value below {}".format(name, + raise ValueError("{} cannot have a value below {}".format(name, min_value)) if max_value is not None and X.max() > max_value: @@ -78,7 +82,8 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None, if not all(values == torch.tensor([0, 1], device=X.device)): raise ValueError("{} must be one-hot encoded.".format(name)) - if not (X.sum(axis=ohe_dim) == 1).all(): + if ((not (X.sum(axis=1) == 1).all()) and (not allow_N) + ) or ((allow_N) and (not ((X.sum(axis=ohe_dim) == 1) | (X.sum(axis=ohe_dim) == 0)).all())): raise ValueError("{} must be one-hot encoded ".format(name) + "and cannot have unknown characters.") @@ -132,7 +137,7 @@ def _cast_as_tensor(value, dtype=None): return torch.tensor(value, dtype=dtype) -def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False): +def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False, allow_N=False): """Converts a PWM/one-hot encoding to a string sequence. This function takes in a PWM or one-hot encoding and converts it to the @@ -156,6 +161,10 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False): Whether to force a sequence to be produced even when there are ties. At each position that there is a tight, the character earlier in the sequence will be used. Default is False. + + allow_N: bool, optional + Whether to allow the return of the character 'N' in the sequence, i.e. + if pwm at a position is all 0's return N. Default is False. Returns @@ -163,6 +172,10 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False): seq: str A string where the length is the second dimension of PWM. """ + + #if (batch, alphabet_size, motif_size) and batch = 1, remove batch axis + if len(pwm.shape) == 3 and pwm.shape[0] == 1: + pwm = pwm[0] if len(pwm.shape) != 2: raise ValueError("PWM must have two dimensions where the " + @@ -174,7 +187,7 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False): "provided alphabet.") pwm_ismax = pwm == pwm.max(dim=0, keepdims=True).values - if pwm_ismax.sum(axis=0).max() > 1 and force == False: + if pwm_ismax.sum(axis=0).max() > 1 and force == False and allow_N == False: raise ValueError("At least one position in the PWM has multiple " + "letters with the same probability.") @@ -182,7 +195,14 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False): if isinstance(pwm, torch.Tensor): pwm = pwm.numpy(force=True) - return ''.join(alphabet[pwm.argmax(axis=0)]) + if allow_N: + n_inds = numpy.where(pwm.sum(axis=0)==0)[0] + dna_chars = alphabet[pwm.argmax(axis=0)] + dna_chars[n_inds] = 'N' + else: + dna_chars = alphabet[pwm.argmax(axis=0)] + + return ''.join(dna_chars) @numba.njit("void(int8[:, :], int8[:], int8[:])") diff --git a/tests/test_utils.py b/tests/test_utils.py index 2fca19c..342ced0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -79,15 +79,19 @@ def test_characters_raise_alphabet(): def test_characters_raise_dimensions(): seq = 'GCTAC' + #this will work for shape (1,4,5) but not for (N,4,5) where N > 1 ohe = torch.tensor([[ [0.25, 0.00, 0.10, 0.95, 0.00], [0.20, 1.00, 1.00, 0.05, 1.00], [0.30, 0.00, 0.30, 0.00, 0.00], [0.25, 0.00, 3.00, 0.00, 0.00] ]]) - + + assert characters(ohe) == seq + + ohe = torch.concat([ohe, ohe], dim=0) assert_raises(ValueError, characters, ohe, ['A', 'C', 'G', 'T']) - + ohe = torch.tensor([0.25, 0.00, 0.10, 0.95, 0.00]) assert_raises(ValueError, characters, ohe, ['A', 'C', 'G', 'T'])