Skip to content

Commit c12bee7

Browse files
larsonersnwnde
authored andcommitted
BUG: Fix bug with sensor_colors (mne-tools#12068)
1 parent 56851d5 commit c12bee7

File tree

12 files changed

+154
-129
lines changed

12 files changed

+154
-129
lines changed

doc/changes/devel.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Enhancements
3131
- Add the possibility to provide a float between 0 and 1 as ``n_grad``, ``n_mag`` and ``n_eeg`` in `~mne.compute_proj_raw`, `~mne.compute_proj_epochs` and `~mne.compute_proj_evoked` to select the number of vectors based on the cumulative explained variance (:gh:`11919` by `Mathieu Scheltienne`_)
3232
- Added support for Artinis fNIRS data files to :func:`mne.io.read_raw_snirf` (:gh:`11926` by `Robert Luke`_)
3333
- Add helpful error messages when using methods on empty :class:`mne.Epochs`-objects (:gh:`11306` by `Martin Schulz`_)
34+
- Add support for passing a :class:`python:dict` as ``sensor_color`` to specify per-channel-type colors in :func:`mne.viz.plot_alignment` (:gh:`12067` by `Eric Larson`_)
3435
- Add inferring EEGLAB files' montage unit automatically based on estimated head radius using :func:`read_raw_eeglab(..., montage_units="auto") <mne.io.read_raw_eeglab>` (:gh:`11925` by `Jack Zhang`_, :gh:`11951` by `Eric Larson`_)
3536
- Add :class:`~mne.time_frequency.EpochsSpectrumArray` and :class:`~mne.time_frequency.SpectrumArray` to support creating power spectra from :class:`NumPy array <numpy.ndarray>` data (:gh:`11803` by `Alex Rockhill`_)
3637
- Add support for writing forward solutions to HDF5 and convenience function :meth:`mne.Forward.save` (:gh:`12036` by `Eric Larson`_)
@@ -56,6 +57,7 @@ Bugs
5657
- Fix bug with axis clip box boundaries in :func:`mne.viz.plot_evoked_topo` and related functions (:gh:`11999` by `Eric Larson`_)
5758
- Fix bug with ``subject_info`` when loading data from and exporting to EDF file (:gh:`11952` by `Paul Roujansky`_)
5859
- Fix bug with delayed checking of :class:`info["bads"] <mne.Info>` (:gh:`12038` by `Eric Larson`_)
60+
- Fix bug with :func:`mne.viz.plot_alignment` where ``sensor_colors`` were not handled properly on a per-channel-type basis (:gh:`12067` by `Eric Larson`_)
5961
- Fix handling of channel information in annotations when loading data from and exporting to EDF file (:gh:`11960` :gh:`12017` :gh:`12044` by `Paul Roujansky`_)
6062
- Add missing ``overwrite`` and ``verbose`` parameters to :meth:`Transform.save() <mne.transforms.Transform.save>` (:gh:`12004` by `Marijn van Vliet`_)
6163
- Fix parsing of eye-link :class:`~mne.Annotations` when ``apply_offsets=False`` is provided to :func:`~mne.io.read_raw_eyelink` (:gh:`12003` by `Mathieu Scheltienne`_)

doc/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@
200200
"path-like": ":term:`path-like`",
201201
"array-like": ":term:`array_like <numpy:array_like>`",
202202
"Path": ":class:`python:pathlib.Path`",
203-
"bool": ":class:`python:bool`",
203+
"bool": ":ref:`python:typebool`",
204204
# Matplotlib
205205
"colormap": ":ref:`colormap <matplotlib:colormaps>`",
206206
"color": ":doc:`color <matplotlib:api/colors_api>`",

mne/defaults.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@
227227
coreg=dict(
228228
mri_fid_opacity=1.0,
229229
dig_fid_opacity=1.0,
230+
# go from unit scaling (e.g., unit-radius sphere) to meters
230231
mri_fid_scale=5e-3,
231232
dig_fid_scale=8e-3,
232233
extra_scale=4e-3,
@@ -235,6 +236,8 @@
235236
eegp_height=0.1,
236237
ecog_scale=5e-3,
237238
seeg_scale=5e-3,
239+
meg_scale=1.0, # sensors are already in SI units
240+
ref_meg_scale=1.0,
238241
dbs_scale=5e-3,
239242
fnirs_scale=5e-3,
240243
source_scale=5e-3,

mne/gui/_coreg.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,7 @@ def _redraw(self, *, verbose=None):
835835
mri_fids=self._add_mri_fiducials,
836836
hsp=self._add_head_shape_points,
837837
hpi=self._add_hpi_coils,
838-
eeg=self._add_eeg_channels,
838+
eeg=self._add_eeg_fnirs_channels,
839839
head_fids=self._add_head_fiducials,
840840
helmet=self._add_helmet,
841841
)
@@ -1217,7 +1217,7 @@ def _add_head_shape_points(self):
12171217
hsp_actors = None
12181218
self._update_actor("head_shape_points", hsp_actors)
12191219

1220-
def _add_eeg_channels(self):
1220+
def _add_eeg_fnirs_channels(self):
12211221
if self._eeg_channels:
12221222
eeg = ["original"]
12231223
picks = pick_types(self._info, eeg=(len(eeg) > 0), fnirs=True)
@@ -1240,8 +1240,7 @@ def _add_eeg_channels(self):
12401240
check_inside=self._check_inside,
12411241
nearest=self._nearest,
12421242
)
1243-
sens_actors = actors["eeg"]
1244-
sens_actors.extend(actors["fnirs"])
1243+
sens_actors = sum(actors.values(), list())
12451244
else:
12461245
sens_actors = None
12471246
else:

