Commit f0affa5

Better Handling of zero median values in Kernel Width (#160)
* Filter zeros out of median computation
* removing IDE files from commit

---------

Signed-off-by: Nicholas Parente <parentenickj@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9d732dd commit f0affa5

4 files changed: +23 -2 lines changed

.gitignore

+5
@@ -123,6 +123,11 @@ venv.bak/
 .spyderproject
 .spyproject
 
+# PyCharm
+.idea/
+*.iml
+*.iws
+
 # Rope project settings
 .ropeproject
 
doc/whats_new/v0.1.rst

+1
@@ -50,6 +50,7 @@ Changelog
 - |Feature| Add a suite of general categorical data CI tests, by `Adam Li`_ (:pr:`128`)
 - |Feature| Implement CAM, SCORE, DAS, NoGAM algorithms in ``dodiscover.toporder`` submodule (:pr:`129`)
 - |Feature| Add Psi-FCI and I-FCI algorithm for handling soft-interventional data, :class:`dodiscover.constraint.PsiFCI` by `Adam Li`_ (:pr:`111`)
+- |Fix| Update the kernel_width method to filter out zeros before computing the median pairwise distance, by `Nick Parente`_ (:pr:`160`)
 
 Code and Documentation Contributors
 -----------------------------------

dodiscover/toporder/utils.py

+3 -2
@@ -19,8 +19,9 @@ def kernel_width(X: NDArray):
         Matrix of the data.
     """
     X_diff = np.expand_dims(X, axis=1) - X  # Gram matrix of the data
-    D = np.linalg.norm(X_diff, axis=2)
-    s = np.median(D.flatten())
+    D = np.linalg.norm(X_diff, axis=2).flatten()
+    D_nonzeros = D[D > 0]  # Remove zeros
+    s = np.median(D_nonzeros) if np.any(D_nonzeros) else 1
     return s
 
 
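To see why the change to `kernel_width` matters, here is a minimal, self-contained sketch that runs the pre-patch and post-patch logic side by side on the same input as the first new test. The names `pairwise_distances`, `kernel_width_before`, and `kernel_width_after` are hypothetical and exist only for this illustration; the library exposes a single `kernel_width` function. A zero width would be a problem wherever the value is used as a divisor, for example as a kernel bandwidth. How callers use the width is not shown in this commit, so treat that motivation as an assumption.

```python
import numpy as np


def pairwise_distances(X):
    """Flattened Euclidean distances between all pairs of rows of X."""
    X_diff = np.expand_dims(X, axis=1) - X  # pairwise differences, shape (n, n, d)
    return np.linalg.norm(X_diff, axis=2).flatten()


def kernel_width_before(X):
    # Pre-patch logic: median over all distances, zeros included.
    return np.median(pairwise_distances(X))


def kernel_width_after(X):
    # Post-patch logic: median over strictly positive distances,
    # falling back to 1 when every distance is zero.
    D = pairwise_distances(X)
    D_nonzeros = D[D > 0]
    return np.median(D_nonzeros) if np.any(D_nonzeros) else 1


# 100 samples, 99 of them identical: most pairwise distances are zero.
X = np.zeros((100, 1))
X[1] = 1.0

print(kernel_width_before(X))  # 0.0
print(kernel_width_after(X))   # 1.0
```

With 99 duplicated rows, only 198 of the 10,000 pairwise distances are nonzero, so the unfiltered median collapses to zero. The filtered median is taken over those 198 distances, all equal to 1. When every row is identical the filtered array is empty and the fallback value of 1 is returned.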

+14
@@ -0,0 +1,14 @@
+import numpy as np
+
+from dodiscover.toporder.utils import kernel_width
+
+
+def test_kernel_width_when_zero_median_pairwise_distances():
+    arr = np.zeros((100, 1), dtype=np.int64)
+    arr[1] = 1
+    assert kernel_width(arr) == 1
+
+
+def test_kernel_width_when_all_zero_pairwise_distances():
+    arr = np.ones((100, 1), dtype=np.int64)
+    assert kernel_width(arr) == 1
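Both tests above exercise degenerate inputs. A possible complementary check, sketched here under assumptions, uses three distinct points whose distances can be verified by hand. `kernel_width_sketch` is a hypothetical stand-in that mirrors the patched logic so the snippet runs without installing the package; in the repository the test would import `kernel_width` from `dodiscover.toporder.utils` instead.

```python
import numpy as np


def kernel_width_sketch(X):
    # Hypothetical stand-in mirroring the patched logic from this commit.
    X_diff = np.expand_dims(X, axis=1) - X  # pairwise differences, shape (n, n, d)
    D = np.linalg.norm(X_diff, axis=2).flatten()
    D_nonzeros = D[D > 0]  # drops self-distances and distances between duplicates
    return np.median(D_nonzeros) if np.any(D_nonzeros) else 1


def test_kernel_width_ignores_self_distances():
    # Points 0, 1, 3 on a line: positive distances are 1, 2, 3, each appearing twice.
    arr = np.array([[0.0], [1.0], [3.0]])
    assert kernel_width_sketch(arr) == 2.0


test_kernel_width_ignores_self_distances()
```

The filter also removes the diagonal self-distances, which are always zero, so the patch changes the result for ordinary inputs too. For these three points the unfiltered median over all nine entries would be 1, while the filtered median is 2.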
