Skip to content

Commit e9c3800

Browse files
committed
modified apsi
1 parent afa67ca commit e9c3800

37 files changed

+1925
-2057
lines changed

CMakeLists.txt

+27-6
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ CMAKE_MINIMUM_REQUIRED(VERSION 3.2)
55
PROJECT(Kunlun)
66
# On some machines, the path of OpenSSL is /usr/local/lib64/openssl/libcrypto.a
77
IF(CMAKE_SYSTEM_NAME STREQUAL "Linux")
8-
SET(OPENSSL_LIBRARIES /usr/local/lib64/libcrypto.a /usr/local/lib64/libssl.a)
9-
SET(OPENSSL_INCLUDE_DIR /usr/local/include)
10-
SET(OPENSSL_DIR /usr/local/lib64)
8+
SET(OPENSSL_LIBRARIES /usr/local/openssl/lib64/libcrypto.a /usr/local/openssl/lib64/libssl.a)
9+
SET(OPENSSL_INCLUDE_DIR /usr/local/openssl/include)
10+
SET(OPENSSL_DIR /usr/local/openssl/lib64)
1111
ELSEIF(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
1212
SET(OPENSSL_LIBRARIES /usr/local/lib/libcrypto.a /usr/local/lib/libssl.a)
1313
SET(OPENSSL_INCLUDE_DIR /usr/local/include)
@@ -242,6 +242,27 @@ target_link_libraries(test_label_peqt ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${
242242
add_executable(test_keyword_pir test/mytest/test_keywordpir.cpp)
243243
target_link_libraries(test_keyword_pir ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
244244

245-
#test_poly
246-
add_executable(test_polynomial test/mytest/test_polynomial.cpp)
247-
target_link_libraries(test_polynomial ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
245+
##test_poly
246+
#add_executable(test_polynomial test/mytest/test_polynomial.cpp)
247+
#target_link_libraries(test_polynomial ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
248+
249+
#test_cuckoo_filter
250+
add_executable(mytest_cuckoo_filter test/mytest/test_cuckoo_filter.cpp)
251+
target_link_libraries(mytest_cuckoo_filter ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
252+
253+
#test_psi
254+
add_executable(test_psi test/mytest/test_psi.cpp)
255+
target_link_libraries(test_psi ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
256+
257+
add_executable(splitdata test/mytest/splitdata.cpp)
258+
target_link_libraries(splitdata ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
259+
260+
261+
262+
#labelpsi
263+
add_library(labelpsi mpc/labelpsi/psi.h mpc/labelpsi/psi.cpp mpc/labelpsi/windowing.cpp mpc/labelpsi/hashing.cpp mpc/labelpsi/random.cpp mpc/labelpsi/polynomials.cpp)
264+
265+
target_link_libraries(labelpsi SEAL::seal)
266+
267+
add_executable(label test/mytest/test_labelpsi.cpp)
268+
target_link_libraries(label labelpsi ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})

config/config.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#pragma once
22

3-
#define NUMBER_OF_LOGICAL_CORES 4
4-
#define NUMBER_OF_PHYSICAL_CORES 4
3+
#define NUMBER_OF_LOGICAL_CORES 64
4+
#define NUMBER_OF_PHYSICAL_CORES 32
55
#define IS_64BIT 1
66
#define HAS_SSE2 1
77

filter/cuckoo_filter.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ class CuckooFilter {
175175
for(auto j=0;j<element[i].size();j++){
176176
result=(result<<8)|element[i][j];
177177
}
178+
std::cout<<result<<std::endl;
178179
//unsigned char buffer[element[i].size()];
179180
//memcpy(buffer, element[i].data(), element[i].size());
180181
//std::cout<<buffer<<std::endl;

mpc/labelpsi/client.cpp

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#include <cassert>
2+
#include <iostream>
3+
4+
#include "boost/asio.hpp"
5+
6+
#include "networking.h"
7+
8+
using namespace std;
9+
using namespace boost::asio;
10+
11+
int main()
12+
{
13+
vector<uint64_t> inputs = {0x02, 0x07, 0x05, 0xfe};
14+
size_t input_bits = 32;
15+
size_t poly_modulus_degree = 8192;
16+
unsigned short port = 9999;
17+
18+
io_context context;
19+
ip::tcp::socket socket(context);
20+
ip::tcp::resolver resolver(context);
21+
connect(socket, resolver.resolve("localhost", "9999", resolver.numeric_service));
22+
Networking net(socket);
23+
24+
cout << "connected, waiting for hello and set size" << endl;
25+
net.read_hello();
26+
size_t sender_size = net.read_uint32();
27+
28+
cout << "picking params" << endl;
29+
PSIParams params(inputs.size(), sender_size, input_bits, poly_modulus_degree);
30+
params.generate_seeds();
31+
net.set_seal_context(params.context);
32+
PSIReceiver receiver(params);
33+
34+
cout << "sending hello, set size, seeds, pk, relin keys" << endl;
35+
net.write_hello();
36+
net.write_uint32(inputs.size());
37+
net.write_uint64s(params.seeds);
38+
net.write_public_key(receiver.public_key());
39+
net.write_relin_keys(receiver.relin_keys());
40+
41+
cout << "encrypting inputs" << endl;
42+
vector<bucket_slot> buckets;
43+
auto encrypted_inputs = receiver.encrypt_inputs(inputs, buckets);
44+
45+
cout << "sending inputs" << endl;
46+
net.write_ciphertexts(encrypted_inputs);
47+
48+
cout << "waiting for encrypted matches" << endl;
49+
vector<Ciphertext> encrypted_matches;
50+
net.read_ciphertexts(encrypted_matches);
51+
52+
cout << "decrypting matches" << endl;
53+
auto matches = receiver.decrypt_labeled_matches(encrypted_matches);
54+
55+
cout << matches.size() << " matches found: ";
56+
for (auto i : matches) {
57+
assert(i.first < buckets.size());
58+
assert(buckets[i.first] != BUCKET_EMPTY);
59+
cout << inputs[buckets[i.first].first] << "-" << i.second << " ";
60+
}
61+
cout << endl;
62+
}

mpc/labelpsi/hashing.cpp

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
#include <cassert>
2+
3+
//#include "aes.h"
4+
#include "../../crypto/aes.hpp"
5+
#include "hashing.h"
6+
7+
using namespace std;
8+
9+
uint64_t aes_hash(AES::Key &key, size_t bits, uint64_t value) {
10+
block a=Block::MakeBlock('0LL',value);
11+
assert(bits < 64);
12+
13+
AES::Enc(key, a);
14+
return (Block::BlockToUint64Low(a) ^ value) & ((1ull << bits) - 1);
15+
}
16+
17+
size_t loc_aes_hash(AES::Key &key, size_t m, uint64_t value) {
18+
return aes_hash(key, m, value >> m) ^ (value & ((1ull << m) - 1));
19+
}
20+
21+
bool cuckoo_hash(shared_ptr<UniformRandomGenerator> random,
22+
vector<uint64_t> &inputs,
23+
size_t m,
24+
vector<bucket_slot> &buckets,
25+
vector<uint64_t> &seeds)
26+
{
27+
buckets.resize(1 << m);
28+
for (size_t i = 0; i < buckets.size(); i++) {
29+
buckets[i] = BUCKET_EMPTY;
30+
}
31+
32+
vector<AES::Key> aes(seeds.size());
33+
for (size_t i = 0; i < seeds.size(); i++) {
34+
aes[i]=AES::GenEncKey(Block::MakeBlock('0LL',seeds[i]));
35+
}
36+
37+
for (size_t i = 0; i < inputs.size(); i++) {
38+
bool resolved = false;
39+
bucket_slot current_item = make_pair(
40+
i,
41+
random_integer(random, seeds.size())
42+
);
43+
44+
// TODO: keep track of # of operations and abort if exceeding some limit
45+
while (!resolved) {
46+
size_t loc = loc_aes_hash(
47+
aes[current_item.second],
48+
m,
49+
inputs[current_item.first]
50+
);
51+
52+
buckets[loc].swap(current_item);
53+
54+
if (current_item == BUCKET_EMPTY) {
55+
resolved = true;
56+
} else {
57+
size_t old_hash = current_item.second;
58+
while (current_item.second == old_hash) {
59+
current_item.second = random_integer(random, seeds.size());
60+
}
61+
}
62+
}
63+
}
64+
65+
return true;
66+
}
67+
68+
bool complete_hash(shared_ptr<UniformRandomGenerator> random,
69+
vector<uint64_t> &inputs,
70+
size_t m,
71+
size_t capacity,
72+
vector<bucket_slot> &buckets,
73+
vector<uint64_t> &seeds)
74+
{
75+
buckets.resize(capacity << m);
76+
for (size_t i = 0; i < buckets.size(); i++) {
77+
buckets[i] = BUCKET_EMPTY;
78+
}
79+
80+
vector<AES::Key> aes(seeds.size());
81+
for (size_t i = 0; i < seeds.size(); i++) {
82+
aes[i]=AES::GenEncKey(Block::MakeBlock('0LL',seeds[i]));
83+
}
84+
85+
vector<size_t> capacity_used(1 << m);
86+
87+
// insert all elements into the table in a deterministic order (filling each
88+
// bucket sequentially)
89+
for (size_t i = 0; i < inputs.size(); i++) {
90+
for (size_t j = 0; j < seeds.size(); j++) {
91+
size_t loc = loc_aes_hash(aes[j], m, inputs[i]);
92+
93+
if (capacity_used[loc] == capacity) {
94+
// all slots in the bucket are used, so we cannot add this
95+
// element
96+
return false;
97+
}
98+
99+
buckets[capacity * loc + capacity_used[loc]] = make_pair(i, j);
100+
capacity_used[loc]++;
101+
}
102+
}
103+
104+
// now shuffle each bucket, to avoid leaking information about bucket load
105+
// distribution through partitioning
106+
for (size_t bucket = 0; bucket < (1 << m); bucket++) {
107+
for (size_t slot = 1; slot < capacity; slot++) {
108+
// uniformly pick a random slot before this one (possibly this
109+
// very same one) and swap
110+
size_t prev_slot = random_integer(random, slot + 1);
111+
buckets[capacity * bucket + slot].swap(buckets[capacity * bucket + prev_slot]);
112+
}
113+
}
114+
115+
return true;
116+
}

mpc/labelpsi/hashing.h

+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
#pragma once
3+
#include <cstdint>
4+
#include <utility>
5+
#include <vector>
6+
7+
#include "random.h"
8+
9+
typedef pair<size_t, size_t> bucket_slot;
10+
11+
const bucket_slot BUCKET_EMPTY = make_pair(0xFFFFFFFFul, 0xFFFFFFFFul);
12+
13+
/* Given a set of inputs, a number of buckets, and seeds for a hash function,
14+
performs permutation-based cuckoo hashing to put at most one element in each
15+
bucket.
16+
Permutation-based hashing means that, after hashing, it is safe to drop the
17+
last m bits of all the inputs in the table.
18+
The number of buckets is 2^m. Non-empty buckets will contain
19+
(input_index, seed_index), empty ones will be equal to BUCKET_EMPTY.
20+
Seeds should be random 64-bit values.
21+
*/
22+
bool cuckoo_hash(shared_ptr<UniformRandomGenerator> random,
23+
vector<uint64_t> &inputs,
24+
size_t m,
25+
vector<bucket_slot> &buckets,
26+
vector<uint64_t> &seeds);
27+
28+
/* Given a set of inputs, a number of buckets, and seeds for a hash function,
29+
places every input, hashed with *every* function, into the corresponding
30+
bucket, using permutation-based hashing.
31+
The number of buckets is 2^m. Non-empty buckets will contain
32+
(input_index, seed_index), empty ones will be equal to BUCKET_EMPTY.
33+
jth element of bucket number i is stored in buckets[i * capacity + j].
34+
Seeds should be random 64-bit values.
35+
*/
36+
bool complete_hash(shared_ptr<UniformRandomGenerator> random,
37+
vector<uint64_t> &inputs,
38+
size_t m,
39+
size_t capacity,
40+
vector<bucket_slot> &buckets,
41+
vector<uint64_t> &seeds);
42+

mpc/labelpsi/polynomials.cpp

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
#include <cassert>
2+
#include <set>
3+
4+
#include "polynomials.h"
5+
6+
uint64_t modexp(uint64_t base, uint64_t exponent, uint64_t modulus) {
7+
uint64_t result = 1;
8+
while (exponent > 0) {
9+
if (exponent & 1) {
10+
result = MUL_MOD(result, base, modulus);
11+
}
12+
base = MUL_MOD(base, base, modulus);
13+
exponent = (exponent >> 1);
14+
}
15+
return result;
16+
}
17+
18+
uint64_t modinv(uint64_t x, uint64_t modulus) {
19+
return modexp(x, modulus - 2, modulus);
20+
}
21+
22+
void polynomial_from_roots(vector<uint64_t> &roots, vector<uint64_t> &coeffs, uint64_t modulus) {
23+
coeffs.clear();
24+
coeffs.resize(roots.size() + 1);
25+
coeffs[0] = 1;
26+
27+
for (size_t i = 0; i < roots.size(); i++) {
28+
// multiply coeffs by (x - root)
29+
uint64_t neg_root = modulus - (roots[i] % modulus);
30+
31+
for (size_t j = i + 1; j > 0; j--) {
32+
coeffs[j] = (coeffs[j - 1] + MUL_MOD(neg_root, coeffs[j], modulus)) % modulus;
33+
}
34+
coeffs[0] = MUL_MOD(coeffs[0], neg_root, modulus);
35+
}
36+
}
37+
38+
void polynomial_from_points(vector<uint64_t> &xs,
39+
vector<uint64_t> &ys,
40+
vector<uint64_t> &coeffs,
41+
uint64_t modulus)
42+
{
43+
assert(xs.size() == ys.size());
44+
coeffs.clear();
45+
coeffs.resize(xs.size());
46+
47+
if (xs.size() == 0) {
48+
return;
49+
}
50+
51+
// at iteration i of the loop, basis contains the coefficients of the basis
52+
// polynomial (x - xs[0]) * (x - xs[1]) * ... * (x - xs[i - 1])
53+
vector<uint64_t> basis(xs.size());
54+
basis[0] = 1;
55+
56+
// at iteration i of the loop, ddif[j] contains the divided difference
57+
// [ys[j], ys[j + 1], ..., ys[j + i]]. thus initially, when i = 0,
58+
// ddif[j] = [ys[j]] = ys[j]
59+
vector<uint64_t> ddif = ys;
60+
61+
for (size_t i = 0; i < xs.size(); i++) {
62+
for (size_t j = 0; j < i + 1; j++) {
63+
coeffs[j] = (coeffs[j] + MUL_MOD(ddif[0], basis[j], modulus)) % modulus;
64+
}
65+
66+
if (i < xs.size() - 1) {
67+
// update basis: multiply it by (x - xs[i])
68+
uint64_t neg_x = modulus - (xs[i] % modulus);
69+
70+
for (size_t j = i + 1; j > 0; j--) {
71+
basis[j] = (basis[j - 1] + MUL_MOD(neg_x, basis[j], modulus)) % modulus;
72+
}
73+
basis[0] = MUL_MOD(basis[0], neg_x, modulus);
74+
75+
// update ddif: compute length-(i + 1) divided differences
76+
for (size_t j = 0; j + i + 1 < xs.size() + 1; j++) {
77+
// dd_{j,j+i+1} = (dd_{j+1, j+i+1} - dd_{j, j+i}) / (x_{j+i+1} - x_j)
78+
uint64_t num = (ddif[j + 1] - ddif[j] + modulus) % modulus;
79+
uint64_t den = (xs[j + i + 1] - xs[j] + modulus) % modulus;
80+
ddif[j] = MUL_MOD(num, modinv(den, modulus), modulus);
81+
}
82+
}
83+
}
84+
}

0 commit comments

Comments
 (0)