Skip to content

Commit 0c0c47a

Browse files
Add support for clear-default missing fields in tskit
Closes #107
1 parent b0d2da8 commit 0c0c47a

File tree

3 files changed

+165
-9
lines changed

3 files changed

+165
-9
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ dev = [
5252
"cyvcf2",
5353
"pytest",
5454
"pytest-cov",
55+
"msprime",
56+
"sgkit",
5557
]
5658

5759
[tool.setuptools]

tests/test_tskit_data.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""
2+
Tests for data originating from tskit format for compatibility
3+
with various outputs.
4+
"""
5+
6+
import bio2zarr.tskit as ts2z
7+
import bio2zarr.vcf as v2z
8+
import msprime
9+
import numpy as np
10+
import numpy.testing as nt
11+
import pytest
12+
import sgkit as sg
13+
import tskit
14+
import xarray.testing as xt
15+
16+
from vcztools.vcf_writer import write_vcf
17+
18+
19+
def add_mutations(ts):
20+
# Add some mutation to the tree sequence. This guarantees that
21+
# we have variation at all sites > 0.
22+
tables = ts.dump_tables()
23+
samples = ts.samples()
24+
states = "ACGT"
25+
for j in range(1, int(ts.sequence_length) - 1):
26+
site = tables.sites.add_row(j, ancestral_state=states[j % 4])
27+
tables.mutations.add_row(
28+
site=site,
29+
derived_state=states[(j + 1) % 4],
30+
node=samples[j % ts.num_samples],
31+
)
32+
return tables.tree_sequence()
33+
34+
35+
@pytest.fixture()
36+
def fx_diploid_msprime_sim(tmp_path):
37+
seed = 1234
38+
ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=seed)
39+
ts = msprime.sim_mutations(ts, rate=0.5, random_seed=seed)
40+
assert ts.num_mutations > 0
41+
zarr_path = tmp_path / "sim.vcz"
42+
ts2z.convert(ts, zarr_path)
43+
return zarr_path
44+
45+
46+
@pytest.fixture()
47+
def fx_haploid_msprime_sim(tmp_path):
48+
seed = 12345
49+
ts = msprime.sim_ancestry(5, ploidy=1, sequence_length=100, random_seed=seed)
50+
ts = msprime.sim_mutations(ts, rate=0.5, random_seed=seed)
51+
assert ts.num_mutations > 0
52+
zarr_path = tmp_path / "sim.vcz"
53+
ts2z.convert(ts, zarr_path)
54+
return zarr_path
55+
56+
57+
def simple_ts_tables():
58+
tables = tskit.TableCollection(sequence_length=100)
59+
for _ in range(4):
60+
ind = -1
61+
ind = tables.individuals.add_row()
62+
tables.nodes.add_row(flags=tskit.NODE_IS_SAMPLE, time=0, individual=ind)
63+
tables.nodes.add_row(flags=0, time=1) # MRCA for 0,1
64+
tables.nodes.add_row(flags=0, time=1) # MRCA for 2,3
65+
tables.edges.add_row(left=0, right=100, parent=4, child=0)
66+
tables.edges.add_row(left=0, right=100, parent=4, child=1)
67+
tables.edges.add_row(left=0, right=100, parent=5, child=2)
68+
tables.edges.add_row(left=0, right=100, parent=5, child=3)
69+
site_id = tables.sites.add_row(position=10, ancestral_state="A")
70+
tables.mutations.add_row(site=site_id, node=4, derived_state="TTTT")
71+
site_id = tables.sites.add_row(position=20, ancestral_state="CCC")
72+
tables.mutations.add_row(site=site_id, node=5, derived_state="G")
73+
site_id = tables.sites.add_row(position=30, ancestral_state="G")
74+
tables.mutations.add_row(site=site_id, node=0, derived_state="AA")
75+
76+
tables.sort()
77+
return tables
78+
79+
80+
@pytest.fixture()
81+
def fx_simple_ts(tmp_path):
82+
ts = simple_ts_tables().tree_sequence()
83+
zarr_path = tmp_path / "sim.vcz"
84+
ts2z.convert(ts, zarr_path)
85+
return zarr_path
86+
87+
88+
# TODO add other fixtures here like stuff with odd mixtures of ploidy,
89+
# and zero variants (need to address
90+
# https://github.com/sgkit-dev/bio2zarr/issues/342 before zero variants
91+
# handled)
92+
93+
94+
class TestVcfRoundTrip:
95+
def assert_bio2zarr_rt(self, tmp_path, tskit_vcz):
96+
vcf_path = tmp_path / "out.vcf"
97+
write_vcf(tskit_vcz, vcf_path)
98+
rt_vcz_path = tmp_path / "rt.vcz"
99+
v2z.convert([vcf_path], rt_vcz_path)
100+
ds1 = sg.load_dataset(tskit_vcz)
101+
ds2 = sg.load_dataset(rt_vcz_path)
102+
drop_fields = [
103+
"variant_id",
104+
"variant_id_mask",
105+
"filter_id",
106+
"filter_description",
107+
"variant_filter",
108+
"variant_quality",
109+
]
110+
xt.assert_equal(ds1, ds2.drop(drop_fields))
111+
num_variants = ds2.dims["variants"]
112+
assert np.all(np.isnan(ds2["variant_quality"].values))
113+
nt.assert_array_equal(
114+
ds2["variant_filter"], np.ones((num_variants, 1), dtype=bool)
115+
)
116+
assert list(ds2["filter_id"].values) == ["PASS"]
117+
118+
def test_diploid_msprime_sim(self, tmp_path, fx_diploid_msprime_sim):
119+
self.assert_bio2zarr_rt(tmp_path, fx_diploid_msprime_sim)
120+
121+
def test_haploid_msprime_sim(self, tmp_path, fx_haploid_msprime_sim):
122+
self.assert_bio2zarr_rt(tmp_path, fx_haploid_msprime_sim)
123+
124+
def test_simple_ts(self, tmp_path, fx_simple_ts):
125+
self.assert_bio2zarr_rt(tmp_path, fx_simple_ts)

vcztools/vcf_writer.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from . import _vcztools, constants, retrieval
1515
from . import filter as filter_mod
16-
from .constants import RESERVED_VARIABLE_NAMES
16+
from .constants import FLOAT32_MISSING, RESERVED_VARIABLE_NAMES
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -156,7 +156,7 @@ def write_vcf(
156156
return
157157

158158
contigs = root["contig_id"][:].astype("S")
159-
filters = root["filter_id"][:].astype("S")
159+
filters = get_filter_ids(root)
160160

161161
for chunk_data in retrieval.variant_chunk_iter(
162162
root,
@@ -187,19 +187,35 @@ def c_chunk_to_vcf(
187187
drop_genotypes,
188188
no_update,
189189
):
190+
format_fields = {}
191+
info_fields = {}
192+
num_samples = len(samples_selection) if samples_selection is not None else None
193+
190194
# TODO check we don't truncate silently by doing this
191195
pos = chunk_data["variant_position"].astype(np.int32)
192196
num_variants = len(pos)
193197
if num_variants == 0:
194198
return ""
199+
# Required fields
195200
chrom = contigs[chunk_data["variant_contig"]]
196-
id = chunk_data["variant_id"].astype("S")
197201
alleles = chunk_data["variant_allele"]
198-
qual = chunk_data["variant_quality"]
199-
filter_ = chunk_data["variant_filter"]
200-
format_fields = {}
201-
info_fields = {}
202-
num_samples = len(samples_selection) if samples_selection is not None else None
202+
203+
# Optional fields which we fill in with "all missing" defaults
204+
if "variant_id" in chunk_data:
205+
id = chunk_data["variant_id"].astype("S")
206+
else:
207+
id = np.array(["."] * num_variants, dtype="S")
208+
if "variant_quality" in chunk_data:
209+
qual = chunk_data["variant_quality"]
210+
else:
211+
qual = np.full(num_variants, FLOAT32_MISSING, dtype=np.float32)
212+
213+
# Filter defaults to "PASS" if not present
214+
if "variant_filter" in chunk_data:
215+
filter_ = chunk_data["variant_filter"]
216+
else:
217+
filter_ = np.ones((num_variants, 1), dtype=bool)
218+
203219
gt = None
204220
gt_phased = None
205221

@@ -213,6 +229,7 @@ def c_chunk_to_vcf(
213229
):
214230
gt_phased = chunk_data["call_genotype_phased"]
215231
else:
232+
# Default to unphased if call_genotype_phased not present
216233
gt_phased = np.zeros_like(gt, dtype=bool)
217234

218235
for name, array in chunk_data.items():
@@ -294,6 +311,18 @@ def c_chunk_to_vcf(
294311
print(line, file=output)
295312

296313

314+
def get_filter_ids(root):
315+
"""
316+
Returns the filter IDs from the specified Zarr store. If the array
317+
does not exist, return a single filter "PASS" by default.
318+
"""
319+
if "filter_id" in root:
320+
filters = root["filter_id"][:].astype("S")
321+
else:
322+
filters = np.array(["PASS"], dtype="S")
323+
return filters
324+
325+
297326
def _generate_header(
298327
ds,
299328
sample_ids,
@@ -304,7 +333,7 @@ def _generate_header(
304333
output = io.StringIO()
305334

306335
contigs = list(ds["contig_id"][:])
307-
filters = list(ds["filter_id"][:])
336+
filters = list(get_filter_ids(ds).astype("U"))
308337
info_fields = []
309338
format_fields = []
310339

0 commit comments

Comments
 (0)