
Commit cec3d0e

feat(gpu): add poly product with circulant matrix
1 parent 3988c85 commit cec3d0e

File tree

10 files changed: +755 -104 lines

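Background: a polynomial product modulo X^N + 1 can be written as a matrix product in which the left-hand polynomial is expanded into its negacyclic circulant matrix, so multiplying one polynomial by many right-hand polynomials at once becomes a single GEMM; this is presumably why the commit moves the generic tgemm kernel into linearalgebra/multiplication.cuh for reuse. As a minimal CPU reference sketch, assuming the usual TFHE convention of reduction modulo X^N + 1 with wrapping 64-bit coefficient arithmetic (the new kernels' exact semantics are not shown in this diff):

    #include <cstddef>
    #include <cstdint>

    // Reference one-to-many wrapping polynomial product modulo X^N + 1.
    // lhs: N coefficients; rhs: n_rhs polynomials of N coefficients each,
    // stored contiguously; result: same layout as rhs. Coefficients use
    // wrapping (mod 2^64) arithmetic, as torus elements do in TFHE.
    void wrapping_poly_mul_one_to_many_ref(uint64_t *result,
                                           const uint64_t *lhs,
                                           const uint64_t *rhs, size_t N,
                                           size_t n_rhs) {
      // Row i of the negacyclic circulant matrix of lhs is
      //   M[i][j] =  lhs[i - j]      if j <= i
      //   M[i][j] = -lhs[N + i - j]  if j > i   (sign flip since X^N = -1)
      for (size_t r = 0; r < n_rhs; ++r) {
        const uint64_t *b = rhs + r * N;
        uint64_t *out = result + r * N;
        for (size_t i = 0; i < N; ++i) {
          uint64_t acc = 0;
          for (size_t j = 0; j < N; ++j) {
            uint64_t a = (j <= i) ? lhs[i - j] : (uint64_t)0 - lhs[N + i - j];
            acc += a * b[j]; // unsigned overflow wraps, which is intended
          }
          out[i] = acc;
        }
      }
    }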

backends/tfhe-cuda-backend/cuda/include/linear_algebra.h

+7
@@ -42,6 +42,13 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_64(
     void const *lwe_array_in, void const *cleartext_array_in,
     const uint32_t input_lwe_dimension,
     const uint32_t input_lwe_ciphertext_count);
+void cuda_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t n_rhs);
+void cuda_glwe_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t glwe_dimension,
+    uint32_t n_rhs);
 void cuda_add_lwe_ciphertext_vector_plaintext_64(
     void *stream, uint32_t gpu_index, void *lwe_array_out,
     void const *lwe_array_in, const uint64_t plaintext_in,
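A possible call site for the first new entry point, sketched under assumptions: only the two declarations above come from the commit; the buffer sizes and the result layout (n_rhs polynomials of polynomial_size coefficients each) are illustrative guesses.

    // Multiply one device-resident polynomial by n_rhs others in one call.
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    uint32_t polynomial_size = 2048, n_rhs = 16; // illustrative sizes
    uint64_t *d_lhs, *d_rhs, *d_result;
    cudaMalloc(&d_lhs, polynomial_size * sizeof(uint64_t));
    cudaMalloc(&d_rhs, (size_t)n_rhs * polynomial_size * sizeof(uint64_t));
    cudaMalloc(&d_result, (size_t)n_rhs * polynomial_size * sizeof(uint64_t));
    // ... copy coefficients into d_lhs and d_rhs ...

    cuda_wrapping_polynomial_mul_one_to_many_64(stream, /*gpu_index=*/0,
                                                d_result, d_lhs, d_rhs,
                                                polynomial_size, n_rhs);
    cudaStreamSynchronize(stream);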

backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh

+4 -104

@@ -7,6 +7,7 @@
 #include "gadget.cuh"
 #include "helper_multi_gpu.h"
 #include "keyswitch.cuh"
+#include "linearalgebra/multiplication.cuh"
 #include "polynomial/functions.cuh"
 #include "polynomial/polynomial_math.cuh"
 #include "torus.cuh"
@@ -17,8 +18,6 @@
 
 #define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
 
-const int BLOCK_SIZE_GEMM = 64;
-const int THREADS_GEMM = 8;
 const int BLOCK_SIZE_DECOMP = 8;
 
 template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
@@ -90,106 +89,6 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
   buffer_in[state_idx] = state;
 }
 
-// Multiply matrices A, B of size (M, K), (K, N) respectively
-// with K as the inner dimension.
-//
-// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
-// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
-// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
-// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
-//
-// This code is adapted by generalizing the 1d block-tiling
-// kernel from https://github.com/siboehm/SGEMM_CUDA
-// to any matrix dimension
-template <typename Torus, typename TorusVec>
-__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
-                      int stride_B, Torus *C) {
-
-  const int BM = BLOCK_SIZE_GEMM;
-  const int BN = BLOCK_SIZE_GEMM;
-  const int BK = THREADS_GEMM;
-  const int TM = THREADS_GEMM;
-
-  const uint cRow = blockIdx.y;
-  const uint cCol = blockIdx.x;
-
-  const int threadCol = threadIdx.x % BN;
-  const int threadRow = threadIdx.x / BN;
-
-  // Allocate space for the current block tile in shared memory
-  __shared__ Torus As[BM * BK];
-  __shared__ Torus Bs[BK * BN];
-
-  // Initialize the pointers to the input blocks from A, B
-  // Tiles from these blocks are loaded to shared memory
-  A += cRow * BM * K;
-  B += cCol * BN;
-
-  // Each thread will handle multiple sub-blocks
-  const uint innerColA = threadIdx.x % BK;
-  const uint innerRowA = threadIdx.x / BK;
-  const uint innerColB = threadIdx.x % BN;
-  const uint innerRowB = threadIdx.x / BN;
-
-  // allocate thread-local cache for results in registerfile
-  Torus threadResults[TM] = {0};
-
-  auto row_A = cRow * BM + innerRowA;
-  auto col_B = cCol * BN + innerColB;
-
-  // For each thread, loop over block tiles
-  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
-    auto col_A = bkIdx + innerColA;
-    auto row_B = bkIdx + innerRowB;
-
-    if (row_A < M && col_A < K) {
-      As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
-    } else {
-      As[innerRowA * BK + innerColA] = 0;
-    }
-
-    if (col_B < N && row_B < K) {
-      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
-    } else {
-      Bs[innerRowB * BN + innerColB] = 0;
-    }
-    synchronize_threads_in_block();
-
-    // Advance blocktile for the next iteration of this loop
-    A += BK;
-    B += BK * stride_B;
-
-    // calculate per-thread results
-    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
-      // we make the dotproduct loop the outside loop, which facilitates
-      // reuse of the Bs entry, which we can cache in a tmp var.
-      Torus tmp = Bs[dotIdx * BN + threadCol];
-      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-        threadResults[resIdx] +=
-            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
-      }
-    }
-    synchronize_threads_in_block();
-  }
-
-  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
-  // BLOCK_SIZE_GEMM)
-  C += cRow * BM * N + cCol * BN;
-
-  // write out the results
-  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-    int outRow = cRow * BM + threadRow * TM + resIdx;
-    int outCol = cCol * BN + threadCol;
-
-    if (outRow >= M)
-      continue;
-    if (outCol >= N)
-      continue;
-
-    C[(threadRow * TM + resIdx) * N + threadCol] += threadResults[resIdx];
-  }
-}
-
 // Finish the keyswitching operation and prepare GLWEs for accumulation.
 // 1. Finish the keyswitching computation partially performed with a GEMM:
 // - negate the dot product between the GLWE and KSK polynomial
@@ -312,7 +211,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
   tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
       num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
-      stride_KSK_buffer, d_mem_1);
+      stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
   check_cuda_error(cudaGetLastError());
 
   auto ksk_block_size = glwe_accumulator_size;
@@ -326,7 +225,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
       tgemm<Torus, TorusVec>
          <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
              num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-              fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+              fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
+              glwe_accumulator_size);
       check_cuda_error(cudaGetLastError());
     }
 
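The only functional change to tgemm in this file is visible at the two call sites above: the kernel now takes an explicit output row stride, and both callers pass glwe_accumulator_size, i.e. a stride equal to the row width N, which reproduces the old behavior. Decoupling the stride from N is what lets a caller accumulate a result tile into a sub-matrix of a wider buffer; a toy sketch of the indexing (not from the commit):

    #include <cstdint>

    // Accumulate an (M, N) tile into a buffer whose rows are stride_C wide.
    // With stride_C == N this is the old dense layout; with stride_C > N,
    // columns N..stride_C-1 of each row are left untouched.
    void accumulate_tile(uint64_t *C, const uint64_t *acc, int M, int N,
                         int stride_C) {
      for (int row = 0; row < M; ++row)
        for (int col = 0; col < N; ++col)
          C[row * stride_C + col] += acc[row * N + col];
    }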
backends/tfhe-cuda-backend/cuda/src/linearalgebra/multiplication.cu

+22
@@ -1,4 +1,5 @@
 #include "linearalgebra/multiplication.cuh"
+#include "polynomial/polynomial_math.cuh"
 
 /*
  * Perform the multiplication of a u32 input LWE ciphertext vector with a u32
@@ -58,3 +59,24 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_64(
       static_cast<const uint64_t *>(cleartext_array_in), input_lwe_dimension,
       input_lwe_ciphertext_count);
 }
+
+void cuda_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t n_rhs) {
+
+  host_wrapping_polynomial_mul_one_to_many<uint64_t, ulonglong4>(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<uint64_t *>(result), static_cast<uint64_t const *>(poly_lhs),
+      static_cast<uint64_t const *>(poly_rhs), polynomial_size, 0, n_rhs);
+}
+
+void cuda_glwe_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *glwe_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t glwe_dimension,
+    uint32_t n_rhs) {
+  host_glwe_wrapping_polynomial_mul_one_to_many<uint64_t, ulonglong4>(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<uint64_t *>(result), static_cast<uint64_t const *>(glwe_lhs),
+      static_cast<uint64_t const *>(poly_rhs), polynomial_size, glwe_dimension,
+      n_rhs);
+}
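The two wrappers differ in the host template they dispatch to and in one argument: the plain polynomial version passes a literal 0 in the position where the GLWE version forwards glwe_dimension. Since a GLWE ciphertext of dimension k carries k + 1 polynomials (k mask polynomials plus a body), the GLWE variant presumably multiplies every polynomial of the input ciphertext by each right-hand polynomial. The buffer shapes in the sketch below are assumptions for illustration, not spelled out in this diff; it continues the earlier usage sketch.

    // Continuing the earlier usage sketch; all shapes are hypothetical:
    //   d_glwe_lhs : (glwe_dimension + 1) * polynomial_size coefficients,
    //                the mask and body polynomials of one GLWE ciphertext
    //   d_poly_rhs : n_rhs polynomials of polynomial_size coefficients
    //   d_result   : assumed n_rhs * (glwe_dimension + 1) * polynomial_size,
    //                one product GLWE per right-hand polynomial
    uint32_t glwe_dimension = 1; // illustrative
    cuda_glwe_wrapping_polynomial_mul_one_to_many_64(
        stream, /*gpu_index=*/0, d_result, d_glwe_lhs, d_poly_rhs,
        polynomial_size, glwe_dimension, n_rhs);
    cudaStreamSynchronize(stream);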

backends/tfhe-cuda-backend/cuda/src/linearalgebra/multiplication.cuh

+104
@@ -86,4 +86,108 @@ host_cleartext_multiplication(cudaStream_t stream, uint32_t gpu_index,
   check_cuda_error(cudaGetLastError());
 }
 
+const int BLOCK_SIZE_GEMM = 64;
+const int THREADS_GEMM = 8;
+
+// Multiply matrices A, B of size (M, K), (K, N) respectively
+// with K as the inner dimension.
+//
+// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
+// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
+// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
+// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
+//
+// This code is adapted by generalizing the 1d block-tiling
+// kernel from https://github.com/siboehm/SGEMM_CUDA
+// to any matrix dimension
+template <typename Torus, typename TorusVec>
+__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
+                      int stride_B, Torus *C, int stride_C) {
+
+  const int BM = BLOCK_SIZE_GEMM;
+  const int BN = BLOCK_SIZE_GEMM;
+  const int BK = THREADS_GEMM;
+  const int TM = THREADS_GEMM;
+
+  const uint cRow = blockIdx.y;
+  const uint cCol = blockIdx.x;
+
+  const int threadCol = threadIdx.x % BN;
+  const int threadRow = threadIdx.x / BN;
+
+  // Allocate space for the current block tile in shared memory
+  __shared__ Torus As[BM * BK];
+  __shared__ Torus Bs[BK * BN];
+
+  // Initialize the pointers to the input blocks from A, B
+  // Tiles from these blocks are loaded to shared memory
+  A += cRow * BM * K;
+  B += cCol * BN;
+
+  // Each thread will handle multiple sub-blocks
+  const uint innerColA = threadIdx.x % BK;
+  const uint innerRowA = threadIdx.x / BK;
+  const uint innerColB = threadIdx.x % BN;
+  const uint innerRowB = threadIdx.x / BN;
+
+  // allocate thread-local cache for results in registerfile
+  Torus threadResults[TM] = {0};
+
+  auto row_A = cRow * BM + innerRowA;
+  auto col_B = cCol * BN + innerColB;
+
+  // For each thread, loop over block tiles
+  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
+    auto col_A = bkIdx + innerColA;
+    auto row_B = bkIdx + innerRowB;
+
+    if (row_A < M && col_A < K) {
+      As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
+    } else {
+      As[innerRowA * BK + innerColA] = 0;
+    }
+
+    if (col_B < N && row_B < K) {
+      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
+    } else {
+      Bs[innerRowB * BN + innerColB] = 0;
+    }
+    synchronize_threads_in_block();
+
+    // Advance blocktile for the next iteration of this loop
+    A += BK;
+    B += BK * stride_B;
+
+    // calculate per-thread results
+    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
+      // we make the dotproduct loop the outside loop, which facilitates
+      // reuse of the Bs entry, which we can cache in a tmp var.
+      Torus tmp = Bs[dotIdx * BN + threadCol];
+      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+        threadResults[resIdx] +=
+            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
+      }
+    }
+    synchronize_threads_in_block();
+  }
+
+  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
+  // BLOCK_SIZE_GEMM)
+  C += cRow * BM * stride_C + cCol * BN;
+
+  // write out the results
+  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+    int outRow = cRow * BM + threadRow * TM + resIdx;
+    int outCol = cCol * BN + threadCol;
+
+    if (outRow >= M)
+      continue;
+    if (outCol >= N)
+      continue;
+
+    C[(threadRow * TM + resIdx) * stride_C + threadCol] +=
+        threadResults[resIdx];
+  }
+}
+
 #endif // CUDA_MULT_H
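The launch geometry implied by the constants above: each block produces a 64 x 64 tile of C with 64 * 64 / 8 = 512 threads (each thread accumulates THREADS_GEMM = 8 results), staging a (64, 8) tile of A and an (8, 64) tile of B in shared memory per iteration. A hedged launch sketch mirroring the call sites in fast_packing_keyswitch.cuh; the device pointers and dimensions are placeholders, while CEIL_DIV and get_shared_mem_size_tgemm appear in the diff above:

    // One block per (BLOCK_SIZE_GEMM, BLOCK_SIZE_GEMM) output tile:
    // blockIdx.x walks the N (column) dimension, blockIdx.y the M (row) one.
    dim3 grid_gemm(CEIL_DIV(N, BLOCK_SIZE_GEMM), CEIL_DIV(M, BLOCK_SIZE_GEMM));
    dim3 threads_gemm(BLOCK_SIZE_GEMM * BLOCK_SIZE_GEMM / THREADS_GEMM); // 512

    // As + Bs hold 2 * 64 * 8 Torus values: 8 KiB of shared memory for u64.
    uint32_t shared_mem_size = get_shared_mem_size_tgemm<uint64_t>();
    tgemm<uint64_t, ulonglong4>
        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            M, N, K, d_A, d_B, /*stride_B=*/N, d_C, /*stride_C=*/N);
    check_cuda_error(cudaGetLastError());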
