-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathulf_recon_fns.py
194 lines (157 loc) · 7.59 KB
/
ulf_recon_fns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# CS Recon Functions for ULF MRI Experiments - Updated by D Waddington 2022
import numpy as np
import matplotlib.pyplot as plt
import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
# functions for metrics adapted from .py file provided by Efrat
from metrics import calc_NRMSE as nrmse
from metrics import calc_SSIM as ssim
# the main reconstruction function
def ulfl1recon(ksp, mask, lamda, iter=30, mps=float('nan')):
    """Retrospectively undersample ``ksp`` and reconstruct it with L1-wavelet CS.

    Parameters
    ----------
    ksp : array
        k-space passed to ``applyMask`` (3-D single-channel or 4-D with the
        channel axis first).
    mask : array
        Sampling mask forwarded to ``applyMask``.
    lamda : float
        Regularization weight given to ``mr.app.L1WaveletRecon``.
    iter : int, optional
        Maximum number of solver iterations (``alg.max_iter``).
    mps : array or float, optional
        Coil sensitivity maps. If it contains any NaN (the default is a NaN
        scalar), maps are generated: all ones for single-channel data,
        otherwise ESPIRiT (calib_width=20, kernel_width=4).

    Returns
    -------
    The image returned by the SigPy app's ``run()``.
    """
    kdata = applyMask(ksp, mask)

    # Pick the sensitivity maps: caller-supplied, trivial, or ESPIRiT-estimated.
    if not np.any(np.isnan(mps)):
        sens = mps
    elif kdata.shape[0] == 1:
        sens = np.ones(kdata.shape, dtype=complex)
    else:
        sens = mr.app.EspiritCalib(kdata, calib_width=20, kernel_width=4,
                                   show_pbar=False).run()

    solver = mr.app.L1WaveletRecon(kdata, sens, lamda, show_pbar=False)
    solver.alg.max_iter = iter
    return solver.run()
# the main reconstruction function but with TV regularization
def ulfTVrecon(ksp, mask, lamda, iter=30, mps=float('nan')):
    """Retrospectively undersample ``ksp`` and reconstruct it with total-variation CS.

    Parameters
    ----------
    ksp : array
        k-space passed to ``applyMask`` (3-D single-channel or 4-D with the
        channel axis first).
    mask : array
        Sampling mask forwarded to ``applyMask``.
    lamda : float
        Regularization weight given to ``mr.app.TotalVariationRecon``.
    iter : int, optional
        Maximum number of solver iterations (``alg.max_iter``).
    mps : array or float, optional
        Coil sensitivity maps. If it contains any NaN (the default is a NaN
        scalar), maps are generated: all ones for single-channel data,
        otherwise ESPIRiT (calib_width=20, kernel_width=4).

    Returns
    -------
    The image returned by the SigPy app's ``run()``.
    """
    kdata = applyMask(ksp, mask)

    # Pick the sensitivity maps: caller-supplied, trivial, or ESPIRiT-estimated.
    if not np.any(np.isnan(mps)):
        sens = mps
    elif kdata.shape[0] == 1:
        sens = np.ones(kdata.shape, dtype=complex)
    else:
        sens = mr.app.EspiritCalib(kdata, calib_width=20, kernel_width=4,
                                   show_pbar=False).run()

    solver = mr.app.TotalVariationRecon(kdata, sens, lamda, show_pbar=False)
    solver.alg.max_iter = iter
    return solver.run()
