Skip to content

Commit

Permalink
Update logic around VCF sorting with scatter/gather
Browse files Browse the repository at this point in the history
  • Loading branch information
bbimber committed Feb 7, 2024
1 parent 6efdb41 commit b122c55
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,22 @@ default void complete(PipelineJob job, List<SequenceOutputFile> inputs, List<Seq

enum ScatterGatherMethod
{
none(),
contig(),
chunked(),
fixedJobs()
none(false),
contig(false),
chunked(true),
fixedJobs(false);

private final boolean _mayRequireSort;

ScatterGatherMethod(boolean mayRequireSort)
{
_mayRequireSort = mayRequireSort;
}

public boolean mayRequireSort()
{
return _mayRequireSort;
}
}

interface Output extends PipelineStepOutput
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,16 @@ private void validateScatterForTask()
sg.validateScatter(getScatterGatherMethod(), this);
}

public boolean scatterMethodRequiresSort()
{
if (_scatterGatherMethod == null || !_scatterGatherMethod.mayRequireSort())
{
return false;
}

return !doAllowSplitContigs();
}

private LinkedHashMap<String, List<Interval>> establishIntervals()
{
LinkedHashMap<String, List<Interval>> ret;
Expand All @@ -137,7 +147,7 @@ else if (_scatterGatherMethod == VariantProcessingStep.ScatterGatherMethod.chunk
getLogger().info("Creating jobs with target bp size: " + basesPerJob + " mbp. allow splitting configs: " + allowSplitChromosomes + ", max contigs per job: " + maxContigsPerJob);

basesPerJob = basesPerJob * 1000000;
ret = ScatterGatherUtils.divideGenome(dict, basesPerJob, allowSplitChromosomes, maxContigsPerJob);
ret = ScatterGatherUtils.divideGenome(dict, basesPerJob, allowSplitChromosomes, maxContigsPerJob, scatterMethodRequiresSort());

}
else if (_scatterGatherMethod == VariantProcessingStep.ScatterGatherMethod.fixedJobs)
Expand All @@ -146,7 +156,7 @@ else if (_scatterGatherMethod == VariantProcessingStep.ScatterGatherMethod.fixed
int numJobs = getParameterJson().getInt("scatterGather.totalJobs");
int jobSize = (int)Math.ceil(totalSize / (double)numJobs);
getLogger().info("Creating " + numJobs + " jobs with approximate size: " + jobSize + " bp.");
ret = ScatterGatherUtils.divideGenome(dict, jobSize, true, -1);
ret = ScatterGatherUtils.divideGenome(dict, jobSize, true, -1, false);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ else if (!vcf.exists())
throw new PipelineJobException("Missing one of more VCFs: " + missing.stream().map(File::getPath).collect(Collectors.joining(",")));
}

boolean sortAfterMerge = handler instanceof VariantProcessingStep.SupportsScatterGather && ((VariantProcessingStep.SupportsScatterGather) handler).doSortAfterMerge();
boolean sortAfterMerge = getPipelineJob().scatterMethodRequiresSort() || handler instanceof VariantProcessingStep.SupportsScatterGather && ((VariantProcessingStep.SupportsScatterGather) handler).doSortAfterMerge();
combined = SequenceAnalysisService.get().combineVcfs(toConcat, combined, genome, getJob().getLogger(), true, null, sortAfterMerge);
}
manager.addOutput(action, "Merged VCF", combined);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import org.junit.Test;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -113,13 +112,17 @@ private void addInterval(String refName, int start, int end)
}
}

public static LinkedHashMap<String, List<Interval>> divideGenome(SAMSequenceDictionary dict, int optimalBasesPerJob, boolean allowSplitChromosomes, int maxContigsPerJob)
public static LinkedHashMap<String, List<Interval>> divideGenome(SAMSequenceDictionary dict, int optimalBasesPerJob, boolean allowSplitChromosomes, int maxContigsPerJob, boolean sortOnContigSize)
{
ActiveIntervalSet ais = new ActiveIntervalSet(optimalBasesPerJob, allowSplitChromosomes, maxContigsPerJob);

// Sort the sequences in descending length, rather than alphabetic on name:
List<SAMSequenceRecord> sortedSeqs = new ArrayList<>(dict.getSequences());
sortedSeqs.sort(Comparator.comparingInt(SAMSequenceRecord::getSequenceLength).reversed());
if (sortOnContigSize)
{
sortedSeqs.sort(Comparator.comparingInt(SAMSequenceRecord::getSequenceLength).reversed());
}

for (SAMSequenceRecord rec : sortedSeqs)
{
ais.add(rec);
Expand Down Expand Up @@ -152,13 +155,13 @@ private SAMSequenceDictionary getDict()
public void testScatter()
{
SAMSequenceDictionary dict = getDict();
Map<String, List<Interval>> ret = divideGenome(dict, 1000, true, -1);
Map<String, List<Interval>> ret = divideGenome(dict, 1000, true, -1, true);
assertEquals("Incorrect number of jobs", 8, ret.size());
assertEquals("Incorrect interval end", 1000, ret.get("Job3").get(0).getEnd());
assertEquals("Incorrect start", 1, ret.get("Job3").get(0).getStart());
assertEquals("Incorrect interval end", 4, ret.get("Job8").size());

Map<String, List<Interval>> ret2 = divideGenome(dict, 3000, false, -1);
Map<String, List<Interval>> ret2 = divideGenome(dict, 3000, false, -1, true);
assertEquals("Incorrect number of jobs", 3, ret2.size());
for (String jobName : ret2.keySet())
{
Expand All @@ -168,7 +171,7 @@ public void testScatter()
}
}

Map<String, List<Interval>> ret3 = divideGenome(dict, 3002, false, -1);
Map<String, List<Interval>> ret3 = divideGenome(dict, 3002, false, -1, true);
assertEquals("Incorrect number of jobs", 3, ret3.size());
for (String jobName : ret3.keySet())
{
Expand All @@ -178,7 +181,7 @@ public void testScatter()
}
}

Map<String, List<Interval>> ret4 = divideGenome(dict, 2999, false, -1);
Map<String, List<Interval>> ret4 = divideGenome(dict, 2999, false, -1, true);
assertEquals("Incorrect number of jobs", 3, ret4.size());
for (String jobName : ret4.keySet())
{
Expand All @@ -188,15 +191,15 @@ public void testScatter()
}
}

Map<String, List<Interval>> ret5 = divideGenome(dict, 750, true, -1);
Map<String, List<Interval>> ret5 = divideGenome(dict, 750, true, -1, true);
assertEquals("Incorrect number of jobs", 9, ret5.size());
assertEquals("Incorrect interval end", 750, ret5.get("Job1").get(0).getEnd());
assertEquals("Incorrect interval end", 4, ret5.get("Job9").size());

assertEquals("Incorrect interval start", 1501, ret5.get("Job3").get(0).getStart());
assertEquals("Incorrect interval start", 1, ret5.get("Job8").get(0).getStart());

Map<String, List<Interval>> ret6 = divideGenome(dict, 5000, false, 2);
Map<String, List<Interval>> ret6 = divideGenome(dict, 5000, false, 2, true);
assertEquals("Incorrect number of jobs", 5, ret6.size());
}
}
Expand Down

0 comments on commit b122c55

Please sign in to comment.