1
+ import random
2
+ import sys
3
+ import time
4
+ import warnings
5
+ from typing import Any , Dict , List , Optional
6
+
7
+ import numpy as np
8
+ from causallearn .graph .GeneralGraph import GeneralGraph
9
+ from causallearn .graph .GraphNode import GraphNode
10
+ from causallearn .score .LocalScoreFunction import (
11
+ local_score_BDeu ,
12
+ local_score_BIC ,
13
+ local_score_BIC_from_cov ,
14
+ local_score_cv_general ,
15
+ local_score_cv_multi ,
16
+ local_score_marginal_general ,
17
+ local_score_marginal_multi ,
18
+ )
19
+ from causallearn .search .PermutationBased .gst import GST ;
20
+ from causallearn .score .LocalScoreFunctionClass import LocalScoreClass
21
+ from causallearn .utils .DAG2CPDAG import dag2cpdag
22
+
23
+
24
def boss(
    X: np.ndarray,
    score_func: str = "local_score_BIC_from_cov",
    parameters: Optional[Dict[str, Any]] = None,
    verbose: Optional[bool] = True,
    node_names: Optional[List[str]] = None,
) -> GeneralGraph:
    """
    Perform a best order score search (BOSS) algorithm

    Parameters
    ----------
    X : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
    score_func : the string name of score function. (str(one of 'local_score_CV_general', 'local_score_marginal_general',
                 'local_score_CV_multi', 'local_score_marginal_multi', 'local_score_BIC', 'local_score_BIC_from_cov', 'local_score_BDeu')).
    parameters : when using CV likelihood,
                  parameters['kfold']: k-fold cross validation
                  parameters['lambda']: regularization parameter
                  parameters['dlabel']: for variables with multi-dimensions,
                               indicate which dimensions belong to the i-th variable.
    verbose : whether to print the time cost and verbose output of the algorithm.
    node_names : optional list of variable names, one per column of X;
                 defaults to ["X1", ..., "Xp"].

    Returns
    -------
    G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 indicates i --> j, G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
    """

    X = X.copy()
    n, p = X.shape
    if n < p:
        warnings.warn("The number of features is much larger than the sample size!")

    # Select the local score function; defaults are filled in only when the
    # caller did not supply parameters.
    if score_func == "local_score_CV_general":
        # k-fold negative cross validated likelihood based on regression in RKHS
        if parameters is None:
            parameters = {
                "kfold": 10,  # 10 fold cross validation
                "lambda": 0.01,  # regularization parameter
            }
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_cv_general, parameters=parameters
        )
    elif score_func == "local_score_marginal_general":
        # negative marginal likelihood based on regression in RKHS
        # (takes no user parameters; any supplied dict is ignored)
        parameters = {}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_marginal_general, parameters=parameters
        )
    elif score_func == "local_score_CV_multi":
        # k-fold negative cross validated likelihood based on regression in RKHS
        # for data with multi-variate dimensions
        if parameters is None:
            parameters = {
                "kfold": 10,
                "lambda": 0.01,  # regularization parameter
                "dlabel": {},
            }
            for i in range(X.shape[1]):
                parameters["dlabel"]["{}".format(i)] = i
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_cv_multi, parameters=parameters
        )
    elif score_func == "local_score_marginal_multi":
        # negative marginal likelihood based on regression in RKHS
        # for data with multi-variate dimensions
        if parameters is None:
            parameters = {"dlabel": {}}
            for i in range(X.shape[1]):
                parameters["dlabel"]["{}".format(i)] = i
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
        )
    elif score_func == "local_score_BIC":
        # SEM BIC score
        warnings.warn("Please use 'local_score_BIC_from_cov' instead")
        if parameters is None:
            parameters = {"lambda_value": 2}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BIC, parameters=parameters
        )
    elif score_func == "local_score_BIC_from_cov":
        # SEM BIC score computed from the covariance matrix
        if parameters is None:
            parameters = {"lambda_value": 2}
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
        )
    elif score_func == "local_score_BDeu":
        # BDeu score
        localScoreClass = LocalScoreClass(
            data=X, local_score_fun=local_score_BDeu, parameters=None
        )
    else:
        raise Exception("Unknown function!")

    score = localScoreClass

    node_names = [("X%d" % (i + 1)) for i in range(p)] if node_names is None else node_names
    nodes = []

    for name in node_names:
        node = GraphNode(name)
        nodes.append(node)

    G = GeneralGraph(nodes)

    runtime = time.perf_counter()

    order = [v for v in range(p)]

    # One grow-shrink tree per variable.  NOTE: the original built this list
    # twice (once before the node setup and again here), discarding the first
    # copy and doing all GST construction work twice; build it a single time.
    gsts = [GST(v, score) for v in order]
    parents = {v: [] for v in order}

    variables = [v for v in order]
    while True:
        improved = False
        random.shuffle(variables)
        if verbose:
            # Recompute the parent sets implied by the current order, purely
            # to report a running edge count.
            for i, v in enumerate(order):
                parents[v].clear()
                gsts[v].trace(order[:i], parents[v])
            sys.stdout.write("\rBOSS edge count: %i    " % np.sum([len(parents[v]) for v in range(p)]))
            sys.stdout.flush()

        # Try relocating every variable within the order; stop once a full
        # sweep yields no improvement.
        for v in variables:
            improved |= better_mutation(v, order, gsts)
        if not improved:
            break

    # Final parent sets implied by the best order found.
    for i, v in enumerate(order):
        parents[v].clear()
        gsts[v].trace(order[:i], parents[v])

    runtime = time.perf_counter() - runtime

    if verbose:
        sys.stdout.write("\nBOSS completed in: %.2fs \n" % runtime)
        sys.stdout.flush()

    for y in range(p):
        for x in parents[y]:
            G.add_directed_edge(nodes[x], nodes[y])

    # Report the Markov equivalence class (CPDAG) of the learned DAG.
    G = dag2cpdag(G)

    return G
