forked from broadinstitute/pyro-cov
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: preprocess_gisaid.py
executable file
·114 lines (92 loc) · 4.16 KB
/
preprocess_gisaid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright Contributors to the Pyro-Cov project.
# SPDX-License-Identifier: Apache-2.0
import argparse
import datetime
import json
import logging
import os
import pickle
import warnings
from collections import Counter, defaultdict
from pyrocov import pangolin
from pyrocov.geo import gisaid_normalize
from pyrocov.mutrans import START_DATE
# Module-level logger; basicConfig prefixes each message with milliseconds
# elapsed since program start (relativeCreated), padded to 9 digits.
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(relativeCreated) 9d %(message)s", level=logging.INFO)
# Map from date-string length to its strptime format: year only, year-month,
# or full year-month-day.
DATE_FORMATS = {4: "%Y", 7: "%Y-%m", 10: "%Y-%m-%d"}


def parse_date(string):
    """Parse a GISAID date string into a ``datetime.datetime``.

    Accepts ``YYYY``, ``YYYY-MM``, or ``YYYY-MM-DD``. Sloppy inputs with
    unpadded month/day components (e.g. ``2020-09-1``) are repaired by
    zero-padding each component after the year to two digits.
    """
    try:
        fmt = DATE_FORMATS[len(string)]
    except KeyError:
        # Attempt to fix poorly formatted dates like 2020-09-1 by
        # zero-padding every component after the year.
        year, *rest = string.split("-")
        string = "-".join([year] + [f"{int(piece):02d}" for piece in rest])
        fmt = DATE_FORMATS[len(string)]
    return datetime.datetime.strptime(string, fmt)
FIELDS = ["virus_name", "accession_id", "collection_date", "location", "add_location"]
def main(args):
    """Preprocess a GISAID JSON feed into column and statistics pickles.

    Reads newline-delimited JSON records from ``args.gisaid_file_in``,
    drops rows without month-level date information or with an
    unknown/ambiguous lineage, normalizes location strings, and writes:

    - ``args.columns_file_out``: pickled dict of parallel column lists
      (lineage, the FIELDS columns, and ``day`` = days since start date).
    - ``args.stats_file_out``: pickled dict of Counters keyed by
      ``date``/``location``/``lineage``.

    Raises:
        OSError: if the input feed file does not exist.
    """
    logger.info(f"Filtering {args.gisaid_file_in}")
    if not os.path.exists(args.gisaid_file_in):
        raise OSError(f"Missing {args.gisaid_file_in}; you may need to request a feed")
    os.makedirs("results", exist_ok=True)
    columns = defaultdict(list)
    stats = defaultdict(Counter)
    covv_fields = ["covv_" + key for key in FIELDS]

    i = -1  # BUG FIX: guard against an empty feed, where the loop never binds i.
    with open(args.gisaid_file_in) as f:
        for i, line in enumerate(f):
            # Optimize for faster reading: truncate each record at the large
            # "sequence" field and close the JSON object before parsing.
            line, _ = line.split(', "sequence": ', 1)
            line += "}"

            # Filter out bad data.
            datum = json.loads(line)
            if len(datum["covv_collection_date"]) < 7:
                continue  # Drop rows with no month information.
            date = parse_date(datum["covv_collection_date"])
            if date < args.start_date:
                date = args.start_date  # Clip rows before start date.
            lineage = datum["covv_lineage"]
            if lineage in (None, "None", "", "XA"):
                continue  # Drop rows with unknown or ambiguous lineage.
            try:
                # Round-trip through the compressed form to canonicalize the
                # lineage name; invalid lineages are warned about and skipped.
                lineage = pangolin.compress(lineage)
                lineage = pangolin.decompress(lineage)
                assert lineage
            except (ValueError, AssertionError) as e:
                warnings.warn(str(e))
                continue

            # Fix duplicate locations.
            datum["covv_location"] = gisaid_normalize(datum["covv_location"])

            # Collate.
            columns["lineage"].append(lineage)
            for covv_key, key in zip(covv_fields, FIELDS):
                columns[key].append(datum[covv_key])
            columns["day"].append((date - args.start_date).days)

            # Aggregate statistics.
            stats["date"][datum["covv_collection_date"]] += 1
            stats["location"][datum["covv_location"]] += 1
            stats["lineage"][lineage] += 1

            if i % args.log_every == 0:
                print(".", end="", flush=True)
            if i >= args.truncate:
                break

    num_rows = i + 1
    num_dropped = num_rows - len(columns["day"])
    if num_rows:
        # BUG FIX: the drop rate was previously computed as
        # num_dropped / num_rows / 100, understating the percentage by 10000x;
        # a percentage is ratio * 100.
        percent = 100 * num_dropped / num_rows
        logger.info(f"dropped {num_dropped}/{num_rows} = {percent:0.2g}% rows")

    logger.info(f"saving {args.columns_file_out}")
    with open(args.columns_file_out, "wb") as f:
        pickle.dump(dict(columns), f)

    logger.info(f"saving {args.stats_file_out}")
    with open(args.stats_file_out, "wb") as f:
        pickle.dump(dict(stats), f)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess GISAID data")
parser.add_argument("--gisaid-file-in", default="results/gisaid.json")
parser.add_argument("--columns-file-out", default="results/gisaid.columns.pkl")
parser.add_argument("--stats-file-out", default="results/gisaid.stats.pkl")
parser.add_argument("--subset-file-out", default="results/gisaid.subset.tsv")
parser.add_argument("--subset-dir-out", default="results/fasta")
parser.add_argument("--start-date", default=START_DATE)
parser.add_argument("-l", "--log-every", default=1000, type=int)
parser.add_argument("--truncate", default=int(1e10), type=int)
args = parser.parse_args()
args.start_date = parse_date(args.start_date)
main(args)