Skip to content

Commit 601f80f

Browse files
committed
Add basic support for '--targets/-t'
1 parent a0b5342 commit 601f80f

File tree

4 files changed

+93
-14
lines changed

4 files changed

+93
-14
lines changed

tests/test_regions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Optional
2+
import pytest
3+
from vcztools.regions import parse_targets_string
4+
5+
@pytest.mark.parametrize(
6+
"targets, expected",
7+
[
8+
("chr1:12-103", ("chr1", 12, 103)),
9+
],
10+
)
11+
def test_parse_targets_string(targets: str, expected: tuple[str, Optional[int], Optional[int]]):
12+
assert parse_targets_string(targets) == expected

vcztools/cli.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,16 @@ def list_commands(self, ctx):
1818
@click.command
1919
@click.argument("path", type=click.Path())
2020
@click.option("-c", is_flag=True, default=False, help="Use C implementation")
21-
def view(path, c):
21+
@click.option(
22+
"-t",
23+
"--targets",
24+
type=str,
25+
default=None,
26+
help="Target regions to include.",
27+
)
28+
def view(path, c, targets):
2229
implementation = "c" if c else "numba"
23-
vcf_writer.write_vcf(path, sys.stdout, implementation=implementation)
30+
vcf_writer.write_vcf(path, sys.stdout, variant_targets=targets, implementation=implementation)
2431

2532

2633
@click.group(cls=NaturalOrderGroup, name="vcztools")

vcztools/regions.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
import re
2+
from typing import Any, List, Optional
3+
4+
import numpy as np
5+
6+
def parse_targets_string(targets: str) -> tuple[str, Optional[int], Optional[int]]:
7+
"""Return the contig, start position and end position from a targets string."""
8+
if re.search(r":\d+-\d*$", targets):
9+
contig, start_end = targets.rsplit(":", 1)
10+
start, end = start_end.split("-")
11+
return contig, int(start), int(end)
12+
raise NotImplementedError()
13+
14+
15+
def pslice_to_slice(
16+
all_contigs: List[str],
17+
variant_contig: Any,
18+
variant_position: Any,
19+
contig: str,
20+
start: Optional[int] = None,
21+
end: Optional[int] = None,
22+
) -> slice:
23+
24+
contig_index = all_contigs.index(contig)
25+
contig_range = np.searchsorted(variant_contig, [contig_index, contig_index + 1])
26+
27+
if start is None and end is None:
28+
start_index, end_index = contig_range
29+
else:
30+
contig_pos = variant_position[slice(contig_range[0], contig_range[1])]
31+
if start is None:
32+
start_index = contig_range[0]
33+
end_index = contig_range[0] + np.searchsorted(contig_pos, [end])[0]
34+
elif end is None:
35+
start_index = contig_range[0] + np.searchsorted(contig_pos, [start])[0]
36+
end_index = contig_range[1]
37+
else:
38+
start_index, end_index = contig_range[0] + np.searchsorted(
39+
contig_pos, [start, end]
40+
)
41+
42+
return slice(start_index, end_index)

vcztools/vcf_writer.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import MutableMapping, Optional, TextIO, Union
66

77
import numpy as np
8+
from vcztools.regions import parse_targets_string, pslice_to_slice
89
import zarr
910

1011
from . import _vcztools
@@ -80,7 +81,7 @@ def dims(arr):
8081

8182

8283
def write_vcf(
83-
vcz, output, *, vcf_header: Optional[str] = None, implementation="numba"
84+
vcz, output, *, vcf_header: Optional[str] = None, variant_targets=None, implementation="numba"
8485
) -> None:
8586
"""Convert a dataset to a VCF file.
8687
@@ -163,7 +164,19 @@ def write_vcf(
163164
contigs = root["contig_id"][:].astype("S")
164165
filters = root["filter_id"][:].astype("S")
165166

167+
if variant_targets is None:
168+
variant_mask = np.ones(pos.shape[0], dtype=bool)
169+
else:
170+
contig, start, end = parse_targets_string(variant_targets)
171+
variant_slice = pslice_to_slice(root["contig_id"][:].astype("U").tolist(), root["variant_contig"], pos, contig, start, end)
172+
variant_mask = np.zeros(pos.shape[0], dtype=bool)
173+
variant_mask[variant_slice] = 1
174+
# Use zarr arrays to get mask chunks aligned with the main data
175+
# for convenience.
176+
z_variant_mask = zarr.array(variant_mask, chunks=pos.chunks[0])
177+
166178
for v_chunk in range(pos.cdata_shape[0]):
179+
v_mask_chunk = z_variant_mask.blocks[v_chunk]
167180
if implementation == "numba":
168181
numba_chunk_to_vcf(
169182
root,
@@ -178,22 +191,27 @@ def write_vcf(
178191
c_chunk_to_vcf(
179192
root,
180193
v_chunk,
194+
v_mask_chunk,
181195
contigs,
182196
filters,
183197
output,
184198
)
185199

186200

187-
def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
188-
chrom = contigs[root.variant_contig.blocks[v_chunk]]
201+
def get_block_selection(zarray, key, mask):
202+
return zarray.blocks[key][mask]
203+
204+
205+
def c_chunk_to_vcf(root, v_chunk, v_mask_chunk, contigs, filters, output):
206+
chrom = contigs[get_block_selection(root.variant_contig, v_chunk, v_mask_chunk)]
189207
# TODO check we don't truncate silently by doing this
190-
pos = root.variant_position.blocks[v_chunk].astype(np.int32)
191-
id = root.variant_id.blocks[v_chunk].astype("S")
192-
alleles = root.variant_allele.blocks[v_chunk]
208+
pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk).astype(np.int32)
209+
id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S")
210+
alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk)
193211
ref = alleles[:, 0].astype("S")
194212
alt = alleles[:, 1:].astype("S")
195-
qual = root.variant_quality.blocks[v_chunk]
196-
filter_ = root.variant_filter.blocks[v_chunk]
213+
qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk)
214+
filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk)
197215

198216
num_variants = len(pos)
199217
if len(id.shape) == 1:
@@ -207,21 +225,21 @@ def c_chunk_to_vcf(root, v_chunk, contigs, filters, output):
207225
for name, array in root.items():
208226
if name.startswith("call_") and not name.startswith("call_genotype"):
209227
vcf_name = name[len("call_") :]
210-
format_fields[vcf_name] = array.blocks[v_chunk]
228+
format_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)
211229
if num_samples is None:
212230
num_samples = array.shape[1]
213231
elif name.startswith("variant_") and name not in RESERVED_VARIABLE_NAMES:
214232
vcf_name = name[len("variant_") :]
215-
info_fields[vcf_name] = array.blocks[v_chunk]
233+
info_fields[vcf_name] = get_block_selection(array, v_chunk, v_mask_chunk)
216234

217235
gt = None
218236
gt_phased = None
219237
if "call_genotype" in root:
220238
array = root["call_genotype"]
221-
gt = array.blocks[v_chunk]
239+
gt = get_block_selection(array, v_chunk, v_mask_chunk)
222240
if "call_genotype_phased" in root:
223241
array = root["call_genotype_phased"]
224-
gt_phased = array.blocks[v_chunk]
242+
gt_phased = get_block_selection(array, v_chunk, v_mask_chunk)
225243
else:
226244
gt_phased = np.zeros_like(gt, dtype=bool)
227245

0 commit comments

Comments
 (0)