Skip to content

Commit 8ca4173

Browse files
authored
Merge pull request #296 from nadir199/master
Add benchmarks and some fixes
2 parents 1744028 + ca956d5 commit 8ca4173

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3484
-87
lines changed

benchmarks/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@ add_subdirectory(DNN/blocks/LSTM/cpu)
2121
add_subdirectory(DNN/blocks/LSTM/cpu_lib)
2222
add_subdirectory(DNN/blocks/LSTM/cpu_lib_sparse)
2323
add_subdirectory(DNN/blocks/vggBlock/cpu/dense)
24+
add_subdirectory(DNN/blocks/vggBlock/cpu/sparse)
2425
add_subdirectory(DNN/blocks/fusedresNet/cpu/dense)
2526
add_subdirectory(DNN/blocks/fusedresNet_inference/cpu/sparse)
2627
add_subdirectory(DNN/blocks/fusedresNet_inference/cpu/dense)
2728
add_subdirectory(DNN/blocks/DenseNetBlock/cpu/dense)
2829
add_subdirectory(DNN/blocks/Conv-ReLU-MaxPool/cpu/dense)
2930
add_subdirectory(DNN/blocks/Conv-ReLU-MaxPool/cpu/sparse)
3031
add_subdirectory(DNN/blocks/Resize-Conv-ReLU-MaxPool/cpu/dense)
32+
add_subdirectory(DNN/blocks/Resize-Conv-ReLU-MaxPool/cpu/sparse)
3133
add_subdirectory(DNN/blocks/Conv-Relu-FC-Softmax/cpu/dense)
Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,2 @@
1-
<<<<<<< HEAD
2-
rm -rf conv_relu_maxpool_generator_tiramisu conv_relu_maxpool_tiramisu.o conv_relu_maxpool_tiramisu.o.h wrapper_nn_block_conv_relu_maxpool conv_relu_maxpool_mkl conv_relu_maxpool_mkldnn mkl_result.txt tiramisu_result.txt tf_model.pb
1+
rm -rf conv_relu_maxpool_generator_tiramisu conv_relu_maxpool_tiramisu.o conv_relu_maxpool_tiramisu.o.h wrapper_nn_block_conv_relu_maxpool conv_relu_maxpool_mkl conv_relu_maxpool_mkldnn mkl_result.txt tiramisu_result.txt tf_model.pb tvm_autotuning.log
32
rm -rf .pkl_memoize_py3 param_tuning.h
4-
=======
5-
rm conv_relu_maxpool_generator_tiramisu conv_relu_maxpool_tiramisu.o conv_relu_maxpool_tiramisu.o.h wrapper_nn_block_conv_relu_maxpool conv_relu_maxpool_mkl conv_relu_maxpool_mkldnn mkl_result.txt tiramisu_result.txt tf_model.pb tvm_autotuning.log
6-
rm -rf .pkl_memoize_py3
7-
>>>>>>> upstream/master

benchmarks/DNN/blocks/Conv-ReLU-MaxPool/cpu/sparse/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ The files in this folder are organized as follows:
1818
Intel MKL
1919
spconv_relu_maxpool_generator_mkl.c: code that calls Intel MKL's dense conv-relu-maxpool.
2020

21+
Intel MKL Sparse
22+
spconv_relu_maxpool_generator_mkl_sparse.cpp: code that calls Intel MKL Sparse's sparse conv-relu-maxpool.
23+
2124
To run this benchmark:
2225

2326
At the directory build/benchmarks/DNN/blocks/Conv-ReLU-Maxpool/cpu/sparse execute
@@ -36,6 +39,11 @@ To run this benchmark:
3639
then
3740
./spconv_relu_maxpool_wrapper
3841

