Skip to content

Commit 9ea1fed

Browse files
committed
final updates to BOSS and GRaSP
1 parent 58419ec commit 9ea1fed

File tree

5 files changed

+31
-137
lines changed

5 files changed

+31
-137
lines changed

causallearn/search/PermutationBased/BOSS.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,11 @@ def boss(
4848
G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 together indicate i --> j; G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
4949
"""
5050

51+
X = X.copy()
5152
n, p = X.shape
5253
if n < p:
5354
warnings.warn("The number of features is much larger than the sample size!")
5455

55-
X = np.mat(X)
5656
if score_func == "local_score_CV_general":
5757
# k-fold negative cross-validated likelihood based on regression in RKHS
5858
if parameters is None:
@@ -63,14 +63,12 @@ def boss(
6363
localScoreClass = LocalScoreClass(
6464
data=X, local_score_fun=local_score_cv_general, parameters=parameters
6565
)
66-
6766
elif score_func == "local_score_marginal_general":
6867
# negative marginal likelihood based on regression in RKHS
6968
parameters = {}
7069
localScoreClass = LocalScoreClass(
7170
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
7271
)
73-
7472
elif score_func == "local_score_CV_multi":
7573
# k-fold negative cross validated likelihood based on regression in RKHS
7674
# for data with multi-variate dimensions
@@ -85,7 +83,6 @@ def boss(
8583
localScoreClass = LocalScoreClass(
8684
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
8785
)
88-
8986
elif score_func == "local_score_marginal_multi":
9087
# negative marginal likelihood based on regression in RKHS
9188
# for data with multi-variate dimensions
@@ -96,7 +93,6 @@ def boss(
9693
localScoreClass = LocalScoreClass(
9794
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
9895
)
99-
10096
elif score_func == "local_score_BIC":
10197
# SEM BIC score
10298
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
@@ -105,25 +101,22 @@ def boss(
105101
localScoreClass = LocalScoreClass(
106102
data=X, local_score_fun=local_score_BIC, parameters=parameters
107103
)
108-
109104
elif score_func == "local_score_BIC_from_cov":
110105
# SEM BIC score
111106
if parameters is None:
112107
parameters = {"lambda_value": 2}
113108
localScoreClass = LocalScoreClass(
114109
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
115110
)
116-
117111
elif score_func == "local_score_BDeu":
118112
# BDeu score
119113
localScoreClass = LocalScoreClass(
120114
data=X, local_score_fun=local_score_BDeu, parameters=None
121115
)
122-
123116
else:
124117
raise Exception("Unknown function!")
118+
125119
score = localScoreClass
126-
127120
gsts = [GST(i, score) for i in range(p)]
128121

129122
node_names = [("X%d" % (i + 1)) for i in range(p)] if node_names is None else node_names

causallearn/search/PermutationBased/GRaSP.py

Lines changed: 9 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ def grasp(
106106
G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 together indicate i --> j; G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
107107
"""
108108

109+
X = X.copy()
109110
n, p = X.shape
110111
if n < p:
111112
warnings.warn("The number of features is much larger than the sample size!")
112113

113-
X = np.mat(X)
114114
if score_func == "local_score_CV_general":
115115
# k-fold negative cross validated likelihood based on regression in RKHS
116116
if parameters is None:
@@ -121,14 +121,12 @@ def grasp(
121121
localScoreClass = LocalScoreClass(
122122
data=X, local_score_fun=local_score_cv_general, parameters=parameters
123123
)
124-
125124
elif score_func == "local_score_marginal_general":
126125
# negative marginal likelihood based on regression in RKHS
127126
parameters = {}
128127
localScoreClass = LocalScoreClass(
129128
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
130129
)
131-
132130
elif score_func == "local_score_CV_multi":
133131
# k-fold negative cross validated likelihood based on regression in RKHS
134132
# for data with multi-variate dimensions
@@ -143,7 +141,6 @@ def grasp(
143141
localScoreClass = LocalScoreClass(
144142
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
145143
)
146-
147144
elif score_func == "local_score_marginal_multi":
148145
# negative marginal likelihood based on regression in RKHS
149146
# for data with multi-variate dimensions
@@ -154,7 +151,6 @@ def grasp(
154151
localScoreClass = LocalScoreClass(
155152
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
156153
)
157-
158154
elif score_func == "local_score_BIC":
159155
# SEM BIC score
160156
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
@@ -163,25 +159,22 @@ def grasp(
163159
localScoreClass = LocalScoreClass(
164160
data=X, local_score_fun=local_score_BIC, parameters=parameters
165161
)
166-
167162
elif score_func == "local_score_BIC_from_cov":
168163
# SEM BIC score
169164
if parameters is None:
170165
parameters = {"lambda_value": 2}
171166
localScoreClass = LocalScoreClass(
172167
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
173168
)
174-
175169
elif score_func == "local_score_BDeu":
176170
# BDeu score
177171
localScoreClass = LocalScoreClass(
178172
data=X, local_score_fun=local_score_BDeu, parameters=None
179173
)
180-
181174
else:
182175
raise Exception("Unknown function!")
183-
score = localScoreClass
184176

177+
score = localScoreClass
185178
gsts = [GST(i, score) for i in range(p)]
186179

187180
node_names = [("X%d" % (i + 1)) for i in range(p)] if node_names is None else node_names
@@ -201,15 +194,11 @@ def grasp(
201194
y_parents = order.get_parents(y)
202195

203196
candidates = [order.get(j) for j in range(0, i)]
204-
# local_score = gsts[y].trace(candidates)
205-
206-
grow(y, y_parents, candidates, score)
207-
local_score = shrink(y, y_parents, score)
208-
197+
local_score = gsts[y].trace(candidates, y_parents)
209198
order.set_local_score(y, local_score)
210199
order.bump_edges(len(y_parents))
211200

212-
while dfs(depth - 1, set(), [], order, score, gsts):
201+
while dfs(depth - 1, set(), [], order, gsts):
213202
if verbose:
214203
sys.stdout.write("\rGRaSP edge count: %i " % order.get_edges())
215204
sys.stdout.flush()
@@ -230,7 +219,7 @@ def grasp(
230219

231220

232221
# performs a dfs over covered tucks
233-
def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
222+
def dfs(depth: int, flipped: set, history: List[set], order, gsts):
234223

235224
cache = [{}, {}, {}, 0]
236225

@@ -258,7 +247,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
258247
cache[3] = order.get_edges()
259248

260249
tuck(i, j, order)
261-
edge_bump, score_bump = update(i, j, order, score, gsts)
250+
edge_bump, score_bump = update(i, j, order, gsts)
262251

263252
# because things that should be zero sometimes are not
264253
if score_bump > 1e-6:
@@ -277,7 +266,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
277266

278267
if len(flipped) > 0 and flipped not in history:
279268
history.append(flipped)
280-
if depth > 0 and dfs(depth - 1, flipped, history, order, score, gsts):
269+
if depth > 0 and dfs(depth - 1, flipped, history, order, gsts):
281270
return True
282271
del history[-1]
283272

@@ -292,7 +281,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
292281

293282

294283
# updates the parents and scores after a tuck
295-
def update(i: int, j: int, order, score, gsts):
284+
def update(i: int, j: int, order, gsts):
296285

297286
edge_bump = 0
298287
old_score = 0
@@ -307,19 +296,7 @@ def update(i: int, j: int, order, score, gsts):
307296

308297
z_parents.clear()
309298
candidates = [order.get(l) for l in range(0, k)]
310-
311-
# for w in [w for w in z_parents if w not in candidates]:
312-
# z_parents.remove(w)
313-
# shrink(z, z_parents, score)
314-
315-
# for w in z_parents:
316-
# candidates.remove(w)
317-
318-
# local_score = gsts[z].trace(candidates, z_parents)
319-
320-
grow(z, z_parents, candidates, score)
321-
local_score = shrink(z, z_parents, score)
322-
299+
local_score = gsts[z].trace(candidates, z_parents)
323300
order.set_local_score(z, local_score)
324301

325302
edge_bump += len(z_parents)
@@ -328,68 +305,6 @@ def update(i: int, j: int, order, score, gsts):
328305
return edge_bump, new_score - old_score
329306

330307

331-
# grow of grow-shrink
332-
def grow(y: int, y_parents: List[int], candidates: List[int], score):
333-
334-
best = -score.score(y, y_parents)
335-
# best = -score.score_nocache(y, y_parents)
336-
337-
add = None
338-
checked = []
339-
while add is not None or len(candidates) > 0:
340-
341-
if add is not None:
342-
checked.remove(add)
343-
y_parents.append(add)
344-
candidates = checked
345-
checked = []
346-
add = None
347-
348-
while len(candidates) > 0:
349-
350-
x = candidates.pop()
351-
y_parents.append(x)
352-
current = -score.score(y, y_parents)
353-
# current = -score.score_nocache(y, y_parents)
354-
y_parents.remove(x)
355-
checked.append(x)
356-
357-
if current > best:
358-
best = current
359-
add = x
360-
361-
return best
362-
363-
364-
# shrink of grow-shrink
365-
def shrink(y: int, y_parents: List[int], score):
366-
367-
best = -score.score(y, y_parents)
368-
# best = -score.score_nocache(y, y_parents)
369-
370-
remove = None
371-
checked = 0
372-
while remove is not None or checked < len(y_parents):
373-
374-
if remove is not None:
375-
y_parents.remove(remove)
376-
checked = 0
377-
remove = None
378-
379-
while checked < len(y_parents):
380-
x = y_parents.pop(0)
381-
current = -score.score(y, y_parents)
382-
# current = -score.score_nocache(y, y_parents)
383-
y_parents.append(x)
384-
checked += 1
385-
386-
if current > best:
387-
best = current
388-
remove = x
389-
390-
return best
391-
392-
393308
# tucks the node at position i into position j
394309
def tuck(i: int, j: int, order):
395310

causallearn/search/PermutationBased/gst.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
class GSTNode:
22

33
def __init__(self, tree, add=None, score=None):
4-
# if score is None: score = -tree.score.score_nocache(tree.vertex, [])
5-
if score is None: score = -tree.score.score(tree.vertex, [])
4+
if score is None: score = -tree.score.score_nocache(tree.vertex, [])
5+
# if score is None: score = -tree.score.score(tree.vertex, [])
66
self.tree = tree
77
self.add = add
88
self.grow_score = score
@@ -17,8 +17,8 @@ def grow(self, available, parents):
1717
self.branches = []
1818
for add in available:
1919
parents.append(add)
20-
# score = -self.tree.score.score_nocache(self.tree.vertex, parents)
21-
score = -self.tree.score.score(self.tree.vertex, parents)
20+
score = -self.tree.score.score_nocache(self.tree.vertex, parents)
21+
# score = -self.tree.score.score(self.tree.vertex, parents)
2222
parents.remove(add)
2323
branch = GSTNode(self.tree, add, score)
2424
if score > self.grow_score: self.branches.append(branch)
@@ -30,8 +30,8 @@ def shrink(self, parents):
3030
best = None
3131
for remove in [parent for parent in parents]:
3232
parents.remove(remove)
33-
# score = -self.tree.score.score_nocache(self.tree.vertex, parents)
34-
score = -self.tree.score.score(self.tree.vertex, parents)
33+
score = -self.tree.score.score_nocache(self.tree.vertex, parents)
34+
# score = -self.tree.score.score(self.tree.vertex, parents)
3535
parents.append(remove)
3636
if score > self.shrink_score:
3737
self.shrink_score = score

tests/TestBOSS.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,8 @@
77
from causallearn.graph.GraphNode import GraphNode
88
from causallearn.search.PermutationBased.BOSS import boss
99
from causallearn.utils.DAG2CPDAG import dag2cpdag
10-
from causallearn.utils.GraphUtils import GraphUtils
11-
import matplotlib.image as mpimg
12-
import matplotlib.pyplot as plt
13-
import io
10+
11+
import gc
1412

1513

1614
def simulate_data(p, d, n):
@@ -53,11 +51,13 @@ def simulate_data(p, d, n):
5351
class TestBOSS(unittest.TestCase):
5452
def test_boss(self):
5553
ps = [30, 60]
56-
# ps = [30]
5754
ds = [0.1, 0.15]
5855
n = 1000
5956
reps = 5
6057

58+
gc.set_threshold(20000, 50, 50)
59+
# gc.set_debug(gc.DEBUG_STATS)
60+
6161
for p in ps:
6262
for d in ds:
6363
stats = [[], [], [], []]
@@ -84,15 +84,8 @@ def test_boss(self):
8484

8585
G0 = dag2cpdag(G0)
8686

87-
G = boss(X, parameters={'lambda_value': 2})
88-
89-
# pyd = GraphUtils.to_pydot(G)
90-
# tmp_png = pyd.create_png(f="png")
91-
# fp = io.BytesIO(tmp_png)
92-
# img = mpimg.imread(fp, format='png')
93-
# plt.axis('off')
94-
# plt.imshow(img)
95-
# plt.show()
87+
G = boss(X, parameters={'lambda_value': 4})
88+
gc.collect()
9689

9790
AdjC = AdjacencyConfusion(G0, G)
9891
stats[0].append(AdjC.get_adj_precision())

0 commit comments

Comments
 (0)