-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwright_fisher.py
117 lines (103 loc) · 4.38 KB
/
wright_fisher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import numpy as np
import tskit
class Parent(object):
    """An individual of the simulated population, carrying two nodes.

    Attributes
    ----------
    index : int
        Label of the individual. The simulation loops in this module set it
        to ``-1`` to flag an individual as dead.
    n0, n1 : int
        Row IDs (in the node table) of the individual's two nodes.
    """

    def __init__(self, index, n0, n1):
        self.index, self.n0, self.n1 = index, n0, n1
class PopState(object):
    """Mutable population state shared by the simulation functions.

    Attributes
    ----------
    parents : list of Parent
        The current population, one entry per slot. Founder ``i`` owns
        nodes ``2 * i`` and ``2 * i + 1``.
    pnodes : list of tuple
        Node pair of every individual created so far, in creation order.
    next_parent : int
        Index to hand to the next individual created (starts at ``N``).
    tables : tskit.TableCollection
        Table collection with sequence length 1.0; founders' nodes are
        added here at construction.
    buffered_edges : list
        One ``[[], []]`` entry per individual created so far (one buffer
        per node of that individual).
    generation_offsets : list of tuple
        ``(start, stop)`` ranges into ``buffered_edges``. The founders
        occupy the first range.
    current_generation : int
        Generation counter, 0 at construction.
    """

    def __init__(self, N):
        self.parents = []
        self.pnodes = []
        for idx in range(N):
            first, second = 2 * idx, 2 * idx + 1
            self.parents.append(Parent(idx, first, second))
            self.pnodes.append((first, second))
        self.next_parent = N
        self.tables = tskit.TableCollection(1.0)
        self.buffered_edges = [[[], []] for _ in range(N)]
        self.generation_offsets = [(0, len(self.buffered_edges))]
        self.current_generation = 0
        # Time is measured forwards here (founders at 0.0) and is meant to
        # be reversed later -- presumably because tskit expects parents to
        # be older than their children; confirm against the caller.
        # The fresh node table hands out IDs 0 .. 2N-1, matching the node
        # pairs assigned to the founders above.
        for _ in range(2 * N):
            self.tables.nodes.add_row(time=0.0)
def wright_fisher(ngens, psurvival, popstate):
    """Simulate overlapping generations, writing edges straight to the tables.

    In every generation each individual dies independently with probability
    ``1 - psurvival``. Each dead individual is replaced by an offspring of
    two parents drawn uniformly, with replacement, from the population as it
    stood at the start of that generation. Parents may therefore include
    individuals dying in the same generation, and both draws may pick the
    same individual. Each parent transmits one of its two nodes, chosen with
    probability 1/2. For every offspring, two new nodes and two
    parent->child edges spanning ``[0, 1)`` are added to ``popstate.tables``.

    Node times are measured forwards. Offspring born in generation ``gen``
    of this call get time ``popstate.current_generation + gen``, and
    ``current_generation`` is advanced by the number of generations
    simulated, so repeated calls keep producing increasing times.

    Parameters
    ----------
    ngens : int
        Number of generations to simulate; ``ngens <= 0`` simulates nothing.
    psurvival : float
        Per-generation survival probability; must lie in ``[0, 1)``.
    popstate : PopState
        Population state; modified in place.

    Returns
    -------
    PopState
        ``popstate`` itself.

    Raises
    ------
    ValueError
        If ``psurvival`` is outside ``[0, 1)`` or is NaN.
    RuntimeError
        If an individual already flagged dead (``index == -1``) is found in
        the population.
    """
    # A positive range check also rejects NaN, which would otherwise
    # silently disable every death.
    if not 0 <= psurvival < 1.0:
        raise ValueError("unhelpful survival probability")
    start = popstate.current_generation
    for gen in range(1, ngens + 1):
        # regulation: decide who dies and pick the parents of each
        # replacement from the not-yet-replaced population
        dead = []
        parent_list = []
        for i, p in enumerate(popstate.parents):
            if p.index == -1:
                raise RuntimeError("oops, dead!")
            if np.random.uniform() > psurvival:
                p.index = -1
                parents = np.random.choice(len(popstate.parents), 2)
                par0 = popstate.parents[parents[0]]
                par1 = popstate.parents[parents[1]]
                # "Mendel": each parent passes on one of its two nodes
                p0node = par1_node = None
                p0node = par0.n1 if np.random.uniform() < 0.5 else par0.n0
                par1_node = par1.n1 if np.random.uniform() < 0.5 else par1.n0
                parent_list.append((p0node, par1_node))
                dead.append(i)
        # replacement: record offspring nodes and edges, then fill the slots
        birth_time = start + gen
        for d, (p0node, p1node) in zip(dead, parent_list):
            n0 = popstate.tables.nodes.add_row(time=birth_time)
            n1 = popstate.tables.nodes.add_row(time=birth_time)
            popstate.tables.edges.add_row(left=0, right=1, parent=p0node, child=n0)
            popstate.tables.edges.add_row(left=0, right=1, parent=p1node, child=n1)
            popstate.parents[d] = Parent(popstate.next_parent, n0, n1)
            popstate.next_parent += 1
    popstate.current_generation = start + max(ngens, 0)
    return popstate
