
Update Scatter/Gather logic to sort on contig size
bbimber committed Jan 4, 2024
1 parent 0c73722 commit 0a65541
Showing 2 changed files with 22 additions and 13 deletions.
File 1 of 2:

@@ -151,14 +151,17 @@ public void processFilesRemote(List<SequenceOutputFile> inputFiles, JobContext c
         List<File> outputs = new ArrayList<>();
         if (getVariantPipelineJob(ctx.getJob()) != null && getVariantPipelineJob(ctx.getJob()).isScatterJob())
         {
-            for (Interval i : getVariantPipelineJob(ctx.getJob()).getIntervalsForTask())
+            int idx = 0;
+            List<Interval> intervals = getVariantPipelineJob(ctx.getJob()).getIntervalsForTask();
+            for (Interval i : intervals)
             {
+                idx++;
                 if (i.getStart() != 1)
                 {
                     throw new PipelineJobException("Expected all intervals to start on the first base: " + i);
                 }

-                File o = runPbsvCall(ctx, filesToProcess, genome, outputBaseName + (getVariantPipelineJob(ctx.getJob()).getIntervalsForTask().size() == 1 ? "" : "." + i.getContig()), i.getContig(), jobCompleted);
+                File o = runPbsvCall(ctx, filesToProcess, genome, outputBaseName + (getVariantPipelineJob(ctx.getJob()).getIntervalsForTask().size() == 1 ? "" : "." + i.getContig()), i.getContig(), (" (" + idx + " of " + intervals.size() + ")"), jobCompleted);
                 if (o != null)
                 {
                     outputs.add(o);
@@ -167,7 +170,7 @@ public void processFilesRemote(List<SequenceOutputFile> inputFiles, JobContext c
             }
         }
         else
         {
-            outputs.add(runPbsvCall(ctx, filesToProcess, genome, outputBaseName, null, jobCompleted));
+            outputs.add(runPbsvCall(ctx, filesToProcess, genome, outputBaseName, null, null, jobCompleted));
         }

         try
@@ -228,11 +231,11 @@ public void processFilesRemote(List<SequenceOutputFile> inputFiles, JobContext c
        }
    }

-    private File runPbsvCall(JobContext ctx, List<File> inputs, ReferenceGenome genome, String outputBaseName, @Nullable String contig, boolean jobCompleted) throws PipelineJobException
+    private File runPbsvCall(JobContext ctx, List<File> inputs, ReferenceGenome genome, String outputBaseName, @Nullable String contig, @Nullable String statusSuffix, boolean jobCompleted) throws PipelineJobException
     {
         if (contig != null)
         {
-            ctx.getJob().setStatus(PipelineJob.TaskStatus.running, "Processing: " + contig);
+            ctx.getJob().setStatus(PipelineJob.TaskStatus.running, "Processing: " + contig + (statusSuffix == null ? "" : statusSuffix));
         }

         if (inputs.isEmpty())
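
The hunks above thread a new idx counter and statusSuffix argument through runPbsvCall so that the pipeline status reports per-contig progress. As a minimal, hypothetical illustration (the contig names are invented, not taken from the pipeline), the string-building mirrors the patched loop:

import java.util.List;

public class StatusSuffixDemo
{
    public static void main(String[] args)
    {
        // Hypothetical contig list; the status text matches the patched loop above
        List<String> contigs = List.of("chr1", "chr2", "chr3");
        int idx = 0;
        for (String contig : contigs)
        {
            idx++;
            String statusSuffix = " (" + idx + " of " + contigs.size() + ")";
            System.out.println("Processing: " + contig + statusSuffix);
        }
        // Prints:
        // Processing: chr1 (1 of 3)
        // Processing: chr2 (2 of 3)
        // Processing: chr3 (3 of 3)
    }
}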
File 2 of 2:

@@ -7,6 +7,8 @@
 import org.junit.Test;

 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -114,7 +116,11 @@ private void addInterval(String refName, int start, int end)
     public static LinkedHashMap<String, List<Interval>> divideGenome(SAMSequenceDictionary dict, int optimalBasesPerJob, boolean allowSplitChromosomes, int maxContigsPerJob)
     {
         ActiveIntervalSet ais = new ActiveIntervalSet(optimalBasesPerJob, allowSplitChromosomes, maxContigsPerJob);
-        for (SAMSequenceRecord rec : dict.getSequences())
+
+        // Sort the sequences in descending length, rather than alphabetic on name:
+        List<SAMSequenceRecord> sortedSeqs = new ArrayList<>(dict.getSequences());
+        sortedSeqs.sort(Comparator.comparingInt(SAMSequenceRecord::getSequenceLength).reversed());
+        for (SAMSequenceRecord rec : sortedSeqs)
         {
             ais.add(rec);
         }
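
For readers unfamiliar with htsjdk, here is a small, self-contained sketch of the descending-length sort added above. The dictionary contents are invented purely for illustration; only the Comparator line is taken from the actual change:

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

public class ContigSortDemo
{
    public static void main(String[] args)
    {
        // Invented contigs: alphabetic order (chr1, chr2, chr3) differs from size order
        SAMSequenceDictionary dict = new SAMSequenceDictionary(new ArrayList<>(List.of(
                new SAMSequenceRecord("chr1", 500),
                new SAMSequenceRecord("chr2", 2000),
                new SAMSequenceRecord("chr3", 1000))));

        // The same sort the patched divideGenome() now performs:
        List<SAMSequenceRecord> sortedSeqs = new ArrayList<>(dict.getSequences());
        sortedSeqs.sort(Comparator.comparingInt(SAMSequenceRecord::getSequenceLength).reversed());

        // Prints: chr2: 2000, chr3: 1000, chr1: 500
        sortedSeqs.forEach(r -> System.out.println(r.getSequenceName() + ": " + r.getSequenceLength()));
    }
}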
@@ -148,8 +154,8 @@ public void testScatter()
         SAMSequenceDictionary dict = getDict();
         Map<String, List<Interval>> ret = divideGenome(dict, 1000, true, -1);
         assertEquals("Incorrect number of jobs", 8, ret.size());
-        assertEquals("Incorrect interval end", 2000, ret.get("Job3").get(0).getEnd());
-        assertEquals("Incorrect start", 1001, ret.get("Job3").get(0).getStart());
+        assertEquals("Incorrect interval end", 1000, ret.get("Job3").get(0).getEnd());
+        assertEquals("Incorrect start", 1, ret.get("Job3").get(0).getStart());
         assertEquals("Incorrect interval end", 4, ret.get("Job8").size());

         Map<String, List<Interval>> ret2 = divideGenome(dict, 3000, false, -1);
@@ -183,12 +189,12 @@ public void testScatter()
         }

         Map<String, List<Interval>> ret5 = divideGenome(dict, 750, true, -1);
-        assertEquals("Incorrect number of jobs", 10, ret5.size());
-        assertEquals("Incorrect interval end", 1000, ret5.get("Job1").get(0).getEnd());
-        assertEquals("Incorrect interval end", 4, ret5.get("Job10").size());
+        assertEquals("Incorrect number of jobs", 9, ret5.size());
+        assertEquals("Incorrect interval end", 750, ret5.get("Job1").get(0).getEnd());
+        assertEquals("Incorrect interval end", 4, ret5.get("Job9").size());

-        assertEquals("Incorrect interval start", 751, ret5.get("Job3").get(0).getStart());
-        assertEquals("Incorrect interval start", 1501, ret5.get("Job8").get(0).getStart());
+        assertEquals("Incorrect interval start", 1501, ret5.get("Job3").get(0).getStart());
+        assertEquals("Incorrect interval start", 1, ret5.get("Job8").get(0).getStart());

         Map<String, List<Interval>> ret6 = divideGenome(dict, 5000, false, 2);
         assertEquals("Incorrect number of jobs", 5, ret6.size());
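
The revised expectations in testScatter() follow from the new ordering: contigs now reach ActiveIntervalSet largest-first, so both the job numbering and, in the 750-base case, the total job count (10 down to 9) change. ActiveIntervalSet's packing rules are not part of this diff, but the direction of the change matches what a greedy, order-dependent packer does when fed items in descending size. A simplified, hypothetical illustration (next-fit packing over invented lengths, not the real algorithm):

import java.util.List;

public class PackingOrderDemo
{
    // Simplified stand-in for the real packing (NOT ActiveIntervalSet): next-fit over
    // whole contigs, closing the current job once adding another would exceed the target.
    static int countJobs(List<Integer> contigLengths, int optimalBasesPerJob)
    {
        int jobs = 0;
        int current = 0;
        for (int len : contigLengths)
        {
            if (current > 0 && current + len > optimalBasesPerJob)
            {
                jobs++;       // close the current job
                current = 0;
            }
            current += len;
        }
        return current > 0 ? jobs + 1 : jobs;
    }

    public static void main(String[] args)
    {
        // Same four contigs, target of 1000 bases per job:
        System.out.println(countJobs(List.of(200, 900, 200, 900), 1000)); // name order: 4 jobs
        System.out.println(countJobs(List.of(900, 900, 200, 200), 1000)); // descending length: 3 jobs
    }
}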
