Skip to content

Commit d98225b

Browse files
authoredApr 5, 2024
Merge pull request #168 from kenneth-lee-ch/main
Added rules 8, 9, 10 to FCI
2 parents 36b0829 + a1b0919 commit d98225b

File tree

1 file changed

+144
-2
lines changed
  • causallearn/search/ConstraintBased

1 file changed

+144
-2
lines changed
 

‎causallearn/search/ConstraintBased/FCI.py

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from causallearn.utils.cit import *
1616
from causallearn.utils.FAS import fas
1717
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
18-
18+
from itertools import combinations
1919

2020
def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
2121
if node == edge.get_node1():
@@ -542,6 +542,142 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
542542
return change_flag
543543

544544

545+
546+
def rule8(graph: Graph, nodes: List[Node]):
547+
nodes = graph.get_nodes()
548+
changeFlag = False
549+
for node_B in nodes:
550+
adj = graph.get_adjacent_nodes(node_B)
551+
if len(adj) < 2:
552+
continue
553+
554+
cg = ChoiceGenerator(len(adj), 2)
555+
combination = cg.next()
556+
557+
while combination is not None:
558+
node_A = adj[combination[0]]
559+
node_C = adj[combination[1]]
560+
combination = cg.next()
561+
562+
if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
563+
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
564+
graph.is_adjacent_to(node_A, node_C) and \
565+
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
566+
(graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
567+
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
568+
graph.is_adjacent_to(node_A, node_C) and \
569+
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
570+
edge1 = graph.get_edge(node_A, node_C)
571+
graph.remove_edge(edge1)
572+
graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
573+
changeFlag = True
574+
575+
return changeFlag
576+
577+
578+
579+
def is_possible_parent(graph: Graph, potential_parent_node, child_node):
580+
if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
581+
return False
582+
if not graph.is_adjacent_to(potential_parent_node, child_node):
583+
return False
584+
585+
if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
586+
graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
587+
return False
588+
else:
589+
return True
590+
591+
592+
def find_possible_children(graph: Graph, parent_node, en_nodes=None):
593+
if en_nodes is None:
594+
nodes = graph.get_nodes()
595+
en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]
596+
597+
potential_child_nodes = set()
598+
for potential_node in en_nodes:
599+
if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
600+
potential_child_nodes.add(potential_node)
601+
602+
return potential_child_nodes
603+
604+
def rule9(graph: Graph, nodes: List[Node]):
605+
changeFlag = False
606+
nodes = graph.get_nodes()
607+
for node_C in nodes:
608+
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
609+
for node_A in intoCArrows:
610+
# we want A o--> C
611+
if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
612+
continue
613+
614+
# look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
615+
a_node_idx = graph.node_map[node_A]
616+
c_node_idx = graph.node_map[node_C]
617+
a_adj_nodes = graph.get_adjacent_nodes(node_A)
618+
nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
619+
possible_children = find_possible_children(graph, node_A, nodes_set)
620+
for node_B in possible_children:
621+
if graph.is_adjacent_to(node_B, node_C):
622+
continue
623+
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
624+
edge1 = graph.get_edge(node_A, node_C)
625+
graph.remove_edge(edge1)
626+
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
627+
changeFlag = True
628+
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
629+
return changeFlag
630+
631+
632+
def rule10(graph: Graph):
633+
changeFlag = False
634+
nodes = graph.get_nodes()
635+
for node_C in nodes:
636+
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
637+
if len(intoCArrows) < 2:
638+
continue
639+
# get all A where A o-> C
640+
Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
641+
if len(Anodes) == 0:
642+
continue
643+
644+
for node_A in Anodes:
645+
A_adj_nodes = graph.get_adjacent_nodes(node_A)
646+
en_nodes = [i for i in A_adj_nodes if i is not node_C]
647+
A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
648+
if len(A_possible_children) < 2:
649+
continue
650+
651+
gen = ChoiceGenerator(len(intoCArrows), 2)
652+
choice = gen.next()
653+
while choice is not None:
654+
node_B = intoCArrows[choice[0]]
655+
node_D = intoCArrows[choice[1]]
656+
657+
choice = gen.next()
658+
# we want B->C<-D
659+
if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
660+
continue
661+
662+
if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
663+
continue
664+
665+
for children in combinations(A_possible_children, 2):
666+
child_one, child_two = children
667+
if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
668+
not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
669+
continue
670+
671+
if not graph.is_adjacent_to(child_one, child_two):
672+
edge1 = graph.get_edge(node_A, node_C)
673+
graph.remove_edge(edge1)
674+
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
675+
changeFlag = True
676+
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
677+
678+
return changeFlag
679+
680+
545681
def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
546682
if path.__contains__(node_a):
547683
return False
@@ -691,7 +827,6 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
691827
break
692828

693829

694-
695830
def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
696831
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
697832
**kwargs) -> Tuple[Graph, List[Edge]]:
@@ -787,6 +922,13 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
787922
if verbose:
788923
print("Epoch")
789924

925+
# rule 8
926+
change_flag = rule8(graph,nodes)
927+
# rule 9
928+
change_flag = rule9(graph, nodes)
929+
# rule 10
930+
change_flag = rule10(graph)
931+
790932
graph.set_pag(True)
791933

792934
edges = get_color_edges(graph)

0 commit comments

Comments
 (0)