42+
To compare the result of tiramisu with MKL Sparse execute :
43+
./compile_and_run_mkl_sparse.sh
44+
then
45+
./spconv_relu_maxpool_wrapper
46+
3947
execution results could be found in the text files :
40-
mkl_result.txt (same for Intel MKL and Intel MKL-DNN)
48+
mkl_result.txt (same for Intel MKL, Intel MKL-DNN and Intel MKL Sparse)
4149
tiramisu_result.txt
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
rm -rf generated_spconv_relu_maxpool.o.h generated_spconv_relu_maxpool.o mkl_result.txt tiramisu_result.txt spconv_relu_maxpool_wrapper spconv_relu_maxpool_generator conv_relu_maxpool_mkl conv_relu_maxpool_mkldnn
1+
rm -rf generated_spconv_relu_maxpool.o.h generated_spconv_relu_maxpool.o mkl_result.txt tiramisu_result.txt spconv_relu_maxpool_wrapper spconv_relu_maxpool_generator conv_relu_maxpool_mkl conv_relu_maxpool_mkldnn conv_relu_maxpool_mkl_sparse
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#set -x
2+
3+
source ../../../../../configure_paths.sh
4+
MKLDNNROOT=/usr/local/
5+
6+
export INCLUDES="-I${MKL_PREFIX}/include/ -I${MKLDNNROOT}/include"
7+
export LIBRARIES="${MKL_FLAGS} -lisl -lz -lpthread -ldl "
8+
export LIBRARIES_DIR="-L${MKL_PREFIX}/lib/${MKL_LIB_PATH_SUFFIX} -L${MKLDNNROOT}/lib"
9+
10+
source ${MKL_PREFIX}/bin/mklvars.sh ${MKL_LIB_PATH_SUFFIX}
11+
12+
g++ -O3 -DMKL_ILP64 -m64 ${INCLUDES} conv_relu_maxpool_generator_mkl_sparse.cpp -o conv_relu_maxpool_mkl_sparse ${LIBRARIES_DIR} -Wl,--no-as-needed -lmkl_intel_ilp64 -lmkl_gnu_thread -lmkl_core -lgomp -lpthread -fopenmp -lm -ldl -lmkldnn
13+
./conv_relu_maxpool_mkl_sparse

benchmarks/DNN/blocks/Conv-ReLU-MaxPool/cpu/sparse/configure.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
#define Y_BL 2
3838
#define Y_NB_BL (N/Y_BL)
3939

40+
// Parameters for MKL Sparse's IM2COL,
41+
#define H_BL 32 // Must be a divisor of N
42+
#define H_NB_BL N/H_BL
43+
#define W_BL 32 // Must be a divisor of N
44+
#define W_NB_BL N/W_BL
45+
4046
// Number of features in the input
4147
#define FIn 3
4248
// Number of features in the output

benchmarks/DNN/blocks/Conv-ReLU-MaxPool/cpu/sparse/conv_relu_maxpool_generator_mkl.c

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static dnnError_t init_conversion(dnnPrimitive_t *cv, float **ptr_out,
4040
}
4141

