Skip to content

Commit 37ae7e3

Browse files
BabaSanfourpre-commit-ci[bot]wmvanvlietdrammock
authored
Add raw stc (mne-tools#12001)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Marijn van Vliet <w.m.vanvliet@gmail.com> Co-authored-by: Daniel McCloy <dan@mccloy.info>
1 parent 647fdd3 commit 37ae7e3

File tree

4 files changed

+57
-19
lines changed

4 files changed

+57
-19
lines changed

doc/changes/devel.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ Enhancements
2929
- Added public :func:`mne.io.write_info` to complement :func:`mne.io.read_info` (:gh:`11918` by `Eric Larson`_)
3030
- Added option ``remove_dc`` to to :meth:`Raw.compute_psd() <mne.io.Raw.compute_psd>`, :meth:`Epochs.compute_psd() <mne.Epochs.compute_psd>`, and :meth:`Evoked.compute_psd() <mne.Evoked.compute_psd>`, to allow skipping DC removal when computing Welch or multitaper spectra (:gh:`11769` by `Nikolai Chapochnikov`_)
3131
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
32+
- Add extracting all time courses in a label using :func:`mne.extract_label_time_course` without applying an aggregation function (like ``mean``) (:gh:`12001` by `Hamza Abdelhedi`_)
3233
- Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_)
3334
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
3435
- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`_)

mne/source_estimate.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3240,6 +3240,7 @@ def _pca_flip(flip, data):
32403240
"mean_flip": lambda flip, data: np.mean(flip * data, axis=0),
32413241
"max": lambda flip, data: np.max(np.abs(data), axis=0),
32423242
"pca_flip": _pca_flip,
3243+
None: lambda flip, data: data, # Return Identity: Preserves all vertices.
32433244
}
32443245

32453246

@@ -3494,7 +3495,7 @@ def _volume_labels(src, labels, mri_resolution):
34943495

34953496

34963497
def _get_default_label_modes():
3497-
return sorted(_label_funcs.keys()) + ["auto"]
3498+
return sorted(_label_funcs.keys(), key=lambda x: (x is None, x)) + ["auto"]
34983499

34993500

35003501
def _get_allowed_label_modes(stc):
@@ -3572,7 +3573,12 @@ def _gen_extract_label_time_course(
35723573
)
35733574

35743575
# do the extraction
3575-
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
3576+
if mode is None:
3577+
# prepopulate an empty list for easy array-like index-based assignment
3578+
label_tc = [None] * max(len(label_vertidx), len(src_flip))
3579+
else:
3580+
# For other modes, initialize the label_tc array
3581+
label_tc = np.zeros((n_labels,) + stc.data.shape[1:], dtype=stc.data.dtype)
35763582
for i, (vertidx, flip) in enumerate(zip(label_vertidx, src_flip)):
35773583
if vertidx is not None:
35783584
if isinstance(vertidx, sparse.csr_matrix):
@@ -3585,15 +3591,13 @@ def _gen_extract_label_time_course(
35853591
this_data = stc.data[vertidx]
35863592
label_tc[i] = func(flip, this_data)
35873593

3588-
# extract label time series for the vol src space (only mean supported)
3589-
offset = nvert[:-n_mean].sum() # effectively :2 or :0
3590-
for i, nv in enumerate(nvert[2:]):
3591-
if nv != 0:
3592-
v2 = offset + nv
3593-
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
3594-
offset = v2
3595-
3596-
# this is a generator!
3594+
if mode is not None:
3595+
offset = nvert[:-n_mean].sum() # effectively :2 or :0
3596+
for i, nv in enumerate(nvert[2:]):
3597+
if nv != 0:
3598+
v2 = offset + nv
3599+
label_tc[n_mode + i] = np.mean(stc.data[offset:v2], axis=0)
3600+
offset = v2
35973601
yield label_tc
35983602

35993603

mne/tests/test_source_estimate.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,24 @@ def test_extract_label_time_course(kind, vector):
678678

679679
label_tcs = dict(mean=np.arange(n_labels)[:, None] * np.ones((n_labels, n_times)))
680680
label_tcs["max"] = label_tcs["mean"]
681+
label_tcs[None] = label_tcs["mean"]
681682

682683
# compute the mean with sign flip
683684
label_tcs["mean_flip"] = np.zeros_like(label_tcs["mean"])
684685
for i, label in enumerate(labels):
685686
label_tcs["mean_flip"][i] = i * np.mean(label_sign_flip(label, src[:2]))
686687

688+
# compute pca_flip
689+
label_flip = []
690+
for i, label in enumerate(labels):
691+
this_flip = i * label_sign_flip(label, src[:2])
692+
label_flip.append(this_flip)
693+
# compute pca_flip
694+
label_tcs["pca_flip"] = np.zeros_like(label_tcs["mean"])
695+
for i, (label, flip) in enumerate(zip(labels, label_flip)):
696+
sign = np.sign(np.dot(np.full((flip.shape[0]), i), flip))
697+
label_tcs["pca_flip"][i] = sign * label_tcs["mean"][i]
698+
687699
# generate some stc's with known data
688700
stcs = list()
689701
pad = (((0, 0), (2, 0), (0, 0)), "constant")
@@ -734,7 +746,7 @@ def test_extract_label_time_course(kind, vector):
734746
assert_array_equal(arr[1:], vol_means_t)
735747

736748
# test the different modes
737-
modes = ["mean", "mean_flip", "pca_flip", "max", "auto"]
749+
modes = ["mean", "mean_flip", "pca_flip", "max", "auto", None]
738750

739751
for mode in modes:
740752
if vector and mode not in ("mean", "max", "auto"):
@@ -748,18 +760,36 @@ def test_extract_label_time_course(kind, vector):
748760
]
749761
assert len(label_tc) == n_stcs
750762
assert len(label_tc_method) == n_stcs
751-
for tc1, tc2 in zip(label_tc, label_tc_method):
752-
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
753-
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
754-
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
763+
for j, (tc1, tc2) in enumerate(zip(label_tc, label_tc_method)):
764+
if mode is None:
765+
assert all(arr.shape[1] == tc1[0].shape[1] for arr in tc1)
766+
assert all(arr.shape[1] == tc2[0].shape[1] for arr in tc2)
767+
assert (len(tc1), tc1[0].shape[1]) == (n_labels,) + end_shape
768+
assert (len(tc2), tc2[0].shape[1]) == (n_labels,) + end_shape
769+
for arr1, arr2 in zip(tc1, tc2): # list of arrays
770+
assert_allclose(arr1, arr2, rtol=1e-8, atol=1e-16)
771+
else:
772+
assert tc1.shape == (n_labels + len(vol_means),) + end_shape
773+
assert tc2.shape == (n_labels + len(vol_means),) + end_shape
774+
assert_allclose(tc1, tc2, rtol=1e-8, atol=1e-16)
755775
if mode == "auto":
756776
use_mode = "mean" if vector else "mean_flip"
757777
else:
758778
use_mode = mode
759-
# XXX we don't check pca_flip, probably should someday...
760-
if use_mode in ("mean", "max", "mean_flip"):
779+
if mode == "pca_flip":
780+
for arr1, arr2 in zip(tc1, label_tcs[use_mode]):
781+
assert_array_almost_equal(arr1, arr2)
782+
elif use_mode is None:
783+
for arr1, arr2 in zip(
784+
tc1[:n_labels], label_tcs[use_mode]
785+
): # list of arrays
786+
assert_allclose(
787+
arr1, np.tile(arr2, (arr1.shape[0], 1)), rtol=1e-8, atol=1e-16
788+
)
789+
elif use_mode in ("mean", "max", "mean_flip"):
761790
assert_array_almost_equal(tc1[:n_labels], label_tcs[use_mode])
762-
assert_array_almost_equal(tc1[n_labels:], vol_means_t)
791+
if mode is not None:
792+
assert_array_almost_equal(tc1[n_labels:], vol_means_t)
763793

764794
# test label with very few vertices (check SVD conditionals)
765795
label = Label(vertices=src[0]["vertno"][:2], hemi="lh")

mne/utils/docs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,9 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
12031203
- ``'auto'`` (default)
12041204
Uses ``'mean_flip'`` when a standard source estimate is applied, and
12051205
``'mean'`` when a vector source estimate is supplied.
1206+
- ``None``
1207+
No aggregation is performed, and an array of shape ``(n_vertices, n_times)`` is
1208+
returned.
12061209
12071210
.. versionadded:: 0.21
12081211
Support for ``'auto'``, vector, and volume source estimates.

0 commit comments

Comments
 (0)