Merge pull request #64 from zhuzhenxi/master
Update saliency_detection.py
xulabs authored Jan 9, 2020
2 parents 36acfc2 + f1bead3 commit 11598c7
Showing 1 changed file with 151 additions and 104 deletions.
255 changes: 151 additions & 104 deletions aitom/segmentation/saliency/feature_decomposition/saliency_detection.py
@@ -1,10 +1,13 @@
import scipy.ndimage as SN
import numpy as np
import heapq
import aitom.io.file as io_file
import matplotlib.pyplot as plt
import time
import math
import numpy.linalg
import multiprocessing
import gc
from numba import jit
from scipy.spatial.distance import cdist
from scipy import signal
@@ -15,19 +18,20 @@
parameters:
a: volume data
gaussian_sigma: sigma of the Gaussian filter used for de-noising
gabor_sigma / gabor_lambda: sigma / lambda for the Gabor filter bank
cluster_center_number: initial number of cluster centers
pick_num: the number of particles to pick out
multiprocessing_num: number of worker processes for Gabor feature extraction (values <= 1 run single-process)
save_flag: set True to save results
return: saliency map, the same shape as a
'''
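A minimal call sketch for the updated signature; the .mrc path and parameter values below simply mirror the __main__ demo at the end of this file:

vol = io_file.read_mrc_data('./aitom_demo_single_particle_tomogram.mrc').astype(np.float32)
saliency_map = saliency_detection(a=vol, gaussian_sigma=2.5, gabor_sigma=14.0, gabor_lambda=13.0, cluster_center_number=10000, pick_num=1000, multiprocessing_num=0, save_flag=True)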
def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_center_number, save_flag=False):
def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_center_number, pick_num=None, multiprocessing_num=0, save_flag=False):
# Step 1
# Data Pre-processing
b_time = time.time()
a = SN.gaussian_filter(input=a, sigma=gaussian_sigma) # de-noise
print('sigma=', gaussian_sigma)
end_time = time.time()
print('de-noise takes', end_time - b_time, 's', ' sigma=', gaussian_sigma)
if save_flag:
img = (a[:, :, int(a.shape[2] / 2)]).copy()
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.savefig('./original.png') # save fig
plt.imsave('./result/original.png', img, cmap='gray')

# Step 2
# Supervoxel over-segmentation
@@ -54,12 +58,8 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen

print('the number of cluster centers = %d' % len(ck))
print(ck[: 5])
label = [[[0 for i in range(a.shape[2])] for i in range(a.shape[1])] for i in range(a.shape[0])]
label = np.array(label) # numba supports numpy array
distance = [[[float('inf') for i in range(a.shape[2])] for i in range(a.shape[1])] for i in range(a.shape[0])]
distance = np.array(distance)
# label = np.zeros((a.shape[0], a.shape[1], a.shape[2])) # numba will report error
# distance = np.full((a.shape[0], a.shape[1], a.shape[2]), np.inf)
label = np.full((a.shape[0], a.shape[1], a.shape[2]), 0)
distance = np.full((a.shape[0], a.shape[1], a.shape[2]), np.inf)
start_time = time.time()
print('Supervoxel over-segmentation begins')
ck = np.array(ck)
@@ -68,7 +68,7 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen
for number in range(10): # 10 iterations suffice for most images
b_time = time.time()
print('\n%d of 10 iterations' % number)
distance, label, ck = fast_SLIC(distance, label, ck, a, interval)
distance, label, ck, redundant_flag = fast_SLIC(distance, label, ck, a, interval, redundant_flag)
# merge cluster centers
ck_dist_min = interval / 2 # merge two cluster centers if the distance between them is less than ck_dist_min
for ck_i in range(len(ck)):
@@ -86,24 +86,16 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen
print('total number of removed cluster centers = ', sum(redundant_flag == True))
e_time = time.time()
print('\n', e_time - b_time, 's')

cluster_center_number = int(len(ck) - sum(redundant_flag))
# renumber cluster center index
label = renumber(redundant_flag=redundant_flag, label=label)
assert np.max(label) == cluster_center_number - 1
end_time = time.time()
print('Supervoxel over-segmentation done,', end_time - start_time, 's')

# save labels for Feature extraction
labels_remove_num = sum(redundant_flag == True)
labels = {}

for i in range(0, a.shape[0]):
for j in range(0, a.shape[1]):
for k in range(0, a.shape[2]):
if label[i][j][k] in labels:
labels[label[i][j][k]].append([i, j, k])
else:
labels[label[i][j][k]] = [[i, j, k]]
assert labels_remove_num + len(labels) == len(ck)

if save_flag:
np.save('./labels', labels)
np.save('./result/label', label)
img = (a[:, :, int(a.shape[2] / 2)]).copy()
k = int(a.shape[2] / 2)
draw_color = np.min(a)
@@ -112,9 +104,9 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen
if label[i][j][k] != label[i - 1][j][k] or label[i][j][k] != label[i + 1][j][k] or label[i][j][k] != \
label[i][j - 1][k] or label[i][j][k] != label[i][j + 1][k]:
img[i][j] = draw_color
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.savefig('./SLIC.png') # save fig
plt.imsave('./result/SLIC.png', img, cmap='gray')
del distance
gc.collect()

# Step 3
# Feature Extraction
@@ -126,58 +118,38 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen
# filters2 = filter_bank_gb3d(sigma=s2, Lambda=Lambda,psi=0,gamma=1)
# filters = filters1 + filters2
filters_num = len(filters)
feature_matrix = np.zeros((len(filters) + 6, len(labels))) # Gabor filter bases features and 6 density features
feature_matrix = np.zeros((len(filters) + 6, cluster_center_number)) # Gabor features and 6 density features
print('%d Gabor based features' % filters_num)

print('Feature extraction begins')
# 3D Gabor filter based features
for i in range(len(filters)):
# convolution
start_time = time.time()
# b=SN.correlate(a,filters[i]) # too slow
b = signal.correlate(a, filters[i], mode='same')
end_time = time.time()
print('feature %d done (%f s)' % (i, end_time - start_time))

# show Gabor filter output
if save_flag:
img = (b[:, :, int(a.shape[2] / 2)]).copy()
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.savefig('./gabor_output(%d).png' % i) # save fig

# generate feature vector
start_time = time.time()
index_col = 0
for key in labels:
vox = labels[key]
sum_vox = 0
for j in range(len(vox)):
sum_vox = sum_vox + b[vox[j][0], vox[j][1], vox[j][2]]
# print('sum.type',type(sum)) <class 'numpy.float64'>
sum_vox = sum_vox / len(vox)
feature_matrix[i][index_col] = sum_vox
index_col += 1
# print(feature_matrix[i, 0:30])
end_time = time.time()
print('feature vector %d done (%f s)' % (i, end_time - start_time))
res_pool = []
if multiprocessing_num > 1:
pool = multiprocessing.Pool(processes=min(multiprocessing_num, multiprocessing.cpu_count()))
else:
pool = None

if pool is not None:
for fm_i in range(len(filters)):
res_pool.append(pool.apply_async(func=gabor_feature_single_job,
kwds={'a': a, 'filters': filters, 'fm_i': fm_i, 'label': label,
'cluster_center_number': cluster_center_number, 'save_flag': False}))
pool.close()
pool.join()
del pool
for pool_i in res_pool:
feature_matrix[pool_i.get()[0], :] = pool_i.get()[1]
else:
for fm_i in range(len(filters)):
_, feature_matrix[fm_i, :] = gabor_feature_single_job(a=a, filters=filters, fm_i=fm_i, label=label, cluster_center_number=cluster_center_number, save_flag=False)
print('3D Gabor filter based features done')

# density features
min_val = np.min(a)
max_val = np.max(a)
width = (max_val - min_val) / 6
index_col = 0
for key in labels:
vox = labels[key]
for j in vox:
bin_num = min(int((a[j[0]][j[1]][j[2]] - min_val) / width), 5) # normalize
feature_matrix[filters_num + bin_num][index_col] += 1
index_col += 1
feature_matrix = density_feature(a=a, feature_matrix=feature_matrix, label=label, filters_num=filters_num)
print('Density features done')

if save_flag:
np.save('./feature_matrix', feature_matrix)
np.save('./result/feature_matrix', feature_matrix)

etime = time.time()
print('Feature extraction done,', etime - stime, 's')
@@ -191,31 +163,15 @@ def saliency_detection(a, gaussian_sigma, gabor_sigma, gabor_lambda, cluster_cen
print('RPCA done, ', end_time - start_time, 's')
supervoxel_saliency = np.sum(S, axis=0) / S.shape[0]
if save_flag:
np.save('./supervoxel_saliency', supervoxel_saliency)
np.save('./result/supervoxel_saliency', supervoxel_saliency)
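For context, a comment sketch of the decomposition behind Step 4, under the usual RPCA assumption:

# feature_matrix ≈ L + S: L is low-rank (background supervoxels share similar features),
# S is sparse (salient supervoxels deviate from that background);
# the per-supervoxel saliency computed above is the column mean of S.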

# Step 5
# Generate Saliency Map
min_saliency = np.min(supervoxel_saliency)
max_saliency = np.max(supervoxel_saliency)
t = (min_saliency + max_saliency) / 2 # threshold
print('min=', min_saliency, 'max=', max_saliency, 'threshold=', t)
index_col = 0
for key in labels:
vox = labels[key]
if supervoxel_saliency[index_col] < t:
supervoxel_saliency[index_col] = min_saliency

for j in vox:
a[j[0]][j[1]][j[2]] = supervoxel_saliency[index_col]
index_col += 1
# print('sum.type',type(sum)) <class 'numpy.float64'>

a = generate_saliency_map(a=a, label=label, supervoxel_saliency=supervoxel_saliency, pick_num=pick_num)
if save_flag:
img = a[:, :, int(a.shape[2] / 2)].copy()
plt.axis('off')
plt.imshow(img, cmap='gray')
plt.savefig('./saliency_map.png')
io_file.put_mrc_data(a, './saliency_map.mrc')
plt.imsave('./result/saliency_map.png', img, cmap='gray')
# io_file.put_mrc_data(a, './saliency_map.mrc')
print('saliency map saved')

return a
@@ -264,10 +220,7 @@ def filter_bank_gb3d(sigma, Lambda, psi=0, gamma=1, truncate=4.0):
for theta_x in np.arange(0, np.pi, np.pi / 4):
for theta_y in np.arange(0, np.pi, np.pi / 4):
for theta_z in np.arange(0, np.pi, np.pi / 4):
if np.sum(np.abs([theta_x, theta_y, theta_z]) < 10e-8) < 2:
continue
thetas = [theta_x, theta_y, theta_z]
print(thetas)
kern = gabor_fn(sigma, thetas, Lambda, psi, gamma, size)
kern /= kern.sum()
filters.append(np.transpose(kern))
@@ -296,6 +249,7 @@ def filter_bank_gb3d(sigma, Lambda, psi=0, gamma=1, truncate=4.0):
kern /= kern.sum()
# kern /= 1.5 * kern.sum()
filters.append(np.transpose(kern))

return filters


@@ -391,7 +345,7 @@ def converged(M, L, S, initial_error):


@jit(nopython=True)
def fast_SLIC(distance, label, ck, a, interval):
def fast_SLIC(distance, label, ck, a, interval, redundant_flag):
m = 10 # m can be in the range [1,40]
for i in range(len(ck)):
boundary = []
@@ -410,25 +364,118 @@ def fast_SLIC(distance, label, ck, a, interval):
label[ix][iy][iz] = i

# update cluster center
sum = np.zeros((len(ck),5))
sum_ck = np.zeros((len(ck),5))
for i in range(0, a.shape[0]):
for j in range(0, a.shape[1]):
for k in range(0, a.shape[2]):
sum[label[i][j][k]][4]=sum[label[i][j][k]][4]+1
sum[label[i][j][k]][:4]=sum[label[i][j][k]][:4]+np.array([i, j, k, a[i][j][k]])
sum_ck[label[i][j][k]][4]=sum_ck[label[i][j][k]][4]+1
sum_ck[label[i][j][k]][:4]=sum_ck[label[i][j][k]][:4]+np.array([i, j, k, a[i][j][k]])
for i in range(len(ck)):
if ck[i][3] == np.inf:
if redundant_flag[i]:
continue
assert sum[i][4]>0
sum[i][:4]=sum[i][:4]/sum[i][4]
ck[i]=sum[i][:4]
return distance,label,ck
if sum_ck[i][4] == 0:
redundant_flag[i] = True
continue
assert sum_ck[i][4]>0
sum_ck[i][:4]=sum_ck[i][:4]/sum_ck[i][4]
ck[i]=sum_ck[i][:4]
return distance,label,ck,redundant_flag
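The voxel-assignment loop inside the 2*interval search window is collapsed in this view; a minimal sketch of the SLIC-style distance it presumably minimizes, assuming the standard formulation with the compactness m = 10 used above:

def slic_distance(d_c, d_s, interval, m=10):
    # d_c: intensity difference to the cluster center, d_s: spatial (Euclidean) distance
    return math.sqrt(d_c ** 2 + (d_s / interval) ** 2 * m ** 2)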


@jit(nopython=True)
def renumber(redundant_flag, label):
reduce_index = np.zeros(len(redundant_flag), dtype=np.int64)  # offset subtracted from voxels labelled i: number of removed centers before i
cnt = 0
for i in range(len(reduce_index)):
if redundant_flag[i]:
cnt += 1
else:
reduce_index[i] = cnt
for i in range(label.shape[0]):
for j in range(label.shape[1]):
for k in range(label.shape[2]):
label[i][j][k] -= reduce_index[label[i][j][k]]
return label
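A small worked example of the renumbering, assuming three centers of which the middle one was flagged redundant:

# redundant_flag = [False, True, False]  ->  reduce_index = [0, 0, 1]
# voxels labelled 0 keep label 0, voxels labelled 2 become 1,
# so the surviving labels are contiguous in [0, cluster_center_number).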


@jit(nopython=True)
def density_feature(a, feature_matrix, label, filters_num):
min_val = np.min(a)
max_val = np.max(a)
width = (max_val - min_val) / 6
for i in range(a.shape[0]):
for j in range(a.shape[1]):
for k in range(a.shape[2]):
bin_num = min(int((a[i][j][k] - min_val) / width), 5) # normalize
feature_matrix[filters_num + bin_num][label[i][j][k]] += 1
return feature_matrix
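Outside the numba path, the same 6-bin per-supervoxel density histogram could be written with plain NumPy; a rough equivalent sketch, not used by this commit:

bins = np.minimum(((a - a.min()) / ((a.max() - a.min()) / 6)).astype(np.int64), 5)
np.add.at(feature_matrix, (filters_num + bins.ravel(), label.ravel()), 1)  # scatter-add one count per voxel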


@jit(nopython=True)
def generate_saliency_map(a, label, supervoxel_saliency, pick_num):
min_saliency = np.min(supervoxel_saliency)
max_saliency = np.max(supervoxel_saliency)
t = (min_saliency + max_saliency) / 2 # threshold
if pick_num is not None:
unique_saliency = np.unique(supervoxel_saliency)
t = heapq.nlargest(pick_num, unique_saliency)[-1]
print('min=', min_saliency, 'max=', max_saliency, 'threshold=', t)

for i in range(len(supervoxel_saliency)):
if supervoxel_saliency[i] < t:
supervoxel_saliency[i] = min_saliency
for i in range(a.shape[0]):
for j in range(a.shape[1]):
for k in range(a.shape[2]):
a[i][j][k] = supervoxel_saliency[label[i][j][k]]
return a
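The write-back loop is just a per-voxel lookup of its supervoxel's saliency; an equivalent vectorized sketch:

suppressed = np.where(supervoxel_saliency < t, min_saliency, supervoxel_saliency)
a = suppressed[label]  # broadcast each supervoxel's saliency back to its voxels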


def gabor_feature_single_job(a, filters, fm_i, label, cluster_center_number, save_flag):
# convolution
start_time = time.time()
# b = SN.correlate(a, filters[fm_i])  # too slow
b = signal.correlate(a, filters[fm_i], mode='same')
end_time = time.time()
print('feature %d done (%f s)' % (fm_i, end_time - start_time))

# show Gabor filter output
if save_flag:
img = (b[:, :, int(a.shape[2] / 2)]).copy()
plt.imsave('./result/gabor_output(%d).png' % fm_i, img, cmap='gray') # save fig

# generate feature vector
start_time = time.time()
result = generate_feature_vector(b=b, label=label, cluster_center_number=cluster_center_number)
end_time = time.time()
print('feature vector %d done (%f s)' % (fm_i, end_time - start_time))
return fm_i, result


@jit(nopython=True)
def generate_feature_vector(b, label, cluster_center_number):
# use float accumulators so the Gabor responses are not truncated to integers
result = np.zeros(cluster_center_number)
sum_f = np.zeros((cluster_center_number, 2))  # column 0: running sum, column 1: voxel count
for i in range(0, b.shape[0]):
for j in range(0, b.shape[1]):
for k in range(0, b.shape[2]):
sum_f[label[i][j][k]][1] = sum_f[label[i][j][k]][1] + 1
sum_f[label[i][j][k]][0] = sum_f[label[i][j][k]][0] + b[i][j][k]
for i in range(cluster_center_number):
assert sum_f[i][1] > 0
result[i] = sum_f[i][0] / sum_f[i][1]
return result
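The same per-supervoxel mean response could also be computed without numba via np.bincount; a sketch of that alternative:

counts = np.bincount(label.ravel(), minlength=cluster_center_number)
sums = np.bincount(label.ravel(), weights=b.ravel(), minlength=cluster_center_number)
mean_response = sums / counts  # same values as result above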


if __name__ == "__main__":
path = './aitom_demo_single_particle_tomogram.mrc' # file path
mrc_header = io_file.read_mrc_header(path)
a = io_file.read_mrc_data(path) # volume data
assert a.shape[0] > 0
a = a.astype(np.float32)
print("file has been read, shape is", a.shape)
saliency_detection(a=a, gaussian_sigma=2.5, gabor_sigma=9.0, gabor_lambda=9.0, cluster_center_number=10000, save_flag=True)
start_time = time.time()
saliency_detection(a=a, gaussian_sigma=2.5, gabor_sigma=14.0, gabor_lambda=13.0, cluster_center_number=10000, multiprocessing_num=0, pick_num=1000, save_flag=True)
end_time = time.time()
print('saliency detection takes', end_time-start_time,'s')
