From e2910b771bf6dd7cc4196af689201c77445dd09d Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Mon, 8 Apr 2024 11:54:47 -0700
Subject: [PATCH] Poisson context likelihood and better cli flexibility (#126)

* Add poisson likelihood functions
* working poisson context likelihood
* revamp filter method
* format and lint
* most tests passing
* add context likelihood test
* testing tweaks and format
* update docs, format, and lint
---
 docs/quickstart.rst           |   4 +-
 gctree/branching_processes.py | 316 ++++++++++++++++++++--------------
 gctree/cli.py                 |  26 ++-
 gctree/isotype.py             |   2 +-
 gctree/isotyping.py           |  19 +-
 gctree/mutation_model.py      | 142 ++++++++++-----
 tests/smalltest.sh            |   8 +
 tests/test_isotype.py         |   8 +-
 tests/test_likelihoods.py     |  36 ++++
 9 files changed, 375 insertions(+), 186 deletions(-)

diff --git a/docs/quickstart.rst b/docs/quickstart.rst
index 5fe28b24..fdf1fcdc 100644
--- a/docs/quickstart.rst
+++ b/docs/quickstart.rst
@@ -106,7 +106,7 @@ This file may be manipulated using ``gctree infer``, instead of providing a
 dnapars ``outfile``.
 
 .. note::
-   Although described below, using mutability parsimony or isotype parsimony
+   Although described below, using context likelihood, mutability parsimony, or isotype parsimony
    as ranking criteria is experimental, and has not yet been shown in a careful
    validation to improve tree inference. Only the default branching process
    likelihood is recommended for tree ranking!
@@ -117,7 +117,7 @@ between trees. Providing arguments ``--isotype_mapfile`` and
 arguments ``--mutability`` and ``--substitution`` allows trees to be ranked
 according to a context-sensitive mutation model. By default, trees are ranked
 lexicographically, first maximizing likelihood, then minimizing isotype
-parsimony and mutabilities, if such information is provided.
+parsimony, and finally maximizing a context-based Poisson likelihood, if such information is provided.
 Ranking priorities can be adjusted using the argument ``--ranking_coeffs``.
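Concretely, a ranking invocation condensed from ``tests/smalltest.sh`` later in this patch (isotype inputs are omitted here for brevity; the isotype term only engages when ``--idmapfile`` and ``--isotype_mapfile`` are also supplied):

   gctree infer tests/small_outfile tests/abundances.csv --root GL --frame 1 \
       --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv \
       --branching_process_ranking_coeff -1 --ranking_coeffs .01 -1 0

This selects trees minimizing ``-1*(branching process log likelihood) + 0.01*(isotype parsimony) + -1*(context log likelihood) + 0*(allele count)``.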
For example, to find the optimal tree
diff --git a/gctree/branching_processes.py b/gctree/branching_processes.py
index bb3be58d..17b9d43e 100755
--- a/gctree/branching_processes.py
+++ b/gctree/branching_processes.py
@@ -6,7 +6,10 @@
 import gctree.utils
 from gctree.isotyping import _isotype_dagfuncs, _isotype_annotation_dagfuncs
-from gctree.mutation_model import _mutability_dagfuncs
+from gctree.mutation_model import (
+    _mutability_dagfuncs,
+    _context_poisson_likelihood_dagfuncs,
+)
 from gctree.phylip_parse import disambiguate
 from frozendict import frozendict
@@ -408,7 +411,7 @@ def mle(self, **kwargs) -> Tuple[np.float64, np.float64]:
             (p, q) = \arg\max_{p,q\in [0,1]}\ell(p, q)
 
         Args:
-            kwargs: keyword arguments passed along to the log likelihood :meth:`CollapsedTree.ll`
+            kwargs: keyword arguments passed along to the branching process likelihood :meth:`CollapsedTree.ll`
 
         Returns:
             Tuple :math:`(p, q)` with estimated branching probability and estimated mutation probability
@@ -602,7 +605,6 @@ def my_layout(node):
                     "\n".join(mutations),
                     fsize=6,
                     tight_text=False,
-                    ftype="Courier",
                 )
                 if start == 0:
                     T.margin_top = 6
@@ -1050,7 +1052,7 @@ def ll(
             marginal: compute the marginal likelihood over trees, otherwise compute the joint likelihood of trees
 
         Returns:
-            Log likelihood :math:`\ell(p, q; T, A)` and its gradient :math:`\nabla\ell(p, q; T, A)`
+            Log branching process likelihood :math:`\ell(p, q; T, A)` and its gradient :math:`\nabla\ell(p, q; T, A)`
         """
         if self._cm_countlist is None:
             if self._forest is not None:
@@ -1122,7 +1124,7 @@ def mle(self, **kwargs) -> Tuple[np.float64, np.float64]:
             (p, q) = \arg\max_{p,q\in [0,1]}\ell(p, q)
 
         Args:
-            kwargs: keyword arguments passed along to the log likelihood :meth:`CollapsedForest.ll`
+            kwargs: keyword arguments passed along to the branching process likelihood :meth:`CollapsedForest.ll`
 
         Returns:
             Tuple :math:`(p, q)` with estimated branching probability and estimated mutation probability
@@ -1131,7 +1133,7 @@ def mle(self, **kwargs) -> Tuple[np.float64, np.float64]:
         return self.parameters
 
     @_requires_dag
-    def filter_trees(
+    def filter_trees(  # noqa: C901
         self,
         ranking_coeffs: Optional[Sequence[float]] = None,
         mutability_file: Optional[str] = None,
@@ -1142,19 +1144,24 @@
         outbase: str = "gctree.out",
         summarize_forest: bool = False,
         tree_stats: bool = False,
+        branching_process_ranking_coeff: float = -1,
+        use_old_mut_parsimony: bool = False,
     ) -> CollapsedForest:
         """Filter trees according to specified criteria.
 
         Trim the forest to minimize a linear combination of branching
         process likelihood, isotype parsimony score,
-        mutability parsimony score, and number of alleles, with coefficients
+        context/mutability-based Poisson likelihood, and number of alleles, with coefficients
         provided in the argument ``ranking_coeffs`, in that order.
 
         Args:
             ranking_coeffs: A list or tuple of coefficients for prioritizing tree weights.
-                The order of coefficients is: isotype parsimony score, mutability parsimony score,
+                The order of coefficients is: isotype parsimony score, context Poisson likelihood,
                 and number of alleles. A coefficient of ``-1`` will be applied to branching process
-                likelihood.
+                likelihood by default, unless a different value is provided to the keyword argument
+                ``branching_process_ranking_coeff``. Trees are chosen to minimize this linear combination
+                of tree weights, so weights for which larger values are more optimal (such as
+                likelihoods) should have negative coefficients.
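                For example (an illustrative walkthrough of the rule just described; these coefficient
                values are hypothetical, not from the patch): with ``branching_process_ranking_coeff=-1``
                and ``ranking_coeffs=[0.1, -1, 0]``, each tree receives the score

                    score = -1  * (branching process log likelihood)
                          + 0.1 * (isotype parsimony)
                          + -1  * (context-based Poisson log likelihood)
                          + 0   * (number of alleles)

                and the forest is trimmed to the trees minimizing this score; the negative
                coefficients on the two likelihoods make larger likelihoods preferred.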
                If ranking_coeffs is not provided, trees will be ranked lexicographically
                by likelihood, then by other traits, in the same order.
            mutability_file: A mutability model
@@ -1162,34 +1169,55 @@ def filter_trees(
             ignore_isotype: Ignore isotype parsimony when ranking. By default, isotype information added with
                 :meth:``add_isotypes`` will be used to compute isotype parsimony, which is used in ranking.
             chain_split: The index at which non-adjacent sequences are concatenated, for calculating
-                mutability parsimony.
+                context-based Poisson likelihood.
             verbose: print information about trimming
             outbase: file name stem for a file with information for each tree in the DAG.
             summarize_forest: whether to write a summary of the forest to file `[outbase].forest_summary.log`
             tree_stats: whether to write stats for each tree in the forest to file `[outbase].tree_stats.log`
+            branching_process_ranking_coeff: Ranking coefficient to use for branching process likelihood. Value
+                is ignored unless ``ranking_coeffs`` argument is provided.
+            use_old_mut_parsimony: Whether to use the deprecated 'mutability parsimony' instead of
+                context-based Poisson likelihood (only applicable if mutability and substitution files are
+                provided).
 
         Returns:
             The trimmed forest, containing all optimal trees according to the specified criteria, and a tuple
-            of data about the trees in that forest, with format (ll, isotype parsimony, mutability parsimony, alleles).
+            of data about the trees in that forest, with format (branching process likelihood, isotype parsimony,
+            context-based Poisson likelihood, alleles).
         """
         dag = self._forest
-        if self.parameters is None:
-            self.mle(marginal=True)
-        p, q = self.parameters
-        ll_dagfuncs = _ll_genotype_dagfuncs(p, q)
-        placeholder_dagfuncs = hdag.utils.AddFuncDict(
-            {
-                "start_func": lambda n: 0,
-                "edge_weight_func": lambda n1, n2: 0,
-                "accum_func": sum,
-            },
-            name="",
-        )
-        if ignore_isotype or not self.is_isotyped:
-            iso_funcs = placeholder_dagfuncs
+
+        if ranking_coeffs:
+            if len(ranking_coeffs) != 3:
+                raise ValueError(
+                    "If ranking_coeffs are provided to `filter_trees` method, a list of three values is expected."
+                )
+            coeffs = [branching_process_ranking_coeff] + list(ranking_coeffs)
+            if sum(abs(c) for c in coeffs) == 0:
+                raise ValueError(
+                    "At least one value provided to ranking_coeffs or the value of branching_process_ranking_coeff must be nonzero."
+ ) else: + coeffs = [1] * 4 + + ( + nz_coeff_bplikelihood, + nz_coeff_isotype_pars, + nz_coeff_context, + nz_coeff_alleles, + ) = [val != 0 for val in coeffs] + coeff_bplikelihood, coeff_isotype_pars, coeff_context, coeff_alleles = coeffs + + dag_filters = [] + if nz_coeff_bplikelihood: + if self.parameters is None: + self.mle(marginal=True) + p, q = self.parameters + ll_dagfuncs = _ll_genotype_dagfuncs(p, q) + dag_filters.append((ll_dagfuncs, coeff_bplikelihood)) if verbose: - print("Isotype parsimony will be used as a ranking criterion") + print(f"Branching process parameters to be used for ranking: {(p, q)}") + if nz_coeff_isotype_pars and self.is_isotyped and (not ignore_isotype): # Check for missing isotype data in all but root node, and fake root-adjacent leaf node rootname = list(self._forest.dagroot.children())[0].attr["name"] if any( @@ -1202,43 +1230,87 @@ def filter_trees( ) iso_funcs = _isotype_dagfuncs() - if mutability_file and substitution_file: - if verbose: - print("Mutation model parsimony will be used as a ranking criterion") + dag_filters.append((iso_funcs, coeff_isotype_pars)) + if nz_coeff_context and mutability_file and substitution_file: + if use_old_mut_parsimony: + mut_funcs = _mutability_dagfuncs( + mutability_file=mutability_file, + substitution_file=substitution_file, + splits=[] if chain_split is None else [chain_split], + ) + else: + mut_funcs = _context_poisson_likelihood_dagfuncs( + mutability_file=mutability_file, + substitution_file=substitution_file, + splits=[] if chain_split is None else [chain_split], + ) + dag_filters.append((mut_funcs, coeff_context)) + if nz_coeff_alleles: + allele_funcs = _allele_dagfuncs() + dag_filters.append((allele_funcs, coeff_alleles)) - mut_funcs = _mutability_dagfuncs( - mutability_file=mutability_file, - substitution_file=substitution_file, - splits=[] if chain_split is None else [chain_split], - ) - else: - mut_funcs = placeholder_dagfuncs - allele_funcs = _allele_dagfuncs() - kwargls = (ll_dagfuncs, iso_funcs, mut_funcs, allele_funcs) + combined_dag_filter = functools.reduce( + lambda x, y: x + y, (dag_filter for dag_filter, _ in dag_filters) + ) if ranking_coeffs: if len(ranking_coeffs) != 3: raise ValueError( "If ranking_coeffs are provided to `filter_trees` method, a list of three values is expected." ) - coeffs = [-1] + list(ranking_coeffs) + for dag_filter, coeff in dag_filters: + if dag_filter.optimal_func == max and coeff > 0: + warnings.warn( + f"Higher values for {dag_filter.weight_funcs.name} are generally better, but the " + "provided ranking coefficient is positive, so trees with lower values will be preferred." + ) + if dag_filter.optimal_func == min and coeff < 0: + warnings.warn( + f"Lower values for {dag_filter.weight_funcs.name} are generally better, but the " + "provided ranking coefficient is negative, so trees with higher values will be preferred." 
+ ) - def minfunckey(weighttuple): - """Weighttuple will have (ll, isotypepars, mutabilitypars, - alleles)""" + filtered_coefficients = [coeff for _, coeff in dag_filters] + + def linear_combinator(weighttuple): return sum( [ priority * float(weight) - for priority, weight in zip(coeffs, weighttuple) + for priority, weight in zip(filtered_coefficients, weighttuple) ] ) + old_edge_func = combined_dag_filter["edge_weight_func"] + ranking_dag_filter = hdag.utils.HistoryDagFilter( + hdag.utils.AddFuncDict( + { + "start_func": lambda n: 0, + "edge_weight_func": lambda n1, n2: linear_combinator( + old_edge_func(n1, n2) + ), + "accum_func": sum, + } + ), + min, + ordering_name="LinearCombination", + ) + ranking_description = ( + "Ranking trees to minimize a linear combination of " + + " + ".join( + str(coeff) + "(" + fl.weight_funcs.name + ")" + for fl, coeff in dag_filters + ) + ) else: - - def minfunckey(weighttuple): - """Weighttuple will have (ll, isotypepars, mutabilitypars, - alleles)""" - # Sort output by likelihood, then isotype parsimony, then mutability score - return (-weighttuple[0],) + weighttuple[1:-1] + ranking_dag_filter = combined_dag_filter + ranking_description = "Ranking trees to " + " then ".join( + opt_name[:3] + "imize " + ord_name + for (opt_name, _), ord_name in zip( + ranking_dag_filter.ordering_names, + ranking_dag_filter.weight_funcs.names, + ) + ) + if verbose: + print(ranking_description) def print_stats(statlist, title, file=None, suppress_score=False): show_score = ranking_coeffs and not suppress_score @@ -1249,87 +1321,63 @@ def reformat(field, n=10): else: return f"{field:{n}.{n}}" - def mask(weighttuple, n=10): - return tuple( - reformat(field, n=n) - for field, kwargs in zip(weighttuple, kwargls) - if kwargs.name - ) - - print(f"Parameters: {(p, q)}", file=file) print("\n" + title + ":", file=file) - statstring = "\t".join(mask(tuple(kwargs.name for kwargs in kwargls), n=14)) + statstring = "\t".join( + tuple( + reformat(dfilter.weight_funcs.name, n=14) + for dfilter, _ in dag_filters + ) + ) print( f"tree \t{statstring}" + ("\ttreescore" if show_score else ""), file=file, ) for j, best_weighttuple in enumerate(statlist, 1): - statstring = "\t".join(mask(best_weighttuple)) + statstring = "\t".join(reformat(it) for it in best_weighttuple) print( f"{j:<10}\t{statstring}" + ( - f"\t{reformat(minfunckey(best_weighttuple))}" + f"\t{reformat(linear_combinator(best_weighttuple))}" if show_score else "" ), file=file, ) - # Filter by likelihood, isotype parsimony, mutability, - # and make ctrees, cforest, and render trees - dagweight_kwargs = ll_dagfuncs + iso_funcs + mut_funcs + allele_funcs - trimdag = dag.copy() - trimdag.trim_optimal_weight( - **dagweight_kwargs, - optimal_func=lambda l: min(l, key=minfunckey), # noqa: E741 - ) - # make sure trimming worked as expected: - min_weightcounter = trimdag.weight_count(**dagweight_kwargs) - min_weightset = {minfunckey(key) for key in min_weightcounter} - if len(min_weightset) != 1: - raise RuntimeError( - "Filtering was not successful. 
After trimming, these weights are represented:", - min_weightset, - ) + trimdag = dag[ranking_dag_filter] best_weighttuple = trimdag.optimal_weight_annotate( - **dagweight_kwargs, - optimal_func=lambda l: min(l, key=minfunckey), # noqa: E741 + **combined_dag_filter, ) + + if verbose: + print_stats([best_weighttuple], "Stats for optimal trees") + if summarize_forest: with open(outbase + ".forest_summary.log", "w") as fh: independent_best = [] - for kwargs in kwargls: - # Only summarize for stats for which information was - # provided (not just placeholders): - if kwargs.name: - independent_best.append([]) - for opt in [min, max]: - tempdag = dag.copy() - opt_weight = tempdag.trim_optimal_weight( - **kwargs, optimal_func=opt + for dfilter, _ in dag_filters: + tempdag = dag.copy() + min_val, max_val = tempdag.weight_range_annotate(**dfilter) + opt_weight = tempdag.trim_optimal_weight(**dfilter) + independent_best.append(opt_weight) + fh.write( + f"\nOverall {dfilter.weight_funcs.name} range {min_val} to {max_val}." + f"\nAmong trees with {dfilter.optimal_func.__name__} {dfilter.weight_funcs.name} of: {opt_weight}\n" + ) + for indfilter, _ in dag_filters: + if indfilter.weight_funcs.name != dfilter.weight_funcs.name: + minval, maxval = tempdag.weight_range_annotate( + **indfilter.weight_funcs ) - independent_best[-1].append(opt_weight) fh.write( - f"\nAmong trees with {opt.__name__} {kwargs.name} of: {opt_weight}\n" + f"\t{indfilter.weight_funcs.name} range: {minval} to {maxval}\n" ) - for inkwargs in kwargls: - if inkwargs != kwargs and inkwargs.name: - minval = tempdag.optimal_weight_annotate( - **inkwargs, optimal_func=min - ) - maxval = tempdag.optimal_weight_annotate( - **inkwargs, optimal_func=max - ) - fh.write( - f"\t{inkwargs.name} range: {minval} to {maxval}\n" - ) - independent_best[0].reverse() print("\n", file=fh) print_stats( [ [ - stat - best[0] + stat - best for stat, best in zip(best_weighttuple, independent_best) ] ], @@ -1339,26 +1387,29 @@ def mask(weighttuple, n=10): ) if tree_stats: - dag_ls = list(dag.weight_count(**dagweight_kwargs).elements()) + dag_ls = list(dag.weight_count(**combined_dag_filter).elements()) # To clear _dp_data fields of their large cargo dag.optimal_weight_annotate(edge_weight_func=lambda n1, n2: 0) + if ranking_coeffs: + minfunckey = linear_combinator + else: + minfunckey = ranking_dag_filter.optimal_func dag_ls.sort(key=minfunckey) - df = pd.DataFrame(dag_ls, columns=dagweight_kwargs.names) + df = pd.DataFrame(dag_ls, columns=combined_dag_filter.weight_funcs.names) df.to_csv(outbase + ".tree_stats.csv") df["set"] = ["all_trees"] * len(df) - bestdf = pd.DataFrame([best_weighttuple], columns=dagweight_kwargs.names) + bestdf = pd.DataFrame( + [best_weighttuple], columns=combined_dag_filter.weight_funcs.names + ) bestdf["set"] = ["best_tree"] toplot_df = pd.concat([df, bestdf], ignore_index=True) pplot = sns.pairplot( - toplot_df[["Log Likelihood", "Isotype Pars.", "Mut. 
Pars.", "set"]],
+            toplot_df.drop(columns=["Alleles"], errors="ignore"),
             hue="set",
             diag_kind="hist",
         )
-        pplot.savefig(outbase + ".tree_stats.pairplot.png")
-
-        if verbose:
-            print_stats([best_weighttuple], "Stats for optimal trees")
+        pplot.savefig(outbase + ".tree_stats.pairplot.pdf")
 
         return (self._trimmed_self(trimdag), best_weighttuple)
 
@@ -1631,7 +1682,7 @@ def _mle_helper(
     bounds = ((1e-6, 1 - 1e-6), (1e-6, 1 - 1e-6))
 
     def f(x):
-        """Negative log likelihood."""
+        """Negative log branching process likelihood."""
         return tuple(-y for y in ll(*x, **kwargs))
 
     grad_check = sco.check_grad(lambda x: f(x)[0], lambda x: f(x)[1], x_0)
@@ -1868,8 +1919,9 @@ def accum_func(cmsetlist: List[multiset.FrozenMultiset]):
     )
 
 
-def _ll_genotype_dagfuncs(p: np.float64, q: np.float64) -> hdag.utils.AddFuncDict:
-    """Return functions for counting tree log likelihood on the history DAG.
+def _ll_genotype_dagfuncs(p: np.float64, q: np.float64) -> hdag.utils.HistoryDagFilter:
+    """Return functions for counting tree log branching process likelihood on
+    the history DAG.
 
     For numerical consistency, we resort to the use of ``decimal.Decimal``.
     This is exactly for the purpose of solving the problem that float sum is
@@ -1884,7 +1936,7 @@
         p, q: branching process parameters
 
     Returns:
-        A :meth:`historydag.utils.AddFuncDict` which may be passed as keyword arguments
+        A :meth:`historydag.utils.HistoryDagFilter` which may be passed as keyword arguments
         to :meth:`historydag.HistoryDag.weight_count`, :meth:`historydag.HistoryDag.trim_optimal_weight`,
         or :meth:`historydag.HistoryDag.optimal_weight_annotate` methods to trim or annotate a
        :meth:`historydag.HistoryDag` according to branching process likelihood.
@@ -1916,17 +1968,20 @@ def accum_func(weightlist):
         res = sum(weight.state for weight in weightlist)
         return hdag.utils.FloatState(float(round(res, 8)), state=res)
 
-    return hdag.utils.AddFuncDict(
-        {
-            "start_func": lambda n: hdag.utils.FloatState(0.0, state=Decimal(0)),
-            "edge_weight_func": edge_weight_ll_genotype,
-            "accum_func": accum_func,
-        },
-        name="Log Likelihood",
+    return hdag.utils.HistoryDagFilter(
+        hdag.utils.AddFuncDict(
+            {
+                "start_func": lambda n: hdag.utils.FloatState(0.0, state=Decimal(0)),
+                "edge_weight_func": edge_weight_ll_genotype,
+                "accum_func": accum_func,
+            },
+            name="LogBPLikelihood",
+        ),
+        max,
     )
 
 
-def _allele_dagfuncs() -> hdag.utils.AddFuncDict:
+def _allele_dagfuncs() -> hdag.utils.HistoryDagFilter:
     """Return functions for filtering trees in a history DAG by allele count.
 
     The number of alleles in a tree is the number of unique sequences observed on nodes of that tree.
 
     Returns:
         A :meth:`historydag.utils.AddFuncDict` which may be passed as keyword arguments
         to :meth:`historydag.HistoryDag.weight_count`, :meth:`historydag.HistoryDag.trim_optimal_weight`,
         or :meth:`historydag.HistoryDag.optimal_weight_annotate`
         methods to trim or annotate a :meth:`historydag.HistoryDag` according to allele count.
 
     Weight format is ``int``.
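    A usage sketch (illustrative, not part of the patch; ``dag`` stands for any
    :meth:`historydag.HistoryDag`), mirroring how ``filter_trees`` consumes these filters:

        allele_filter = _allele_dagfuncs()
        counts = dag.weight_count(**allele_filter)  # Counter: allele count -> number of trees
        trimmed = dag[allele_filter]                # keep only trees optimal (here, minimal) in allele count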
""" - return hdag.utils.AddFuncDict( - { - "start_func": lambda n: 0, - "edge_weight_func": lambda n1, n2: n1.label != n2.label, - "accum_func": sum, - }, - name="Alleles", + return hdag.utils.HistoryDagFilter( + hdag.utils.AddFuncDict( + { + "start_func": lambda n: 0, + "edge_weight_func": lambda n1, n2: n1.label != n2.label, + "accum_func": sum, + }, + name="Alleles", + ), + min, ) diff --git a/gctree/cli.py b/gctree/cli.py index c878cd24..46f0ee72 100644 --- a/gctree/cli.py +++ b/gctree/cli.py @@ -209,6 +209,8 @@ def isotype_add(forest): mutability_file=args.mutability, substitution_file=args.substitution, chain_split=args.chain_split, + branching_process_ranking_coeff=args.branching_process_ranking_coeff, + use_old_mut_parsimony=args.use_old_mut_parsimony, ) if args.verbose: @@ -535,7 +537,7 @@ def get_parser(): help=( "when using concatenated heavy and light chains, this is the 0-based" " index at which the 2nd chain begins, needed for determining coding frame in both chains," - " and also to correctly calculate mutability parsimony." + " and also to correctly calculate context-based Poisson likelihood." ), ) parser_infer.add_argument( @@ -610,6 +612,16 @@ def get_parser(): "See a file excerpt in the documentation for :meth:`mutation_model.MutationModel`." ), ) + parser_infer.add_argument( + "--branching_process_ranking_coeff", + type=float, + default=-1, + help=( + "Coefficient used for branching process likelihood, when ranking trees by a linear " + "combination of traits. This value will be ignored if `--ranking_coeffs` argument is not " + "also provided." + ), + ) parser_infer.add_argument( "--ranking_coeffs", type=float, @@ -620,7 +632,17 @@ def get_parser(): "Coefficients are in order: isotype parsimony, mutation model parsimony, number of alleles. " "A coefficient of -1 will be applied to branching process likelihood. " "If not provided, trees will be ranked lexicographically by likelihood, " - "isotype parsimony, and mutability parsimony in that order." + "isotype parsimony, and context-based Poisson likelihood in that order." + ), + ) + parser_infer.add_argument( + "--use_old_mut_parsimony", + action="store_true", + help=( + "Use old mutability parsimony instead of poisson context likelihood. Not recommended " + "unless attempting to reproduce results from older versions of gctree. " + "This argument will have no effect unless an S5F model is provided with the arguments " + "`--mutability` and `--substitution`." ), ) parser_infer.add_argument( diff --git a/gctree/isotype.py b/gctree/isotype.py index 31055ac9..abe9ec39 100644 --- a/gctree/isotype.py +++ b/gctree/isotype.py @@ -46,7 +46,7 @@ def get_parser() -> argparse.ArgumentParser: " nodes.\n\n" "This tool doesn’t make any judgements about which tree is best.\n" "Tree output order is the same as in gctree inference: ranking is\n" - "by log likelihood before isotype additions. A determination of\n" + "by branching process likelihood before isotype additions. A determination of\n" "which is the best tree is left to the user, based on likelihoods,\n" "isotype parsimony score, and changes in the number of nodes after\n" "isotype additions.\n" diff --git a/gctree/isotyping.py b/gctree/isotyping.py index 9695bec5..15d25a1a 100644 --- a/gctree/isotyping.py +++ b/gctree/isotyping.py @@ -405,7 +405,7 @@ def explode_idmap( return newidmap -def _isotype_dagfuncs() -> hdag.utils.AddFuncDict: +def _isotype_dagfuncs() -> hdag.utils.HistoryDagFilter: """Return functions for filtering by isotype parsimony score on the history DAG. 
@@ -435,13 +435,16 @@ def edge_weight_func(n1: hdag.HistoryDagNode, n2: hdag.HistoryDagNode): n1iso = list(n1isos.keys())[0] return int(sum(isotype_distance(n1iso, n2iso) for n2iso in n2isos.keys())) - return hdag.utils.AddFuncDict( - { - "start_func": lambda n: 0, - "edge_weight_func": edge_weight_func, - "accum_func": sum, - }, - name="Isotype Pars.", + return hdag.utils.HistoryDagFilter( + hdag.utils.AddFuncDict( + { + "start_func": lambda n: 0, + "edge_weight_func": edge_weight_func, + "accum_func": sum, + }, + name="Isotype Pars.", + ), + min, ) diff --git a/gctree/mutation_model.py b/gctree/mutation_model.py index 41019a71..2a29ae38 100644 --- a/gctree/mutation_model.py +++ b/gctree/mutation_model.py @@ -10,6 +10,8 @@ import historydag as hdag from multiset import FrozenMultiset from typing import Tuple, List, Callable, Optional +import itertools +import math class MutationModel: @@ -129,20 +131,25 @@ def mutability(self, kmer: str) -> Tuple[np.float64, np.float64]: "sequence {} must contain only characters A, C, G, T, or N".format(kmer) ) - mutabilities_to_average, substitutions_to_average = zip( - *[self.context_model[x] for x in MutationModel._disambiguate(kmer)] - ) - - average_mutability = np.mean(mutabilities_to_average) - average_substitution = { - b: sum( - substitution_dict[b] for substitution_dict in substitutions_to_average + cached = self.context_model.get(kmer, None) + if cached is None: + mutabilities_to_average, substitutions_to_average = zip( + *[self.context_model[x] for x in MutationModel._disambiguate(kmer)] ) - / len(substitutions_to_average) - for b in "ACGT" - } - return average_mutability, average_substitution + average_mutability = np.mean(mutabilities_to_average) + average_substitution = { + b: sum( + substitution_dict[b] + for substitution_dict in substitutions_to_average + ) + / len(substitutions_to_average) + for b in "ACGT" + } + cached = average_mutability, average_substitution + self.context_model[kmer] = cached + + return cached def mutabilities(self, sequence: str) -> List[Tuple[np.float64, np.float64]]: r"""Returns the mutability of a sequence at each site, along with @@ -440,7 +447,7 @@ def _sequence_disambiguations(sequence, _accum=""): def _mutability_dagfuncs( *args, splits: List[int] = [], **kwargs -) -> hdag.utils.AddFuncDict: +) -> hdag.utils.HistoryDagFilter: """Return functions for counting mutability parsimony on the history DAG. Mutability parsimony of a tree is the sum over all edges in the tree @@ -478,9 +485,16 @@ def distance(node1, node2): else: return dist(node1.label.sequence, node2.label.sequence) - return hdag.utils.AddFuncDict( - {"start_func": lambda n: 0, "edge_weight_func": distance, "accum_func": sum}, - name="Mut. Pars.", + return hdag.utils.HistoryDagFilter( + hdag.utils.AddFuncDict( + { + "start_func": lambda n: 0, + "edge_weight_func": distance, + "accum_func": sum, + }, + name="Mut. Pars.", + ), + min, ) @@ -488,26 +502,21 @@ def _mutability_distance_precursors( mutation_model: MutationModel, splits: List[int] = [] ): chunk_idxs = list(zip([0] + splits, splits + [None])) - # Caching could be moved to the MutationModel class instead. 
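# A quick illustration of the caching behavior introduced in MutationModel.mutability
# above (a sketch, using the S5F file names from this patch's tests):
#
#     model = MutationModel(
#         mutability_file="HS5F_Mutability.csv", substitution_file="HS5F_Substitution.csv"
#     )
#     mu, sub = model.mutability("NNAAG")    # first call averages over disambiguated 5-mers
#     assert "NNAAG" in model.context_model  # result is now cached under the ambiguous kmer
#     mu2, _ = model.mutability("NNAAG")     # later calls return the cached pair directly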
-    context_model = mutation_model.context_model.copy()
-    k = mutation_model.k
-    h = k // 2
-    # Build all sequences with (when k=5) one or two Ns on either end
-    templates = [
-        ("N" * left, "N" * (k - left - right), "N" * right)
-        for left in range(h + 1)
-        for right in range(h + 1)
-        if left != 0 or right != 0
-    ]
-
-    kmers_to_compute = [
-        leftns + stub + rightns
-        for leftns, ambig_stub, rightns in templates
-        for stub in _sequence_disambiguations(ambig_stub)
-    ]
-    # Cache all these mutabilities in context_model also
-    context_model.update(
-        {kmer: mutation_model.mutability(kmer) for kmer in kmers_to_compute}
+
+    h = mutation_model.k // 2
+
+    # Pad sequences with N's, including at the chain-split boundary, to
+    # avoid unrelated sites being treated as part of each other's context.
+
+    # Indices at which padding N's will be in sequences returned from add_ns.
+    # Does not include indices of last two N's.
+    padding_indices = set(
+        itertools.chain.from_iterable(
+            [
+                range(split + idx * h, split + (idx + 1) * h)
+                for idx, split in enumerate([0] + splits)
+            ]
+        )
     )
 
     def add_ns(seq: str):
@@ -535,8 +544,8 @@ def sum_minus_logp(pairs: FrozenMultiset):
             p_arr = [
                 mult
                 * (
-                    np.log(context_model[mer][0])
-                    + np.log(context_model[mer][1][newbase])
+                    np.log(mutation_model.mutability(mer)[0])
+                    + np.log(mutation_model.mutability(mer)[1][newbase])
                 )
                 for (mer, newbase), mult in pairs
             ]
@@ -544,7 +553,17 @@
         else:
             return 0.0
 
-    return (mutpairs, sum_minus_logp)
+    def mutability_sum(parent_seq):
+        padded_seq = add_ns(parent_seq)
+        for idx in padding_indices:
+            assert padded_seq[idx] == "N"
+        return sum(
+            mutation_model.mutability(padded_seq[idx - h : idx + h + 1])[0]
+            for idx, _ in enumerate(padded_seq[:-h])
+            if idx not in padding_indices
+        )
+
+    return (mutpairs, sum_minus_logp, mutability_sum)
 
 
 def _mutability_distance(mutation_model: MutationModel, splits=[]):
@@ -562,7 +581,7 @@ def _mutability_distance(mutation_model: MutationModel, splits=[]):
 
     Note that, in particular, this function is not symmetric on its arguments.
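    A symbolic sketch of the returned distance (a restatement of the code above, not text
    from the patch): for parent sequence :math:`x` and child :math:`y` differing at sites
    :math:`S`,

        :math:`\mathrm{dist}(x, y) = -\sum_{i \in S} \big( \log m_i(x) + \log s_i(x \to y_i) \big)`

    where the 5-mer contexts, mutabilities :math:`m` and substitution probabilities
    :math:`s` are all taken from the parent :math:`x`, which is why the function is
    asymmetric in its two arguments.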
""" - mutpairs, sum_minus_logp = _mutability_distance_precursors( + mutpairs, sum_minus_logp, _ = _mutability_distance_precursors( mutation_model, splits=splits ) @@ -570,3 +589,46 @@ def distance(seq1, seq2): return sum_minus_logp(mutpairs(seq1, seq2)) return distance + + +def _context_poisson_likelihood(mutation_model: MutationModel, splits=[]): + mutpairs, sum_minus_logp, mutability_sum = _mutability_distance_precursors( + mutation_model, splits=splits + ) + + def distance(seq1, seq2): + subs = mutpairs(seq1, seq2) + sub_count = len(subs) + if sub_count == 0: + return 0 + else: + mut_sum = mutability_sum(seq1) + substitution_sum = -sum_minus_logp(subs) + return ( + substitution_sum + + (sub_count * (math.log(sub_count) - math.log(mut_sum))) + - sub_count + ) + + return distance + + +def _context_poisson_likelihood_dagfuncs(*args, splits: List[int] = [], **kwargs): + mutation_model = MutationModel(*args, **kwargs) + distance = _context_poisson_likelihood(mutation_model, splits=splits) + + return hdag.utils.HistoryDagFilter( + hdag.utils.AddFuncDict( + { + "start_func": lambda n: 0, + "edge_weight_func": lambda n1, n2: ( + 0 + if n1.is_ua_node() + else distance(n1.label.sequence, n2.label.sequence) + ), + "accum_func": sum, + }, + name="LogContextLikelihood", + ), + max, + ) diff --git a/tests/smalltest.sh b/tests/smalltest.sh index d90387e6..7f4d69f9 100755 --- a/tests/smalltest.sh +++ b/tests/smalltest.sh @@ -7,7 +7,15 @@ export MPLBACKEND=agg mkdir -p tests/smalltest_output wget -O HS5F_Mutability.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Mutability.csv wget -O HS5F_Substitution.csv https://bitbucket.org/kleinstein/shazam/raw/ba4b30fc6791e2cfd5712e9024803c53b136e664/data-raw/HS5F_Substitution.csv + +gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs 1 1 0 --use_old_mut_parsimony --branching_process_ranking_coeff 0 + +gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv --ranking_coeffs .01 -1 0 --branching_process_ranking_coeff -1 --summarize_forest --tree_stats + gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel + gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt + gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv + gctree infer tests/small_outfile tests/abundances.csv --outbase tests/smalltest_output/gctree.infer --root GL --frame 1 --verbose --idlabel --idmapfile tests/idmap.txt --isotype_mapfile tests/isotypemap.txt --mutability HS5F_Mutability.csv --substitution HS5F_Substitution.csv diff --git a/tests/test_isotype.py b/tests/test_isotype.py index fe1695d8..e6d280fb 100644 --- a/tests/test_isotype.py +++ b/tests/test_isotype.py @@ -51,9 +51,9 @@ def 
test_trim_byisotype(): for node in tdag.preorder(): if node.attr is not None: node.attr["isotype"] = node._dp_data - kwargs = _isotype_dagfuncs() - c = tdag.weight_count(**kwargs) + dag_filter = _isotype_dagfuncs() + c = tdag.weight_count(**dag_filter) key = min(c) count = c[key] - tdag.trim_optimal_weight(**kwargs, optimal_func=min) - assert tdag.weight_count(**kwargs) == {key: count} + tdag.trim_optimal_weight(**dag_filter) + assert tdag.weight_count(**dag_filter) == {key: count} diff --git a/tests/test_likelihoods.py b/tests/test_likelihoods.py index 538adafa..157a2005 100644 --- a/tests/test_likelihoods.py +++ b/tests/test_likelihoods.py @@ -1,6 +1,8 @@ import gctree.branching_processes as bp import gctree.phylip_parse as pp import gctree.utils as utils +import gctree.mutation_model as mm +from math import log import numpy as np from multiset import FrozenMultiset @@ -198,3 +200,37 @@ def test_recursion_depth(): bp.CollapsedTree._max_ll_cache = {} with np.errstate(all="raise"): bp.CollapsedTree._ll_genotype(2, 500, 0.4, 0.6) + + +def test_context_likelihood(): + # These files will be present if pytest is run through `make test`. + mutation_model = mm.MutationModel( + mutability_file="HS5F_Mutability.csv", substitution_file="HS5F_Substitution.csv" + ) + log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[]) + + parent_seq = "AAGAAA" + child_seq = "AATCAA" + + term1 = sum( + log( + mutation_model.mutability(fivemer)[0] + * mutation_model.mutability(fivemer)[1][target_base] + ) + for fivemer, target_base in [("AAGAA", "T"), ("AGAAA", "C")] + ) + sum_mutabilities = sum( + mutation_model.mutability(fivemer)[0] + for fivemer in ["NNAAG", "NAAGA", "AAGAA", "AGAAA", "GAAAN", "AAANN"] + ) + true_val = term1 + 2 * log(2 / sum_mutabilities) - 2 + assert true_val == log_likelihood(parent_seq, child_seq) + + # Now test chain split: + parent_seq = parent_seq + parent_seq + child_seq = child_seq + child_seq + # At index 6, the second concatenated sequence starts. + log_likelihood = mm._context_poisson_likelihood(mutation_model, splits=[6]) + + true_val = 2 * term1 + 4 * log(4 / (2 * sum_mutabilities)) - 4 + assert true_val == log_likelihood(parent_seq, child_seq)
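
# For reference, the quantity these assertions check, as implemented by
# `_context_poisson_likelihood` above (a restatement of the code, stated here for the
# reader's convenience):
#
#     \log L = \sum_{i=1}^{n} \log(m_i s_i) + n (\log n - \log M) - n
#
# where n is the number of substitutions on the branch, m_i and s_i are the S5F
# mutability and substitution probability of the parent sequence's 5-mer context at the
# i-th substituted site, and M is the summed mutability over all non-padding sites of
# the N-padded parent sequence. In the first assertion, n = 2 and M = sum_mutabilities;
# in the chain-split case, n = 4 and M = 2 * sum_mutabilities — matching both values of
# true_val computed above.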