Skip to content

Minor issues fix #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 9 additions & 149 deletions causallearn/search/ConstraintBased/FCI.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from causallearn.utils.cit import *
from causallearn.utils.FAS import fas
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from itertools import combinations


def traverseSemiDirected(node: Node, edge: Edge) -> Node | None:
if node == edge.get_node1():
Expand Down Expand Up @@ -320,9 +320,8 @@ def rulesR1R2cycle(graph: Graph, bk: BackgroundKnowledge | None, changeFlag: boo

def isNoncollider(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], node_i: Node, node_j: Node,
node_k: Node) -> bool:
node_map = graph.get_node_map()
sep_set = sep_sets.get((node_map[node_i], node_map[node_k]))
return sep_set is not None and sep_set.__contains__(node_map[node_j])
sep_set = sep_sets[(graph.get_node_map()[node_i], graph.get_node_map()[node_k])]
return sep_set is not None and sep_set.__contains__(graph.get_node_map()[node_j])


def ruleR3(graph: Graph, sep_sets: Dict[Tuple[int, int], Set[int]], bk: BackgroundKnowledge | None, changeFlag: bool,
Expand Down Expand Up @@ -543,142 +542,6 @@ def ruleR4B(graph: Graph, maxPathLength: int, data: ndarray, independence_test_m
return change_flag



def rule8(graph: Graph, nodes: List[Node]):
nodes = graph.get_nodes()
changeFlag = False
for node_B in nodes:
adj = graph.get_adjacent_nodes(node_B)
if len(adj) < 2:
continue

cg = ChoiceGenerator(len(adj), 2)
combination = cg.next()

while combination is not None:
node_A = adj[combination[0]]
node_C = adj[combination[1]]
combination = cg.next()

if(graph.get_endpoint(node_A, node_B) == Endpoint.ARROW and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
graph.is_adjacent_to(node_A, node_C) and \
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE) or \
(graph.get_endpoint(node_A, node_B) == Endpoint.CIRCLE and graph.get_endpoint(node_B, node_A) == Endpoint.TAIL and \
graph.get_endpoint(node_B, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_B) == Endpoint.TAIL and \
graph.is_adjacent_to(node_A, node_C) and \
graph.get_endpoint(node_A, node_C) == Endpoint.ARROW and graph.get_endpoint(node_C, node_A)== Endpoint.CIRCLE):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C,Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True

return changeFlag



def is_possible_parent(graph: Graph, potential_parent_node, child_node):
if graph.node_map[potential_parent_node] == graph.node_map[child_node]:
return False
if not graph.is_adjacent_to(potential_parent_node, child_node):
return False

if graph.get_endpoint(child_node, potential_parent_node) == Endpoint.ARROW or \
graph.get_endpoint(potential_parent_node, child_node) == Endpoint.TAIL:
return False
else:
return True


def find_possible_children(graph: Graph, parent_node, en_nodes=None):
if en_nodes is None:
nodes = graph.get_nodes()
en_nodes = [node for node in nodes if graph.node_map[node] != graph.node_map[parent_node]]

potential_child_nodes = set()
for potential_node in en_nodes:
if is_possible_parent(graph, potential_parent_node=parent_node, child_node=potential_node):
potential_child_nodes.add(potential_node)

return potential_child_nodes

def rule9(graph: Graph, nodes: List[Node]):
changeFlag = False
nodes = graph.get_nodes()
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
for node_A in intoCArrows:
# we want A o--> C
if not graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE:
continue

# look for a possibly directed uncovered path s.t. B and C are not connected (for the given A o--> C
a_node_idx = graph.node_map[node_A]
c_node_idx = graph.node_map[node_C]
a_adj_nodes = graph.get_adjacent_nodes(node_A)
nodes_set = [node for node in a_adj_nodes if graph.node_map[node] != a_node_idx and graph.node_map[node]!= c_node_idx]
possible_children = find_possible_children(graph, node_A, nodes_set)
for node_B in possible_children:
if graph.is_adjacent_to(node_B, node_C):
continue
if existsSemiDirectedPath(node_from=node_B, node_to=node_C, G=graph):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A
return changeFlag


def rule10(graph: Graph):
changeFlag = False
nodes = graph.get_nodes()
for node_C in nodes:
intoCArrows = graph.get_nodes_into(node_C, Endpoint.ARROW)
if len(intoCArrows) < 2:
continue
# get all A where A o-> C
Anodes = [node_A for node_A in intoCArrows if graph.get_endpoint(node_C, node_A) == Endpoint.CIRCLE]
if len(Anodes) == 0:
continue

for node_A in Anodes:
A_adj_nodes = graph.get_adjacent_nodes(node_A)
en_nodes = [i for i in A_adj_nodes if i is not node_C]
A_possible_children = find_possible_children(graph, parent_node=node_A, en_nodes=en_nodes)
if len(A_possible_children) < 2:
continue

gen = ChoiceGenerator(len(intoCArrows), 2)
choice = gen.next()
while choice is not None:
node_B = intoCArrows[choice[0]]
node_D = intoCArrows[choice[1]]

choice = gen.next()
# we want B->C<-D
if graph.get_endpoint(node_C, node_B) != Endpoint.TAIL:
continue

if graph.get_endpoint(node_C, node_D) != Endpoint.TAIL:
continue

for children in combinations(A_possible_children, 2):
child_one, child_two = children
if not existsSemiDirectedPath(node_from=child_one, node_to=node_B, G=graph) or \
not existsSemiDirectedPath(node_from=child_two, node_to=node_D, G=graph):
continue

if not graph.is_adjacent_to(child_one, child_two):
edge1 = graph.get_edge(node_A, node_C)
graph.remove_edge(edge1)
graph.add_edge(Edge(node_A, node_C, Endpoint.TAIL, Endpoint.ARROW))
changeFlag = True
break #once we found it, break out since we have already oriented Ao->C to A->C, we want to find the next A

return changeFlag


def visibleEdgeHelperVisit(graph: Graph, node_c: Node, node_a: Node, node_b: Node, path: List[Node]) -> bool:
if path.__contains__(node_a):
return False
Expand Down Expand Up @@ -828,8 +691,10 @@ def _contains_all(set_a: Set[Node], set_b: Set[Node]):
break



def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float = 0.05, depth: int = -1,
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None, show_progress: bool = True,
max_path_length: int = -1, verbose: bool = False, background_knowledge: BackgroundKnowledge | None = None,
show_progress: bool = True, node_names = None,
**kwargs) -> Tuple[Graph, List[Edge]]:
"""
Perform Fast Causal Inference (FCI) algorithm for causal discovery
Expand Down Expand Up @@ -884,8 +749,10 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =


nodes = []
if node_names is None:
node_names = [f"X{i + 1}" for i in range(dataset.shape[1])]
for i in range(dataset.shape[1]):
node = GraphNode(f"X{i + 1}")
node = GraphNode(node_names[i])
node.add_attribute("id", i)
nodes.append(node)

Expand Down Expand Up @@ -923,13 +790,6 @@ def fci(dataset: ndarray, independence_test_method: str=fisherz, alpha: float =
if verbose:
print("Epoch")

# rule 8
change_flag = rule8(graph,nodes)
# rule 9
change_flag = rule9(graph, nodes)
# rule 10
change_flag = rule10(graph)

graph.set_pag(True)

edges = get_color_edges(graph)
Expand Down
2 changes: 1 addition & 1 deletion causallearn/utils/PCUtils/BackgroundKnowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def is_forbidden(self, node1: Node, node2: Node) -> bool:

# then check in tier_map
if self.tier_value_map.keys().__contains__(node1) and self.tier_value_map.keys().__contains__(node2):
if self.tier_value_map.get(node1) >= self.tier_value_map.get(node2):
if self.tier_value_map.get(node1) > self.tier_value_map.get(node2): # Allow orientation within the same tier
return True

return False
Expand Down
Loading