Skip to content

Commit a8964d6

Browse files
authoredJul 18, 2024
Merge pull request #188 from py-why/add_boss
add boss + gets and update grasp to use gsts
2 parents ae2f5b2 + 5399515 commit a8964d6

File tree

7 files changed

+481
-177
lines changed

7 files changed

+481
-177
lines changed
 

‎causallearn/score/LocalScoreFunction.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def local_score_BIC(Data: ndarray, i: int, PAi: List[int], parameters=None) -> f
3434
if len(PAi) == 0:
3535
return n * np.log(cov[i, i])
3636

37-
yX = np.mat(cov[np.ix_([i], PAi)])
38-
XX = np.mat(cov[np.ix_(PAi, PAi)])
39-
H = np.log(cov[i, i] - yX * np.linalg.inv(XX) * yX.T)
37+
yX = cov[np.ix_([i], PAi)]
38+
XX = cov[np.ix_(PAi, PAi)]
39+
H = np.log(cov[i, i] - yX @ np.linalg.inv(XX) @ yX.T)
4040

4141
return n * H + np.log(n) * len(PAi) * lambda_value
4242

@@ -68,9 +68,9 @@ def local_score_BIC_from_cov(
6868
if len(PAi) == 0:
6969
return n * np.log(cov[i, i])
7070

71-
yX = np.mat(cov[np.ix_([i], PAi)])
72-
XX = np.mat(cov[np.ix_(PAi, PAi)])
73-
H = np.log(cov[i, i] - yX * np.linalg.inv(XX) * yX.T)
71+
yX = cov[np.ix_([i], PAi)]
72+
XX = cov[np.ix_(PAi, PAi)]
73+
H = np.log(cov[i, i] - yX @ np.linalg.inv(XX) @ yX.T)
7474

7575
return n * H + np.log(n) * len(PAi) * lambda_value
7676

‎causallearn/score/LocalScoreFunctionClass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,9 @@ def score(self, i: int, PAi: List[int]) -> float:
4646
self.score_cache[i][hash_key] = self.local_score_fun(self.data, i, PAi, self.parameters)
4747

4848
return self.score_cache[i][hash_key]
49+
50+
def score_nocache(self, i: int, PAi: List[int]) -> float:
51+
if self.local_score_fun == local_score_BIC_from_cov:
52+
return self.local_score_fun((self.cov, self.n), i, PAi, self.parameters)
53+
else:
54+
return self.local_score_fun(self.data, i, PAi, self.parameters)
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import random
2+
import sys
3+
import time
4+
import warnings
5+
from typing import Any, Dict, List, Optional
6+
7+
import numpy as np
8+
from causallearn.graph.GeneralGraph import GeneralGraph
9+
from causallearn.graph.GraphNode import GraphNode
10+
from causallearn.score.LocalScoreFunction import (
11+
local_score_BDeu,
12+
local_score_BIC,
13+
local_score_BIC_from_cov,
14+
local_score_cv_general,
15+
local_score_cv_multi,
16+
local_score_marginal_general,
17+
local_score_marginal_multi,
18+
)
19+
from causallearn.search.PermutationBased.gst import GST;
20+
from causallearn.score.LocalScoreFunctionClass import LocalScoreClass
21+
from causallearn.utils.DAG2CPDAG import dag2cpdag
22+
23+
24+
def boss(
25+
X: np.ndarray,
26+
score_func: str = "local_score_BIC_from_cov",
27+
parameters: Optional[Dict[str, Any]] = None,
28+
verbose: Optional[bool] = True,
29+
node_names: Optional[List[str]] = None,
30+
) -> GeneralGraph:
31+
"""
32+
Perform a best order score search (BOSS) algorithm
33+
34+
Parameters
35+
----------
36+
X : data set (numpy ndarray), shape (n_samples, n_features). The input data, where n_samples is the number of samples and n_features is the number of features.
37+
score_func : the string name of score function. (str(one of 'local_score_CV_general', 'local_score_marginal_general',
38+
'local_score_CV_multi', 'local_score_marginal_multi', 'local_score_BIC', 'local_score_BIC_from_cov', 'local_score_BDeu')).
39+
parameters : when using CV likelihood,
40+
parameters['kfold']: k-fold cross validation
41+
parameters['lambda']: regularization parameter
42+
parameters['dlabel']: for variables with multi-dimensions,
43+
indicate which dimensions belong to the i-th variable.
44+
verbose : whether to print the time cost and verbose output of the algorithm.
45+
46+
Returns
47+
-------
48+
G : learned causal graph, where G.graph[j,i] = 1 and G.graph[i,j] = -1 indicates i --> j, G.graph[i,j] = G.graph[j,i] = -1 indicates i --- j.
49+
"""
50+
51+
X = X.copy()
52+
n, p = X.shape
53+
if n < p:
54+
warnings.warn("The number of features is much larger than the sample size!")
55+
56+
if score_func == "local_score_CV_general":
57+
# % k-fold negative cross validated likelihood based on regression in RKHS
58+
if parameters is None:
59+
parameters = {
60+
"kfold": 10, # 10 fold cross validation
61+
"lambda": 0.01,
62+
} # regularization parameter
63+
localScoreClass = LocalScoreClass(
64+
data=X, local_score_fun=local_score_cv_general, parameters=parameters
65+
)
66+
elif score_func == "local_score_marginal_general":
67+
# negative marginal likelihood based on regression in RKHS
68+
parameters = {}
69+
localScoreClass = LocalScoreClass(
70+
data=X, local_score_fun=local_score_marginal_general, parameters=parameters
71+
)
72+
elif score_func == "local_score_CV_multi":
73+
# k-fold negative cross validated likelihood based on regression in RKHS
74+
# for data with multi-variate dimensions
75+
if parameters is None:
76+
parameters = {
77+
"kfold": 10,
78+
"lambda": 0.01,
79+
"dlabel": {},
80+
} # regularization parameter
81+
for i in range(X.shape[1]):
82+
parameters["dlabel"]["{}".format(i)] = i
83+
localScoreClass = LocalScoreClass(
84+
data=X, local_score_fun=local_score_cv_multi, parameters=parameters
85+
)
86+
elif score_func == "local_score_marginal_multi":
87+
# negative marginal likelihood based on regression in RKHS
88+
# for data with multi-variate dimensions
89+
if parameters is None:
90+
parameters = {"dlabel": {}}
91+
for i in range(X.shape[1]):
92+
parameters["dlabel"]["{}".format(i)] = i
93+
localScoreClass = LocalScoreClass(
94+
data=X, local_score_fun=local_score_marginal_multi, parameters=parameters
95+
)
96+
elif score_func == "local_score_BIC":
97+
# SEM BIC score
98+
warnings.warn("Please use 'local_score_BIC_from_cov' instead")
99+
if parameters is None:
100+
parameters = {"lambda_value": 2}
101+
localScoreClass = LocalScoreClass(
102+
data=X, local_score_fun=local_score_BIC, parameters=parameters
103+
)
104+
elif score_func == "local_score_BIC_from_cov":
105+
# SEM BIC score
106+
if parameters is None:
107+
parameters = {"lambda_value": 2}
108+
localScoreClass = LocalScoreClass(
109+
data=X, local_score_fun=local_score_BIC_from_cov, parameters=parameters
110+
)
111+
elif score_func == "local_score_BDeu":
112+
# BDeu score
113+
localScoreClass = LocalScoreClass(
114+
data=X, local_score_fun=local_score_BDeu, parameters=None
115+
)
116+
else:
117+
raise Exception("Unknown function!")
118+
119+
score = localScoreClass
120+
gsts = [GST(i, score) for i in range(p)]
121+
122+
node_names = [("X%d" % (i + 1)) for i in range(p)] if node_names is None else node_names
123+
nodes = []
124+
125+
for name in node_names:
126+
node = GraphNode(name)
127+
nodes.append(node)
128+
129+
G = GeneralGraph(nodes)
130+
131+
runtime = time.perf_counter()
132+
133+
order = [v for v in range(p)]
134+
135+
gsts = [GST(v, score) for v in order]
136+
parents = {v: [] for v in order}
137+
138+
variables = [v for v in order]
139+
while True:
140+
improved = False
141+
random.shuffle(variables)
142+
if verbose:
143+
for i, v in enumerate(order):
144+
parents[v].clear()
145+
gsts[v].trace(order[:i], parents[v])
146+
sys.stdout.write("\rBOSS edge count: %i " % np.sum([len(parents[v]) for v in range(p)]))
147+
sys.stdout.flush()
148+
149+
for v in variables:
150+
improved |= better_mutation(v, order, gsts)
151+
if not improved: break
152+
153+
for i, v in enumerate(order):
154+
parents[v].clear()
155+
gsts[v].trace(order[:i], parents[v])
156+
157+
runtime = time.perf_counter() - runtime
158+
159+
if verbose:
160+
sys.stdout.write("\nBOSS completed in: %.2fs \n" % runtime)
161+
sys.stdout.flush()
162+
163+
for y in range(p):
164+
for x in parents[y]:
165+
G.add_directed_edge(nodes[x], nodes[y])
166+
167+
G = dag2cpdag(G)
168+
169+
return G
170+
171+
172+
def reversed_enumerate(iter, j):
173+
for w in reversed(iter):
174+
yield j, w
175+
j -= 1
176+
177+
178+
def better_mutation(v, order, gsts):
179+
i = order.index(v)
180+
p = len(order)
181+
scores = np.zeros(p + 1)
182+
183+
prefix = []
184+
score = 0
185+
for j, w in enumerate(order):
186+
scores[j] = gsts[v].trace(prefix) + score
187+
if v != w:
188+
score += gsts[w].trace(prefix)
189+
prefix.append(w)
190+
191+
scores[p] = gsts[v].trace(prefix) + score
192+
best = p
193+
194+
prefix.append(v)
195+
score = 0
196+
for j, w in reversed_enumerate(order, p - 1):
197+
if v != w:
198+
prefix.remove(w)
199+
score += gsts[w].trace(prefix)
200+
scores[j] += score
201+
if scores[j] > scores[best]: best = j
202+
203+
if scores[i] + 1e-6 > scores[best]: return False
204+
order.remove(v)
205+
order.insert(best - int(best > i), v)
206+
207+
return True

0 commit comments

Comments
 (0)