4242
// Original version by: Kyle Spafford Adapted for COO Format
43-
int initRandomSparseMatrix(float matrix[FOut][FIn][K][K], float density, const int KK, const int fin_size, const int fout_size)
43+
int initRandomSparseMatrix(float matrix[FOut][FIn][K][K], float density, const int KK, const int fin_size, const int fout_size, int seed)
4444
{
4545
const int n = KK * KK * fin_size * fout_size * density; // number of non zero elements
4646
int nnzAssigned = 0;
@@ -50,10 +50,10 @@ int initRandomSparseMatrix(float matrix[FOut][FIn][K][K], float density, const i
5050
int total_num_entries = KK * KK * fin_size * fout_size;
5151
double prob = (double)n / ((double) total_num_entries);
5252

53-
// Randomly decide whether entry i,j gets a value, but ensure n values
53+
// Randomly decide whether an entry gets a value, but ensure n values
5454
// are assigned
5555
int fillRemaining = 0;
56-
srand(1);
56+
srand(seed);
5757
for (int fout = 0; fout < fout_size; fout++)
5858
{
5959
for (int fin = 0; fin < fin_size; fin++)
@@ -113,7 +113,7 @@ int main()
113113
size_t maxpool_kernel_size[] = {2, 2};
114114
size_t maxpool_strides[] = {2, 2};
115115
int maxpool_offset[] = {0, 0};
116-
initRandomSparseMatrix(conv_filter_param, WEIGHTS_DENSITY, K, FIn, FOut);
116+
initRandomSparseMatrix(conv_filter_param, WEIGHTS_DENSITY, K, FIn, FOut, 1);
117117

118118
srand(3);
119119
// Allocate buffers
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
#include <cstdio>
2+
#include <cstdlib>
3+
#include <iostream>
4+
#include <vector>
5+
#include <algorithm>
6+
#include <omp.h>
7+
#include "mkl.h"
8+
9+
#include "mkldnn.hpp"
10+
#include "im2col.hpp"
11+
#include "mkl_spblas.h"
12+
13+
// Original version by: Kyle Spafford Adapted for CSR format
14+
void initRandomWeights(float* filter_values, MKL_INT* filter_idx, MKL_INT* filter_finptr, const int n, const int KK, const int fin_size, const int fout_size, const int seed)
15+
{
16+
int nnzAssigned = 0;
17+
// Figure out the probability that a nonzero should be assigned to a given
18+
// spot in the matrix
19+
int total_num_entries = KK * KK * fin_size * fout_size;
20+
double prob = (double)n / ((double) total_num_entries);
21+
22+
// Seed random number generator
23+
srand(seed);
24+
25+
// Randomly decide whether an entry gets a value, but ensure n values
26+
// are assigned
27+
int fillRemaining = 0;
28+
29+
for (int fout = 0; fout < fout_size; fout++)
30+
{
31+
filter_finptr[fout] = (MKL_INT)nnzAssigned;
32+
for (int fin = 0; fin < fin_size; fin++)
33+
{
34+
for (int ky = 0; ky < KK; ky++)
35+
{
36+
for (int kx = 0; kx < KK; kx++)
37+
{
38+
int numEntriesLeft = total_num_entries - ((fout * KK * KK * fin_size) + (fin * KK * KK) + (ky * KK) + kx);
39+
int needToAssign = n - nnzAssigned;
40+
if (numEntriesLeft <= needToAssign) {
41+
fillRemaining = 1;
42+
}
43+
if ((nnzAssigned < n && ((double) rand() / (RAND_MAX + 1.0)) <= prob) || fillRemaining)
44+
{
45+
filter_idx[nnzAssigned] = (MKL_INT)(fin * KK * KK + ky * KK + kx);
46+
filter_values[nnzAssigned] = ((float)(rand()%256 - 128)) / 127.f;
47+
nnzAssigned++;
48+
}
49+
}
50+
}
51+
}
52+
}
53+
filter_finptr[fout_size] = nnzAssigned;
54+
if (nnzAssigned != n)
55+
exit(500);
56+
}
57+
58+
int generateCSRWeights(float *filter_values, float density, MKL_INT *filter_idx, MKL_INT* filter_finptr, int KK, int fin_size, int fout_size, int seed) {
59+
int nNonzero = KK * KK * fin_size * fout_size * density;
60+
initRandomWeights(filter_values, filter_idx, filter_finptr, nNonzero, KK, fin_size, fout_size, seed);
61+
return nNonzero;
62+
}
63+
64+
using namespace mkldnn;
65+
66+
int main()
67+
{
68+
std::vector<double> duration_vector;
69+
70+
engine cpu_engine(engine::kind::cpu, 0);
71+
stream cpu_stream(cpu_engine);
72+
73+
std::vector<primitive> net;
74+
std::vector<std::unordered_map<int, memory>> net_args;
75+
76+
memory::dims pool_strides = {2, 2};
77+
memory::dims pool_kernel = {2, 2};
78+
memory::dims pool_padding = {0, 0};
79+
80+
int FNNZ = FOut*FIn*K*K*WEIGHTS_DENSITY;
81+
float filter_values[FNNZ];
82+
MKL_INT filter_idx[FNNZ]; //MKL_INT
83+
MKL_INT filter_finptr[FOut+1];
84+
// Generate sparse weights matrix
85+
generateCSRWeights(filter_values, WEIGHTS_DENSITY, filter_idx, filter_finptr, K, FIn, FOut, 1);
86+
87+
// Descriptor of main sparse matrix properties
88+
struct matrix_descr descrFilter;
89+
// // Structure with sparse matrix stored in CSR format
90+
sparse_matrix_t csrFilter;
91+
float alpha = 1.0, beta = 0.0;
92+
93+
// Create handle with matrix stored in CSR format
94+
mkl_sparse_s_create_csr (&csrFilter, SPARSE_INDEX_BASE_ZERO,
95+
FOut, // number of rows
96+
FIn*K*K, // number of cols
97+
filter_finptr,
98+
filter_finptr+1,
99+
filter_idx,
100+
filter_values);
101+
102+
// Analyze sparse matrix; choose proper kernels and workload balancing strategy
103+
mkl_sparse_optimize(csrFilter);
104+
105+
// Create matrix descriptor
106+
descrFilter.type = SPARSE_MATRIX_TYPE_GENERAL;
107+
108+
// Allocate buffers
109+
float* input_buf = (float*)malloc(sizeof(float) * FIn * (N + 2) * (N + 2) * BATCH_SIZE);
110+
float* conv_bias_buf = (float*)malloc(sizeof(float) * FOut);
111+
float* result_buf = (float*)malloc(sizeof(float) * FIn * (N) * (N) * K * K * BATCH_SIZE);
112+
float* conv_output_buf = (float*)malloc(sizeof(float) * FOut * (N) * (N) * BATCH_SIZE);
113+
114+
srand(3);
115+
for(int b = 0; b < BATCH_SIZE; ++b)
116+
for (int fin = 0; fin < FIn; ++fin)
117+
for (int y = 0; y < N + 2; ++y)
118+
for (int x = 0; x < N + 2; ++x)
119+
input_buf[x + y*(N+2) + fin*(N+2)*(N+2) + b*(N+2)*(N+2)*FIn] = ((float)(rand() % 256 - 128)) / 127.f;
120+
121+
for (int i = 0; i < FOut; i++)
122+
conv_bias_buf[i] = ((float)(rand()%256 - 128)) / 127.f;
123+
124+
printf("Buffers Initialized\n");
125+
126+
auto conv_output_md = memory::desc(
127+
{BATCH_SIZE, FOut, N, N},
128+
memory::data_type::f32,
129+
memory::format_tag::nchw
130+
131+
);
132+
auto conv_output_mem = memory(conv_output_md, cpu_engine, conv_output_buf);
133+
134+
auto relu_desc = eltwise_forward::desc(prop_kind::forward_inference,
135+
algorithm::eltwise_relu, conv_output_md,
136+
0);
137+
auto relu_pd = eltwise_forward::primitive_desc(relu_desc, cpu_engine);
138+
net.push_back(eltwise_forward(relu_pd));
139+
net_args.push_back({
140+
{MKLDNN_ARG_SRC, conv_output_mem},
141+
{MKLDNN_ARG_DST, conv_output_mem}
142+
});
143+
144+
auto pool_output_md = memory::desc(
145+
{BATCH_SIZE, FOut, N/2, N/2},
146+
memory::data_type::f32,
147+
memory::format_tag::any
148+
);
149+
150+
auto pool_d = pooling_forward::desc(
151+
prop_kind::forward_inference,
152+
algorithm::pooling_max,
153+
conv_output_md,
154+
pool_output_md,
155+
pool_strides,
156+
pool_kernel,
157+
pool_padding,
158+
pool_padding
159+
);
160+
161+
auto pool_pd = pooling_forward::primitive_desc(
162+
pool_d,
163+
cpu_engine
164+
);
165+
166+
auto pool_dst_mem = memory(pool_pd.dst_desc(), cpu_engine);
167+
168+
net.push_back(pooling_forward(pool_pd));
169+
net_args.push_back({
170+
{MKLDNN_ARG_SRC, conv_output_mem},
171+
{MKLDNN_ARG_DST, pool_dst_mem}
172+
});
173+
174+
omp_set_num_threads(4);
175+
for (int i = 0; i < NB_TESTS; ++i) {
176+
double start = rtclock();
177+
for(int batch = 0; batch<BATCH_SIZE; batch++){
178+
im2col_cpu(&input_buf[batch*(FIn*(N+2)*(N+2))], FIn,
179+
N+2, N+2, K, K,
180+
1, 1,
181+
&result_buf[batch*(FIn*N*N*K*K)]
182+
);
183+
// Filter weights are (FOut) * (FIn * K * K)
184+
// Lowered Input is (FIn * K * K) * (N * N)
185+
// The result of the mult is : (FOut) * (N * N)
186+
// Calculates C = alpha*A*B + C
187+
mkl_sparse_s_mm(SPARSE_OPERATION_NON_TRANSPOSE,
188+
alpha,
189+
csrFilter,
190+
descrFilter,
191+
SPARSE_LAYOUT_ROW_MAJOR,
192+
&result_buf[batch*(FIn*N*N*K*K)],
193+
N*N,
194+
N*N,
195+
beta,
196+
&conv_output_buf[batch*(FOut*N*N)],
197+
N*N
198+
);
199+
#pragma omp parallel for
200+
for(int fout = 0; fout<FOut; fout++){
201+
for(int y=0; y<N; y++)
202+
for(int x=0; x<N; x++)
203+
conv_output_buf[batch*(FOut*N*N) + fout*N*N + y*N + x] += conv_bias_buf[fout];
204+
}
205+
}
206+
// Execute relu/maxpool
207+
for (size_t j = 0; j < net.size(); ++j)
208+
net[j].execute(cpu_stream, net_args[j]);
209+
cpu_stream.wait();
210+
211+
double end = rtclock();
212+
duration_vector.push_back((end - start) * 1000);
213+
}
214+
215+
std::cout << "\t\tSparse Lowered Convolution time : "
216+
<< median(duration_vector) << " ms" << std::endl;
217+
218+
auto output_usr_md = memory::desc(
219+
{BATCH_SIZE, FOut, N/2, N/2},
220+
memory::data_type::f32,
221+
memory::format_tag::nchw
222+
);
223+
224+
auto output_mem = memory(output_usr_md, cpu_engine);
225+
reorder(pool_dst_mem, output_mem)
226+
.execute(cpu_stream, pool_dst_mem, output_mem);
227+
228+
if (WRITE_RESULT_TO_FILE){
229+
float* output_buf = (float*)output_mem.get_data_handle();
230+
// Write results to file
231+
FILE* f = fopen("mkl_result.txt", "w");
232+
if (f == NULL) {
233+
printf("Error creating mkl_sparse_result.txt.\n");
234+
return 0;
235+
}
236+
237+
for(int b=0; b<BATCH_SIZE; b++)
238+
for(int fout=0; fout<FOut; fout++)
239+
for(int y=0; y<N/2; y++)
240+
for(int x=0; x<N/2; x++)
241+
fprintf(f, "%.17g\n", output_buf[x + y*N/2 + fout*N/2*N/2 + b*N/2*N/2*FOut]);
242+
243+
fclose(f);
244+
}
245+
mkl_sparse_destroy(csrFilter);
246+
free(input_buf);
247+
free(result_buf);
248+
free(conv_output_buf);
249+
return 0;
250+
}

benchmarks/DNN/blocks/Conv-ReLU-MaxPool/cpu/sparse/conv_relu_maxpool_generator_mkldnn.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using namespace mkldnn;
77
using namespace std;
88

99
// Original version by: Kyle Spafford Adapted for COO Format
10-
int initRandomSparseMatrix(float* matrix, float density, const int KK, const int fin_size, const int fout_size)
10+
int initRandomSparseMatrix(float* matrix, float density, const int KK, const int fin_size, const int fout_size, int seed)
1111
{
1212
const int n = KK * KK * fin_size * fout_size * density; // number of non zero elements
1313
int nnzAssigned = 0;
@@ -20,7 +20,7 @@ int initRandomSparseMatrix(float* matrix, float density, const int KK, const int
2020
// Randomly decide whether entry i,j gets a value, but ensure n values
2121
// are assigned
2222
int fillRemaining = 0;
23-
srand(1);
23+
srand(seed);
2424
for (int fout = 0; fout < fout_size; fout++)
2525
{
2626
for (int fin = 0; fin < fin_size; fin++)
@@ -77,7 +77,7 @@ void conv_relu_maxpool_block()
7777
std::vector<float> conv_bias_buf(FOut);
7878
std::vector<float> conv_weights_buf(FOut * FIn * K * K);
7979

80-
initRandomSparseMatrix(conv_weights_buf.data(), WEIGHTS_DENSITY, K, FIn, FOut);
80+
initRandomSparseMatrix(conv_weights_buf.data(), WEIGHTS_DENSITY, K, FIn, FOut, 1);
8181

8282
srand(3);
8383
for (int i = 0; i < BATCH_SIZE*FIn*(N + 2)*(N + 2); i++)

0 commit comments

Comments
 (0)