#include "gadget.cuh"
#include "helper_multi_gpu.h"
#include "keyswitch.cuh"
+ #include "linearalgebra/multiplication.cuh"
#include "polynomial/functions.cuh"
#include "polynomial/polynomial_math.cuh"
#include "torus.cuh"
#define CEIL_DIV(M, N) ((M) + (N)-1) / (N)

- const int BLOCK_SIZE_GEMM = 64;
- const int THREADS_GEMM = 8;
const int BLOCK_SIZE_DECOMP = 8;

template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
@@ -90,106 +89,6 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
  buffer_in[state_idx] = state;
}

- // Multiply matrices A, B of size (M, K), (K, N) respectively,
- // with K as the inner dimension.
- //
- // A block of threads processes blocks of size (BLOCK_SIZE_GEMM,
- // BLOCK_SIZE_GEMM), splitting them into multiple tiles: (BLOCK_SIZE_GEMM,
- // THREADS_GEMM)-shaped tiles of values from A, and (THREADS_GEMM,
- // BLOCK_SIZE_GEMM)-shaped tiles of values from B.
- //
- // This code adapts the 1D block-tiling kernel from
- // https://github.com/siboehm/SGEMM_CUDA, generalized to arbitrary
- // matrix dimensions.
- template <typename Torus, typename TorusVec>
- __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
-                       int stride_B, Torus *C) {
-
-   const int BM = BLOCK_SIZE_GEMM;
-   const int BN = BLOCK_SIZE_GEMM;
-   const int BK = THREADS_GEMM;
-   const int TM = THREADS_GEMM;
-
-   const uint cRow = blockIdx.y;
-   const uint cCol = blockIdx.x;
-
-   const int threadCol = threadIdx.x % BN;
-   const int threadRow = threadIdx.x / BN;
-
-   // Allocate space for the current block tile in shared memory
-   __shared__ Torus As[BM * BK];
-   __shared__ Torus Bs[BK * BN];
-
-   // Initialize the pointers to the input blocks from A, B.
-   // Tiles from these blocks are loaded to shared memory.
-   A += cRow * BM * K;
-   B += cCol * BN;
-
-   // Each thread will handle multiple sub-blocks
-   const uint innerColA = threadIdx.x % BK;
-   const uint innerRowA = threadIdx.x / BK;
-   const uint innerColB = threadIdx.x % BN;
-   const uint innerRowB = threadIdx.x / BN;
-
-   // Allocate a thread-local cache for results in the register file
-   Torus threadResults[TM] = {0};
-
-   auto row_A = cRow * BM + innerRowA;
-   auto col_B = cCol * BN + innerColB;
-
-   // For each thread, loop over block tiles
-   for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
-     auto col_A = bkIdx + innerColA;
-     auto row_B = bkIdx + innerRowB;
-
-     if (row_A < M && col_A < K) {
-       As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
-     } else {
-       As[innerRowA * BK + innerColA] = 0;
-     }
-
-     if (col_B < N && row_B < K) {
-       Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
-     } else {
-       Bs[innerRowB * BN + innerColB] = 0;
-     }
-     synchronize_threads_in_block();
-
-     // Advance the block tile for the next iteration of this loop
-     A += BK;
-     B += BK * stride_B;
-
-     // Calculate per-thread results
-     for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
-       // Make the dot-product loop the outer loop, which facilitates
-       // reuse of the Bs entry, which we can cache in a tmp var.
-       Torus tmp = Bs[dotIdx * BN + threadCol];
-       for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-         threadResults[resIdx] +=
-             As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
-       }
-     }
-     synchronize_threads_in_block();
-   }
-
-   // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
-   // BLOCK_SIZE_GEMM)
-   C += cRow * BM * N + cCol * BN;
-
-   // Write out the results
-   for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-     int outRow = cRow * BM + threadRow * TM + resIdx;
-     int outCol = cCol * BN + threadCol;
-
-     if (outRow >= M)
-       continue;
-     if (outCol >= N)
-       continue;
-
-     C[(threadRow * TM + resIdx) * N + threadCol] += threadResults[resIdx];
-   }
- }
-
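For reference, the tiling above implies a specific launch shape even though the host-side setup is not part of this hunk: the grid covers the output in BLOCK_SIZE_GEMM x BLOCK_SIZE_GEMM blocks, and since each thread accumulates THREADS_GEMM (= TM) results, a block needs BLOCK_SIZE_GEMM * THREADS_GEMM threads. A minimal sketch under those assumptions; M, N, K, d_A, d_B, d_C, and stream are placeholder names, not identifiers taken from this file:

// Sketch only: launch configuration implied by the removed kernel's tiling.
// Grid: one block per BLOCK_SIZE_GEMM x BLOCK_SIZE_GEMM output tile.
// Block: BLOCK_SIZE_GEMM * THREADS_GEMM threads (each thread produces
//        THREADS_GEMM output values of the 64 x 64 tile).
dim3 grid_gemm(CEIL_DIV(N, BLOCK_SIZE_GEMM), CEIL_DIV(M, BLOCK_SIZE_GEMM));
dim3 threads_gemm(BLOCK_SIZE_GEMM * THREADS_GEMM);

tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm,
                         get_shared_mem_size_tgemm<Torus>(), stream>>>(
    M, N, K, d_A, d_B, /*stride_B=*/N, d_C);

The host hunks further below use the same grid_gemm, threads_gemm, and shared_mem_size names, so this is only meant to connect the tiling constants to the launch parameters.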
// Finish the keyswitching operation and prepare GLWEs for accumulation.
// 1. Finish the keyswitching computation partially performed with a GEMM:
//    - negate the dot product between the GLWE and KSK polynomial
@@ -312,7 +211,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
  uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
  tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
      num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
-     stride_KSK_buffer, d_mem_1);
+     stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
  check_cuda_error(cudaGetLastError());

  auto ksk_block_size = glwe_accumulator_size;
@@ -326,7 +225,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
    tgemm<Torus, TorusVec>
        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-           fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+           fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
+           glwe_accumulator_size);
    check_cuda_error(cudaGetLastError());
  }
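Note on the call-site changes above: both tgemm invocations now forward glwe_accumulator_size as an extra trailing argument, which at these call sites equals the N dimension already passed as the second argument. The kernel itself is no longer defined in this file and is presumably pulled in via the new linearalgebra/multiplication.cuh include, with the extra argument acting as an explicit output stride. A sketch of the assumed new declaration, which is not shown in this diff:

// Assumed declaration of the relocated kernel (an assumption, not visible
// in this diff): a trailing stride_C would let the write-out loop index C
// with C[... * stride_C + ...] rather than the hard-coded N used by the
// removed in-file version above.
template <typename Torus, typename TorusVec>
__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
                      int stride_B, Torus *C, int stride_C);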