@@ -106,11 +106,11 @@ def grasp(
106
106
G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 indicates i --> j, G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
107
107
"""
108
108
109
+ X = X .copy ()
109
110
n , p = X .shape
110
111
if n < p :
111
112
warnings .warn ("The number of features is much larger than the sample size!" )
112
113
113
- X = np .mat (X )
114
114
if score_func == "local_score_CV_general" :
115
115
# k-fold negative cross validated likelihood based on regression in RKHS
116
116
if parameters is None :
@@ -121,14 +121,12 @@ def grasp(
121
121
localScoreClass = LocalScoreClass (
122
122
data = X , local_score_fun = local_score_cv_general , parameters = parameters
123
123
)
124
-
125
124
elif score_func == "local_score_marginal_general" :
126
125
# negative marginal likelihood based on regression in RKHS
127
126
parameters = {}
128
127
localScoreClass = LocalScoreClass (
129
128
data = X , local_score_fun = local_score_marginal_general , parameters = parameters
130
129
)
131
-
132
130
elif score_func == "local_score_CV_multi" :
133
131
# k-fold negative cross validated likelihood based on regression in RKHS
134
132
# for data with multi-variate dimensions
@@ -143,7 +141,6 @@ def grasp(
143
141
localScoreClass = LocalScoreClass (
144
142
data = X , local_score_fun = local_score_cv_multi , parameters = parameters
145
143
)
146
-
147
144
elif score_func == "local_score_marginal_multi" :
148
145
# negative marginal likelihood based on regression in RKHS
149
146
# for data with multi-variate dimensions
@@ -154,7 +151,6 @@ def grasp(
154
151
localScoreClass = LocalScoreClass (
155
152
data = X , local_score_fun = local_score_marginal_multi , parameters = parameters
156
153
)
157
-
158
154
elif score_func == "local_score_BIC" :
159
155
# SEM BIC score
160
156
warnings .warn ("Please use 'local_score_BIC_from_cov' instead" )
@@ -163,25 +159,22 @@ def grasp(
163
159
localScoreClass = LocalScoreClass (
164
160
data = X , local_score_fun = local_score_BIC , parameters = parameters
165
161
)
166
-
167
162
elif score_func == "local_score_BIC_from_cov" :
168
163
# SEM BIC score
169
164
if parameters is None :
170
165
parameters = {"lambda_value" : 2 }
171
166
localScoreClass = LocalScoreClass (
172
167
data = X , local_score_fun = local_score_BIC_from_cov , parameters = parameters
173
168
)
174
-
175
169
elif score_func == "local_score_BDeu" :
176
170
# BDeu score
177
171
localScoreClass = LocalScoreClass (
178
172
data = X , local_score_fun = local_score_BDeu , parameters = None
179
173
)
180
-
181
174
else :
182
175
raise Exception ("Unknown function!" )
183
- score = localScoreClass
184
176
177
+ score = localScoreClass
185
178
gsts = [GST (i , score ) for i in range (p )]
186
179
187
180
node_names = [("X%d" % (i + 1 )) for i in range (p )] if node_names is None else node_names
@@ -201,15 +194,11 @@ def grasp(
201
194
y_parents = order .get_parents (y )
202
195
203
196
candidates = [order .get (j ) for j in range (0 , i )]
204
- # local_score = gsts[y].trace(candidates)
205
-
206
- grow (y , y_parents , candidates , score )
207
- local_score = shrink (y , y_parents , score )
208
-
197
+ local_score = gsts [y ].trace (candidates , y_parents )
209
198
order .set_local_score (y , local_score )
210
199
order .bump_edges (len (y_parents ))
211
200
212
- while dfs (depth - 1 , set (), [], order , score , gsts ):
201
+ while dfs (depth - 1 , set (), [], order , gsts ):
213
202
if verbose :
214
203
sys .stdout .write ("\r GRaSP edge count: %i " % order .get_edges ())
215
204
sys .stdout .flush ()
@@ -230,7 +219,7 @@ def grasp(
230
219
231
220
232
221
# performs a dfs over covered tucks
233
- def dfs (depth : int , flipped : set , history : List [set ], order , score , gsts ):
222
+ def dfs (depth : int , flipped : set , history : List [set ], order , gsts ):
234
223
235
224
cache = [{}, {}, {}, 0 ]
236
225
@@ -258,7 +247,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
258
247
cache [3 ] = order .get_edges ()
259
248
260
249
tuck (i , j , order )
261
- edge_bump , score_bump = update (i , j , order , score , gsts )
250
+ edge_bump , score_bump = update (i , j , order , gsts )
262
251
263
252
# because things that should be zero sometimes are not
264
253
if score_bump > 1e-6 :
@@ -277,7 +266,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
277
266
278
267
if len (flipped ) > 0 and flipped not in history :
279
268
history .append (flipped )
280
- if depth > 0 and dfs (depth - 1 , flipped , history , order , score , gsts ):
269
+ if depth > 0 and dfs (depth - 1 , flipped , history , order , gsts ):
281
270
return True
282
271
del history [- 1 ]
283
272
@@ -292,7 +281,7 @@ def dfs(depth: int, flipped: set, history: List[set], order, score, gsts):
292
281
293
282
294
283
# updates the parents and scores after a tuck
295
- def update (i : int , j : int , order , score , gsts ):
284
+ def update (i : int , j : int , order , gsts ):
296
285
297
286
edge_bump = 0
298
287
old_score = 0
@@ -307,19 +296,7 @@ def update(i: int, j: int, order, score, gsts):
307
296
308
297
z_parents .clear ()
309
298
candidates = [order .get (l ) for l in range (0 , k )]
310
-
311
- # for w in [w for w in z_parents if w not in candidates]:
312
- # z_parents.remove(w)
313
- # shrink(z, z_parents, score)
314
-
315
- # for w in z_parents:
316
- # candidates.remove(w)
317
-
318
- # local_score = gsts[z].trace(candidates, z_parents)
319
-
320
- grow (z , z_parents , candidates , score )
321
- local_score = shrink (z , z_parents , score )
322
-
299
+ local_score = gsts [z ].trace (candidates , z_parents )
323
300
order .set_local_score (z , local_score )
324
301
325
302
edge_bump += len (z_parents )
@@ -328,68 +305,6 @@ def update(i: int, j: int, order, score, gsts):
328
305
return edge_bump , new_score - old_score
329
306
330
307
331
- # grow of grow-shrink
332
def grow(y: int, y_parents: List[int], candidates: List[int], score):
    """Greedy forward ("grow") phase of grow-shrink for node ``y``.

    Repeatedly sweeps the remaining candidates, tentatively adding each one to
    ``y_parents`` and evaluating ``-score.score(y, y_parents)``.  After a full
    sweep, the single candidate that produced the largest strict improvement
    is added permanently and the sweep restarts on the rest.  Stops when a
    sweep yields no improvement or no candidates remain.

    Parameters
    ----------
    y : node whose parent set is being built.
    y_parents : current parents of ``y``; extended in place.
    candidates : nodes that may be added; the list passed in is emptied in
        place by the first sweep.
    score : object exposing ``score(y, parents)``; its result is negated here,
        so a larger negated value counts as better.

    Returns
    -------
    The best negated score reached.
    """
    best = -score.score(y, y_parents)

    # The first sweep pops straight from the caller's list; later sweeps run
    # over the locally built list of not-yet-accepted nodes.
    pool = candidates
    while pool:
        winner = None
        tried = []

        while pool:
            node = pool.pop()

            # Score with `node` tentatively included, then undo the inclusion.
            y_parents.append(node)
            trial = -score.score(y, y_parents)
            y_parents.remove(node)
            tried.append(node)

            # Strict comparison: on ties the earliest-tried node is kept.
            if trial > best:
                best, winner = trial, node

        if winner is None:
            break

        # Commit the best node of this sweep and retry the others.
        tried.remove(winner)
        y_parents.append(winner)
        pool = tried

    return best
362
-
363
-
364
- # shrink of grow-shrink
365
def shrink(y: int, y_parents: List[int], score):
    """Greedy backward ("shrink") phase of grow-shrink for node ``y``.

    Repeatedly sweeps the current parents, tentatively dropping each one and
    evaluating ``-score.score(y, y_parents)`` without it.  After a full sweep,
    the single parent whose removal produced the largest strict improvement is
    removed permanently and the sweep restarts.  Stops when a sweep yields no
    improvement or no parents remain.

    Parameters
    ----------
    y : node whose parent set is being pruned.
    y_parents : current parents of ``y``; pruned in place.
    score : object exposing ``score(y, parents)``; its result is negated here,
        so a larger negated value counts as better.

    Returns
    -------
    The best negated score reached.
    """
    best = -score.score(y, y_parents)

    while True:
        loser = None

        # Rotate through the list: take the head off, score the remainder and
        # put the node back at the tail.  After len(y_parents) steps the list
        # is back in its starting order.
        for _ in range(len(y_parents)):
            node = y_parents.pop(0)
            trial = -score.score(y, y_parents)
            y_parents.append(node)

            # Strict comparison: on ties the earliest-tried node is kept.
            if trial > best:
                best, loser = trial, node

        if loser is None:
            return best

        # Drop the most harmful parent of this sweep and try again.
        y_parents.remove(loser)
391
-
392
-
393
308
# tucks the node at position i into position j
394
309
def tuck (i : int , j : int , order ):
395
310
0 commit comments