# retrospective undersampling function that applies a mask to fully sampled 3D kspace
def applyMask(ksp, mask):
    """Retrospectively undersample k-space by multiplying it with a 2-D mask.

    Parameters
    ----------
    ksp : array
        3-D single-channel k-space ``(nx, ny, nz)`` or 4-D multi-channel
        k-space ``(channels, nx, ny, nz)``.
    mask : array
        Sampling mask over the first and last spatial axes, shape ``(nx, nz)``
        (or anything broadcastable to that plane). The same mask is applied to
        every channel and to every index of the middle (ny) axis.

    Returns
    -------
    numpy.ndarray
        New complex64 array of shape ``(channels, nx, ny, nz)``; single-channel
        input gets a leading channel axis of length 1.

    Raises
    ------
    ValueError
        If ``ksp`` is neither 3-D nor 4-D.
    """
    ksp = np.asarray(ksp)
    if ksp.ndim == 3:
        # Give single-channel data the same layout as multi-channel data.
        ksp = ksp[np.newaxis]
    elif ksp.ndim != 4:
        raise ValueError(
            f"applyMask expects 3-D or 4-D k-space, got shape {ksp.shape}")

    mask = np.asarray(mask)
    if mask.ndim >= 1:
        # Insert a length-1 axis before the mask's last axis so an (nx, nz)
        # mask lines up with ksp axes 1 and 3 and broadcasts along ny (axis 2).
        # This replaces the old per-slice loop, whose np.squeeze broke nz == 1.
        mask = mask[..., np.newaxis, :]

    # astype always copies, so the caller's k-space is never modified.
    return (ksp * mask).astype(np.complex64)
# Mask generation using the poisson disc function from SigPy but with a calibration region added
def poissonDiscSigpy(imSize, accel, in_seed, calib_size=10):
    """Build a 2-D Poisson-disc sampling mask with a fully sampled square centre.

    Parameters
    ----------
    imSize : sequence of int
        Image size; only ``imSize[0]`` and ``imSize[2]`` are used, giving a
        mask of shape ``(imSize[0], imSize[2])``.
    accel : float
        Target acceleration passed to ``sp.mri.poisson``; ``1`` yields a mask
        of ones.
    in_seed : int
        Random seed forwarded to ``sp.mri.poisson``.
    calib_size : int, optional
        Side length of the square region around the centre that is set to 1.

    Returns
    -------
    numpy.ndarray
        The mask: float64 ones when ``accel == 1``, otherwise the array
        returned by ``sp.mri.poisson`` with the calibration square filled in.

    Raises
    ------
    ValueError
        If the calibration square does not fit inside the mask.
    """
    n0, n2 = imSize[0], imSize[2]
    if accel == 1:
        maskpoissonsquare = np.ones((n0, n2))
    else:
        maskpoissonsquare = sp.mri.poisson([n0, n2], accel, tol=0.1, seed=in_seed)

    # Same index arithmetic as before: a calib_size-long window placed around n/2.
    rows = slice(int(n0 / 2 - calib_size / 2 + 1), int(n0 / 2 + calib_size / 2 + 1))
    cols = slice(int(n2 / 2 - calib_size / 2 + 1), int(n2 / 2 + calib_size / 2 + 1))
    if maskpoissonsquare[rows, cols].shape != (calib_size, calib_size):
        raise ValueError(
            f"calib_size={calib_size} does not fit in a mask of shape "
            f"{maskpoissonsquare.shape}")

    # Assign a scalar so the mask keeps its own dtype; writing a complex array
    # into the float mask of the accel == 1 branch raised a ComplexWarning.
    maskpoissonsquare[rows, cols] = 1
    return maskpoissonsquare