def wright_fisher_eb(ngens, psurvival, popstate):
    """Simulate overlapping generations, buffering edges per parent node.

    In every generation each individual dies independently with probability
    ``1 - psurvival``. Each dead individual is replaced by an offspring of
    two parents drawn uniformly, with replacement, from the population as it
    stood at the start of that generation. Each parent passes on one of its
    two nodes, chosen with probability 1/2.

    Offspring nodes are added to ``popstate.tables``, but edges are not.
    Each edge ``(left, right, parent_node, child_node)`` is appended to
    ``popstate.buffered_edges[parent.index][k]``, where ``k`` (0 or 1) says
    which of the parent's nodes was transmitted. Every newborn also gets an
    empty ``[[], []]`` buffer and its node pair appended to
    ``popstate.pnodes``. For each generation with at least one death, the
    ``(start, stop)`` range of buffers created in that generation is
    appended to ``popstate.generation_offsets``.

    Node times are measured forwards from ``popstate.current_generation``,
    which is advanced by the number of generations simulated.

    Parameters
    ----------
    ngens : int
        Number of generations to simulate; ``ngens <= 0`` simulates nothing.
    psurvival : float
        Per-generation survival probability; must lie in ``[0, 1)``.
    popstate : PopState
        Population state; modified in place.

    Returns
    -------
    PopState
        ``popstate`` itself.

    Raises
    ------
    ValueError
        If ``psurvival`` is outside ``[0, 1)`` or is NaN.
    RuntimeError
        If an individual already flagged dead (``index == -1``) is found in
        the population.
    """
    # A positive range check also rejects NaN, which would otherwise
    # silently disable every death.
    if not 0 <= psurvival < 1.0:
        raise ValueError("unhelpful survival probability")
    start = popstate.current_generation
    for gen in range(1, ngens + 1):
        # regulation: decide who dies and pick the parents of each
        # replacement from the not-yet-replaced population
        dead = []
        parent_list = []
        for i, p in enumerate(popstate.parents):
            if p.index == -1:
                raise RuntimeError("oops, dead!")
            if np.random.uniform() > psurvival:
                parents = np.random.choice(len(popstate.parents), 2)
                par0 = popstate.parents[parents[0]]
                par1 = popstate.parents[parents[1]]
                # "Mendel": i0/i1 record which node (0 or 1) each parent
                # transmits
                i0 = 1 if np.random.uniform() < 0.5 else 0
                i1 = 1 if np.random.uniform() < 0.5 else 0
                p0node = par0.n1 if i0 else par0.n0
                p1node = par1.n1 if i1 else par1.n0
                parent_list.append(
                    (p0node, p1node, i0, i1, par0.index, par1.index)
                )
                dead.append(i)
        first_new = len(popstate.buffered_edges)
        birth_time = start + gen
        for d, (p0node, p1node, i0, i1, idx0, idx1) in zip(dead, parent_list):
            # NOTE: apply "dead" flag here
            # so that we aren't giving invalid
            # indexes in the regulation step above
            popstate.parents[d].index = -1
            n0 = popstate.tables.nodes.add_row(time=birth_time)
            n1 = popstate.tables.nodes.add_row(time=birth_time)
            popstate.buffered_edges[idx0][i0].append((0, 1, p0node, n0))
            popstate.buffered_edges[idx1][i1].append((0, 1, p1node, n1))
            popstate.pnodes.append((n0, n1))
            popstate.parents[d] = Parent(popstate.next_parent, n0, n1)
            popstate.next_parent += 1
            # Keeps len(buffered_edges) in step with next_parent, so a
            # Parent.index can be used as a buffer index (assumes the two
            # started equal, as PopState sets them up).
            popstate.buffered_edges.append([[], []])
        if dead:
            popstate.generation_offsets.append(
                (first_new, len(popstate.buffered_edges))
            )
    # Advance by the generations actually simulated; reading the loop
    # variable here raised UnboundLocalError when ngens <= 0.
    popstate.current_generation = start + max(ngens, 0)
    return popstate