Skip to content

WIP: Add GED transformer #13259

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft

WIP: Add GED transformer #13259

wants to merge 8 commits into from

Conversation

Genuster
Copy link
Contributor

@Genuster Genuster commented May 22, 2025

What does this implement/fix?

Adds transformer for generalized eigenvalue decomposition (or approximate joint diagonalization) of covariance matrices.
It generalizes xdawn, csp, ssd, and spoc algorithms.

Additional information

Currently testing that it outputs identical filters and patterns as child classes for all tests.
Next step will be to clean up redundant code from the child classes, and unify linear algebra, e.g.:
(1) choose between mne's pinv or np.linalg.pinv for all classes
(2) remove hacky _mult_order from _smart_ged (it's currently there as, numerically, change of multiplication order breaks np.testing.assert_allclose for ssd)
(3) perhaps ssd should use mne.cov.compute_whitener() instead of its own whitener implementation. It won't be identical, but conceptually seems to do the same thing.
(4) add feature to perform GED in the principal subspace for xdawn

@larsoner
Copy link
Member

Already have a failure but fortunately it's just a tol issue I think:

mne/decoding/tests/test_csp.py:444: in test_spoc
    spoc.fit(X, y)
mne/decoding/csp.py:985: in fit
    np.testing.assert_allclose(old_filters, self.filters_)
E   AssertionError: 
E   Not equal to tolerance rtol=1e-07, atol=0
E   
E   Mismatched elements: 1 / 100 (1%)
E   Max absolute difference among violations: 9.04019248e-09
E   Max relative difference among violations: 1.11806536e-07
E    ACTUAL: array([[  2.037415,   1.424886,   2.718162,  -3.07798 ,  -3.862132,
E             1.412549,  -3.821452,   1.276637,   1.899782,  -2.389858],
E          [ 11.534231, -22.178034, -12.321628, -52.410096,  62.876084,...
E    DESIRED: array([[  2.037415,   1.424886,   2.718162,  -3.07798 ,  -3.862132,
E             1.412549,  -3.821452,   1.276637,   1.899782,  -2.389858],
E          [ 11.534231, -22.178034, -12.321628, -52.410096,  62.876084,...

I would just bump the rtol a bit here to 1e-6, and if you know the magnitudes are in the single/double digits then an atol=1e-7 would also be reasonable (could do both).

@Genuster
Copy link
Contributor Author

Genuster commented May 22, 2025

Thanks!
Interesting how it passed macos-13/mamba/3.12, but didn't pass macos-latest/mamba/3.12

It might be that the small difference between filters_ will propagate and increase in patterns_, so rtol/atol won't be much of help for patterns_. But let's see

@larsoner
Copy link
Member

Different architectures, macos-13 is Intel x86_64 and macos-latest is ARM / M1. And Windows also failed, could be use of MKL there or something. I'm cautiously optimistic it's just floating point errors...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants