WIP: Add GED transformer #13259

Draft
Wants to merge 24 commits into base: main
Changes from all commits (24 commits)
632c819  assert_allclose for base ged for csp, spoc, ssd and xdawn (Genuster, May 22, 2025)
4f8b5fa  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, May 22, 2025)
7c072d1  update _epoch_cov logging following merge (Genuster, May 22, 2025)
211d23f  add a few preliminary docstrings (Genuster, May 22, 2025)
a701e42  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, May 22, 2025)
0d58c8d  bump rtol/atol for spoc (Genuster, May 22, 2025)
2a1c5cb  Add big sklearn compliance test (Genuster, May 29, 2025)
3e9c32c  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, May 29, 2025)
6e8b3aa  add __sklearn_tags__ to vulture's whitelist (Genuster, Jun 2, 2025)
b2e24ea  calm vulture down per attribute (Genuster, Jun 2, 2025)
fbd585e  put the TransformerMixin back (Genuster, Jun 2, 2025)
a9d5390  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, Jun 2, 2025)
d142bd0  fix validation of covariances (Genuster, Jun 2, 2025)
9329494  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, Jun 3, 2025)
6796366  add gedtranformer tests with audvis dataset (Genuster, Jun 3, 2025)
7a291b1  fixes following Eric's comments (Genuster, Jun 4, 2025)
9b34bd3  Merge remote-tracking branch 'upstream/main' into base-GED (Genuster, Jun 4, 2025)
7c867ec  document shapes (Genuster, Jun 4, 2025)
e1e8d6d  another small test for GEDtransformer (Genuster, Jun 6, 2025)
5edc6fa  change name of restricting map to restricting matrix (Genuster, Jun 6, 2025)
89fb141  a few more ged tests (Genuster, Jun 6, 2025)
3986c99  fix multiplication order in original SSD (Genuster, Jun 6, 2025)
11b038f  add assert_allclose to xdawn and csp transform methods. (Genuster, Jun 6, 2025)
25e1ae3  more ged tests (Genuster, Jun 6, 2025)

266 changes: 266 additions & 0 deletions mne/decoding/_covs_ged.py
@@ -0,0 +1,266 @@
"""Covariance estimation for GED transformers."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np

from .._fiff.meas_info import Info, create_info
from .._fiff.pick import _picks_to_idx
from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance
from ..defaults import _handle_default
from ..filter import filter_data
from ..rank import compute_rank
from ..utils import _verbose_safe_false, logger


def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
"""Concatenate epochs before computing the covariance."""
_, n_channels, _ = x_class.shape

x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1)
cov = _regularized_covariance(
x_class,
reg=reg,
method_params=cov_method_params,
rank=rank,
info=info,
cov_kind=cov_kind,
log_rank=log_rank,
log_ch_type="data",
)
weight = x_class.shape[0]

return cov, weight


def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info):
"""Mean of per-epoch covariances."""
name = reg if isinstance(reg, str) else "empirical"
name += " with shrinkage" if isinstance(reg, float) else ""
logger.info(
f"Estimating {cov_kind + (' ' if cov_kind else '')}"
f"covariance (average over epochs; {name.upper()})"
)
cov = sum(
_regularized_covariance(
this_X,
reg=reg,
method_params=cov_method_params,
rank=rank,
info=info,
cov_kind=cov_kind,
log_rank=log_rank and ii == 0,
log_ch_type="data",
verbose=_verbose_safe_false(),
)
for ii, this_X in enumerate(x_class)
)
cov /= len(x_class)
weight = len(x_class)

return cov, weight


def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace):
_, n_channels, _ = X.shape
classes_ = np.unique(y)
if cov_est == "concat":
cov_estimator = _concat_cov
elif cov_est == "epoch":
cov_estimator = _epoch_cov
# Someday we could allow the user to pass this, then we wouldn't need to convert
# but in the meantime they can use a pipeline with a scaler
_info = create_info(n_channels, 1000.0, "mag")
if isinstance(rank, dict):
_rank = {"mag": sum(rank.values())}
else:
_rank = _compute_rank_raw_array(
X.transpose(1, 0, 2).reshape(X.shape[1], -1),
_info,
rank=rank,
scalings=None,
log_ch_type="data",
)

covs = []
sample_weights = []
for ci, this_class in enumerate(classes_):
cov, weight = cov_estimator(
X[y == this_class],
cov_kind=f"class={this_class}",
log_rank=ci == 0,
reg=reg,
cov_method_params=cov_method_params,
rank=_rank,
info=_info,
)

if norm_trace:
cov /= np.trace(cov)

covs.append(cov)
sample_weights.append(weight)

covs = np.stack(covs)
C_ref = covs.mean(0)

return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights))


def _xdawn_estimate(
X,
y,
reg,
cov_method_params,
R=None,
events=None,
tmin=0,
sfreq=1,
info=None,
rank="full",
):
from ..preprocessing.xdawn import _least_square_evoked

if not isinstance(X, np.ndarray) or X.ndim != 3:
raise ValueError("X must be 3D ndarray")

classes = np.unique(y)

# XXX Eventually this could be made to deal with rank deficiency properly
# by exposing this "rank" parameter, but this will require refactoring
# the linalg.eigh call to operate in the lower-dimension
# subspace, then project back out.

# Retrieve or compute whitening covariance
if R is None:
R = _regularized_covariance(
np.hstack(X), reg, cov_method_params, info, rank=rank
)
elif isinstance(R, Covariance):
R = R.data
if not isinstance(R, np.ndarray) or (
not np.array_equal(R.shape, np.tile(X.shape[1], 2))
):
raise ValueError(
"R must be None, a covariance instance, "
"or an array of shape (n_chans, n_chans)"
)

# Get prototype events
if events is not None:
evokeds, toeplitzs = _least_square_evoked(X, events, tmin, sfreq)
else:
evokeds, toeplitzs = list(), list()
for c in classes:
# Prototyped response for each class
evokeds.append(np.mean(X[y == c, :, :], axis=0))
toeplitzs.append(1.0)

covs = []
for evo, toeplitz in zip(evokeds, toeplitzs):
# Estimate covariance matrix of the prototype response
evo = np.dot(evo, toeplitz)
evo_cov = _regularized_covariance(evo, reg, cov_method_params, info, rank=rank)
covs.append(evo_cov)

covs.append(R)
C_ref = None
rank = None
info = None
return covs, C_ref, info, rank, dict()


def _ssd_estimate(
X,
y,
reg,
cov_method_params,
info,
picks,
filt_params_signal,
filt_params_noise,
rank,
):
if isinstance(info, Info):
sfreq = info["sfreq"]
elif isinstance(info, float): # special case, mostly for testing
sfreq = info
info = create_info(X.shape[-2], sfreq, ch_types="eeg")
picks = _picks_to_idx(info, picks, none="data", exclude="bads")
X_aux = X[..., picks, :]
X_signal = filter_data(X_aux, sfreq, **filt_params_signal)
X_noise = filter_data(X_aux, sfreq, **filt_params_noise)
X_noise -= X_signal
if X.ndim == 3:
X_signal = np.hstack(X_signal)
X_noise = np.hstack(X_noise)

# prevent rank change when computing cov with rank='full'
S = _regularized_covariance(
X_signal,
reg=reg,
method_params=cov_method_params,
rank="full",
info=info,
)
R = _regularized_covariance(
X_noise,
reg=reg,
method_params=cov_method_params,
rank="full",
info=info,
)
covs = [S, R]
C_ref = S

all_ranks = list()
for cov in covs:
r = list(
compute_rank(
Covariance(
cov,
info.ch_names,
list(),
list(),
0,
verbose=_verbose_safe_false(),
),
rank,
_handle_default("scalings_cov_rank", None),
info,
).values()
)[0]
all_ranks.append(r)
rank = np.min(all_ranks)
return covs, C_ref, info, rank, dict()


def _spoc_estimate(X, y, reg, cov_method_params, rank):
# Normalize target variable
target = y.astype(np.float64)
target -= target.mean()
target /= target.std()

n_epochs, n_channels = X.shape[:2]

# Estimate single trial covariance
covs = np.empty((n_epochs, n_channels, n_channels))
for ii, epoch in enumerate(X):
covs[ii] = _regularized_covariance(
epoch,
reg=reg,
method_params=cov_method_params,
rank=rank,
log_ch_type="data",
log_rank=ii == 0,
)

S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0)
R = covs.mean(0)

covs = [S, R]
C_ref = None
info = None
return covs, C_ref, info, rank, dict()
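
For orientation, below is a minimal sketch of how one of these private estimators could be exercised directly on synthetic data. It relies only on the _csp_estimate signature shown in the diff above; the import path, the random data, and the parameter values are illustrative assumptions, not part of the PR.

import numpy as np

# Assumed import path for the private helper introduced in this PR.
from mne.decoding._covs_ged import _csp_estimate

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 8, 200))  # (n_epochs, n_channels, n_times)
y = np.array([0] * 10 + [1] * 10)      # two-class labels

# Returns the per-class covariances, their mean as the reference C_ref,
# the synthetic Info, the estimated rank, and a dict of sample weights.
covs, C_ref, info, rank, extra = _csp_estimate(
    X,
    y,
    reg=None,               # empirical covariance
    cov_method_params=None,
    cov_est="concat",       # or "epoch" to average per-epoch covariances
    rank=None,
    norm_trace=False,
)
assert covs.shape == (2, 8, 8)
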
69 changes: 69 additions & 0 deletions mne/decoding/_mod_ged.py
@@ -0,0 +1,69 @@
"""Eigenvalue eigenvector modifiers for GED transformers."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import numpy as np


def _compute_mutual_info(covs, sample_weights, evecs):
class_probas = sample_weights / sample_weights.sum()

mutual_info = []
for jj in range(evecs.shape[1]):
aa, bb = 0, 0
for cov, prob in zip(covs, class_probas):
tmp = np.dot(np.dot(evecs[:, jj].T, cov), evecs[:, jj])
aa += prob * np.log(np.sqrt(tmp))
bb += prob * (tmp**2 - 1)
mi = -(aa + (3.0 / 16) * (bb**2))
mutual_info.append(mi)

return mutual_info


def _csp_mod(evals, evecs, covs, evecs_order, sample_weights):
n_classes = sample_weights.shape[0]
if evecs_order == "mutual_info" and n_classes > 2:
mutual_info = _compute_mutual_info(covs, sample_weights, evecs)
ix = np.argsort(mutual_info)[::-1]
elif evecs_order == "mutual_info" and n_classes == 2:
ix = np.argsort(np.abs(evals - 0.5))[::-1]
elif evecs_order == "alternate" and n_classes == 2:
i = np.argsort(evals)
ix = np.empty_like(i)
ix[1::2] = i[: len(i) // 2]
ix[0::2] = i[len(i) // 2 :][::-1]
if evals is not None:
evals = evals[ix]
evecs = evecs[:, ix]
return evals, evecs


def _xdawn_mod(evals, evecs, covs=None):
evals, evecs = _sort_descending(evals, evecs)
evecs /= np.linalg.norm(evecs, axis=0)
return evals, evecs


def _ssd_mod(evals, evecs, covs=None):
evals, evecs = _sort_descending(evals, evecs)
return evals, evecs


def _spoc_mod(evals, evecs, covs=None):
evals = evals.real
evecs = evecs.real
evals, evecs = _sort_descending(evals, evecs, by_abs=True)
return evals, evecs


def _sort_descending(evals, evecs, by_abs=False):
if by_abs:
ix = np.argsort(np.abs(evals))[::-1]
else:
ix = np.argsort(evals)[::-1]
evals = evals[ix]
evecs = evecs[:, ix]
return evals, evecs
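
Similarly, a hedged sketch of the eigenpair modifiers defined above: the toy eigendecomposition and the import path are illustrative assumptions, while the helper names and signatures come from the diff.

import numpy as np

# Assumed import path for the private helpers introduced in this PR.
from mne.decoding._mod_ged import _sort_descending, _spoc_mod

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4))
S = A @ A.T                       # toy symmetric "covariance"
evals, evecs = np.linalg.eigh(S)  # eigh returns eigenvalues in ascending order

# Reorder eigenpairs from largest to smallest eigenvalue.
evals_d, evecs_d = _sort_descending(evals, evecs)
assert np.all(np.diff(evals_d) <= 0)

# The SPoC modifier drops imaginary parts (a generalized eigensolver can
# return complex output) and sorts by eigenvalue magnitude instead.
evals_s, evecs_s = _spoc_mod(evals.astype(complex), evecs.astype(complex))
assert np.all(np.diff(np.abs(evals_s)) <= 0)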