170
+
171
+
172
def reversed_enumerate(iter, j):
    """Enumerate *iter* back to front, pairing each item with a counter
    that starts at *j* and decreases by one for every item yielded."""
    for offset, item in enumerate(reversed(iter)):
        yield j - offset, item
176
+
177
+
178
def better_mutation(v, order, gsts):
    """Try to move variable *v* to its best position within *order*.

    Computes, for every insertion point j, the total score of the order
    obtained by placing v at position j while keeping the relative order of
    all other variables fixed.  Scoring is delegated to the grow-shrink
    trees in *gsts* via trace(prefix).
    NOTE(review): trace(prefix) is assumed to return the local score of its
    variable given the prefix set — confirm against GST.

    If a position beats v's current one by more than a small tolerance,
    *order* is mutated in place and True is returned; otherwise *order* is
    left unchanged and False is returned.
    """
    i = order.index(v)        # current position of v
    p = len(order)
    scores = np.zeros(p + 1)  # scores[j]: total score with v inserted at slot j

    # Forward pass: accumulate the scores of the variables preceding each
    # insertion slot (v itself is excluded from the growing prefix).
    prefix = []
    score = 0
    for j, w in enumerate(order):
        scores[j] = gsts[v].trace(prefix) + score
        if v != w:
            score += gsts[w].trace(prefix)
            prefix.append(w)

    # Slot p = insert v at the very end.
    scores[p] = gsts[v].trace(prefix) + score
    best = p

    # Backward pass: for each slot j, add the scores of the variables after
    # it, now scored with v present in their prefix.
    prefix.append(v)
    score = 0
    for j, w in reversed_enumerate(order, p - 1):
        if v != w:
            prefix.remove(w)
            score += gsts[w].trace(prefix)
        scores[j] += score
        if scores[j] > scores[best]: best = j

    # Keep the current position unless another slot is strictly better;
    # the 1e-6 tolerance guards against floating-point ties.
    if scores[i] + 1e-6 > scores[best]: return False
    order.remove(v)
    # Removing v shifts slots after i down by one, hence the adjustment.
    order.insert(best - int(best > i), v)

    return True
0 commit comments