# finds the optimal regularization parameter lamda for reconstruction
def find_lamda_mask(ksp, GT, mps=float('nan'), calib_size=10, show_plot=True, iter=30):
    """Sweep the L1-wavelet weight ``lamda`` and return the value with minimum NRMSE.

    For each value in a fixed log-spaced grid, ``ulfl1recon`` reconstructs
    ``ksp`` with an all-ones mask (the data are used as given, without further
    retrospective undersampling). Voxels where ``|GT| < 1e-4`` are zeroed in
    both the reconstruction and a copy of ``GT`` before NRMSE and SSIM are
    computed.

    Parameters
    ----------
    ksp : array
        k-space, 3-D single-channel or 4-D with the channel axis first.
    GT : array
        Reference image compared against each reconstruction. A copy is
        thresholded; the caller's array is no longer modified in place.
    mps : array or float, optional
        Sensitivity maps forwarded to ``ulfl1recon`` (NaN lets it build them).
    calib_size : int, optional
        Not used by this function; kept for signature compatibility.
    show_plot : bool, optional
        If true, plot NRMSE (left axis) and SSIM (right axis) against lamda.
    iter : int, optional
        Solver iterations used for every reconstruction.

    Returns
    -------
    The grid value of lamda with the lowest NRMSE.
    """
    if ksp.ndim == 3:  # single channel: add a leading channel axis
        ksp = np.reshape(ksp, (1,) + ksp.shape)

    lamda_vals = np.array([1E-4, 2E-4, 5E-4, 1E-3, 2E-3, 5E-3, 1E-2, 2E-2, 5E-2,
                           1E-1, 2E-1, 5E-1, 1, 2, 5, 10, 20])
    nrmse_vals = np.zeros(lamda_vals.size)
    ssim_vals = np.zeros(lamda_vals.size)

    # Threshold a copy so the caller's ground truth is not silently altered.
    GT = np.array(GT, copy=True)
    mask_metrics = np.abs(GT) < 0.0001
    GT[mask_metrics] = 0
    gt_mag = abs(GT)

    # All-ones mask: applyMask leaves the (already sampled) data unchanged.
    full_mask = np.ones((ksp.shape[1], ksp.shape[3]), dtype='complex64')
    for i, lamda in enumerate(lamda_vals):
        img_l1wav = ulfl1recon(ksp, full_mask, lamda, iter, mps)
        img_l1wav[mask_metrics] = 0
        nrmse_vals[i] = nrmse(abs(img_l1wav), gt_mag)
        ssim_vals[i] = ssim(abs(img_l1wav), gt_mag)

    lamda_opt = lamda_vals[np.argmin(nrmse_vals)]

    if show_plot:
        fig1 = plt.figure()
        ax1 = fig1.add_subplot(111)
        nrmse_line, = ax1.plot(lamda_vals, nrmse_vals, 'ob-', label='NRMSE')
        ax1.set(title='Lamda Opt', ylabel='NRMSE', xlabel='lamda')
        ax1.set_xscale('log')
        # Second y-axis on the right sharing ax1's (log-scaled) x-axis.
        ax2 = ax1.twinx()
        ssim_line, = ax2.plot(lamda_vals, ssim_vals, 'or-', label='SSIM')
        ax2.set(ylabel='SSIM')
        # legend("NRMSE") would label the line "N" (a str is an iterable of
        # labels); build one legend from the labelled lines instead.
        ax2.legend(handles=[nrmse_line, ssim_line])
        plt.show()

    print('Minimum NRMSE for lamda value of ', lamda_opt)
    return lamda_opt
