Commit f99d91b

feat(gpu): add poly product with circulant matrix
1 parent f962716 · commit f99d91b

10 files changed: +755 −104 lines
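
Background for this change: multiplying one polynomial a(X) by many polynomials b_i(X), with wrapping u64 coefficients, can be phrased as a single matrix product: the circulant matrix built from a times the matrix whose columns are the b_i. That reduction to GEMM is what the kernels below exploit. A small worked instance for N = 3, assuming the negacyclic convention (reduction modulo X^N + 1, standard in TFHE, though not spelled out on this page):

a(X)\,b(X) \bmod (X^3 + 1) \;\longleftrightarrow\;
\begin{pmatrix} a_0 & -a_2 & -a_1 \\ a_1 & a_0 & -a_2 \\ a_2 & a_1 & a_0 \end{pmatrix}
\begin{pmatrix} b_0 \\ b_1 \\ b_2 \end{pmatrix}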

backends/tfhe-cuda-backend/cuda/include/linear_algebra.h (+7)

@@ -42,6 +42,13 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_64(
     void const *lwe_array_in, void const *cleartext_array_in,
     const uint32_t input_lwe_dimension,
     const uint32_t input_lwe_ciphertext_count);
+void cuda_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t n_rhs);
+void cuda_glwe_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t glwe_dimension,
+    uint32_t n_rhs);
 void cuda_add_lwe_ciphertext_vector_plaintext_64(
     void *stream, uint32_t gpu_index, void *lwe_array_out,
     void const *lwe_array_in, const uint64_t plaintext_in,
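
A minimal host-side usage sketch for the first new entry point. This driver is not part of the commit; it assumes poly_rhs and result are dense row-major [n_rhs, polynomial_size] u64 buffers (the layout is not stated in the header), and example_one_to_many is a hypothetical name.

#include <cstdint>
#include <cuda_runtime.h>
#include "linear_algebra.h" // declares cuda_wrapping_polynomial_mul_one_to_many_64

void example_one_to_many(uint32_t gpu_index, const uint64_t *h_lhs,
                         const uint64_t *h_rhs, uint64_t *h_out,
                         uint32_t polynomial_size, uint32_t n_rhs) {
  cudaSetDevice(gpu_index);
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  size_t lhs_bytes = polynomial_size * sizeof(uint64_t);
  size_t many_bytes = (size_t)n_rhs * polynomial_size * sizeof(uint64_t);
  uint64_t *d_lhs, *d_rhs, *d_out;
  cudaMalloc(&d_lhs, lhs_bytes);
  cudaMalloc(&d_rhs, many_bytes);
  cudaMalloc(&d_out, many_bytes);
  cudaMemcpyAsync(d_lhs, h_lhs, lhs_bytes, cudaMemcpyHostToDevice, stream);
  cudaMemcpyAsync(d_rhs, h_rhs, many_bytes, cudaMemcpyHostToDevice, stream);

  // One lhs polynomial multiplied against n_rhs polynomials in one call.
  cuda_wrapping_polynomial_mul_one_to_many_64(stream, gpu_index, d_out, d_lhs,
                                              d_rhs, polynomial_size, n_rhs);

  cudaMemcpyAsync(h_out, d_out, many_bytes, cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  cudaFree(d_lhs);
  cudaFree(d_rhs);
  cudaFree(d_out);
  cudaStreamDestroy(stream);
}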

backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh (+4 −104)

@@ -8,6 +8,7 @@
 #include "gadget.cuh"
 #include "helper_multi_gpu.h"
 #include "keyswitch.cuh"
+#include "linearalgebra/multiplication.cuh"
 #include "polynomial/functions.cuh"
 #include "polynomial/polynomial_math.cuh"
 #include "torus.cuh"
@@ -18,8 +19,6 @@
 
 #define CEIL_DIV(M, N) ((M) + (N)-1) / (N)
 
-const int BLOCK_SIZE_GEMM = 64;
-const int THREADS_GEMM = 8;
 const int BLOCK_SIZE_DECOMP = 8;
 
 template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
@@ -91,106 +90,6 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
   buffer_in[state_idx] = state;
 }
 
-// Multiply matrices A, B of size (M, K), (K, N) respectively
-// with K as the inner dimension.
-//
-// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
-// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
-// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
-// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
-//
-// This code is adapted by generalizing the 1d block-tiling
-// kernel from https://github.com/siboehm/SGEMM_CUDA
-// to any matrix dimension
-template <typename Torus, typename TorusVec>
-__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
-                      int stride_B, Torus *C) {
-
-  const int BM = BLOCK_SIZE_GEMM;
-  const int BN = BLOCK_SIZE_GEMM;
-  const int BK = THREADS_GEMM;
-  const int TM = THREADS_GEMM;
-
-  const uint cRow = blockIdx.y;
-  const uint cCol = blockIdx.x;
-
-  const int threadCol = threadIdx.x % BN;
-  const int threadRow = threadIdx.x / BN;
-
-  // Allocate space for the current block tile in shared memory
-  __shared__ Torus As[BM * BK];
-  __shared__ Torus Bs[BK * BN];
-
-  // Initialize the pointers to the input blocks from A, B
-  // Tiles from these blocks are loaded to shared memory
-  A += cRow * BM * K;
-  B += cCol * BN;
-
-  // Each thread will handle multiple sub-blocks
-  const uint innerColA = threadIdx.x % BK;
-  const uint innerRowA = threadIdx.x / BK;
-  const uint innerColB = threadIdx.x % BN;
-  const uint innerRowB = threadIdx.x / BN;
-
-  // allocate thread-local cache for results in registerfile
-  Torus threadResults[TM] = {0};
-
-  auto row_A = cRow * BM + innerRowA;
-  auto col_B = cCol * BN + innerColB;
-
-  // For each thread, loop over block tiles
-  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
-    auto col_A = bkIdx + innerColA;
-    auto row_B = bkIdx + innerRowB;
-
-    if (row_A < M && col_A < K) {
-      As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
-    } else {
-      As[innerRowA * BK + innerColA] = 0;
-    }
-
-    if (col_B < N && row_B < K) {
-      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
-    } else {
-      Bs[innerRowB * BN + innerColB] = 0;
-    }
-    synchronize_threads_in_block();
-
-    // Advance blocktile for the next iteration of this loop
-    A += BK;
-    B += BK * stride_B;
-
-    // calculate per-thread results
-    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
-      // we make the dotproduct loop the outside loop, which facilitates
-      // reuse of the Bs entry, which we can cache in a tmp var.
-      Torus tmp = Bs[dotIdx * BN + threadCol];
-      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-        threadResults[resIdx] +=
-            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
-      }
-    }
-    synchronize_threads_in_block();
-  }
-
-  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
-  // BLOCK_SIZE_GEMM)
-  C += cRow * BM * N + cCol * BN;
-
-  // write out the results
-  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
-    int outRow = cRow * BM + threadRow * TM + resIdx;
-    int outCol = cCol * BN + threadCol;
-
-    if (outRow >= M)
-      continue;
-    if (outCol >= N)
-      continue;
-
-    C[(threadRow * TM + resIdx) * N + threadCol] += threadResults[resIdx];
-  }
-}
-
 // Finish the keyswitching operation and prepare GLWEs for accumulation.
 // 1. Finish the keyswitching computation partially performed with a GEMM:
 //   - negate the dot product between the GLWE and KSK polynomial
@@ -313,7 +212,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
   uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
   tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
       num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
-      stride_KSK_buffer, d_mem_1);
+      stride_KSK_buffer, d_mem_1, glwe_accumulator_size);
   check_cuda_error(cudaGetLastError());
 
   auto ksk_block_size = glwe_accumulator_size;
@@ -327,7 +226,8 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
     tgemm<Torus, TorusVec>
         <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
            num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
-            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
+            fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1,
+            glwe_accumulator_size);
     check_cuda_error(cudaGetLastError());
   }
 
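
Aside from the relocated include and constants, the functional change in this file is that both tgemm call sites now pass glwe_accumulator_size as the kernel's new stride_C argument. The output row width of these GEMMs already equals N (= glwe_accumulator_size), so behavior here is unchanged; the extra stride only matters when a caller wants the M x N result written into a sub-block of a wider row-major buffer, as the one-to-many polynomial product does. A toy standalone sketch of the indexing difference (names illustrative, not from the commit):

#include <cstdio>

int main() {
  // An M x N output block embedded in rows of width stride_C.
  const int M = 2, N = 3, stride_C = 5;
  long C[M * stride_C] = {0};
  for (int row = 0; row < M; ++row)
    for (int col = 0; col < N; ++col)
      C[row * stride_C + col] += 1; // tgemm accumulates with += the same way
  for (int row = 0; row < M; ++row) {
    for (int col = 0; col < stride_C; ++col)
      printf("%ld ", C[row * stride_C + col]);
    printf("\n"); // prints "1 1 1 0 0" twice: trailing columns untouched
  }
  return 0;
}

With stride_C == N this degenerates to the old dense write pattern C[row * N + col].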

backends/tfhe-cuda-backend/cuda/src/linearalgebra/multiplication.cu (+22)

@@ -1,4 +1,5 @@
 #include "linearalgebra/multiplication.cuh"
+#include "polynomial/polynomial_math.cuh"
 
 /*
  * Perform the multiplication of a u32 input LWE ciphertext vector with a u32
@@ -58,3 +59,24 @@ void cuda_mult_lwe_ciphertext_vector_cleartext_vector_64(
       static_cast<const uint64_t *>(cleartext_array_in), input_lwe_dimension,
       input_lwe_ciphertext_count);
 }
+
+void cuda_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *poly_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t n_rhs) {
+
+  host_wrapping_polynomial_mul_one_to_many<uint64_t, ulonglong4>(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<uint64_t *>(result), static_cast<uint64_t const *>(poly_lhs),
+      static_cast<uint64_t const *>(poly_rhs), polynomial_size, 0, n_rhs);
+}
+
+void cuda_glwe_wrapping_polynomial_mul_one_to_many_64(
+    void *stream, uint32_t gpu_index, void *result, void const *glwe_lhs,
+    void const *poly_rhs, uint32_t polynomial_size, uint32_t glwe_dimension,
+    uint32_t n_rhs) {
+  host_glwe_wrapping_polynomial_mul_one_to_many<uint64_t, ulonglong4>(
+      static_cast<cudaStream_t>(stream), gpu_index,
+      static_cast<uint64_t *>(result), static_cast<uint64_t const *>(glwe_lhs),
+      static_cast<uint64_t const *>(poly_rhs), polynomial_size, glwe_dimension,
+      n_rhs);
+}
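
Note that the plain-polynomial wrapper forwards glwe_dimension = 0 to the same host routine, treating a lone polynomial as the degenerate GLWE case. As a sanity check for the new entry points, a minimal CPU oracle sketch: it assumes the GPU computes result_i = poly_lhs * poly_rhs_i with wrapping u64 coefficient arithmetic (implied by the name) and negacyclic reduction modulo X^N + 1 (the usual TFHE convention, not spelled out in this diff).

#include <cstdint>
#include <vector>

std::vector<uint64_t>
negacyclic_mul_one_to_many(const std::vector<uint64_t> &lhs, // N coeffs
                           const std::vector<uint64_t> &rhs, // n_rhs * N coeffs
                           uint32_t N, uint32_t n_rhs) {
  std::vector<uint64_t> out((size_t)n_rhs * N, 0);
  for (uint32_t r = 0; r < n_rhs; ++r)
    for (uint32_t i = 0; i < N; ++i)
      for (uint32_t j = 0; j < N; ++j) {
        uint64_t prod = lhs[i] * rhs[(size_t)r * N + j]; // wraps mod 2^64
        uint32_t k = i + j;
        if (k < N)
          out[(size_t)r * N + k] += prod;
        else
          out[(size_t)r * N + (k - N)] -= prod; // X^N = -1
      }
  return out;
}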

backends/tfhe-cuda-backend/cuda/src/linearalgebra/multiplication.cuh (+104)

@@ -86,4 +86,108 @@ host_cleartext_multiplication(cudaStream_t stream, uint32_t gpu_index,
   check_cuda_error(cudaGetLastError());
 }
 
+const int BLOCK_SIZE_GEMM = 64;
+const int THREADS_GEMM = 8;
+
+// Multiply matrices A, B of size (M, K), (K, N) respectively,
+// with K as the inner dimension.
+//
+// A block of threads processes blocks of size (BLOCK_SIZE_GEMM,
+// BLOCK_SIZE_GEMM), splitting them into multiple tiles: (BLOCK_SIZE_GEMM,
+// THREADS_GEMM)-shaped tiles of values from A, and (THREADS_GEMM,
+// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
+//
+// This code is adapted from the 1D block-tiling kernel at
+// https://github.com/siboehm/SGEMM_CUDA, generalized to
+// arbitrary matrix dimensions.
+template <typename Torus, typename TorusVec>
+__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
+                      int stride_B, Torus *C, int stride_C) {
+
+  const int BM = BLOCK_SIZE_GEMM;
+  const int BN = BLOCK_SIZE_GEMM;
+  const int BK = THREADS_GEMM;
+  const int TM = THREADS_GEMM;
+
+  const uint cRow = blockIdx.y;
+  const uint cCol = blockIdx.x;
+
+  const int threadCol = threadIdx.x % BN;
+  const int threadRow = threadIdx.x / BN;
+
+  // Allocate space for the current block tile in shared memory
+  __shared__ Torus As[BM * BK];
+  __shared__ Torus Bs[BK * BN];
+
+  // Initialize the pointers to the input blocks from A, B
+  // Tiles from these blocks are loaded to shared memory
+  A += cRow * BM * K;
+  B += cCol * BN;
+
+  // Each thread will handle multiple sub-blocks
+  const uint innerColA = threadIdx.x % BK;
+  const uint innerRowA = threadIdx.x / BK;
+  const uint innerColB = threadIdx.x % BN;
+  const uint innerRowB = threadIdx.x / BN;
+
+  // Allocate a thread-local cache for results in the register file
+  Torus threadResults[TM] = {0};
+
+  auto row_A = cRow * BM + innerRowA;
+  auto col_B = cCol * BN + innerColB;
+
+  // For each thread, loop over block tiles
+  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
+    auto col_A = bkIdx + innerColA;
+    auto row_B = bkIdx + innerRowB;
+
+    if (row_A < M && col_A < K) {
+      As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
+    } else {
+      As[innerRowA * BK + innerColA] = 0;
+    }
+
+    if (col_B < N && row_B < K) {
+      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
+    } else {
+      Bs[innerRowB * BN + innerColB] = 0;
+    }
+    synchronize_threads_in_block();
+
+    // Advance the block tile for the next iteration of this loop
+    A += BK;
+    B += BK * stride_B;
+
+    // Calculate per-thread results
+    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
+      // We make the dot-product loop the outer loop, which facilitates
+      // reuse of the Bs entry, which we can cache in a tmp variable.
+      Torus tmp = Bs[dotIdx * BN + threadCol];
+      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+        threadResults[resIdx] +=
+            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
+      }
+    }
+    synchronize_threads_in_block();
+  }
+
+  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
+  // BLOCK_SIZE_GEMM); rows of C are stride_C elements apart
+  C += cRow * BM * stride_C + cCol * BN;
+
+  // Write out the results
+  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+    int outRow = cRow * BM + threadRow * TM + resIdx;
+    int outCol = cCol * BN + threadCol;
+
+    if (outRow >= M)
+      continue;
+    if (outCol >= N)
+      continue;
+
+    C[(threadRow * TM + resIdx) * stride_C + threadCol] +=
+        threadResults[resIdx];
+  }
+}
+
 #endif // CUDA_MULT_H
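
For reference, the tiling arithmetic of the relocated kernel: each block computes a 64 x 64 tile of C using 64 * 64 / 8 = 512 threads, each accumulating TM = 8 rows of one output column in registers. A hedged launch sketch for the dense case; launch_tgemm_u64 is a hypothetical helper, and the grid layout mirrors what the fast-packing-keyswitch call site implies but is not shown verbatim on this page.

#include <cstdint>
#include <cuda_runtime.h>
#include "linearalgebra/multiplication.cuh" // tgemm, BLOCK_SIZE_GEMM, THREADS_GEMM

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

void launch_tgemm_u64(cudaStream_t stream, int M, int N, int K,
                      const uint64_t *d_A, const uint64_t *d_B,
                      uint64_t *d_C) {
  // One block per 64 x 64 output tile; blockIdx.x walks columns (N),
  // blockIdx.y walks rows (M), matching cCol/cRow in the kernel.
  dim3 grid_gemm(CEIL_DIV(N, BLOCK_SIZE_GEMM), CEIL_DIV(M, BLOCK_SIZE_GEMM), 1);
  dim3 threads_gemm(BLOCK_SIZE_GEMM * BLOCK_SIZE_GEMM / THREADS_GEMM, 1, 1);
  // The As/Bs tiles are statically sized inside the kernel; the dynamic size
  // is passed anyway, mirroring the existing call sites: (BM*BK + BK*BN) words.
  size_t shared_mem_size =
      2ull * BLOCK_SIZE_GEMM * THREADS_GEMM * sizeof(uint64_t);
  // Dense case: B and C both use their natural row strides.
  tgemm<uint64_t, ulonglong4>
      <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
          M, N, K, d_A, d_B, /*stride_B=*/N, d_C, /*stride_C=*/N);
}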