mne/utils/docs.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3969,6 +3969,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):
39693969
automatically generated, corresponding to all non-zero events.
39703970
"""
39713971

3972+
docdict[
3973+
"sensor_colors"
3974+
] = """
3975+
sensor_colors : array-like of color | dict | None
3976+
Colors to use for the sensor glyphs. Can be None (default) to use default colors.
3977+
A dict should provide the colors (values) for each channel type (keys), e.g.::
3978+
3979+
dict(eeg=eeg_colors)
3980+
3981+
Where the value (``eeg_colors`` above) can be broadcast to an array of colors with
3982+
length that matches the number of channels of that type, i.e., is compatible with
3983+
:func:`matplotlib.colors.to_rgba_array`. A few examples of this for the case above
3984+
are the string ``"k"``, a list of ``n_eeg`` color strings, or an NumPy ndarray of
3985+
shape ``(n_eeg, 3)`` or ``(n_eeg, 4)``.
3986+
"""
3987+
39723988
docdict[
39733989
"sensors_topomap"
39743990
] = """

mne/viz/_3d.py

Lines changed: 100 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#
1010
# License: Simplified BSD
1111

12+
from collections import defaultdict
1213
import os
1314
import os.path as op
1415
import warnings
@@ -604,11 +605,10 @@ def plot_alignment(
604605
.. versionadded:: 0.16
605606
.. versionchanged:: 1.0
606607
Defaults to ``'terrain'``.
607-
sensor_colors : array-like | None
608-
Colors to use for the sensor glyphs. Can be list-like of color strings
609-
(length ``n_sensors``) or array-like of RGB(A) values (shape
610-
``(n_sensors, 3)`` or ``(n_sensors, 4)``). ``None`` (the default) uses
611-
the default sensor colors for the :func:`~mne.viz.plot_alignment` GUI.
608+
%(sensor_colors)s
609+
610+
.. versionchanged:: 1.6
611+
Support for passing a ``dict`` was added.
612612
%(verbose)s
613613
614614
Returns
@@ -1437,29 +1437,16 @@ def _plot_sensors(
14371437
sensor_colors=None,
14381438
):
14391439
"""Render sensors in a 3D scene."""
1440+
from matplotlib.colors import to_rgba_array
1441+
14401442
defaults = DEFAULTS["coreg"]
14411443
ch_pos, sources, detectors = _ch_pos_in_coord_frame(
14421444
pick_info(info, picks), to_cf_t=to_cf_t, warn_meg=warn_meg
14431445
)
14441446

1445-
actors = dict(
1446-
meg=list(),
1447-
ref_meg=list(),
1448-
eeg=list(),
1449-
fnirs=list(),
1450-
ecog=list(),
1451-
seeg=list(),
1452-
dbs=list(),
1453-
)
1454-
locs = dict(
1455-
eeg=list(),
1456-
fnirs=list(),
1457-
ecog=list(),
1458-
seeg=list(),
1459-
source=list(),
1460-
detector=list(),
1461-
)
1462-
scalar = 1 if units == "m" else 1e3
1447+
actors = defaultdict(lambda: list())
1448+
locs = defaultdict(lambda: list())
1449+
unit_scalar = 1 if units == "m" else 1e3
14631450
for ch_name, ch_coord in ch_pos.items():
14641451
ch_type = channel_type(info, info.ch_names.index(ch_name))
14651452
# for default picking
@@ -1471,119 +1458,126 @@ def _plot_sensors(
14711458
plot_sensors = (ch_type != "fnirs" or "channels" in fnirs) and (
14721459
ch_type != "eeg" or "original" in eeg
14731460
)
1474-
color = defaults[ch_type + "_color"]
14751461
# plot sensors
14761462
if isinstance(ch_coord, tuple): # is meg, plot coil
1477-
verts, triangles = ch_coord
1478-
actor, _ = renderer.surface(
1479-
surface=dict(rr=verts * scalar, tris=triangles),
1480-
color=color,
1481-
opacity=0.25,
1482-
backface_culling=True,
1483-
)
1484-
actors[ch_type].append(actor)
1485-
else:
1486-
if plot_sensors:
1487-
locs[ch_type].append(ch_coord)
1463+
ch_coord = dict(rr=ch_coord[0] * unit_scalar, tris=ch_coord[1])
1464+
if plot_sensors:
1465+
locs[ch_type].append(ch_coord)
14881466
if ch_name in sources and "sources" in fnirs:
14891467
locs["source"].append(sources[ch_name])
14901468
if ch_name in detectors and "detectors" in fnirs:
14911469
locs["detector"].append(detectors[ch_name])
1470+
# Plot these now
14921471
if ch_name in sources and ch_name in detectors and "pairs" in fnirs:
14931472
actor, _ = renderer.tube( # array of origin and dest points
1494-
origin=sources[ch_name][np.newaxis] * scalar,
1495-
destination=detectors[ch_name][np.newaxis] * scalar,
1496-
radius=0.001 * scalar,
1473+
origin=sources[ch_name][np.newaxis] * unit_scalar,
1474+
destination=detectors[ch_name][np.newaxis] * unit_scalar,
1475+
radius=0.001 * unit_scalar,
14971476
)
14981477
actors[ch_type].append(actor)
1478+
del ch_type
14991479

1500-
# add sensors
1501-
for sensor_type in locs.keys():
1502-
if len(locs[sensor_type]) > 0:
1503-
sens_loc = np.array(locs[sensor_type])
1504-
sens_loc = sens_loc[~np.isnan(sens_loc).any(axis=1)]
1505-
scale = defaults[sensor_type + "_scale"]
1506-
if sensor_colors is None:
1507-
color = defaults[sensor_type + "_color"]
1480+
# now actually plot the sensors
1481+
extra = ""
1482+
types = (dict, None)
1483+
if len(locs) == 0:
1484+
return
1485+
elif len(locs) == 1:
1486+
# Upsample from array-like to dict when there is one channel type
1487+
extra = "(or array-like since only one sensor type is plotted)"
1488+
if sensor_colors is not None and not isinstance(sensor_colors, dict):
1489+
sensor_colors = {
1490+
list(locs)[0]: to_rgba_array(sensor_colors),
1491+
}
1492+
else:
1493+
extra = f"when more than one channel type ({list(locs)}) is plotted"
1494+
_validate_type(sensor_colors, types, "sensor_colors", extra=extra)
1495+
del extra, types
1496+
if sensor_colors is None:
1497+
sensor_colors = dict()
1498+
assert isinstance(sensor_colors, dict)
1499+
for ch_type, sens_loc in locs.items():
1500+
assert len(sens_loc) # should be guaranteed above
1501+
colors = to_rgba_array(sensor_colors.get(ch_type, defaults[ch_type + "_color"]))
1502+
_check_option(
1503+
f"len(sensor_colors[{repr(ch_type)}])",
1504+
colors.shape[0],
1505+
(len(sens_loc), 1),
1506+
)
1507+
scale = defaults[ch_type + "_scale"] * unit_scalar
1508+
if isinstance(sens_loc[0], dict): # meg coil
1509+
if len(colors) == 1:
1510+
colors = [colors[0]] * len(sens_loc)
1511+
for surface, color in zip(sens_loc, colors):
1512+
actor, _ = renderer.surface(
1513+
surface=surface,
1514+
color=color[:3],
1515+
opacity=sensor_opacity * color[3],
1516+
backface_culling=False, # visible from all sides
1517+
)
1518+
actors[ch_type].append(actor)
1519+
else:
1520+
sens_loc = np.array(sens_loc, float)
1521+
mask = ~np.isnan(sens_loc).any(axis=1)
1522+
if len(colors) == 1:
1523+
# Single color mode (one actor)
15081524
actor, _ = _plot_glyphs(
15091525
renderer=renderer,
1510-
loc=sens_loc * scalar,
1511-
color=color,
1512-
scale=scale * scalar,
1513-
opacity=sensor_opacity,
1526+
loc=sens_loc[mask] * unit_scalar,
1527+
color=colors[0, :3],
1528+
scale=scale,
1529+
opacity=sensor_opacity * colors[0, 3],
15141530
orient_glyphs=orient_glyphs,
15151531
scale_by_distance=scale_by_distance,
15161532
project_points=project_points,
15171533
surf=surf,
15181534
check_inside=check_inside,
15191535
nearest=nearest,
15201536
)
1521-
if sensor_type in ("source", "detector"):
1522-
sensor_type = "fnirs"
1523-
actors[sensor_type].append(actor)
1537+
actors[ch_type].append(actor)
15241538
else:
1525-
actor_list = []
1526-
for idx_sen in range(sens_loc.shape[0]):
1527-
sensor_colors = np.asarray(sensor_colors)
1528-
if (
1529-
sensor_colors.ndim not in (1, 2)
1530-
or sensor_colors.shape[0] != sens_loc.shape[0]
1531-
):
1532-
raise ValueError(
1533-
"sensor_colors should either be None or be "
1534-
"array-like with shape (n_sensors,) or "
1535-
"(n_sensors, 3) or (n_sensors, 4). Got shape "
1536-
f"{sensor_colors.shape}."
1537-
)
1538-
color = sensor_colors[idx_sen]
1539-
1539+
# Multi-color mode (multiple actors)
1540+
for loc, color, usable in zip(sens_loc, colors, mask):
1541+
if not usable:
1542+
continue
15401543
actor, _ = _plot_glyphs(
15411544
renderer=renderer,
1542-
loc=(sens_loc * scalar)[idx_sen, :],
1543-
color=color,
1544-
scale=scale * scalar,
1545-
opacity=sensor_opacity,
1545+
loc=loc * unit_scalar,
1546+
color=color[:3],
1547+
scale=scale,
1548+
opacity=sensor_opacity * color[3],
15461549
orient_glyphs=orient_glyphs,
15471550
scale_by_distance=scale_by_distance,
15481551
project_points=project_points,
15491552
surf=surf,
15501553
check_inside=check_inside,
15511554
nearest=nearest,
15521555
)
1553-
actor_list.append(actor)
1554-
if sensor_type in ("source", "detector"):
1555-
sensor_type = "fnirs"
1556-
actors[sensor_type].append(actor_list)
1557-
1558-
# add projected eeg
1559-
eeg_indices = pick_types(info, eeg=True)
1560-
if eeg_indices.size > 0 and "projected" in eeg:
1561-
logger.info("Projecting sensors to the head surface")
1562-
eeg_loc = np.array([ch_pos[info.ch_names[idx]] for idx in eeg_indices])
1563-
eeg_loc = eeg_loc[~np.isnan(eeg_loc).any(axis=1)]
1564-
eegp_loc, eegp_nn = _project_onto_surface(
1565-
eeg_loc, head_surf, project_rrs=True, return_nn=True
1566-
)[2:4]
1567-
del eeg_loc
1568-
eegp_loc *= scalar
1569-
scale = defaults["eegp_scale"] * scalar
1570-
actor, _ = renderer.quiver3d(
1571-
x=eegp_loc[:, 0],
1572-
y=eegp_loc[:, 1],
1573-
z=eegp_loc[:, 2],
1574-
u=eegp_nn[:, 0],
1575-
v=eegp_nn[:, 1],
1576-
w=eegp_nn[:, 2],
1577-
color=defaults["eegp_color"],
1578-
mode="cylinder",
1579-
scale=scale,
1580-
opacity=0.6,
1581-
glyph_height=defaults["eegp_height"],
1582-
glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0),
1583-
glyph_resolution=20,
1584-
backface_culling=True,
1585-
)
1586-
actors["eeg"].append(actor)
1556+
actors[ch_type].append(actor)
1557+
if ch_type == "eeg" and "projected" in eeg:
1558+
logger.info("Projecting sensors to the head surface")
1559+
eegp_loc, eegp_nn = _project_onto_surface(
1560+
sens_loc[mask], head_surf, project_rrs=True, return_nn=True
1561+
)[2:4]
1562+
eegp_loc *= unit_scalar
1563+
actor, _ = renderer.quiver3d(
1564+
x=eegp_loc[:, 0],
1565+
y=eegp_loc[:, 1],
1566+
z=eegp_loc[:, 2],
1567+
u=eegp_nn[:, 0],
1568+
v=eegp_nn[:, 1],
1569+
w=eegp_nn[:, 2],
1570+
color=defaults["eegp_color"],
1571+
mode="cylinder",
1572+
scale=defaults["eegp_scale"] * unit_scalar,
1573+
opacity=0.6,
1574+
glyph_height=defaults["eegp_height"],
1575+
glyph_center=(0.0, -defaults["eegp_height"] / 2.0, 0),
1576+
glyph_resolution=20,
1577+
backface_culling=True,
1578+
)
1579+
actors["eeg"].append(actor)
1580+
actors = dict(actors) # get rid of defaultdict
15871581

15881582
return actors
15891583

mne/viz/_brain/_brain.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2763,6 +2763,8 @@ def add_sensors(
27632763
seeg=True,
27642764
dbs=True,
27652765
max_dist=0.004,
2766+
*,
2767+
sensor_colors=None,
27662768
verbose=None,
27672769
):
27682770
"""Add mesh objects to represent sensor positions.
@@ -2778,6 +2780,9 @@ def add_sensors(
27782780
%(seeg)s
27792781
%(dbs)s
27802782
%(max_dist_ieeg)s
2783+
%(sensor_colors)s
2784+
2785+
.. versionadded:: 1.6
27812786
%(verbose)s
27822787
27832788
Notes
@@ -2832,6 +2837,7 @@ def add_sensors(
28322837
warn_meg,
28332838
head_surf,
28342839
self._units,
2840+
sensor_colors=sensor_colors,
28352841
)
28362842
for item, actors in sensors_actors.items():
28372843
for actor in actors:

0 commit comments

Comments
 (0)