Skip to content

Commit

Permalink
Merge pull request #16 from Al-Murphy/main
Browse files Browse the repository at this point in the history
enable N return in characters()
  • Loading branch information
jmschrei authored Sep 6, 2024
2 parents 7d7ced4 + 7944034 commit 631baa0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
1 change: 1 addition & 0 deletions docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ match
- Various other changes:
1) Counts from regions that cannot be extracted from a provided bigwig file (such as for a missing chromosome) are now set to nan rather than 0. This will affect the threshold value used for filtering background regions.
2) Small change to the binning strategy for gc values, which could mean that matching loci generated in a previous version will not be reproduced exactly in all cases, even when using the same random seed.
3) Enable handling of ambiguous genomic positions, i.e., 'N' characters in sequences or all-zero ([0, 0, 0, 0]) positions in one-hot encodings. The `characters()` and `_validate_input()` functions in the `utils` module now accept an `allow_N` argument to support this.


Version 0.2.3
Expand Down
36 changes: 28 additions & 8 deletions tangermeme/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from tqdm import tqdm


def _validate_input(X, name, shape=None, dtype=None, min_value=None,
max_value=None, ohe=False, ohe_dim=1):
def _validate_input(X, name, shape=None, dtype=None, min_value=None,
max_value=None, ohe=False, ohe_dim=1,allow_N=False):
"""An internal function for validating properties of the input.
This function will take in an object and verify characteristics of it, such
Expand All @@ -33,7 +33,7 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None,
dtype: torch.dtype or None, optional
The dtype the tensor must have. If not provided, no check is performed.
Default is None.
min_value: float or None, optional
The minimum value that can be in the tensor, inclusive. If None, no
check is performed. Default is None.
Expand All @@ -45,6 +45,10 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None,
ohe: bool, optional
Whether the input must be a one-hot encoding, i.e., only consist of
zeroes and ones. Default is False.
allow_N: bool, optional
Whether to accept positions that are all zeroes along `ohe_dim`, i.e.,
ambiguous 'N' positions, in addition to valid one-hot positions when
checking the one-hot encoding. Default is False.
"""

if not isinstance(X, torch.Tensor):
Expand All @@ -63,7 +67,7 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None,
raise ValueError("{} must have dtype {}".format(name, dtype))

if min_value is not None and X.min() < min_value:
raise ValueError("{} cannot have a value below {}".format(name,
raise ValueError("{} cannot have a value below {}".format(name,
min_value))

if max_value is not None and X.max() > max_value:
Expand All @@ -78,7 +82,8 @@ def _validate_input(X, name, shape=None, dtype=None, min_value=None,
if not all(values == torch.tensor([0, 1], device=X.device)):
raise ValueError("{} must be one-hot encoded.".format(name))

if not (X.sum(axis=ohe_dim) == 1).all():
if ((not (X.sum(axis=1) == 1).all()) and (not allow_N)
) or ((allow_N) and (not ((X.sum(axis=ohe_dim) == 1) | (X.sum(axis=ohe_dim) == 0)).all())):
raise ValueError("{} must be one-hot encoded ".format(name) +
"and cannot have unknown characters.")

Expand Down Expand Up @@ -132,7 +137,7 @@ def _cast_as_tensor(value, dtype=None):
return torch.tensor(value, dtype=dtype)


def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False):
def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False, allow_N=False):
"""Converts a PWM/one-hot encoding to a string sequence.
This function takes in a PWM or one-hot encoding and converts it to the
Expand All @@ -156,13 +161,21 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False):
Whether to force a sequence to be produced even when there are ties.
At each position where there is a tie, the character that comes earlier
in the alphabet will be used. Default is False.
allow_N: bool, optional
Whether to return the character 'N' at positions where the PWM sums to
zero across the alphabet, e.g., an all-zero column. When True, tied
positions do not raise an error even if `force` is False. Default is
False.
Returns
-------
seq: str
A string where the length is the second dimension of PWM.
"""

#if (batch, alphabet_size, motif_size) and batch = 1, remove batch axis
if len(pwm.shape) == 3 and pwm.shape[0] == 1:
pwm = pwm[0]

if len(pwm.shape) != 2:
raise ValueError("PWM must have two dimensions where the " +
Expand All @@ -174,15 +187,22 @@ def characters(pwm, alphabet=['A', 'C', 'G', 'T'], force=False):
"provided alphabet.")

pwm_ismax = pwm == pwm.max(dim=0, keepdims=True).values
if pwm_ismax.sum(axis=0).max() > 1 and force == False:
if pwm_ismax.sum(axis=0).max() > 1 and force == False and allow_N == False:
raise ValueError("At least one position in the PWM has multiple " +
"letters with the same probability.")

alphabet = numpy.array(alphabet)
if isinstance(pwm, torch.Tensor):
pwm = pwm.numpy(force=True)

return ''.join(alphabet[pwm.argmax(axis=0)])
if allow_N:
n_inds = numpy.where(pwm.sum(axis=0)==0)[0]
dna_chars = alphabet[pwm.argmax(axis=0)]
dna_chars[n_inds] = 'N'
else:
dna_chars = alphabet[pwm.argmax(axis=0)]

return ''.join(dna_chars)


@numba.njit("void(int8[:, :], int8[:], int8[:])", cache=True)
Expand Down
8 changes: 6 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,19 @@ def test_characters_raise_alphabet():

def test_characters_raise_dimensions():
seq = 'GCTAC'
#this will work for shape (1,4,5) but not for (N,4,5) where N > 1
ohe = torch.tensor([[
[0.25, 0.00, 0.10, 0.95, 0.00],
[0.20, 1.00, 1.00, 0.05, 1.00],
[0.30, 0.00, 0.30, 0.00, 0.00],
[0.25, 0.00, 3.00, 0.00, 0.00]
]])


assert characters(ohe) == seq

ohe = torch.concat([ohe, ohe], dim=0)
assert_raises(ValueError, characters, ohe, ['A', 'C', 'G', 'T'])

ohe = torch.tensor([0.25, 0.00, 0.10, 0.95, 0.00])
assert_raises(ValueError, characters, ohe, ['A', 'C', 'G', 'T'])

Expand Down

0 comments on commit 631baa0

Please sign in to comment.