#include "gadget.cuh"
#include "helper_multi_gpu.h"
#include "keyswitch.cuh"
+#include "linearalgebra/multiplication.cuh"
#include "polynomial/functions.cuh"
#include "polynomial/polynomial_math.cuh"
#include "torus.cuh"

#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)

-const int BLOCK_SIZE_GEMM = 64;
-const int THREADS_GEMM = 8;
const int BLOCK_SIZE_DECOMP = 8;

template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
@@ -91,106 +90,6 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
  buffer_in[state_idx] = state;
}

-// Multiply matrices A, B of size (M, K), (K, N) respectively
-// with K as the inner dimension.
-//
-// A block of threads processes blocks of size (BLOCK_SIZE_GEMM,
-// BLOCK_SIZE_GEMM), splitting them into multiple tiles: (BLOCK_SIZE_GEMM,
-// THREADS_GEMM)-shaped tiles of values from A, and (THREADS_GEMM,
-// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
-//
-// This code is adapted by generalizing the 1d block-tiling
-// kernel from https://github.com/siboehm/SGEMM_CUDA
-// to any matrix dimension
-template <typename Torus, typename TorusVec>
-__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
-                      int stride_B, Torus *C) {
-
-  const int BM = BLOCK_SIZE_GEMM;
-  const int BN = BLOCK_SIZE_GEMM;
-  const int BK = THREADS_GEMM;
-  const int TM = THREADS_GEMM;
-
-  const uint cRow = blockIdx.y;
-  const uint cCol = blockIdx.x;
-
-  const int threadCol = threadIdx.x % BN;
-  const int threadRow = threadIdx.x / BN;
-
-  // Allocate space for the current block tile in shared memory
-  __shared__ Torus As[BM * BK];
-  __shared__ Torus Bs[BK * BN];
-
-  // Initialize the pointers to the input blocks from A, B
-  // Tiles from these blocks are loaded to shared memory
-  A += cRow * BM * K;
-  B += cCol * BN;
-
-  // Each thread will handle multiple sub-blocks
-  const uint innerColA = threadIdx.x % BK;
-  const uint innerRowA = threadIdx.x / BK;
-  const uint innerColB = threadIdx.x % BN;
-  const uint innerRowB = threadIdx.x / BN;
-
-  // allocate thread-local cache for results in registerfile
-  Torus threadResults[TM] = {0};
-
-  auto row_A = cRow * BM + innerRowA;
-  auto col_B = cCol * BN + innerColB;
-
-  // For each thread, loop over block tiles
-  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
-    auto col_A = bkIdx + innerColA;
-    auto row_B = bkIdx + innerRowB;
-
-    if (row_A < M && col_A < K) {
-      As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
-    } else {
-      As[innerRowA * BK + innerColA] = 0;
-    }
-
-    if (col_B < N && row_B < K) {
-      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
-    } else {
-      Bs[innerRowB * BN + innerColB] = 0;
-    }
-    synchronize_threads_in_block();
-
-    // Advance blocktile for the next iteration of this loop
-    A += BK;
-    B += BK * stride_B;
-
-    // calculate per-thread results
-    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
-      // we make the dotproduct loop the outside loop, which facilitates
-      // reuse of the Bs entry, which we can cache in a tmp var.
-      Torus tmp = Bs[dotIdx * BN + threadCol];
-      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-        threadResults[resIdx] +=
-            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
-      }
-    }
-    synchronize_threads_in_block();
-  }
-
-  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
-  // BLOCK_SIZE_GEMM)
-  C += cRow * BM * N + cCol * BN;
-
-  // write out the results
-  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-    int outRow = cRow * BM + threadRow * TM + resIdx;
-    int outCol = cCol * BN + threadCol;
-
-    if (outRow >= M)
-      continue;
-    if (outCol >= N)
-      continue;
-
-    C[(threadRow * TM + resIdx) * N + threadCol] += threadResults[resIdx];
-  }
-}
-
// Finish the keyswitching operation and prepare GLWEs for accumulation.
// 1. Finish the keyswitching computation partially performed with a GEMM:
//    - negate the dot product between the GLWE and KSK polynomial
@@ -313,7 +212,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
  uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
  tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
      num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
-      stride_KSK_buffer, d_mem_1);
+      stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
  check_cuda_error(cudaGetLastError());

  auto ksk_block_size = glwe_accumulator_size;
@@ -327,7 +226,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
    tgemm<Torus, TorusVec>
        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
+            glwe_accumulator_size);
    check_cuda_error(cudaGetLastError());
  }
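Note on the extra trailing argument: both tgemm launches above now pass glwe_accumulator_size after the output pointer d_mem_1. The kernel itself now lives in linearalgebra/multiplication.cuh (included at the top of this diff), so its exact signature is not shown here; a plausible reading is that it gained an explicit output-row stride alongside the existing stride_B. The sketch below only illustrates that assumption, it is not the relocated block-tiled kernel, and the name tgemm_reference is hypothetical.

// Minimal reference sketch, assuming the relocated kernel takes an output
// stride (stride_C) in addition to stride_B. Hypothetical helper, not the
// tiled implementation from linearalgebra/multiplication.cuh.
template <typename Torus>
__global__ void tgemm_reference(int M, int N, int K, const Torus *A,
                                const Torus *B, int stride_B, Torus *C,
                                int stride_C) {
  // One thread per output element; guard the ragged edges.
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= M || col >= N)
    return;

  // Plain dot product over the inner dimension K; rows of B are stride_B apart.
  Torus acc = 0;
  for (int k = 0; k < K; ++k)
    acc += A[row * K + k] * B[k * stride_B + col];

  // Accumulate into C, whose row stride may differ from N. Passing
  // glwe_accumulator_size at the call sites would map to stride_C here.
  C[row * stride_C + col] += acc;
}

Under that assumption, the first launch above corresponds to M = num_lwes, N = glwe_accumulator_size, K = lwe_dimension, with the accumulator width forwarded as the output stride.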