From 632c81959f6f8d9a600ba3d12348611f96a32940 Mon Sep 17 00:00:00 2001 From: Genuster Date: Thu, 22 May 2025 16:24:09 +0300 Subject: [PATCH 01/18] assert_allclose for base ged for csp, spoc, ssd and xdawn --- mne/decoding/base.py | 86 +++++++++- mne/decoding/covs_ged.py | 318 +++++++++++++++++++++++++++++++++++++ mne/decoding/csp.py | 64 +++++++- mne/decoding/ged.py | 227 ++++++++++++++++++++++++++ mne/decoding/mod_ged.py | 69 ++++++++ mne/decoding/ssd.py | 38 ++++- mne/preprocessing/xdawn.py | 40 ++++- 7 files changed, 833 insertions(+), 9 deletions(-) create mode 100644 mne/decoding/covs_ged.py create mode 100644 mne/decoding/ged.py create mode 100644 mne/decoding/mod_ged.py diff --git a/mne/decoding/base.py b/mne/decoding/base.py index f73cd976fe3..3737f11960a 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -22,7 +22,91 @@ from sklearn.utils import check_array, check_X_y, indexable from ..parallel import parallel_func -from ..utils import _pl, logger, verbose, warn +from ..utils import _pl, logger, pinv, verbose, warn +from .ged import _get_ssd_rank, _handle_restr_map, _smart_ajd, _smart_ged +from .transformer import MNETransformerMixin + + +class GEDTransformer(MNETransformerMixin, BaseEstimator): + """...""" + + def __init__( + self, + n_filters, + cov_callable, + cov_params, + mod_ged_callable, + mod_params, + dec_type="single", + restr_map=None, + R_func=None, + ): + self.n_filters = n_filters + self.cov_callable = cov_callable + self.cov_params = cov_params + self.mod_ged_callable = mod_ged_callable + self.mod_params = mod_params + self.dec_type = dec_type + self.restr_map = restr_map + self.R_func = R_func + + def fit(self, X, y=None): + """...""" + covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params) + if self.dec_type == "single": + if len(covs) > 2: + sample_weights = kwargs["sample_weights"] + restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + evecs = _smart_ajd(covs, restr_map, weights=sample_weights) + evals = None + else: + S = covs[0] + R = covs[1] + if self.restr_map == "ssd": + rank = _get_ssd_rank(S, R, info, rank) + mult_order = "ssd" + else: + mult_order = None + restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + evals, evecs = _smart_ged( + S, R, restr_map, R_func=self.R_func, mult_order=mult_order + ) + + evals, evecs = self.mod_ged_callable( + evals, evecs, covs, **self.mod_params, **kwargs + ) + self.evals_ = evals + self.filters_ = evecs.T + if self.restr_map == "ssd": + self.patterns_ = np.linalg.pinv(evecs) + else: + self.patterns_ = pinv(evecs) + + elif self.dec_type == "multi": + self.classes_ = np.unique(y) + R = covs[-1] + restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + all_evals, all_evecs, all_patterns = list(), list(), list() + for i in range(len(self.classes_)): + S = covs[i] + evals, evecs = _smart_ged(S, R, restr_map, R_func=self.R_func) + + evals, evecs = self.mod_ged_callable( + evals, evecs, covs, **self.mod_params, **kwargs + ) + all_evals.append(evals) + all_evecs.append(evecs.T) + all_patterns.append(np.linalg.pinv(evecs)) + self.evals_ = np.array(all_evals) + self.filters_ = np.array(all_evecs) + self.patterns_ = np.array(all_patterns) + + return self + + def transform(self, X): + """...""" + X = np.dot(self.filters_, X) + return X class LinearModel(MetaEstimatorMixin, BaseEstimator): diff --git a/mne/decoding/covs_ged.py b/mne/decoding/covs_ged.py new file mode 100644 index 00000000000..3df65d8107f --- /dev/null +++ b/mne/decoding/covs_ged.py @@ -0,0 +1,318 @@ +"""Covariance estimation for GED transformers.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np +import scipy.linalg + +from .._fiff.meas_info import Info, create_info +from .._fiff.pick import _picks_to_idx +from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance +from ..filter import filter_data +from ..utils import pinv + + +def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): + """Concatenate epochs before computing the covariance.""" + _, n_channels, _ = x_class.shape + + x_class = x_class.transpose(1, 0, 2).reshape(n_channels, -1) + cov = _regularized_covariance( + x_class, + reg=reg, + method_params=cov_method_params, + rank=rank, + info=info, + cov_kind=cov_kind, + log_rank=log_rank, + log_ch_type="data", + ) + weight = x_class.shape[0] + + return cov, weight + + +def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): + """Mean of per-epoch covariances.""" + cov = sum( + _regularized_covariance( + this_X, + reg=reg, + method_params=cov_method_params, + rank=rank, + info=info, + cov_kind=cov_kind, + log_rank=log_rank and ii == 0, + log_ch_type="data", + ) + for ii, this_X in enumerate(x_class) + ) + cov /= len(x_class) + weight = len(x_class) + + return cov, weight + + +def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace): + _, n_channels, _ = X.shape + classes_ = np.unique(y) + if cov_est == "concat": + cov_estimator = _concat_cov + elif cov_est == "epoch": + cov_estimator = _epoch_cov + # Someday we could allow the user to pass this, then we wouldn't need to convert + # but in the meantime they can use a pipeline with a scaler + _info = create_info(n_channels, 1000.0, "mag") + if isinstance(rank, dict): + _rank = {"mag": sum(rank.values())} + else: + _rank = _compute_rank_raw_array( + X.transpose(1, 0, 2).reshape(X.shape[1], -1), + _info, + rank=rank, + scalings=None, + log_ch_type="data", + ) + + covs = [] + sample_weights = [] + for ci, this_class in enumerate(classes_): + cov, weight = cov_estimator( + X[y == this_class], + cov_kind=f"class={this_class}", + log_rank=ci == 0, + reg=reg, + cov_method_params=cov_method_params, + rank=_rank, + info=_info, + ) + + if norm_trace: + cov /= np.trace(cov) + + covs.append(cov) + sample_weights.append(weight) + + covs = np.stack(covs) + C_ref = covs.mean(0) + + return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights)) + + +def _construct_signal_from_epochs(epochs, events, sfreq, tmin): + """Reconstruct pseudo continuous signal from epochs.""" + n_epochs, n_channels, n_times = epochs.shape + tmax = tmin + n_times / float(sfreq) + start = np.min(events[:, 0]) + int(tmin * sfreq) + stop = np.max(events[:, 0]) + int(tmax * sfreq) + 1 + + n_samples = stop - start + n_epochs, n_channels, n_times = epochs.shape + events_pos = events[:, 0] - events[0, 0] + + raw = np.zeros((n_channels, n_samples)) + for idx in range(n_epochs): + onset = events_pos[idx] + offset = onset + n_times + raw[:, onset:offset] = epochs[idx] + + return raw + + +def _least_square_evoked(epochs_data, events, tmin, sfreq): + """Least square estimation of evoked response from epochs data. + + Parameters + ---------- + epochs_data : array, shape (n_channels, n_times) + The epochs data to estimate evoked. + events : array, shape (n_events, 3) + The events typically returned by the read_events function. + If some events don't match the events of interest as specified + by event_id, they will be ignored. + tmin : float + Start time before event. + sfreq : float + Sampling frequency. + + Returns + ------- + evokeds : array, shape (n_class, n_components, n_times) + An concatenated array of evoked data for each event type. + toeplitz : array, shape (n_class * n_components, n_channels) + An concatenated array of toeplitz matrix for each event type. + """ + n_epochs, n_channels, n_times = epochs_data.shape + tmax = tmin + n_times / float(sfreq) + + # Deal with shuffled epochs + events = events.copy() + events[:, 0] -= events[0, 0] + int(tmin * sfreq) + + # Construct raw signal + raw = _construct_signal_from_epochs(epochs_data, events, sfreq, tmin) + + # Compute the independent evoked responses per condition, while correcting + # for event overlaps. + n_min, n_max = int(tmin * sfreq), int(tmax * sfreq) + window = n_max - n_min + n_samples = raw.shape[1] + toeplitz = list() + classes = np.unique(events[:, 2]) + for ii, this_class in enumerate(classes): + # select events by type + sel = events[:, 2] == this_class + + # build toeplitz matrix + trig = np.zeros((n_samples,)) + ix_trig = (events[sel, 0]) + n_min + trig[ix_trig] = 1 + toeplitz.append(scipy.linalg.toeplitz(trig[0:window], trig)) + + # Concatenate toeplitz + toeplitz = np.array(toeplitz) + X = np.concatenate(toeplitz) + + # least square estimation + predictor = np.dot(pinv(np.dot(X, X.T)), X) + evokeds = np.dot(predictor, raw.T) + evokeds = np.transpose(np.vsplit(evokeds, len(classes)), (0, 2, 1)) + return evokeds, toeplitz + + +def _xdawn_estimate( + X, + y, + reg, + cov_method_params, + R=None, + events=None, + tmin=0, + sfreq=1, + info=None, + rank="full", +): + if not isinstance(X, np.ndarray) or X.ndim != 3: + raise ValueError("X must be 3D ndarray") + + classes = np.unique(y) + + # XXX Eventually this could be made to deal with rank deficiency properly + # by exposing this "rank" parameter, but this will require refactoring + # the linalg.eigh call to operate in the lower-dimension + # subspace, then project back out. + + # Retrieve or compute whitening covariance + if R is None: + R = _regularized_covariance( + np.hstack(X), reg, cov_method_params, info, rank=rank + ) + elif isinstance(R, Covariance): + R = R.data + if not isinstance(R, np.ndarray) or ( + not np.array_equal(R.shape, np.tile(X.shape[1], 2)) + ): + raise ValueError( + "R must be None, a covariance instance, " + "or an array of shape (n_chans, n_chans)" + ) + + # Get prototype events + if events is not None: + evokeds, toeplitzs = _least_square_evoked(X, events, tmin, sfreq) + else: + evokeds, toeplitzs = list(), list() + for c in classes: + # Prototyped response for each class + evokeds.append(np.mean(X[y == c, :, :], axis=0)) + toeplitzs.append(1.0) + + covs = [] + for evo, toeplitz in zip(evokeds, toeplitzs): + # Estimate covariance matrix of the prototype response + evo = np.dot(evo, toeplitz) + evo_cov = _regularized_covariance(evo, reg, cov_method_params, info, rank=rank) + covs.append(evo_cov) + + covs.append(R) + covs = np.stack(covs) + C_ref = None + rank = None + info = None + return covs, C_ref, info, rank, dict() + + +def _ssd_estimate( + X, + y, + reg, + cov_method_params, + info, + picks, + filt_params_signal, + filt_params_noise, + rank, +): + if isinstance(info, Info): + sfreq = info["sfreq"] + elif isinstance(info, float): # special case, mostly for testing + sfreq = info + info = create_info(X.shape[-2], sfreq, ch_types="eeg") + picks = _picks_to_idx(info, picks, none="data", exclude="bads") + X_aux = X[..., picks, :] + X_signal = filter_data(X_aux, sfreq, **filt_params_signal) + X_noise = filter_data(X_aux, sfreq, **filt_params_noise) + X_noise -= X_signal + if X.ndim == 3: + X_signal = np.hstack(X_signal) + X_noise = np.hstack(X_noise) + + # prevent rank change when computing cov with rank='full' + S = _regularized_covariance( + X_signal, + reg=reg, + method_params=cov_method_params, + rank="full", + info=info, + ) + R = _regularized_covariance( + X_noise, + reg=reg, + method_params=cov_method_params, + rank="full", + info=info, + ) + covs = [S, R] + C_ref = S + return covs, C_ref, info, rank, dict() + + +def _spoc_estimate(X, y, reg, cov_method_params, rank): + # Normalize target variable + target = y.astype(np.float64) + target -= target.mean() + target /= target.std() + + n_epochs, n_channels = X.shape[:2] + + # Estimate single trial covariance + covs = np.empty((n_epochs, n_channels, n_channels)) + for ii, epoch in enumerate(X): + covs[ii] = _regularized_covariance( + epoch, + reg=reg, + method_params=cov_method_params, + rank=rank, + log_ch_type="data", + log_rank=ii == 0, + ) + + S = np.mean(covs * target[:, np.newaxis, np.newaxis], axis=0) + R = covs.mean(0) + + covs = [S, R] + C_ref = None + info = None + return covs, C_ref, info, rank, dict() diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index ea38fd58ca3..883004467ee 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,7 +6,6 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import create_info @@ -20,11 +19,13 @@ fill_doc, pinv, ) -from .transformer import MNETransformerMixin +from .base import GEDTransformer +from .covs_ged import _csp_estimate, _spoc_estimate +from .mod_ged import _csp_mod, _spoc_mod @fill_doc -class CSP(MNETransformerMixin, BaseEstimator): +class CSP(GEDTransformer): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial @@ -124,6 +125,26 @@ def __init__( self.cov_method_params = cov_method_params self.component_order = component_order + cov_params = dict( + reg=reg, + cov_method_params=cov_method_params, + cov_est=cov_est, + rank=rank, + norm_trace=norm_trace, + ) + + mod_params = dict(evecs_order=component_order) + super().__init__( + n_components, + _csp_estimate, + cov_params, + _csp_mod, + mod_params, + dec_type="single", + restr_map="restricting", + R_func=sum, + ) + def _validate_params(self, *, y): _validate_type(self.n_components, int, "n_components") if hasattr(self, "cov_est"): @@ -191,6 +212,16 @@ def fit(self, X, y): self.filters_ = eigen_vectors.T self.patterns_ = pinv(eigen_vectors) + old_filters = self.filters_ + old_patterns = self.patterns_ + super().fit(X, y) + if self.evals_ is None: + assert eigen_values is None + else: + np.testing.assert_allclose(eigen_values[ix], self.evals_) + np.testing.assert_allclose(old_filters, self.filters_) + np.testing.assert_allclose(old_patterns, self.patterns_) + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -857,6 +888,25 @@ def __init__( rank=rank, cov_method_params=cov_method_params, ) + + cov_params = dict( + reg=reg, + cov_method_params=cov_method_params, + rank=rank, + ) + + mod_params = dict() + super(CSP, self).__init__( + n_components, + _spoc_estimate, + cov_params, + _spoc_mod, + mod_params, + dec_type="single", + restr_map=None, + R_func=None, + ) + # Covariance estimation have to be done on the single epoch level, # unlike CSP where covariance estimation can also be achieved through # concatenation of all epochs from the same class. @@ -919,6 +969,14 @@ def fit(self, X, y): self.patterns_ = pinv(evecs).T # n_channels x n_channels self.filters_ = evecs # n_channels x n_channels + old_filters = self.filters_ + old_patterns = self.patterns_ + super(CSP, self).fit(X, y) + + np.testing.assert_allclose(evals[ix], self.evals_) + np.testing.assert_allclose(old_filters, self.filters_) + np.testing.assert_allclose(old_patterns, self.patterns_) + pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py new file mode 100644 index 00000000000..5e505f8be9a --- /dev/null +++ b/mne/decoding/ged.py @@ -0,0 +1,227 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np +import scipy.linalg + +from ..cov import Covariance, _smart_eigh, compute_whitener +from ..defaults import _handle_default +from ..rank import compute_rank +from ..utils import _verbose_safe_false, logger + + +def _handle_restr_map(C_ref, restr_map, info, rank): + if C_ref is None or restr_map is None: + return None + if restr_map == "whitening": + projs = info["projs"] + C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], projs, 0) + restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True) + elif restr_map == "ssd": + restr_map = _get_ssd_whitener(C_ref, rank) + elif restr_map == "restricting": + restr_map = _get_restricting_map(C_ref, info, rank) + elif isinstance(restr_map, callable): + pass + else: + raise ValueError( + "restr_map should either be callable or one of whitening, ssd, restricting" + ) + return restr_map + + +def _smart_ged(S, R, restr_map, R_func=None, mult_order=None): + """...""" + if restr_map is None: + evals, evecs = scipy.linalg.eigh(S, R) + return evals, evecs + + if mult_order == "ssd": + S_restr = restr_map @ (S @ restr_map.T) + R_restr = restr_map @ (R @ restr_map.T) + else: + S_restr = restr_map @ S @ restr_map.T + R_restr = restr_map @ R @ restr_map.T + if R_func is not None: + R_restr = R_func([S_restr, R_restr]) + evals, evecs_restr = scipy.linalg.eigh(S_restr, R_restr) + evecs = restr_map.T @ evecs_restr + + return evals, evecs + + +def _ajd_pham(X, eps=1e-6, max_iter=15): + """Approximate joint diagonalization based on Pham's algorithm. + + This is a direct implementation of the PHAM's AJD algorithm [1]. + + Parameters + ---------- + X : ndarray, shape (n_epochs, n_channels, n_channels) + A set of covariance matrices to diagonalize. + eps : float, default 1e-6 + The tolerance for stopping criterion. + max_iter : int, default 1000 + The maximum number of iteration to reach convergence. + + Returns + ------- + V : ndarray, shape (n_channels, n_channels) + The diagonalizer. + D : ndarray, shape (n_epochs, n_channels, n_channels) + The set of quasi diagonal matrices. + + References + ---------- + .. [1] Pham, Dinh Tuan. "Joint approximate diagonalization of positive + definite Hermitian matrices." SIAM Journal on Matrix Analysis and + Applications 22, no. 4 (2001): 1136-1152. + + """ + # Adapted from http://github.com/alexandrebarachant/pyRiemann + n_epochs = X.shape[0] + + # Reshape input matrix + A = np.concatenate(X, axis=0).T + + # Init variables + n_times, n_m = A.shape + V = np.eye(n_times) + epsilon = n_times * (n_times - 1) * eps + + for it in range(max_iter): + decr = 0 + for ii in range(1, n_times): + for jj in range(ii): + Ii = np.arange(ii, n_m, n_times) + Ij = np.arange(jj, n_m, n_times) + + c1 = A[ii, Ii] + c2 = A[jj, Ij] + + g12 = np.mean(A[ii, Ij] / c1) + g21 = np.mean(A[ii, Ij] / c2) + + omega21 = np.mean(c1 / c2) + omega12 = np.mean(c2 / c1) + omega = np.sqrt(omega12 * omega21) + + tmp = np.sqrt(omega21 / omega12) + tmp1 = (tmp * g12 + g21) / (omega + 1) + tmp2 = (tmp * g12 - g21) / max(omega - 1, 1e-9) + + h12 = tmp1 + tmp2 + h21 = np.conj((tmp1 - tmp2) / tmp) + + decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0 + + tmp = 1 + 1.0j * 0.5 * np.imag(h12 * h21) + tmp = np.real(tmp + np.sqrt(tmp**2 - h12 * h21)) + tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) + + A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :]) + tmp = np.c_[A[:, Ii], A[:, Ij]] + tmp = np.reshape(tmp, (n_times * n_epochs, 2), order="F") + tmp = np.dot(tmp, tau.T) + + tmp = np.reshape(tmp, (n_times, n_epochs * 2), order="F") + A[:, Ii] = tmp[:, :n_epochs] + A[:, Ij] = tmp[:, n_epochs:] + V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :]) + if decr < epsilon: + break + D = np.reshape(A, (n_times, -1, n_times)).transpose(1, 0, 2) + return V, D + + +def _smart_ajd(covs, restr_map, weights): + covs = np.array([restr_map @ cov @ restr_map.T for cov in covs], float) + evecs_restr, D = _ajd_pham(covs) + evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) + evecs = restr_map.T @ evecs + return evecs + + +def _get_restricting_map(C, info, rank): + _, ref_evecs, mask = _smart_eigh( + C, + info, + rank, + proj_subspace=True, + do_compute_rank=False, + log_ch_type="data", + ) + restr_map = ref_evecs[mask] + return restr_map + + +def _normalize_eigenvectors(evecs, covs, sample_weights): + # Here we apply an euclidean mean. See pyRiemann for other metrics + mean_cov = np.average(covs, axis=0, weights=sample_weights) + + for ii in range(evecs.shape[1]): + tmp = np.dot(np.dot(evecs[:, ii].T, mean_cov), evecs[:, ii]) + evecs[:, ii] /= np.sqrt(tmp) + return evecs + + +def _get_ssd_rank(S, R, info, rank): + # find ranks of covariance matrices + rank_signal = list( + compute_rank( + Covariance( + S, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + rank_noise = list( + compute_rank( + Covariance( + R, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + rank = np.min([rank_signal, rank_noise]) # should be identical + return rank + + +def _get_ssd_whitener(S, rank): + """Perform dimensionality reduction on the covariance matrices.""" + n_channels = S.shape[0] + if rank < n_channels: + eigvals, eigvects = scipy.linalg.eigh(S) + # sort in descending order + ix = np.argsort(eigvals)[::-1] + eigvals = eigvals[ix] + eigvects = eigvects[:, ix] + # compute rank subspace projection matrix + rank_proj = np.matmul( + eigvects[:, :rank], np.eye(rank) * (eigvals[:rank] ** -0.5) + ) + logger.info( + "Projecting covariance of %i channels to %i rank subspace", + n_channels, + rank, + ) + else: + rank_proj = np.eye(n_channels) + logger.info("Preserving covariance rank (%i)", rank) + + return rank_proj.T diff --git a/mne/decoding/mod_ged.py b/mne/decoding/mod_ged.py new file mode 100644 index 00000000000..ad3dc031f16 --- /dev/null +++ b/mne/decoding/mod_ged.py @@ -0,0 +1,69 @@ +"""Eigenvalue eigenvector modifiers for GED transformers.""" + +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + +import numpy as np + + +def _compute_mutual_info(covs, sample_weights, evecs): + class_probas = sample_weights / sample_weights.sum() + + mutual_info = [] + for jj in range(evecs.shape[1]): + aa, bb = 0, 0 + for cov, prob in zip(covs, class_probas): + tmp = np.dot(np.dot(evecs[:, jj].T, cov), evecs[:, jj]) + aa += prob * np.log(np.sqrt(tmp)) + bb += prob * (tmp**2 - 1) + mi = -(aa + (3.0 / 16) * (bb**2)) + mutual_info.append(mi) + + return mutual_info + + +def _csp_mod(evals, evecs, covs, evecs_order, sample_weights): + n_classes = sample_weights.shape[0] + if evecs_order == "mutual_info" and n_classes > 2: + mutual_info = _compute_mutual_info(covs, sample_weights, evecs) + ix = np.argsort(mutual_info)[::-1] + elif evecs_order == "mutual_info" and n_classes == 2: + ix = np.argsort(np.abs(evals - 0.5))[::-1] + elif evecs_order == "alternate" and n_classes == 2: + i = np.argsort(evals) + ix = np.empty_like(i) + ix[1::2] = i[: len(i) // 2] + ix[0::2] = i[len(i) // 2 :][::-1] + if evals is not None: + evals = evals[ix] + evecs = evecs[:, ix] + return evals, evecs + + +def _xdawn_mod(evals, evecs, covs=None): + evals, evecs = _sort_descending(evals, evecs) + evecs /= np.linalg.norm(evecs, axis=0) + return evals, evecs + + +def _ssd_mod(evals, evecs, covs=None): + evals, evecs = _sort_descending(evals, evecs) + return evals, evecs + + +def _spoc_mod(evals, evecs, covs=None): + evals = evals.real + evecs = evecs.real + evals, evecs = _sort_descending(evals, evecs, by_abs=True) + return evals, evecs + + +def _sort_descending(evals, evecs, by_abs=False): + if by_abs: + ix = np.argsort(np.abs(evals))[::-1] + else: + ix = np.argsort(evals)[::-1] + evals = evals[ix] + evecs = evecs[:, ix] + return evals, evecs diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 111ded9f274..367be7038d3 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,7 +4,6 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import Info, create_info @@ -21,11 +20,13 @@ fill_doc, logger, ) -from .transformer import MNETransformerMixin +from .base import GEDTransformer +from .covs_ged import _ssd_estimate +from .mod_ged import _ssd_mod @fill_doc -class SSD(MNETransformerMixin, BaseEstimator): +class SSD(GEDTransformer): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -118,6 +119,28 @@ def __init__( self.cov_method_params = cov_method_params self.rank = rank + cov_params = dict( + reg=reg, + cov_method_params=cov_method_params, + info=info, + picks=picks, + filt_params_signal=filt_params_signal, + filt_params_noise=filt_params_noise, + rank=rank, + ) + + mod_params = dict() + super().__init__( + n_components, + _ssd_estimate, + cov_params, + _ssd_mod, + mod_params, + dec_type="single", + restr_map="ssd", + R_func=None, + ) + def _validate_params(self, X): if isinstance(self.info, float): # special case, mostly for testing self.sfreq_ = self.info @@ -240,6 +263,15 @@ def fit(self, X, y=None): self.filters_ = np.matmul(rank_proj, eigvects_[:, ix]) self.patterns_ = np.linalg.pinv(self.filters_) + old_filters = self.filters_ + old_patterns = self.patterns_ + super().fit(X, y) + self.filters_ = self.filters_.T + + np.testing.assert_allclose(self.eigvals_, self.evals_) + np.testing.assert_allclose(old_filters, self.filters_) + np.testing.assert_allclose(old_patterns, self.patterns_) + # We assume that ordering by spectral ratio is more important # than the initial ordering. This ordering should be also learned when # fitting. diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 606b49370df..f794e404f46 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -7,7 +7,9 @@ from .._fiff.pick import _pick_data_channels, pick_info from ..cov import Covariance, _regularized_covariance -from ..decoding import BaseEstimator, TransformerMixin +from ..decoding.base import GEDTransformer +from ..decoding.covs_ged import _xdawn_estimate +from ..decoding.mod_ged import _xdawn_mod from ..epochs import BaseEpochs from ..evoked import Evoked, EvokedArray from ..io import BaseRaw @@ -212,7 +214,7 @@ def _fit_xdawn( return filters, patterns, evokeds -class _XdawnTransformer(BaseEstimator, TransformerMixin): +class _XdawnTransformer(GEDTransformer): """Implementation of the Xdawn Algorithm compatible with scikit-learn. Xdawn is a spatial filtering method designed to improve the signal @@ -259,6 +261,20 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None self.reg = reg self.method_params = method_params + cov_params = dict(reg=reg, cov_method_params=method_params, R=signal_cov) + + mod_params = dict() + super().__init__( + n_components, + _xdawn_estimate, + cov_params, + _xdawn_mod, + mod_params, + dec_type="multi", + restr_map=None, + R_func=None, + ) + def fit(self, X, y=None): """Fit Xdawn spatial filters. @@ -286,6 +302,26 @@ def fit(self, X, y=None): signal_cov=self.signal_cov, method_params=self.method_params, ) + old_filters = self.filters_ + old_patterns = self.patterns_ + super().fit(X, y) + self.filters_ = np.concatenate( + [ + self.filters_[i, : self.n_components] + for i in range(self.filters_.shape[0]) + ], + axis=0, + ) + self.patterns_ = np.concatenate( + [ + self.patterns_[i, : self.n_components] + for i in range(self.patterns_.shape[0]) + ], + axis=0, + ) + np.testing.assert_allclose(old_filters, self.filters_) + np.testing.assert_allclose(old_patterns, self.patterns_) + return self def transform(self, X): From 7c072d15f7a7ce7440f20a444d884f5c4e9da007 Mon Sep 17 00:00:00 2001 From: Genuster Date: Thu, 22 May 2025 16:35:21 +0300 Subject: [PATCH 02/18] update _epoch_cov logging following merge --- mne/decoding/covs_ged.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mne/decoding/covs_ged.py b/mne/decoding/covs_ged.py index 3df65d8107f..89e87f48820 100644 --- a/mne/decoding/covs_ged.py +++ b/mne/decoding/covs_ged.py @@ -11,7 +11,7 @@ from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance from ..filter import filter_data -from ..utils import pinv +from ..utils import _verbose_safe_false, logger, pinv def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): @@ -36,6 +36,12 @@ def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, in def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): """Mean of per-epoch covariances.""" + name = reg if isinstance(reg, str) else "empirical" + name += " with shrinkage" if isinstance(reg, float) else "" + logger.info( + f"Estimating {cov_kind + (' ' if cov_kind else '')}" + f"covariance (average over epochs; {name.upper()})" + ) cov = sum( _regularized_covariance( this_X, @@ -46,6 +52,7 @@ def _epoch_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, inf cov_kind=cov_kind, log_rank=log_rank and ii == 0, log_ch_type="data", + verbose=_verbose_safe_false(), ) for ii, this_X in enumerate(x_class) ) From 211d23f6cb1b9eaed2cf0b62d598ef7c36e5fdc7 Mon Sep 17 00:00:00 2001 From: Genuster Date: Thu, 22 May 2025 17:53:51 +0300 Subject: [PATCH 03/18] add a few preliminary docstrings --- mne/decoding/base.py | 58 +++++++++++++++++++++++++++++++++++++++++++- mne/decoding/ged.py | 31 ++++++++++++++++++++--- 2 files changed, 85 insertions(+), 4 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 3737f11960a..d00b02b8391 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -28,7 +28,63 @@ class GEDTransformer(MNETransformerMixin, BaseEstimator): - """...""" + """M/EEG signal decomposition using the generalized eigenvalue decomposition (GED). + + Given two channel covariance matrices S and R, the goal is to find spatial filters + that maximise contrast between S and R. + + Parameters + ---------- + n_filters : int + The number of spatial filters to decompose M/EEG signals. + cov_callable : callable + Function used to estimate covariances and reference matrix (C_ref) from the + data. + cov_params : dict + Parameters passed to cov_callable. + mod_ged_callable : callable + Function used to modify (e.g. sort or normalize) generalized + eigenvalues and eigenvectors. + mod_params : dict + Parameters passed to mod_ged_callable. + dec_type : "single" | "multi" + When "single" and cov_callable returns > 2 covariances, + approximate joint diagonalization based on Pham's algorithm + will be used instead of GED. + When 'multi', GED is performed separately for each class, i.e. each covariance + (except the last) returned by cov_callable is decomposed with the last + covariance. In this case, number of covariances should be number of classes + 1. + Defaults to "single". + restr_map : "restricting" | "whitening" | "ssd" | None + Restricting transformation for covariance matrices before performing GED. + If "restricting" only restriction to the principal subspace of the C_ref + will be performed. + If "whitening", covariance matrices will be additionally rescaled according + to the whitening for the C_ref. + If "ssd", perform simplified version of "whitening", + preserved for compatibility. + If None, no restriction will be applied. Defaults to None. + R_func : callable | None + If provided GED will be performed on (S, R_func(S,R)). + + Attributes + ---------- + evals_ : ndarray, shape (n_channels) + If fit, generalized eigenvalues used to decompose S and R, else None. + filters_ : ndarray, shape (n_channels or less, n_channels) + If fit, spatial filters (unmixing matrix) used to decompose the data, + else None. + patterns_ : ndarray, shape (n_channels or less, n_channels) + If fit, spatial patterns (mixing matrix) used to restore M/EEG signals, + else None. + + See Also + -------- + CSP + SPoC + SSD + mne.preprocessing.Xdawn + """ def __init__( self, diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index 5e505f8be9a..d71db4aa2f8 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -12,6 +12,11 @@ def _handle_restr_map(C_ref, restr_map, info, rank): + """Get restricting map to C_ref rank-dimensional principal subspace. + + Returns matrix of shape (rank, n_chs) used to restrict or + restrict+rescale (whiten) covariances matrices. + """ if C_ref is None or restr_map is None: return None if restr_map == "whitening": @@ -31,8 +36,15 @@ def _handle_restr_map(C_ref, restr_map, info, rank): return restr_map -def _smart_ged(S, R, restr_map, R_func=None, mult_order=None): - """...""" +def _smart_ged(S, R, restr_map=None, R_func=None, mult_order=None): + """Perform smart generalized eigenvalue decomposition (GED) of S and R. + + If restr_map is provided S and R will be restricted to the principal subspace + of a reference matrix with rank r (see _handle_restr_map), then GED is performed + on the restricted S and R and then generalized eigenvectors are transformed back + to the original space. The g-eigenvectors matrix is of shape (n_chs, r). + If callable R_func is provided the GED will be performed on (S, R_func(S,R)) + """ if restr_map is None: evals, evecs = scipy.linalg.eigh(S, R) return evals, evecs @@ -135,7 +147,19 @@ def _ajd_pham(X, eps=1e-6, max_iter=15): return V, D -def _smart_ajd(covs, restr_map, weights): +def _smart_ajd(covs, restr_map=None, weights=None): + """Perform smart approximate joint diagonalization. + + If restr_map is provided all the cov matrices will be restricted to the + principal subspace of a reference matrix with rank r (see _handle_restr_map), + then GED is performed on the restricted S and R and then generalized eigenvectors + are transformed back to the original space. + The matrix of generalized eigenvectors is of shape (n_chs, r). + """ + if restr_map is None: + evecs, D = _ajd_pham(covs) + return evecs + covs = np.array([restr_map @ cov @ restr_map.T for cov in covs], float) evecs_restr, D = _ajd_pham(covs) evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) @@ -144,6 +168,7 @@ def _smart_ajd(covs, restr_map, weights): def _get_restricting_map(C, info, rank): + """Get map restricting covariance to rank-dimensional principal subspace of C.""" _, ref_evecs, mask = _smart_eigh( C, info, From 0d58c8d567b34454b594c05570e75cedbe6766fe Mon Sep 17 00:00:00 2001 From: Genuster Date: Thu, 22 May 2025 19:43:38 +0300 Subject: [PATCH 04/18] bump rtol/atol for spoc --- mne/decoding/csp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 1d378a66b3e..51ac2ce0c0c 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -982,8 +982,8 @@ def fit(self, X, y): super(CSP, self).fit(X, y) np.testing.assert_allclose(evals[ix], self.evals_) - np.testing.assert_allclose(old_filters, self.filters_) - np.testing.assert_allclose(old_patterns, self.patterns_) + np.testing.assert_allclose(old_filters, self.filters_, rtol=1e-6, atol=1e-7) + np.testing.assert_allclose(old_patterns, self.patterns_, rtol=1e-6, atol=1e-7) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) From 2a1c5cb27467962e7a505d6cb5012afcc3c64edc Mon Sep 17 00:00:00 2001 From: Genuster Date: Thu, 29 May 2025 22:23:39 +0300 Subject: [PATCH 05/18] Add big sklearn compliance test --- mne/decoding/base.py | 81 +++++++++++++++++---- mne/decoding/covs_ged.py | 22 ++++++ mne/decoding/csp.py | 4 +- mne/decoding/ged.py | 69 ++++++------------ mne/decoding/ssd.py | 2 +- mne/decoding/tests/test_ged.py | 124 +++++++++++++++++++++++++++++++++ mne/preprocessing/xdawn.py | 2 +- 7 files changed, 241 insertions(+), 63 deletions(-) create mode 100644 mne/decoding/tests/test_ged.py diff --git a/mne/decoding/base.py b/mne/decoding/base.py index d00b02b8391..ac4700b53ed 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -8,11 +8,11 @@ import numbers import numpy as np +import scipy.linalg from sklearn import model_selection as models from sklearn.base import ( # noqa: F401 BaseEstimator, MetaEstimatorMixin, - TransformerMixin, clone, is_classifier, ) @@ -20,10 +20,11 @@ from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv from sklearn.utils import check_array, check_X_y, indexable +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func from ..utils import _pl, logger, pinv, verbose, warn -from .ged import _get_ssd_rank, _handle_restr_map, _smart_ajd, _smart_ged +from .ged import _handle_restr_map, _smart_ajd, _smart_ged from .transformer import MNETransformerMixin @@ -55,7 +56,7 @@ class GEDTransformer(MNETransformerMixin, BaseEstimator): (except the last) returned by cov_callable is decomposed with the last covariance. In this case, number of covariances should be number of classes + 1. Defaults to "single". - restr_map : "restricting" | "whitening" | "ssd" | None + restr_type : "restricting" | "whitening" | "ssd" | None Restricting transformation for covariance matrices before performing GED. If "restricting" only restriction to the principal subspace of the C_ref will be performed. @@ -94,7 +95,7 @@ def __init__( mod_ged_callable, mod_params, dec_type="single", - restr_map=None, + restr_type=None, R_func=None, ): self.n_filters = n_filters @@ -103,27 +104,35 @@ def __init__( self.mod_ged_callable = mod_ged_callable self.mod_params = mod_params self.dec_type = dec_type - self.restr_map = restr_map + self.restr_type = restr_type self.R_func = R_func def fit(self, X, y=None): """...""" + X, y = self._check_data( + X, + y=y, + fit=True, + return_y=True, + atleast_3d=False if self.restr_type == "ssd" else True, + ) covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params) + self._validate_covariances(covs + [C_ref]) if self.dec_type == "single": if len(covs) > 2: + covs = np.array(covs) sample_weights = kwargs["sample_weights"] - restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) evecs = _smart_ajd(covs, restr_map, weights=sample_weights) evals = None else: S = covs[0] R = covs[1] - if self.restr_map == "ssd": - rank = _get_ssd_rank(S, R, info, rank) + if self.restr_type == "ssd": mult_order = "ssd" else: mult_order = None - restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) evals, evecs = _smart_ged( S, R, restr_map, R_func=self.R_func, mult_order=mult_order ) @@ -133,7 +142,7 @@ def fit(self, X, y=None): ) self.evals_ = evals self.filters_ = evecs.T - if self.restr_map == "ssd": + if self.restr_type == "ssd": self.patterns_ = np.linalg.pinv(evecs) else: self.patterns_ = pinv(evecs) @@ -141,11 +150,18 @@ def fit(self, X, y=None): elif self.dec_type == "multi": self.classes_ = np.unique(y) R = covs[-1] - restr_map = _handle_restr_map(C_ref, self.restr_map, info, rank) + if self.restr_type == "ssd": + mult_order = "ssd" + else: + mult_order = None + restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) all_evals, all_evecs, all_patterns = list(), list(), list() for i in range(len(self.classes_)): S = covs[i] - evals, evecs = _smart_ged(S, R, restr_map, R_func=self.R_func) + + evals, evecs = _smart_ged( + S, R, restr_map, R_func=self.R_func, mult_order=mult_order + ) evals, evecs = self.mod_ged_callable( evals, evecs, covs, **self.mod_params, **kwargs @@ -161,9 +177,48 @@ def fit(self, X, y=None): def transform(self, X): """...""" - X = np.dot(self.filters_, X) + check_is_fitted(self, "filters_") + X = self._check_data(X) + if self.dec_type == "single": + pick_filters = self.filters_[: self.n_filters] + elif self.dec_type == "multi": + pick_filters = np.concatenate( + [ + self.filters_[i, : self.n_filters] + for i in range(self.filters_.shape[0]) + ], + axis=0, + ) + X = np.asarray([pick_filters @ epoch for epoch in X]) return X + def _validate_covariances(self, covs): + for cov in covs: + if cov is None: + continue + is_sym = scipy.linalg.issymmetric(cov, rtol=1e-10, atol=1e-11) + if not is_sym: + raise ValueError( + "One of covariances or C_ref is not symmetric, " + "check your cov_callable" + ) + if not np.all(np.linalg.eigvals(cov) >= 0): + ValueError( + "One of covariances or C_ref has negative eigenvalues, " + "check your cov_callable" + ) + + def __sklearn_tags__(self): + """Tag the transformer.""" + tags = super().__sklearn_tags__() + tags.estimator_type = "transformer" + # Can be a transformer where S and R covs are not based on y classes. + tags.target_tags.required = False + tags.target_tags.one_d_labels = True + tags.input_tags.two_d_array = True + tags.input_tags.three_d_array = True + return tags + class LinearModel(MetaEstimatorMixin, BaseEstimator): """Compute and store patterns from linear models. diff --git a/mne/decoding/covs_ged.py b/mne/decoding/covs_ged.py index 89e87f48820..f40e9fefcf3 100644 --- a/mne/decoding/covs_ged.py +++ b/mne/decoding/covs_ged.py @@ -10,7 +10,9 @@ from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _compute_rank_raw_array, _regularized_covariance +from ..defaults import _handle_default from ..filter import filter_data +from ..rank import compute_rank from ..utils import _verbose_safe_false, logger, pinv @@ -293,6 +295,26 @@ def _ssd_estimate( ) covs = [S, R] C_ref = S + + all_ranks = list() + for cov in covs: + r = list( + compute_rank( + Covariance( + cov, + info.ch_names, + list(), + list(), + 0, + verbose=_verbose_safe_false(), + ), + rank, + _handle_default("scalings_cov_rank", None), + info, + ).values() + )[0] + all_ranks.append(r) + rank = np.min(all_ranks) return covs, C_ref, info, rank, dict() diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 51ac2ce0c0c..80e1d9c8f47 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -142,7 +142,7 @@ def __init__( _csp_mod, mod_params, dec_type="single", - restr_map="restricting", + restr_type="restricting", R_func=sum, ) @@ -911,7 +911,7 @@ def __init__( _spoc_mod, mod_params, dec_type="single", - restr_map=None, + restr_type=None, R_func=None, ) diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index d71db4aa2f8..cabf23bccbb 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -6,28 +6,26 @@ import scipy.linalg from ..cov import Covariance, _smart_eigh, compute_whitener -from ..defaults import _handle_default -from ..rank import compute_rank -from ..utils import _verbose_safe_false, logger +from ..utils import logger -def _handle_restr_map(C_ref, restr_map, info, rank): +def _handle_restr_map(C_ref, restr_type, info, rank): """Get restricting map to C_ref rank-dimensional principal subspace. Returns matrix of shape (rank, n_chs) used to restrict or restrict+rescale (whiten) covariances matrices. """ - if C_ref is None or restr_map is None: + if C_ref is None or restr_type is None: return None - if restr_map == "whitening": + if restr_type == "whitening": projs = info["projs"] C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], projs, 0) - restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True) - elif restr_map == "ssd": + restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True)[0] + elif restr_type == "ssd": restr_map = _get_ssd_whitener(C_ref, rank) - elif restr_map == "restricting": + elif restr_type == "restricting": restr_map = _get_restricting_map(C_ref, info, rank) - elif isinstance(restr_map, callable): + elif isinstance(restr_type, callable): pass else: raise ValueError( @@ -147,6 +145,15 @@ def _ajd_pham(X, eps=1e-6, max_iter=15): return V, D +def _is_all_pos_def(covs): + for cov in covs: + try: + _ = scipy.linalg.cholesky(cov) + except np.linalg.LinAlgError: + return False + return True + + def _smart_ajd(covs, restr_map=None, weights=None): """Perform smart approximate joint diagonalization. @@ -157,6 +164,12 @@ def _smart_ajd(covs, restr_map=None, weights=None): The matrix of generalized eigenvectors is of shape (n_chs, r). """ if restr_map is None: + is_all_pos_def = _is_all_pos_def(covs) + if not is_all_pos_def: + raise ValueError( + "If C_ref is not provided by covariance estimator, " + "all the covs should be positive definite" + ) evecs, D = _ajd_pham(covs) return evecs @@ -191,42 +204,6 @@ def _normalize_eigenvectors(evecs, covs, sample_weights): return evecs -def _get_ssd_rank(S, R, info, rank): - # find ranks of covariance matrices - rank_signal = list( - compute_rank( - Covariance( - S, - info.ch_names, - list(), - list(), - 0, - verbose=_verbose_safe_false(), - ), - rank, - _handle_default("scalings_cov_rank", None), - info, - ).values() - )[0] - rank_noise = list( - compute_rank( - Covariance( - R, - info.ch_names, - list(), - list(), - 0, - verbose=_verbose_safe_false(), - ), - rank, - _handle_default("scalings_cov_rank", None), - info, - ).values() - )[0] - rank = np.min([rank_signal, rank_noise]) # should be identical - return rank - - def _get_ssd_whitener(S, rank): """Perform dimensionality reduction on the covariance matrices.""" n_channels = S.shape[0] diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 367be7038d3..b8e0a060fb0 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -137,7 +137,7 @@ def __init__( _ssd_mod, mod_params, dec_type="single", - restr_map="ssd", + restr_type="ssd", R_func=None, ) diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py new file mode 100644 index 00000000000..cfc1c4abf81 --- /dev/null +++ b/mne/decoding/tests/test_ged.py @@ -0,0 +1,124 @@ +# Authors: The MNE-Python contributors. +# License: BSD-3-Clause +# Copyright the MNE-Python contributors. + + +import functools + +import numpy as np +import pytest + +pytest.importorskip("sklearn") + + +from sklearn.model_selection import ParameterGrid +from sklearn.utils.estimator_checks import parametrize_with_checks + +from mne import compute_rank, create_info +from mne._fiff.proj import make_eeg_average_ref_proj +from mne.cov import Covariance, _regularized_covariance +from mne.decoding.base import GEDTransformer + + +def _mock_info(n_channels): + info = create_info(n_channels, 1000.0, "eeg") + avg_eeg_projector = make_eeg_average_ref_proj(info=info, activate=False) + info["projs"].append(avg_eeg_projector) + return info + + +def _get_min_rank(covs, info): + min_rank = dict( + eeg=min( + list( + compute_rank( + Covariance( + cov, + info.ch_names, + list(), + list(), + 0, + # verbose=_verbose_safe_false(), + ), + rank=None, + # _handle_default("scalings_cov_rank", None), + info=info, + ).values() + )[0] + for cov in covs + ) + ) + return min_rank + + +def _mock_cov_callable(X, y, cov_method_params=None): + if cov_method_params is None: + cov_method_params = dict() + n_epochs, n_channels, n_times = X.shape + + # To pass sklearn check: + if n_channels == 1: + n_channels = 2 + X = np.tile(X, (1, n_channels, 1)) + + # To make covariance estimation sensible + if n_times == 1: + n_times = n_channels + X = np.tile(X, (1, 1, n_channels)) + + classes = np.unique(y) + covs, sample_weights = list(), list() + for ci, this_class in enumerate(classes): + class_data = X[y == this_class] + class_data = class_data.transpose(1, 0, 2).reshape(n_channels, -1) + cov = _regularized_covariance(class_data, **cov_method_params) + covs.append(cov) + sample_weights.append(class_data.shape[0]) + + ref_data = X.transpose(1, 0, 2).reshape(n_channels, -1) + C_ref = _regularized_covariance(ref_data, **cov_method_params) + info = _mock_info(n_channels) + rank = _get_min_rank(covs, info) + kwargs = dict() + + # To pass sklearn check: + if len(covs) == 1: + covs.append(covs[0]) + + elif len(covs) > 2: + kwargs["sample_weights"] = sample_weights + return covs, C_ref, info, rank, kwargs + + +def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): + if evals is not None: + ix = np.argsort(evals)[::-1] + evals = evals[ix] + evecs = evecs[:, ix] + return evals, evecs + + +param_grid = dict( + n_filters=[4], + cov_callable=[_mock_cov_callable], + cov_params=[ + dict(cov_method_params=dict(reg="empirical")), + ], + mod_ged_callable=[_mock_mod_ged_callable], + mod_params=[dict()], + dec_type=["single", "multi"], + restr_type=[ + "restricting", + "whitening", + ], # Not covering "ssd" here because its tests work with 2D data. + R_func=[functools.partial(np.sum, axis=0)], +) + +ged_estimators = [GEDTransformer(**p) for p in ParameterGrid(param_grid)] + + +@pytest.mark.slowtest +@parametrize_with_checks(ged_estimators) +def test_sklearn_compliance(estimator, check): + """Test GEDTransformer compliance with sklearn.""" + check(estimator) diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index f794e404f46..5ccc397a087 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -271,7 +271,7 @@ def __init__(self, n_components=2, reg=None, signal_cov=None, method_params=None _xdawn_mod, mod_params, dec_type="multi", - restr_map=None, + restr_type=None, R_func=None, ) From 6e8b3aa7dd5ba37f68b84729573c3c1e439582da Mon Sep 17 00:00:00 2001 From: Genuster Date: Mon, 2 Jun 2025 19:30:46 +0300 Subject: [PATCH 06/18] add __sklearn_tags__ to vulture's whitelist --- tools/vulture_allowlist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 9d0e215ee80..08623cf14b8 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -43,6 +43,7 @@ _._more_tags _.multi_class _.preserves_dtype +_.__sklearn_tags__ deep # Backward compat or rarely used From b2e24eae398bf6567176054d0f3a7e0bab7174df Mon Sep 17 00:00:00 2001 From: Genuster Date: Mon, 2 Jun 2025 19:43:05 +0300 Subject: [PATCH 07/18] calm vulture down per attribute --- tools/vulture_allowlist.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 08623cf14b8..32ee1091131 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -43,7 +43,9 @@ _._more_tags _.multi_class _.preserves_dtype -_.__sklearn_tags__ +_.one_d_labels +_.two_d_array +_.three_d_array deep # Backward compat or rarely used From fbd585e83e5264cbaa1a9fb919c062cafbfea4c2 Mon Sep 17 00:00:00 2001 From: Genuster Date: Mon, 2 Jun 2025 20:01:28 +0300 Subject: [PATCH 08/18] put the TransformerMixin back --- mne/decoding/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index ac4700b53ed..bd84c937311 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -13,6 +13,7 @@ from sklearn.base import ( # noqa: F401 BaseEstimator, MetaEstimatorMixin, + TransformerMixin, clone, is_classifier, ) From d142bd04048cc306ddc2415bdc9b72c2ab297ffc Mon Sep 17 00:00:00 2001 From: Genuster Date: Mon, 2 Jun 2025 22:29:52 +0300 Subject: [PATCH 09/18] fix validation of covariances --- mne/decoding/base.py | 5 +++-- mne/decoding/covs_ged.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index bd84c937311..ef1e3367857 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -118,10 +118,11 @@ def fit(self, X, y=None): atleast_3d=False if self.restr_type == "ssd" else True, ) covs, C_ref, info, rank, kwargs = self.cov_callable(X, y, **self.cov_params) - self._validate_covariances(covs + [C_ref]) + covs = np.stack(covs) + self._validate_covariances(covs) + self._validate_covariances([C_ref]) if self.dec_type == "single": if len(covs) > 2: - covs = np.array(covs) sample_weights = kwargs["sample_weights"] restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) evecs = _smart_ajd(covs, restr_map, weights=sample_weights) diff --git a/mne/decoding/covs_ged.py b/mne/decoding/covs_ged.py index f40e9fefcf3..627b7ebc900 100644 --- a/mne/decoding/covs_ged.py +++ b/mne/decoding/covs_ged.py @@ -246,7 +246,6 @@ def _xdawn_estimate( covs.append(evo_cov) covs.append(R) - covs = np.stack(covs) C_ref = None rank = None info = None From 679636605a282e500a74da5b09f80ff2cbbc2a39 Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 4 Jun 2025 01:37:00 +0300 Subject: [PATCH 10/18] add gedtranformer tests with audvis dataset --- mne/decoding/tests/test_ged.py | 153 ++++++++++++++++++++++++++++++++- 1 file changed, 151 insertions(+), 2 deletions(-) diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index cfc1c4abf81..10f11b765c0 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -2,8 +2,8 @@ # License: BSD-3-Clause # Copyright the MNE-Python contributors. - import functools +from pathlib import Path import numpy as np import pytest @@ -12,12 +12,22 @@ from sklearn.model_selection import ParameterGrid +from sklearn.utils._testing import assert_allclose from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import compute_rank, create_info +from mne import Epochs, compute_rank, create_info, pick_types, read_events from mne._fiff.proj import make_eeg_average_ref_proj from mne.cov import Covariance, _regularized_covariance from mne.decoding.base import GEDTransformer +from mne.decoding.ged import _get_restricting_map, _smart_ajd, _smart_ged +from mne.io import read_raw + +data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" +raw_fname = data_dir / "test_raw.fif" +event_name = data_dir / "test-eve.fif" +tmin, tmax = -0.1, 0.2 +# if stop is too small pca may fail in some cases, but we're okay on this file +start, stop = 0, 8 def _mock_info(n_channels): @@ -122,3 +132,142 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): def test_sklearn_compliance(estimator, check): """Test GEDTransformer compliance with sklearn.""" check(estimator) + + +def _get_X_y(event_id): + raw = read_raw(raw_fname, preload=False) + events = read_events(event_name) + picks = pick_types( + raw.info, meg=True, stim=False, ecg=False, eog=False, exclude="bads" + ) + picks = picks[2:12:3] # subselect channels -> disable proj! + raw.add_proj([], remove_existing=True) + epochs = Epochs( + raw, + events, + event_id, + tmin, + tmax, + picks=picks, + baseline=(None, 0), + preload=True, + proj=False, + ) + X = epochs.get_data(copy=False) + y = epochs.events[:, -1] + return X, y + + +def test_ged_binary_cov(): + """Test GEDTransformer on audvis dataset with two covariances.""" + event_id = dict(aud_l=1, vis_l=3) + X, y = _get_X_y(event_id) + # Test "single" decomposition + covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) + S, R = covs[0], covs[1] + restr_map = _get_restricting_map(C_ref, info, rank) + evals, evecs = _smart_ged(S, R, restr_map=restr_map, R_func=None) + actual_evals, actual_evecs = _mock_mod_ged_callable(evals, evecs, [S, R], **kwargs) + actual_filters = actual_evecs.T + + ged = GEDTransformer( + n_filters=4, + cov_callable=_mock_cov_callable, + cov_params=dict(), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="single", + restr_type="restricting", + R_func=None, + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) + + # Test "multi" decomposition (loop), restr_map can be reused + all_evals, all_evecs = list(), list() + for i in range(len(covs)): + S = covs[i] + evals, evecs = _smart_ged(S, R, restr_map) + evals, evecs = _mock_mod_ged_callable(evals, evecs, covs) + all_evals.append(evals) + all_evecs.append(evecs.T) + actual_evals = np.array(all_evals) + actual_filters = np.array(all_evecs) + + ged = GEDTransformer( + n_filters=4, + cov_callable=_mock_cov_callable, + cov_params=dict(), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="multi", + restr_type="restricting", + R_func=None, + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) + + +def test_ged_multicov(): + """Test GEDTransformer on audvis dataset with multiple covariances.""" + event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4) + X, y = _get_X_y(event_id) + # Test "single" decomposition for multicov (AJD) + covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) + restr_map = _get_restricting_map(C_ref, info, rank) + evecs = _smart_ajd(covs, restr_map=restr_map) + evals = None + _, actual_evecs = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) + actual_filters = actual_evecs.T + + ged = GEDTransformer( + n_filters=4, + cov_callable=_mock_cov_callable, + cov_params=dict(), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="single", + restr_type="restricting", + R_func=None, + ) + ged.fit(X, y) + desired_filters = ged.filters_ + + assert_allclose(actual_filters, desired_filters) + + # Test "multi" decomposition for multicov (loop) + R = covs[-1] + all_evals, all_evecs = list(), list() + for i in range(len(covs)): + S = covs[i] + evals, evecs = _smart_ged(S, R, restr_map) + evals, evecs = _mock_mod_ged_callable(evals, evecs, covs) + all_evals.append(evals) + all_evecs.append(evecs.T) + actual_evals = np.array(all_evals) + actual_filters = np.array(all_evecs) + + ged = GEDTransformer( + n_filters=4, + cov_callable=_mock_cov_callable, + cov_params=dict(), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="multi", + restr_type="restricting", + R_func=None, + ) + ged.fit(X, y) + desired_evals = ged.evals_ + desired_filters = ged.filters_ + + assert_allclose(actual_evals, desired_evals) + assert_allclose(actual_filters, desired_filters) From 7a291b1f02a306d90d01fbdeadc88d8264ea8323 Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 4 Jun 2025 21:41:42 +0300 Subject: [PATCH 11/18] fixes following Eric's comments --- mne/decoding/{covs_ged.py => _covs_ged.py} | 86 +--------------------- mne/decoding/{mod_ged.py => _mod_ged.py} | 0 mne/decoding/base.py | 2 +- mne/decoding/csp.py | 8 +- mne/decoding/ged.py | 86 +--------------------- mne/decoding/ssd.py | 8 +- mne/decoding/tests/test_ged.py | 12 +-- mne/preprocessing/xdawn.py | 8 +- 8 files changed, 24 insertions(+), 186 deletions(-) rename mne/decoding/{covs_ged.py => _covs_ged.py} (72%) rename mne/decoding/{mod_ged.py => _mod_ged.py} (100%) diff --git a/mne/decoding/covs_ged.py b/mne/decoding/_covs_ged.py similarity index 72% rename from mne/decoding/covs_ged.py rename to mne/decoding/_covs_ged.py index 627b7ebc900..3914a929770 100644 --- a/mne/decoding/covs_ged.py +++ b/mne/decoding/_covs_ged.py @@ -5,7 +5,6 @@ # Copyright the MNE-Python contributors. import numpy as np -import scipy.linalg from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx @@ -13,7 +12,7 @@ from ..defaults import _handle_default from ..filter import filter_data from ..rank import compute_rank -from ..utils import _verbose_safe_false, logger, pinv +from ..utils import _verbose_safe_false, logger def _concat_cov(x_class, *, cov_kind, log_rank, reg, cov_method_params, rank, info): @@ -110,87 +109,6 @@ def _csp_estimate(X, y, reg, cov_method_params, cov_est, rank, norm_trace): return covs, C_ref, _info, _rank, dict(sample_weights=np.array(sample_weights)) -def _construct_signal_from_epochs(epochs, events, sfreq, tmin): - """Reconstruct pseudo continuous signal from epochs.""" - n_epochs, n_channels, n_times = epochs.shape - tmax = tmin + n_times / float(sfreq) - start = np.min(events[:, 0]) + int(tmin * sfreq) - stop = np.max(events[:, 0]) + int(tmax * sfreq) + 1 - - n_samples = stop - start - n_epochs, n_channels, n_times = epochs.shape - events_pos = events[:, 0] - events[0, 0] - - raw = np.zeros((n_channels, n_samples)) - for idx in range(n_epochs): - onset = events_pos[idx] - offset = onset + n_times - raw[:, onset:offset] = epochs[idx] - - return raw - - -def _least_square_evoked(epochs_data, events, tmin, sfreq): - """Least square estimation of evoked response from epochs data. - - Parameters - ---------- - epochs_data : array, shape (n_channels, n_times) - The epochs data to estimate evoked. - events : array, shape (n_events, 3) - The events typically returned by the read_events function. - If some events don't match the events of interest as specified - by event_id, they will be ignored. - tmin : float - Start time before event. - sfreq : float - Sampling frequency. - - Returns - ------- - evokeds : array, shape (n_class, n_components, n_times) - An concatenated array of evoked data for each event type. - toeplitz : array, shape (n_class * n_components, n_channels) - An concatenated array of toeplitz matrix for each event type. - """ - n_epochs, n_channels, n_times = epochs_data.shape - tmax = tmin + n_times / float(sfreq) - - # Deal with shuffled epochs - events = events.copy() - events[:, 0] -= events[0, 0] + int(tmin * sfreq) - - # Construct raw signal - raw = _construct_signal_from_epochs(epochs_data, events, sfreq, tmin) - - # Compute the independent evoked responses per condition, while correcting - # for event overlaps. - n_min, n_max = int(tmin * sfreq), int(tmax * sfreq) - window = n_max - n_min - n_samples = raw.shape[1] - toeplitz = list() - classes = np.unique(events[:, 2]) - for ii, this_class in enumerate(classes): - # select events by type - sel = events[:, 2] == this_class - - # build toeplitz matrix - trig = np.zeros((n_samples,)) - ix_trig = (events[sel, 0]) + n_min - trig[ix_trig] = 1 - toeplitz.append(scipy.linalg.toeplitz(trig[0:window], trig)) - - # Concatenate toeplitz - toeplitz = np.array(toeplitz) - X = np.concatenate(toeplitz) - - # least square estimation - predictor = np.dot(pinv(np.dot(X, X.T)), X) - evokeds = np.dot(predictor, raw.T) - evokeds = np.transpose(np.vsplit(evokeds, len(classes)), (0, 2, 1)) - return evokeds, toeplitz - - def _xdawn_estimate( X, y, @@ -203,6 +121,8 @@ def _xdawn_estimate( info=None, rank="full", ): + from ..preprocessing.xdawn import _least_square_evoked + if not isinstance(X, np.ndarray) or X.ndim != 3: raise ValueError("X must be 3D ndarray") diff --git a/mne/decoding/mod_ged.py b/mne/decoding/_mod_ged.py similarity index 100% rename from mne/decoding/mod_ged.py rename to mne/decoding/_mod_ged.py diff --git a/mne/decoding/base.py b/mne/decoding/base.py index ef1e3367857..f8b83f71ddc 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -29,7 +29,7 @@ from .transformer import MNETransformerMixin -class GEDTransformer(MNETransformerMixin, BaseEstimator): +class _GEDTransformer(MNETransformerMixin, BaseEstimator): """M/EEG signal decomposition using the generalized eigenvalue decomposition (GED). Given two channel covariance matrices S and R, the goal is to find spatial filters diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 80e1d9c8f47..53678ee827f 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -20,13 +20,13 @@ logger, pinv, ) -from .base import GEDTransformer -from .covs_ged import _csp_estimate, _spoc_estimate -from .mod_ged import _csp_mod, _spoc_mod +from ._covs_ged import _csp_estimate, _spoc_estimate +from ._mod_ged import _csp_mod, _spoc_mod +from .base import _GEDTransformer @fill_doc -class CSP(GEDTransformer): +class CSP(_GEDTransformer): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index cabf23bccbb..68f11f9b1c4 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -61,90 +61,6 @@ def _smart_ged(S, R, restr_map=None, R_func=None, mult_order=None): return evals, evecs -def _ajd_pham(X, eps=1e-6, max_iter=15): - """Approximate joint diagonalization based on Pham's algorithm. - - This is a direct implementation of the PHAM's AJD algorithm [1]. - - Parameters - ---------- - X : ndarray, shape (n_epochs, n_channels, n_channels) - A set of covariance matrices to diagonalize. - eps : float, default 1e-6 - The tolerance for stopping criterion. - max_iter : int, default 1000 - The maximum number of iteration to reach convergence. - - Returns - ------- - V : ndarray, shape (n_channels, n_channels) - The diagonalizer. - D : ndarray, shape (n_epochs, n_channels, n_channels) - The set of quasi diagonal matrices. - - References - ---------- - .. [1] Pham, Dinh Tuan. "Joint approximate diagonalization of positive - definite Hermitian matrices." SIAM Journal on Matrix Analysis and - Applications 22, no. 4 (2001): 1136-1152. - - """ - # Adapted from http://github.com/alexandrebarachant/pyRiemann - n_epochs = X.shape[0] - - # Reshape input matrix - A = np.concatenate(X, axis=0).T - - # Init variables - n_times, n_m = A.shape - V = np.eye(n_times) - epsilon = n_times * (n_times - 1) * eps - - for it in range(max_iter): - decr = 0 - for ii in range(1, n_times): - for jj in range(ii): - Ii = np.arange(ii, n_m, n_times) - Ij = np.arange(jj, n_m, n_times) - - c1 = A[ii, Ii] - c2 = A[jj, Ij] - - g12 = np.mean(A[ii, Ij] / c1) - g21 = np.mean(A[ii, Ij] / c2) - - omega21 = np.mean(c1 / c2) - omega12 = np.mean(c2 / c1) - omega = np.sqrt(omega12 * omega21) - - tmp = np.sqrt(omega21 / omega12) - tmp1 = (tmp * g12 + g21) / (omega + 1) - tmp2 = (tmp * g12 - g21) / max(omega - 1, 1e-9) - - h12 = tmp1 + tmp2 - h21 = np.conj((tmp1 - tmp2) / tmp) - - decr += n_epochs * (g12 * np.conj(h12) + g21 * h21) / 2.0 - - tmp = 1 + 1.0j * 0.5 * np.imag(h12 * h21) - tmp = np.real(tmp + np.sqrt(tmp**2 - h12 * h21)) - tau = np.array([[1, -h12 / tmp], [-h21 / tmp, 1]]) - - A[[ii, jj], :] = np.dot(tau, A[[ii, jj], :]) - tmp = np.c_[A[:, Ii], A[:, Ij]] - tmp = np.reshape(tmp, (n_times * n_epochs, 2), order="F") - tmp = np.dot(tmp, tau.T) - - tmp = np.reshape(tmp, (n_times, n_epochs * 2), order="F") - A[:, Ii] = tmp[:, :n_epochs] - A[:, Ij] = tmp[:, n_epochs:] - V[[ii, jj], :] = np.dot(tau, V[[ii, jj], :]) - if decr < epsilon: - break - D = np.reshape(A, (n_times, -1, n_times)).transpose(1, 0, 2) - return V, D - - def _is_all_pos_def(covs): for cov in covs: try: @@ -163,6 +79,8 @@ def _smart_ajd(covs, restr_map=None, weights=None): are transformed back to the original space. The matrix of generalized eigenvectors is of shape (n_chs, r). """ + from .csp import _ajd_pham + if restr_map is None: is_all_pos_def = _is_all_pos_def(covs) if not is_all_pos_def: diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index b8e0a060fb0..7c2c8a3d7ac 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -20,13 +20,13 @@ fill_doc, logger, ) -from .base import GEDTransformer -from .covs_ged import _ssd_estimate -from .mod_ged import _ssd_mod +from ._covs_ged import _ssd_estimate +from ._mod_ged import _ssd_mod +from .base import _GEDTransformer @fill_doc -class SSD(GEDTransformer): +class SSD(_GEDTransformer): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index 10f11b765c0..b72369404cf 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -18,7 +18,7 @@ from mne import Epochs, compute_rank, create_info, pick_types, read_events from mne._fiff.proj import make_eeg_average_ref_proj from mne.cov import Covariance, _regularized_covariance -from mne.decoding.base import GEDTransformer +from mne.decoding.base import _GEDTransformer from mne.decoding.ged import _get_restricting_map, _smart_ajd, _smart_ged from mne.io import read_raw @@ -124,7 +124,7 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): R_func=[functools.partial(np.sum, axis=0)], ) -ged_estimators = [GEDTransformer(**p) for p in ParameterGrid(param_grid)] +ged_estimators = [_GEDTransformer(**p) for p in ParameterGrid(param_grid)] @pytest.mark.slowtest @@ -170,7 +170,7 @@ def test_ged_binary_cov(): actual_evals, actual_evecs = _mock_mod_ged_callable(evals, evecs, [S, R], **kwargs) actual_filters = actual_evecs.T - ged = GEDTransformer( + ged = _GEDTransformer( n_filters=4, cov_callable=_mock_cov_callable, cov_params=dict(), @@ -198,7 +198,7 @@ def test_ged_binary_cov(): actual_evals = np.array(all_evals) actual_filters = np.array(all_evecs) - ged = GEDTransformer( + ged = _GEDTransformer( n_filters=4, cov_callable=_mock_cov_callable, cov_params=dict(), @@ -228,7 +228,7 @@ def test_ged_multicov(): _, actual_evecs = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) actual_filters = actual_evecs.T - ged = GEDTransformer( + ged = _GEDTransformer( n_filters=4, cov_callable=_mock_cov_callable, cov_params=dict(), @@ -255,7 +255,7 @@ def test_ged_multicov(): actual_evals = np.array(all_evals) actual_filters = np.array(all_evecs) - ged = GEDTransformer( + ged = _GEDTransformer( n_filters=4, cov_callable=_mock_cov_callable, cov_params=dict(), diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 5ccc397a087..d7775a87705 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -7,9 +7,9 @@ from .._fiff.pick import _pick_data_channels, pick_info from ..cov import Covariance, _regularized_covariance -from ..decoding.base import GEDTransformer -from ..decoding.covs_ged import _xdawn_estimate -from ..decoding.mod_ged import _xdawn_mod +from ..decoding._covs_ged import _xdawn_estimate +from ..decoding._mod_ged import _xdawn_mod +from ..decoding.base import _GEDTransformer from ..epochs import BaseEpochs from ..evoked import Evoked, EvokedArray from ..io import BaseRaw @@ -214,7 +214,7 @@ def _fit_xdawn( return filters, patterns, evokeds -class _XdawnTransformer(GEDTransformer): +class _XdawnTransformer(_GEDTransformer): """Implementation of the Xdawn Algorithm compatible with scikit-learn. Xdawn is a spatial filtering method designed to improve the signal From 7c867ecc83a409da5915d7062bab53a061035fac Mon Sep 17 00:00:00 2001 From: Genuster Date: Wed, 4 Jun 2025 23:20:56 +0300 Subject: [PATCH 12/18] document shapes --- mne/decoding/base.py | 9 ++------- mne/decoding/csp.py | 1 + mne/decoding/ssd.py | 2 ++ mne/preprocessing/xdawn.py | 25 +++++++++++++------------ 4 files changed, 18 insertions(+), 19 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index f8b83f71ddc..645a5d09b9d 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -184,12 +184,8 @@ def transform(self, X): if self.dec_type == "single": pick_filters = self.filters_[: self.n_filters] elif self.dec_type == "multi": - pick_filters = np.concatenate( - [ - self.filters_[i, : self.n_filters] - for i in range(self.filters_.shape[0]) - ], - axis=0, + pick_filters = self.filters_[:, : self.n_filters, :].reshape( + -1, self.filters_.shape[2] ) X = np.asarray([pick_filters @ epoch for epoch in X]) return X @@ -213,7 +209,6 @@ def _validate_covariances(self, covs): def __sklearn_tags__(self): """Tag the transformer.""" tags = super().__sklearn_tags__() - tags.estimator_type = "transformer" # Can be a transformer where S and R covs are not based on y classes. tags.target_tags.required = False tags.target_tags.one_d_labels = True diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 53678ee827f..47d40ef1b1e 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -216,6 +216,7 @@ def fit(self, X, y): old_filters = self.filters_ old_patterns = self.patterns_ super().fit(X, y) + # AJD returns evals_ as None. if self.evals_ is None: assert eigen_values is None else: diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 7c2c8a3d7ac..25df128860f 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -266,6 +266,8 @@ def fit(self, X, y=None): old_filters = self.filters_ old_patterns = self.patterns_ super().fit(X, y) + # SSD, as opposed to CSP and Xdawn stores filters as (n_chs, n_components) + # So need to transpose into (n_components, n_chs) self.filters_ = self.filters_.T np.testing.assert_allclose(self.eigvals_, self.evals_) diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index d7775a87705..5d0e52cd7d8 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -305,19 +305,20 @@ def fit(self, X, y=None): old_filters = self.filters_ old_patterns = self.patterns_ super().fit(X, y) - self.filters_ = np.concatenate( - [ - self.filters_[i, : self.n_components] - for i in range(self.filters_.shape[0]) - ], - axis=0, + # Xdawn performs separate GED for each class. + # filters_ returned by _fit_xdawn are subset per + # n_components and then appended and are of shape + # (n_classes*n_components, n_chs). + # GEDTransformer creates new dimension per class without subsetting + # for easier analysis and visualisations. + # So it needs to be performed post-hoc to conform with Xdawn. + # The shape returned by GED here is (n_classes, n_evecs, n_chs) + # Need to transform and subset into (n_classes*n_components, n_chs) + self.filters_ = self.filters_[:, : self.n_components, :].reshape( + -1, self.filters_.shape[2] ) - self.patterns_ = np.concatenate( - [ - self.patterns_[i, : self.n_components] - for i in range(self.patterns_.shape[0]) - ], - axis=0, + self.patterns_ = self.patterns_[:, : self.n_components, :].reshape( + -1, self.patterns_.shape[2] ) np.testing.assert_allclose(old_filters, self.filters_) np.testing.assert_allclose(old_patterns, self.patterns_) From e1e8d6d3a5b1735a386678678ba97377ebc455a1 Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 6 Jun 2025 16:13:25 +0300 Subject: [PATCH 13/18] another small test for GEDtransformer --- mne/decoding/base.py | 2 +- mne/decoding/tests/test_ged.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 645a5d09b9d..3daddfef399 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -201,7 +201,7 @@ def _validate_covariances(self, covs): "check your cov_callable" ) if not np.all(np.linalg.eigvals(cov) >= 0): - ValueError( + raise ValueError( "One of covariances or C_ref has negative eigenvalues, " "check your cov_callable" ) diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index b72369404cf..efb962341c8 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -271,3 +271,24 @@ def test_ged_multicov(): assert_allclose(actual_evals, desired_evals) assert_allclose(actual_filters, desired_filters) + + +def test_ged_invalid_cov(): + """Test _validate_covariances raises proper errors.""" + ged = _GEDTransformer( + n_filters=1, + cov_callable=_mock_cov_callable, + cov_params=dict(), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="single", + restr_type=None, + R_func=None, + ) + asymm_cov = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + with pytest.raises(ValueError): + ged._validate_covariances([asymm_cov, None]) + + negsemidef_cov = np.array([[-2, 0, 0], [0, -1, 0], [0, 0, -3]]) + with pytest.raises(ValueError): + ged._validate_covariances([negsemidef_cov, None]) From 5edc6fa01836f55c104783c850c013539317327f Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 6 Jun 2025 16:15:32 +0300 Subject: [PATCH 14/18] change name of restricting map to restricting matrix --- mne/decoding/base.py | 14 ++++----- mne/decoding/ged.py | 52 +++++++++++++++++----------------- mne/decoding/tests/test_ged.py | 16 +++++------ 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 3daddfef399..96516383f18 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -25,7 +25,7 @@ from ..parallel import parallel_func from ..utils import _pl, logger, pinv, verbose, warn -from .ged import _handle_restr_map, _smart_ajd, _smart_ged +from .ged import _handle_restr_mat, _smart_ajd, _smart_ged from .transformer import MNETransformerMixin @@ -124,8 +124,8 @@ def fit(self, X, y=None): if self.dec_type == "single": if len(covs) > 2: sample_weights = kwargs["sample_weights"] - restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) - evecs = _smart_ajd(covs, restr_map, weights=sample_weights) + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) + evecs = _smart_ajd(covs, restr_mat, weights=sample_weights) evals = None else: S = covs[0] @@ -134,9 +134,9 @@ def fit(self, X, y=None): mult_order = "ssd" else: mult_order = None - restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) evals, evecs = _smart_ged( - S, R, restr_map, R_func=self.R_func, mult_order=mult_order + S, R, restr_mat, R_func=self.R_func, mult_order=mult_order ) evals, evecs = self.mod_ged_callable( @@ -156,13 +156,13 @@ def fit(self, X, y=None): mult_order = "ssd" else: mult_order = None - restr_map = _handle_restr_map(C_ref, self.restr_type, info, rank) + restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) all_evals, all_evecs, all_patterns = list(), list(), list() for i in range(len(self.classes_)): S = covs[i] evals, evecs = _smart_ged( - S, R, restr_map, R_func=self.R_func, mult_order=mult_order + S, R, restr_mat, R_func=self.R_func, mult_order=mult_order ) evals, evecs = self.mod_ged_callable( diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index 68f11f9b1c4..75dbb757eb6 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -9,8 +9,8 @@ from ..utils import logger -def _handle_restr_map(C_ref, restr_type, info, rank): - """Get restricting map to C_ref rank-dimensional principal subspace. +def _handle_restr_mat(C_ref, restr_type, info, rank): + """Get restricting matrix to C_ref rank-dimensional principal subspace. Returns matrix of shape (rank, n_chs) used to restrict or restrict+rescale (whiten) covariances matrices. @@ -20,43 +20,43 @@ def _handle_restr_map(C_ref, restr_type, info, rank): if restr_type == "whitening": projs = info["projs"] C_ref_cov = Covariance(C_ref, info.ch_names, info["bads"], projs, 0) - restr_map = compute_whitener(C_ref_cov, info, rank=rank, pca=True)[0] + restr_mat = compute_whitener(C_ref_cov, info, rank=rank, pca=True)[0] elif restr_type == "ssd": - restr_map = _get_ssd_whitener(C_ref, rank) + restr_mat = _get_ssd_whitener(C_ref, rank) elif restr_type == "restricting": - restr_map = _get_restricting_map(C_ref, info, rank) + restr_mat = _get_restr_mat(C_ref, info, rank) elif isinstance(restr_type, callable): pass else: raise ValueError( - "restr_map should either be callable or one of whitening, ssd, restricting" + "restr_type should either be callable or one of whitening, ssd, restricting" ) - return restr_map + return restr_mat -def _smart_ged(S, R, restr_map=None, R_func=None, mult_order=None): +def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None): """Perform smart generalized eigenvalue decomposition (GED) of S and R. - If restr_map is provided S and R will be restricted to the principal subspace - of a reference matrix with rank r (see _handle_restr_map), then GED is performed + If restr_mat is provided S and R will be restricted to the principal subspace + of a reference matrix with rank r (see _handle_restr_mat), then GED is performed on the restricted S and R and then generalized eigenvectors are transformed back to the original space. The g-eigenvectors matrix is of shape (n_chs, r). If callable R_func is provided the GED will be performed on (S, R_func(S,R)) """ - if restr_map is None: + if restr_mat is None: evals, evecs = scipy.linalg.eigh(S, R) return evals, evecs if mult_order == "ssd": - S_restr = restr_map @ (S @ restr_map.T) - R_restr = restr_map @ (R @ restr_map.T) + S_restr = restr_mat @ (S @ restr_mat.T) + R_restr = restr_mat @ (R @ restr_mat.T) else: - S_restr = restr_map @ S @ restr_map.T - R_restr = restr_map @ R @ restr_map.T + S_restr = restr_mat @ S @ restr_mat.T + R_restr = restr_mat @ R @ restr_mat.T if R_func is not None: R_restr = R_func([S_restr, R_restr]) evals, evecs_restr = scipy.linalg.eigh(S_restr, R_restr) - evecs = restr_map.T @ evecs_restr + evecs = restr_mat.T @ evecs_restr return evals, evecs @@ -70,18 +70,18 @@ def _is_all_pos_def(covs): return True -def _smart_ajd(covs, restr_map=None, weights=None): +def _smart_ajd(covs, restr_mat=None, weights=None): """Perform smart approximate joint diagonalization. - If restr_map is provided all the cov matrices will be restricted to the - principal subspace of a reference matrix with rank r (see _handle_restr_map), + If restr_mat is provided all the cov matrices will be restricted to the + principal subspace of a reference matrix with rank r (see _handle_restr_mat), then GED is performed on the restricted S and R and then generalized eigenvectors are transformed back to the original space. The matrix of generalized eigenvectors is of shape (n_chs, r). """ from .csp import _ajd_pham - if restr_map is None: + if restr_mat is None: is_all_pos_def = _is_all_pos_def(covs) if not is_all_pos_def: raise ValueError( @@ -91,15 +91,15 @@ def _smart_ajd(covs, restr_map=None, weights=None): evecs, D = _ajd_pham(covs) return evecs - covs = np.array([restr_map @ cov @ restr_map.T for cov in covs], float) + covs = np.array([restr_mat @ cov @ restr_mat.T for cov in covs], float) evecs_restr, D = _ajd_pham(covs) evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) - evecs = restr_map.T @ evecs + evecs = restr_mat.T @ evecs return evecs -def _get_restricting_map(C, info, rank): - """Get map restricting covariance to rank-dimensional principal subspace of C.""" +def _get_restr_mat(C, info, rank): + """Get matrix restricting covariance to rank-dimensional principal subspace of C.""" _, ref_evecs, mask = _smart_eigh( C, info, @@ -108,8 +108,8 @@ def _get_restricting_map(C, info, rank): do_compute_rank=False, log_ch_type="data", ) - restr_map = ref_evecs[mask] - return restr_map + restr_mat = ref_evecs[mask] + return restr_mat def _normalize_eigenvectors(evecs, covs, sample_weights): diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index efb962341c8..21d3a96f871 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -19,7 +19,7 @@ from mne._fiff.proj import make_eeg_average_ref_proj from mne.cov import Covariance, _regularized_covariance from mne.decoding.base import _GEDTransformer -from mne.decoding.ged import _get_restricting_map, _smart_ajd, _smart_ged +from mne.decoding.ged import _get_restr_mat, _smart_ajd, _smart_ged from mne.io import read_raw data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" @@ -165,8 +165,8 @@ def test_ged_binary_cov(): # Test "single" decomposition covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) S, R = covs[0], covs[1] - restr_map = _get_restricting_map(C_ref, info, rank) - evals, evecs = _smart_ged(S, R, restr_map=restr_map, R_func=None) + restr_mat = _get_restr_mat(C_ref, info, rank) + evals, evecs = _smart_ged(S, R, restr_mat=restr_mat, R_func=None) actual_evals, actual_evecs = _mock_mod_ged_callable(evals, evecs, [S, R], **kwargs) actual_filters = actual_evecs.T @@ -187,11 +187,11 @@ def test_ged_binary_cov(): assert_allclose(actual_evals, desired_evals) assert_allclose(actual_filters, desired_filters) - # Test "multi" decomposition (loop), restr_map can be reused + # Test "multi" decomposition (loop), restr_mat can be reused all_evals, all_evecs = list(), list() for i in range(len(covs)): S = covs[i] - evals, evecs = _smart_ged(S, R, restr_map) + evals, evecs = _smart_ged(S, R, restr_mat) evals, evecs = _mock_mod_ged_callable(evals, evecs, covs) all_evals.append(evals) all_evecs.append(evecs.T) @@ -222,8 +222,8 @@ def test_ged_multicov(): X, y = _get_X_y(event_id) # Test "single" decomposition for multicov (AJD) covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) - restr_map = _get_restricting_map(C_ref, info, rank) - evecs = _smart_ajd(covs, restr_map=restr_map) + restr_mat = _get_restr_mat(C_ref, info, rank) + evecs = _smart_ajd(covs, restr_mat=restr_mat) evals = None _, actual_evecs = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) actual_filters = actual_evecs.T @@ -248,7 +248,7 @@ def test_ged_multicov(): all_evals, all_evecs = list(), list() for i in range(len(covs)): S = covs[i] - evals, evecs = _smart_ged(S, R, restr_map) + evals, evecs = _smart_ged(S, R, restr_mat) evals, evecs = _mock_mod_ged_callable(evals, evecs, covs) all_evals.append(evals) all_evecs.append(evecs.T) From 89fb1411ce28c607388d26dce7da548c431db614 Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 6 Jun 2025 18:50:43 +0300 Subject: [PATCH 15/18] a few more ged tests --- mne/decoding/base.py | 18 ++++++++--------- mne/decoding/ged.py | 33 +++++++++++++++++++++--------- mne/decoding/tests/test_ged.py | 37 +++++++++++++++++++++++++++++----- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 96516383f18..1d8e240f621 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -8,7 +8,6 @@ import numbers import numpy as np -import scipy.linalg from sklearn import model_selection as models from sklearn.base import ( # noqa: F401 BaseEstimator, @@ -25,7 +24,7 @@ from ..parallel import parallel_func from ..utils import _pl, logger, pinv, verbose, warn -from .ged import _handle_restr_mat, _smart_ajd, _smart_ged +from .ged import _handle_restr_mat, _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged from .transformer import MNETransformerMixin @@ -194,15 +193,14 @@ def _validate_covariances(self, covs): for cov in covs: if cov is None: continue - is_sym = scipy.linalg.issymmetric(cov, rtol=1e-10, atol=1e-11) - if not is_sym: + # XXX: A lot of mne.decoding classes use mne.cov._regularized_covariance. + # Depending on the data it sometimes returns negative semidefinite matrices. + # So adding the validation of positive semidefinitiveness + # will require overhauling covariance estimation first. + is_cov = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False) + if not is_cov: raise ValueError( - "One of covariances or C_ref is not symmetric, " - "check your cov_callable" - ) - if not np.all(np.linalg.eigvals(cov) >= 0): - raise ValueError( - "One of covariances or C_ref has negative eigenvalues, " + "One of covariances is not symmetric (or positive semidefinite), " "check your cov_callable" ) diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index 75dbb757eb6..4627c89514d 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -25,8 +25,8 @@ def _handle_restr_mat(C_ref, restr_type, info, rank): restr_mat = _get_ssd_whitener(C_ref, rank) elif restr_type == "restricting": restr_mat = _get_restr_mat(C_ref, info, rank) - elif isinstance(restr_type, callable): - pass + elif callable(restr_type): + restr_mat = restr_type else: raise ValueError( "restr_type should either be callable or one of whitening, ssd, restricting" @@ -61,15 +61,30 @@ def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None): return evals, evecs -def _is_all_pos_def(covs): - for cov in covs: - try: - _ = scipy.linalg.cholesky(cov) - except np.linalg.LinAlgError: - return False +def _is_cov_symm_pos_semidef( + cov, rtol=1e-10, atol=1e-11, eval_tol=1e-15, check_pos_semidef=True +): + is_symm = scipy.linalg.issymmetric(cov, rtol=rtol, atol=atol) + if not is_symm: + return False + + if check_pos_semidef: + # numerically slightly negative evals are considered 0 + is_pos_semidef = np.all(scipy.linalg.eigvalsh(cov) >= -eval_tol) + return is_pos_semidef + return True +def _is_cov_pos_def(cov, eval_tol=1e-15): + is_symm = _is_cov_symm_pos_semidef(cov, check_pos_semidef=False) + if not is_symm: + return False + # numerically slightly positive evals are considered 0 + is_pos_def = np.all(scipy.linalg.eigvalsh(cov) > eval_tol) + return is_pos_def + + def _smart_ajd(covs, restr_mat=None, weights=None): """Perform smart approximate joint diagonalization. @@ -82,7 +97,7 @@ def _smart_ajd(covs, restr_mat=None, weights=None): from .csp import _ajd_pham if restr_mat is None: - is_all_pos_def = _is_all_pos_def(covs) + is_all_pos_def = all([_is_cov_pos_def(cov) for cov in covs]) if not is_all_pos_def: raise ValueError( "If C_ref is not provided by covariance estimator, " diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index 21d3a96f871..f8db73ad070 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -19,7 +19,13 @@ from mne._fiff.proj import make_eeg_average_ref_proj from mne.cov import Covariance, _regularized_covariance from mne.decoding.base import _GEDTransformer -from mne.decoding.ged import _get_restr_mat, _smart_ajd, _smart_ged +from mne.decoding.ged import ( + _get_restr_mat, + _handle_restr_mat, + _is_cov_pos_def, + _smart_ajd, + _smart_ged, +) from mne.io import read_raw data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" @@ -286,9 +292,30 @@ def test_ged_invalid_cov(): R_func=None, ) asymm_cov = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="not symmetric"): ged._validate_covariances([asymm_cov, None]) - negsemidef_cov = np.array([[-2, 0, 0], [0, -1, 0], [0, 0, -3]]) - with pytest.raises(ValueError): - ged._validate_covariances([negsemidef_cov, None]) + +def test__handle_restr_mat_invalid_restr_type(): + """Test _handle_restr_mat raises correct error when wrong restr_type.""" + C_ref = np.eye(3) + with pytest.raises(ValueError, match="restr_type"): + _handle_restr_mat(C_ref, restr_type="blah", info=None, rank=None) + + +def test__is_cov_pos_def(): + """Test _is_cov_pos_def works.""" + sing_pos_semidef = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]) + pos_def = np.array([[5.0, 1.0, 1.0], [1.0, 6.0, 2.0], [1.0, 2.0, 7.0]]) + assert not _is_cov_pos_def(sing_pos_semidef) + assert _is_cov_pos_def(pos_def) + + +def test__smart_ajd_when_restr_mat_is_none(): + """Test _smart_ajd raises ValueError when restr_mat is None.""" + sing_pos_semidef = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]) + pos_def1 = np.array([[5.0, 1.0, 1.0], [1.0, 6.0, 2.0], [1.0, 2.0, 7.0]]) + pos_def2 = np.array([[10, 1, 2], [1, 12, 3], [2, 3, 15]]) + bad_covs = [sing_pos_semidef, pos_def1, pos_def2] + with pytest.raises(ValueError, match="positive definite"): + _smart_ajd(bad_covs, restr_mat=None, weights=None) From 3986c996a317ae46a538308368d43e2f22e2a9c4 Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 6 Jun 2025 18:58:48 +0300 Subject: [PATCH 16/18] fix multiplication order in original SSD --- mne/decoding/base.py | 16 ++-------------- mne/decoding/ged.py | 10 +++------- mne/decoding/ssd.py | 4 ++-- 3 files changed, 7 insertions(+), 23 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index 1d8e240f621..acb0566f655 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -129,14 +129,8 @@ def fit(self, X, y=None): else: S = covs[0] R = covs[1] - if self.restr_type == "ssd": - mult_order = "ssd" - else: - mult_order = None restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) - evals, evecs = _smart_ged( - S, R, restr_mat, R_func=self.R_func, mult_order=mult_order - ) + evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func) evals, evecs = self.mod_ged_callable( evals, evecs, covs, **self.mod_params, **kwargs @@ -151,18 +145,12 @@ def fit(self, X, y=None): elif self.dec_type == "multi": self.classes_ = np.unique(y) R = covs[-1] - if self.restr_type == "ssd": - mult_order = "ssd" - else: - mult_order = None restr_mat = _handle_restr_mat(C_ref, self.restr_type, info, rank) all_evals, all_evecs, all_patterns = list(), list(), list() for i in range(len(self.classes_)): S = covs[i] - evals, evecs = _smart_ged( - S, R, restr_mat, R_func=self.R_func, mult_order=mult_order - ) + evals, evecs = _smart_ged(S, R, restr_mat, R_func=self.R_func) evals, evecs = self.mod_ged_callable( evals, evecs, covs, **self.mod_params, **kwargs diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index 4627c89514d..b176b3ad3fd 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -34,7 +34,7 @@ def _handle_restr_mat(C_ref, restr_type, info, rank): return restr_mat -def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None): +def _smart_ged(S, R, restr_mat=None, R_func=None): """Perform smart generalized eigenvalue decomposition (GED) of S and R. If restr_mat is provided S and R will be restricted to the principal subspace @@ -47,12 +47,8 @@ def _smart_ged(S, R, restr_mat=None, R_func=None, mult_order=None): evals, evecs = scipy.linalg.eigh(S, R) return evals, evecs - if mult_order == "ssd": - S_restr = restr_mat @ (S @ restr_mat.T) - R_restr = restr_mat @ (R @ restr_mat.T) - else: - S_restr = restr_mat @ S @ restr_mat.T - R_restr = restr_mat @ R @ restr_mat.T + S_restr = restr_mat @ S @ restr_mat.T + R_restr = restr_mat @ R @ restr_mat.T if R_func is not None: R_restr = R_func([S_restr, R_restr]) evals, evecs_restr = scipy.linalg.eigh(S_restr, R_restr) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 25df128860f..117bd45689b 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -460,6 +460,6 @@ def _dimensionality_reduction(cov_signal, cov_noise, info, rank): logger.info("Preserving covariance rank (%i)", rank) # project covariance matrices to rank subspace - cov_signal = np.matmul(rank_proj.T, np.matmul(cov_signal, rank_proj)) - cov_noise = np.matmul(rank_proj.T, np.matmul(cov_noise, rank_proj)) + cov_signal = rank_proj.T @ cov_signal @ rank_proj + cov_noise = rank_proj.T @ cov_noise @ rank_proj return cov_signal, cov_noise, rank_proj From 11b038f4446d139752b53aae03ff87bc55d8c7fb Mon Sep 17 00:00:00 2001 From: Genuster Date: Fri, 6 Jun 2025 23:17:02 +0300 Subject: [PATCH 17/18] add assert_allclose to xdawn and csp transform methods. --- mne/decoding/base.py | 20 +++++++++++++------- mne/decoding/csp.py | 14 ++++++++------ mne/decoding/tests/test_ged.py | 12 ++++++------ mne/preprocessing/xdawn.py | 6 ++++++ 4 files changed, 33 insertions(+), 19 deletions(-) diff --git a/mne/decoding/base.py b/mne/decoding/base.py index acb0566f655..1b4bd665d1e 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -36,7 +36,7 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator): Parameters ---------- - n_filters : int + n_components : int The number of spatial filters to decompose M/EEG signals. cov_callable : callable Function used to estimate covariances and reference matrix (C_ref) from the @@ -89,7 +89,7 @@ class _GEDTransformer(MNETransformerMixin, BaseEstimator): def __init__( self, - n_filters, + n_components, cov_callable, cov_params, mod_ged_callable, @@ -98,7 +98,7 @@ def __init__( restr_type=None, R_func=None, ): - self.n_filters = n_filters + self.n_components = n_components self.cov_callable = cov_callable self.cov_params = cov_params self.mod_ged_callable = mod_ged_callable @@ -169,12 +169,18 @@ def transform(self, X): check_is_fitted(self, "filters_") X = self._check_data(X) if self.dec_type == "single": - pick_filters = self.filters_[: self.n_filters] + pick_filters = self.filters_[: self.n_components] elif self.dec_type == "multi": - pick_filters = self.filters_[:, : self.n_filters, :].reshape( - -1, self.filters_.shape[2] + # XXX: Hack to assert_allclose in Xdawn's transform. + # Will be removed when overhauling xdawn. + if hasattr(self, "new_filters_"): + filters = self.new_filters_ + else: + filters = self.filters_ + pick_filters = filters[:, : self.n_components, :].reshape( + -1, filters.shape[2] ) - X = np.asarray([pick_filters @ epoch for epoch in X]) + X = pick_filters @ X return X def _validate_covariances(self, covs): diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 47d40ef1b1e..e72adaa0196 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -136,11 +136,11 @@ def __init__( mod_params = dict(evecs_order=component_order) super().__init__( - n_components, - _csp_estimate, - cov_params, - _csp_mod, - mod_params, + n_components=n_components, + cov_callable=_csp_estimate, + cov_params=cov_params, + mod_ged_callable=_csp_mod, + mod_params=mod_params, dec_type="single", restr_type="restricting", R_func=sum, @@ -254,9 +254,11 @@ def transform(self, X): """ check_is_fitted(self, "filters_") X = self._check_data(X) + orig_X = X.copy() pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) - + ged_X = super().transform(orig_X) + np.testing.assert_allclose(X, ged_X) # compute features (mean band power) if self.transform_into == "average_power": X = (X**2).mean(axis=2) diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index f8db73ad070..433f28003ce 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -115,7 +115,7 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): param_grid = dict( - n_filters=[4], + n_components=[4], cov_callable=[_mock_cov_callable], cov_params=[ dict(cov_method_params=dict(reg="empirical")), @@ -177,7 +177,7 @@ def test_ged_binary_cov(): actual_filters = actual_evecs.T ged = _GEDTransformer( - n_filters=4, + n_components=4, cov_callable=_mock_cov_callable, cov_params=dict(), mod_ged_callable=_mock_mod_ged_callable, @@ -205,7 +205,7 @@ def test_ged_binary_cov(): actual_filters = np.array(all_evecs) ged = _GEDTransformer( - n_filters=4, + n_components=4, cov_callable=_mock_cov_callable, cov_params=dict(), mod_ged_callable=_mock_mod_ged_callable, @@ -235,7 +235,7 @@ def test_ged_multicov(): actual_filters = actual_evecs.T ged = _GEDTransformer( - n_filters=4, + n_components=4, cov_callable=_mock_cov_callable, cov_params=dict(), mod_ged_callable=_mock_mod_ged_callable, @@ -262,7 +262,7 @@ def test_ged_multicov(): actual_filters = np.array(all_evecs) ged = _GEDTransformer( - n_filters=4, + n_components=4, cov_callable=_mock_cov_callable, cov_params=dict(), mod_ged_callable=_mock_mod_ged_callable, @@ -282,7 +282,7 @@ def test_ged_multicov(): def test_ged_invalid_cov(): """Test _validate_covariances raises proper errors.""" ged = _GEDTransformer( - n_filters=1, + n_components=1, cov_callable=_mock_cov_callable, cov_params=dict(), mod_ged_callable=_mock_mod_ged_callable, diff --git a/mne/preprocessing/xdawn.py b/mne/preprocessing/xdawn.py index 5d0e52cd7d8..45681c8387c 100644 --- a/mne/preprocessing/xdawn.py +++ b/mne/preprocessing/xdawn.py @@ -305,6 +305,9 @@ def fit(self, X, y=None): old_filters = self.filters_ old_patterns = self.patterns_ super().fit(X, y) + + # Hack for assert_allclose in transform + self.new_filters_ = self.filters_.copy() # Xdawn performs separate GED for each class. # filters_ returned by _fit_xdawn are subset per # n_components and then appended and are of shape @@ -339,6 +342,7 @@ def transform(self, X): The transformed data. """ X, _ = self._check_Xy(X) + orig_X = X.copy() # Check size if self.filters_.shape[1] != X.shape[1]: @@ -350,6 +354,8 @@ def transform(self, X): # Transform X = np.dot(self.filters_, X) X = X.transpose((1, 0, 2)) + ged_X = super().transform(orig_X) + np.testing.assert_allclose(X, ged_X) return X def inverse_transform(self, X): From 25e1ae32f3b6e227fa44718ee3946b64099cd801 Mon Sep 17 00:00:00 2001 From: Genuster Date: Sat, 7 Jun 2025 00:06:11 +0300 Subject: [PATCH 18/18] more ged tests --- mne/decoding/ged.py | 13 +++---- mne/decoding/tests/test_ged.py | 68 +++++++++++++++++++++++++++++----- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/mne/decoding/ged.py b/mne/decoding/ged.py index b176b3ad3fd..ad3a90e25c4 100644 --- a/mne/decoding/ged.py +++ b/mne/decoding/ged.py @@ -25,8 +25,6 @@ def _handle_restr_mat(C_ref, restr_type, info, rank): restr_mat = _get_ssd_whitener(C_ref, rank) elif restr_type == "restricting": restr_mat = _get_restr_mat(C_ref, info, rank) - elif callable(restr_type): - restr_mat = restr_type else: raise ValueError( "restr_type should either be callable or one of whitening, ssd, restricting" @@ -102,11 +100,12 @@ def _smart_ajd(covs, restr_mat=None, weights=None): evecs, D = _ajd_pham(covs) return evecs - covs = np.array([restr_mat @ cov @ restr_mat.T for cov in covs], float) - evecs_restr, D = _ajd_pham(covs) - evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) - evecs = restr_mat.T @ evecs - return evecs + else: + covs = np.array([restr_mat @ cov @ restr_mat.T for cov in covs], float) + evecs_restr, D = _ajd_pham(covs) + evecs = _normalize_eigenvectors(evecs_restr.T, covs, weights) + evecs = restr_mat.T @ evecs + return evecs def _get_restr_mat(C, info, rank): diff --git a/mne/decoding/tests/test_ged.py b/mne/decoding/tests/test_ged.py index 433f28003ce..d4dd7b1ad2f 100644 --- a/mne/decoding/tests/test_ged.py +++ b/mne/decoding/tests/test_ged.py @@ -23,6 +23,7 @@ _get_restr_mat, _handle_restr_mat, _is_cov_pos_def, + _is_cov_symm_pos_semidef, _smart_ajd, _smart_ged, ) @@ -67,7 +68,7 @@ def _get_min_rank(covs, info): return min_rank -def _mock_cov_callable(X, y, cov_method_params=None): +def _mock_cov_callable(X, y, cov_method_params=None, compute_C_ref=True): if cov_method_params is None: cov_method_params = dict() n_epochs, n_channels, n_times = X.shape @@ -92,7 +93,10 @@ def _mock_cov_callable(X, y, cov_method_params=None): sample_weights.append(class_data.shape[0]) ref_data = X.transpose(1, 0, 2).reshape(n_channels, -1) - C_ref = _regularized_covariance(ref_data, **cov_method_params) + if compute_C_ref: + C_ref = _regularized_covariance(ref_data, **cov_method_params) + else: + C_ref = None info = _mock_info(n_channels) rank = _get_min_rank(covs, info) kwargs = dict() @@ -123,10 +127,12 @@ def _mock_mod_ged_callable(evals, evecs, covs, **kwargs): mod_ged_callable=[_mock_mod_ged_callable], mod_params=[dict()], dec_type=["single", "multi"], + # XXX: Not covering "ssd" here because test_ssd.py works with 2D data. + # Need to fix its tests first. restr_type=[ "restricting", "whitening", - ], # Not covering "ssd" here because its tests work with 2D data. + ], R_func=[functools.partial(np.sum, axis=0)], ) @@ -159,7 +165,7 @@ def _get_X_y(event_id): preload=True, proj=False, ) - X = epochs.get_data(copy=False) + X = epochs.get_data(copy=False, units=dict(eeg="uV", grad="fT/cm", mag="fT")) y = epochs.events[:, -1] return X, y @@ -226,7 +232,7 @@ def test_ged_multicov(): """Test GEDTransformer on audvis dataset with multiple covariances.""" event_id = dict(aud_l=1, aud_r=2, vis_l=3, vis_r=4) X, y = _get_X_y(event_id) - # Test "single" decomposition for multicov (AJD) + # Test "single" decomposition for multicov (AJD) with C_ref covs, C_ref, info, rank, kwargs = _mock_cov_callable(X, y) restr_mat = _get_restr_mat(C_ref, info, rank) evecs = _smart_ajd(covs, restr_mat=restr_mat) @@ -278,6 +284,31 @@ def test_ged_multicov(): assert_allclose(actual_evals, desired_evals) assert_allclose(actual_filters, desired_filters) + # Test "single" decomposition for multicov (AJD) without C_ref + covs, C_ref, info, rank, kwargs = _mock_cov_callable( + X, y, cov_method_params=dict(reg="oas"), compute_C_ref=False + ) + covs = np.stack(covs) + evecs = _smart_ajd(covs, restr_mat=None) + evals = None + _, actual_evecs = _mock_mod_ged_callable(evals, evecs, covs, **kwargs) + actual_filters = actual_evecs.T + + ged = _GEDTransformer( + n_components=4, + cov_callable=_mock_cov_callable, + cov_params=dict(cov_method_params=dict(reg="oas"), compute_C_ref=False), + mod_ged_callable=_mock_mod_ged_callable, + mod_params=dict(), + dec_type="single", + restr_type="restricting", + R_func=None, + ) + ged.fit(X, y) + desired_filters = ged.filters_ + + assert_allclose(actual_filters, desired_filters) + def test_ged_invalid_cov(): """Test _validate_covariances raises proper errors.""" @@ -303,19 +334,36 @@ def test__handle_restr_mat_invalid_restr_type(): _handle_restr_mat(C_ref, restr_type="blah", info=None, rank=None) +def test_cov_validators(): + """Test that covariance validators indeed validate.""" + asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) + + assert not _is_cov_symm_pos_semidef(asymm) + assert _is_cov_symm_pos_semidef(sing_pos_semidef) + assert _is_cov_symm_pos_semidef(pos_def) + + assert not _is_cov_pos_def(asymm) + assert not _is_cov_pos_def(sing_pos_semidef) + assert _is_cov_pos_def(pos_def) + + def test__is_cov_pos_def(): """Test _is_cov_pos_def works.""" - sing_pos_semidef = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]) - pos_def = np.array([[5.0, 1.0, 1.0], [1.0, 6.0, 2.0], [1.0, 2.0, 7.0]]) + asymm = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) + assert not _is_cov_pos_def(asymm) assert not _is_cov_pos_def(sing_pos_semidef) assert _is_cov_pos_def(pos_def) def test__smart_ajd_when_restr_mat_is_none(): """Test _smart_ajd raises ValueError when restr_mat is None.""" - sing_pos_semidef = np.array([[1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]]) - pos_def1 = np.array([[5.0, 1.0, 1.0], [1.0, 6.0, 2.0], [1.0, 2.0, 7.0]]) + sing_pos_semidef = np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + pos_def1 = np.array([[5, 1, 1], [1, 6, 2], [1, 2, 7]]) pos_def2 = np.array([[10, 1, 2], [1, 12, 3], [2, 3, 15]]) - bad_covs = [sing_pos_semidef, pos_def1, pos_def2] + bad_covs = np.stack([sing_pos_semidef, pos_def1, pos_def2]) with pytest.raises(ValueError, match="positive definite"): _smart_ajd(bad_covs, restr_mat=None, weights=None)