Skip to content

Commit

Permalink
Add colorbar to correlation matrix plot (#1401)
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo authored Feb 12, 2025
2 parents 8190929 + 7e71ad3 commit 9300df3
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions xcp_d/interfaces/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ def plot_matrix(self, corr_mat, network_labels, ax):
np.fill_diagonal(corr_mat, 0)

# Plot the correlation matrix
ax.imshow(corr_mat, vmin=-1, vmax=1, cmap='seismic')
im = ax.imshow(corr_mat, vmin=-1, vmax=1, cmap='seismic')

# Add lines separating networks
for idx in break_idx[1:-1]:
Expand All @@ -404,9 +404,11 @@ def plot_matrix(self, corr_mat, network_labels, ax):
ax.axes.set_yticklabels(unique_labels)
ax.axes.set_xticklabels(unique_labels, rotation=90)

return ax
return im, ax

def _run_interface(self, runtime):
from matplotlib.gridspec import GridSpec

priority_list = [
'MIDB',
'MyersLabonte',
Expand Down Expand Up @@ -442,22 +444,21 @@ def _run_interface(self, runtime):
}

if len(selected_atlases) == 4:
nrows, ncols, figsize, ax_idx = 2, 2, (20, 20), [(0, 0), (0, 1), (1, 0), (1, 1)]
nrows, ncols, figsize = 2, 2, (20, 20)
else:
nrows, ncols, figsize = 1, len(selected_atlases), (10 * len(selected_atlases), 10)
ax_idx = list(range(ncols))

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
if isinstance(axes, plt.Axes):
axes = np.array([axes])
fig = plt.figure(figsize=figsize)
gs = GridSpec(nrows, ncols + 1, width_ratios=[*[1] * ncols, 0.05])

# Create axes for each matrix
for i_ax, atlas in enumerate(selected_atlases):
i_row, i_col = divmod(i_ax, ncols)
ax = fig.add_subplot(gs[i_row, i_col])
atlas_idx = self.inputs.atlases.index(atlas)
atlas_file = self.inputs.correlations_tsv[atlas_idx]
dseg_file = self.inputs.atlas_tsvs[atlas_idx]

sel_ax_idx = ax_idx[i_ax]

column_name = COMMUNITY_LOOKUP.get(atlas, 'network_label')
dseg_df = pd.read_table(dseg_file)
corrs_df = pd.read_table(atlas_file, index_col='Node')
Expand All @@ -475,8 +476,7 @@ def _run_interface(self, runtime):
else:
network_labels = ['None'] * dseg_df.shape[0]

ax = axes[sel_ax_idx]
ax = self.plot_matrix(
im, ax = self.plot_matrix(
corr_mat=corrs_df.to_numpy(),
network_labels=network_labels,
ax=ax,
Expand All @@ -486,6 +486,10 @@ def _run_interface(self, runtime):
fontdict={'weight': 'normal', 'size': 20},
)

# Add colorbar in the reserved space
cbar_ax = fig.add_subplot(gs[0, -1])
plt.colorbar(im, cax=cbar_ax)
cbar_ax.set_yticks([-1, 0, 1])
fig.tight_layout()

# Write the results out
Expand Down

0 comments on commit 9300df3

Please sign in to comment.