# finds the optimal number of iterations for reconstruction
def find_iter_mask(ksp, GT, lamda_opt, mps=float('nan'), show_plot=True):
    """Sweep the solver iteration count and return the value with minimum NRMSE.

    For each iteration count in a fixed grid, ``ulfl1recon`` reconstructs
    ``ksp`` with an all-ones mask and weight ``lamda_opt``. Voxels where
    ``|GT| < 1e-4`` are zeroed in both the reconstruction and a copy of ``GT``
    before NRMSE and SSIM are computed.

    Parameters
    ----------
    ksp : array
        k-space, 3-D single-channel or 4-D with the channel axis first.
        (3-D input previously crashed on ``ksp.shape[3]``.)
    GT : array
        Reference image compared against each reconstruction. A copy is
        thresholded; the caller's array is no longer modified in place.
    lamda_opt : float
        Regularization weight used for every reconstruction.
    mps : array or float, optional
        Sensitivity maps forwarded to ``ulfl1recon`` (NaN lets it build them).
    show_plot : bool, optional
        If true, plot NRMSE (left axis) and SSIM (right axis) against iterations.

    Returns
    -------
    The grid iteration count with the lowest NRMSE.
    """
    if ksp.ndim == 3:  # single channel: add a leading channel axis
        ksp = np.reshape(ksp, (1,) + ksp.shape)

    iter_vals = np.array([1, 2, 3, 4, 5, 10, 20, 30, 50, 70, 100, 150, 200])
    nrmse_iter_vals = np.zeros(iter_vals.size)
    ssim_iter_vals = np.zeros(iter_vals.size)

    # Threshold a copy so the caller's ground truth is not silently altered.
    GT = np.array(GT, copy=True)
    mask_metrics = np.abs(GT) < 0.0001
    GT[mask_metrics] = 0
    gt_mag = abs(GT)

    # All-ones mask: applyMask leaves the (already sampled) data unchanged.
    full_mask = np.ones((ksp.shape[1], ksp.shape[3]), dtype='complex64')
    for i, n_iter in enumerate(iter_vals):
        img_l1wav = ulfl1recon(ksp, full_mask, lamda_opt, n_iter, mps)
        img_l1wav[mask_metrics] = 0
        nrmse_iter_vals[i] = nrmse(abs(img_l1wav), gt_mag)
        ssim_iter_vals[i] = ssim(abs(img_l1wav), gt_mag)

    iter_opt = iter_vals[np.argmin(nrmse_iter_vals)]

    if show_plot:
        fig2 = plt.figure()
        ax1 = fig2.add_subplot(111)
        nrmse_line, = ax1.plot(iter_vals, nrmse_iter_vals, 'ob-', label='NRMSE')
        ax1.set(title='Iter Opt', ylabel='NRMSE', xlabel='Iterations')
        # Second y-axis on the right sharing ax1's x-axis.
        ax2 = ax1.twinx()
        ssim_line, = ax2.plot(iter_vals, ssim_iter_vals, 'or-', label='SSIM')
        ax2.set(ylabel='SSIM')
        # legend("NRMSE") would label the line "N" (a str is an iterable of
        # labels); build one legend from the labelled lines instead.
        ax2.legend(handles=[nrmse_line, ssim_line])
        plt.show()

    print('Minimum NRMSE for iter value of ', iter_opt)
    return iter_opt
# to combine multi-channel data via a sensitivity map (mps)
def coil_combine(imgs_mc, mps):
    """Combine per-coil images by conjugate sensitivity weighting.

    Computes ``sum over axis 0 of conj(mps) * imgs_mc``, i.e. the leading axis
    of the (broadcast) product is treated as the coil axis and summed away.
    """
    weighted = np.conj(mps) * imgs_mc
    return weighted.sum(axis=0)
# code to add white gaussian noise
# author - Mathuranathan Viswanathan (gaussianwaves.com)
# This code is part of the book Digital Modulations using Python
from numpy import sum,isrealobj,sqrt
from numpy.random import standard_normal
def awgn(s, SNRdB, L=1):
    """Add white Gaussian noise to ``s`` for a requested SNR and return ``s + noise``.

    Parameters
    ----------
    s : numpy.ndarray
        Input signal, real or complex, of any dimensionality >= 1.
    SNRdB : float
        Desired signal-to-noise ratio in dB.
    L : float, optional
        Oversampling factor that scales the measured signal power (default 1).

    Returns
    -------
    numpy.ndarray
        ``s`` plus the generated noise (only the noisy signal is returned).

    Notes
    -----
    Signal power is ``L * sum(|s|**2) / len(s)``. For a 1-D vector this is the
    mean power per sample. For an N-D array ``len(s)`` is the length of the first
    axis, so the power is per first-axis slice rather than per element.
    NOTE(review): for multi-dimensional k-space this gives a much larger noise
    level than a per-element SNR would; confirm this is the intended convention.
    Complex inputs get circular noise: real parts are drawn first, then
    imaginary parts, from the global NumPy random state.
    """
    snr_linear = 10 ** (SNRdB / 10)
    signal_power = L * np.sum(np.abs(s) ** 2) / len(s)
    noise_psd = signal_power / snr_linear
    sigma = np.sqrt(noise_psd / 2)

    if not np.isrealobj(s):
        re_draw = np.random.standard_normal(s.shape)
        im_draw = np.random.standard_normal(s.shape)
        noise = sigma * (re_draw + 1j * im_draw)
    else:
        noise = sigma * np.random.standard_normal(s.shape)

    return s + noise