Skip to content

Commit 516e960

Browse files
committed
Add support for targets in numba too
1 parent 9d8f88e commit 516e960

File tree

2 files changed: 16 additions and 11 deletions.

tests/test_vcf_writer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_write_vcf(shared_datadir, tmp_path, output_is_path, implementation):
5555
assert_vcfs_close(path, output)
5656

5757

58-
@pytest.mark.parametrize("implementation", ["c"])
58+
@pytest.mark.parametrize("implementation", ["c", "numba"])
5959
def test_write_vcf__targets(shared_datadir, tmp_path, implementation):
6060
path = shared_datadir / "vcf" / "sample.vcf.gz"
6161
intermediate_icf = tmp_path.joinpath("intermediate.icf")
@@ -72,8 +72,12 @@ def test_write_vcf__targets(shared_datadir, tmp_path, implementation):
7272

7373
assert v.samples == ["NA00001", "NA00002", "NA00003"]
7474

75+
count = 0
7576
for variant in v:
7677
assert variant.CHROM == "20"
78+
count += 1
79+
80+
assert count == 6
7781

7882

7983
def test_write_vcf__set_header(shared_datadir, tmp_path):

vcztools/vcf_writer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ def write_vcf(
188188
numba_chunk_to_vcf(
189189
root,
190190
v_chunk,
191+
v_mask_chunk,
191192
header_info_fields,
192193
header_format_fields,
193194
contigs,
@@ -296,16 +297,16 @@ def c_chunk_to_vcf(root, v_chunk, v_mask_chunk, contigs, filters, output):
296297

297298

298299
def numba_chunk_to_vcf(
299-
root, v_chunk, header_info_fields, header_format_fields, contigs, filters, output
300+
root, v_chunk, v_mask_chunk, header_info_fields, header_format_fields, contigs, filters, output
300301
):
301302
# fixed fields
302303

303-
chrom = root.variant_contig.blocks[v_chunk]
304-
pos = root.variant_position.blocks[v_chunk]
305-
id = root.variant_id.blocks[v_chunk].astype("S")
306-
alleles = root.variant_allele.blocks[v_chunk].astype("S")
307-
qual = root.variant_quality.blocks[v_chunk]
308-
filter_ = root.variant_filter.blocks[v_chunk]
304+
chrom = get_block_selection(root.variant_contig, v_chunk, v_mask_chunk)
305+
pos = get_block_selection(root.variant_position, v_chunk, v_mask_chunk)
306+
id = get_block_selection(root.variant_id, v_chunk, v_mask_chunk).astype("S")
307+
alleles = get_block_selection(root.variant_allele, v_chunk, v_mask_chunk).astype("S")
308+
qual = get_block_selection(root.variant_quality, v_chunk, v_mask_chunk)
309+
filter_ = get_block_selection(root.variant_filter, v_chunk, v_mask_chunk)
309310

310311
n_variants = len(pos)
311312

@@ -327,7 +328,7 @@ def numba_chunk_to_vcf(
327328
# not the other way around. This is probably not what we want to
328329
# do, but keeping it this way to preserve tests initially.
329330
continue
330-
values = arr.blocks[v_chunk]
331+
values = get_block_selection(arr, v_chunk, v_mask_chunk)
331332
if arr.dtype == bool:
332333
info_mask[k] = create_mask(values)
333334
info_bufs.append(np.zeros(0, dtype=np.uint8))
@@ -366,7 +367,7 @@ def numba_chunk_to_vcf(
366367
var = "call_genotype" if key == "GT" else f"call_{key}"
367368
if var not in root:
368369
continue
369-
values = root[var].blocks[v_chunk]
370+
values = get_block_selection(root[var], v_chunk, v_mask_chunk)
370371
if key == "GT":
371372
n_samples = values.shape[1]
372373
format_mask[k] = create_mask(values)
@@ -394,7 +395,7 @@ def numba_chunk_to_vcf(
394395
format_indexes = np.empty((len(format_values), n_samples + 1), dtype=np.int32)
395396

396397
if "call_genotype_phased" in root:
397-
call_genotype_phased = root["call_genotype_phased"].blocks[v_chunk][:]
398+
call_genotype_phased = get_block_selection(root["call_genotype_phased"], v_chunk, v_mask_chunk)[:]
398399
else:
399400
call_genotype_phased = np.full((n_variants, n_samples), False, dtype=bool)
400401

0 commit comments

Comments (0)