Skip to content

Commit ce7fcbd

Browse files
committed
fix keyword sealPIR
1 parent 15f9cb7 commit ce7fcbd

9 files changed

+618
-43
lines changed

Diff for: CMakeLists.txt

+13
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,16 @@ target_link_libraries(test_cope ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE
215215
add_executable(test_oprf_psi test/mytest/test_psi_from_oprf.cpp)
216216
target_link_libraries(test_oprf_psi ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
217217

218+
#ecdh_psi
219+
add_executable(test_ecdh_psi test/mytest/test_ecdh_psi.cpp)
220+
target_link_libraries(test_ecdh_psi ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
221+
222+
#hash function
223+
add_executable(test_hash test/mytest/test_hash.cpp)
224+
target_link_libraries(test_hash ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
225+
226+
#multi thread
227+
add_executable(test_thread test/mytest/test_multi_thread.cpp)
228+
target_link_libraries(test_thread ${OPENSSL_LIBRARIES} OpenMP::OpenMP_CXX ${CMAKE_DL_LIBS})
229+
230+

Diff for: mpc/pir/sealpir_keyword.hpp

+62-41
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ SEALPIR can be executed for keyword-based queries through network communication
1414
#include "seal/seal.h"
1515
#include "seal/util/polyarithsmallmod.h"
1616
#include "../../netio/stream_channel.hpp"
17+
#include "../../crypto/bigint.hpp"
18+
#include "../../crypto/hash.hpp"
19+
#include "../../include/global.hpp"
20+
1721
#include "pir.hpp"
1822
#include "pir_client.hpp"
1923
#include "pir_server.hpp"
@@ -35,55 +39,63 @@ namespace SEALPIRKEYWORD {
3539
/*
3640
* add keyword query
3741
* use a vector to save mapping of indexs and keywords
38-
3942
* */
4043
/*First clinet need to read file for KeyWord*/
4144
std::ifstream fin;
42-
std::vector<block> index_data;
45+
std::vector<std::string> keyword;
4346
fin.open("A_PIR_ID.txt", std::ios::binary);
4447
if (!fin) {
4548
std::cout << "Failed to open file" << std::endl;
4649
std::exit(-1);
4750
}
4851
std::string line;
4952
int i = 0;
50-
while (std::getline(fin, line)) {
51-
index_data.push_back(Block::MakeBlock(0LL, std::stoull(line)));
53+
while (std::getline(fin, line, '\n')) {
54+
if (!line.empty() && line[line.size() - 1] == '\r')
55+
line.erase(line.size() - 1);
56+
keyword.push_back(line);
5257
}
5358
fin.close();
54-
/*Second client receive file for Server for mapping of indexs and keywords */
59+
/*Second client receive file for Server for mapping of indexs and hash(keywords) */
5560
block a;
5661
/*Recive file size */
5762
io.ReceiveBlock(a);
5863
int server_file_len = Block::BlockToInt64(a);
5964
/*Receive file*/
60-
std::vector<block> index_num;
65+
std::vector<block> index_num(server_file_len);
6166

6267
std::cout << "Start receive file data" << std::endl;
68+
//io.ReceiveBlocks(index_num.data(),server_file_len);
69+
70+
#pragma omp parallel for num_threads(thread_count)
6371
for (int i = 0; i < server_file_len; i++) {
6472
io.ReceiveBlock(a);
65-
index_num.push_back(a);
73+
index_num[i] = a;
6674
}
6775
std::cout << "receive file finish" << std::endl;
68-
//use map to save mapping indexs of keywords in order to achieve fast finding
76+
//use map to save mapping indexs of hash(keywords) in order to achieve fast finding
6977
// construct map
7078
std::map<std::uint64_t, std::uint64_t> m;
7179
// data insert map
72-
for (auto a: index_num) {
73-
std::uint64_t index = Block::BlockToUint64High(a);
74-
std::uint64_t keyword = Block::BlockToInt64(a);
80+
//#pragma omp parallel for num_threads(thread_count)
81+
for (int i = 0; i < index_num.size(); i++) {
82+
std::uint64_t index = Block::BlockToUint64High(index_num[i]);
83+
std::uint64_t keyword = Block::BlockToInt64(index_num[i]);
7584
m[keyword] = index;
7685
}
77-
// find index by keyword and save data to block
78-
for (int i = 0; i < index_data.size(); i++) {
86+
87+
std::vector<std::uint64_t> keyword_index(keyword.size());
88+
//#pragma omp parallel for num_threads(thread_count)
89+
for (int i = 0; i < keyword.size(); i++) {
90+
auto a = Hash::StringToBlock(keyword[i]);
91+
auto s = Block::BlockToInt64(a);
7992
std::map<std::uint64_t, std::uint64_t>::iterator it = m.find(
80-
Block::BlockToInt64(index_data[i])); // 在 std::map 容器中查找 key 对应的 value
93+
s); // 在 std::map 容器中查找 key 对应的 value
8194
if (it != m.end()) {
82-
block temp = Block::MakeBlock(it->second, Block::BlockToInt64(index_data[i]));
83-
index_data[i] = temp;
84-
95+
keyword_index[i] = (it->second);
8596
}
8697
}
98+
std::cout << "keyword_inde size= " << keyword_index.size() << std::endl;
8799
auto start_time = std::chrono::steady_clock::now();
88100
/*The preprocessing of the files has been completed */
89101
cout << "Main: Generating galois keys for client" << endl;
@@ -103,15 +115,19 @@ namespace SEALPIRKEYWORD {
103115
* it means the query is finished and the network communication is disconnected*/
104116
std::string isend = "1";
105117
io.SendString(isend);
106-
for (int i = 0; i < index_data.size(); i++) {
118+
int query_toltal_size = 0;
107119

120+
for (int i = 0; i < keyword_index.size(); i++) {
108121
/*Generate query*/
109-
uint64_t ele_index = Block::BlockToUint64High(index_data[i]);
122+
//uint64_t ele_index = Block::BlockToUint64High(index_data[i]);
123+
uint64_t ele_index = keyword_index[i];
110124
uint64_t index = client.get_fv_index(ele_index); // index of FV plaintext
111125
uint64_t offset = client.get_fv_offset(ele_index); // offset in FV plaintext
126+
112127
/*Serialization query sending by network*/
113128
stringstream client_stream;
114129
int query_size = client.generate_serialized_query(index, client_stream);
130+
query_toltal_size += query_size;
115131
/*First send query size to server*/
116132
io.SendBlock(Block::MakeBlock(0LL, query_size));
117133
std::cout << "query_size" << query_size << std::endl;
@@ -130,7 +146,7 @@ namespace SEALPIRKEYWORD {
130146
PirReply reply = client.deserialize_reply(sreply, context);
131147
vector<uint8_t> elems = client.decode_reply(reply, offset);
132148
ans.push_back(elems);
133-
if (i != index_data.size() - 1)
149+
if (i != keyword_index.size() - 1)
134150
isend = "1";
135151
else isend = "0";
136152
io.SendString(isend);
@@ -141,9 +157,10 @@ namespace SEALPIRKEYWORD {
141157
auto running_time = end_time - start_time;
142158
std::cout << "SealPIR:Client side takes time = "
143159
<< std::chrono::duration<double, std::milli>(running_time).count() << " ms" << std::endl;
144-
160+
std::cout << "SealPIR:Client Query size= ["
161+
<< (double) (query_toltal_size) / (1024 * 1024) << " MB]" << std::endl;
145162
return ans;
146-
}
163+
} //client
147164

148165
void Server(NetIO &io, PirParams params, const EncryptionParameters &enc_params, std::string filename) {
149166
/*, unique_ptr<uint8_t[]> &db1*/
@@ -174,19 +191,27 @@ namespace SEALPIRKEYWORD {
174191

175192
}
176193
fin.close();
177-
/*save index and keyword to block */
178-
std::vector<block> index;
179-
for (int i = 0; i < file_data.size(); i++) {
180-
block a = Block::MakeBlock(std::stoull(file_data[i][0]), std::stoull(file_data[i][1]));
181-
index.push_back(a);
194+
/*
195+
* mapping keyword-> hash(keyword)
196+
* block=(index,hash(keyword)
197+
* */
198+
std::vector<block> hash_index(file_data.size());
199+
#pragma omp parallel for num_threads(thread_count)
200+
for (auto i = 0; i < file_data.size(); i++) {
201+
auto s = Hash::StringToBlock(file_data[i][1]);
202+
block a = Block::MakeBlock(std::stoull(file_data[i][0]), Block::BlockToInt64(s));
203+
hash_index[i] = a;
182204
}
205+
183206
std::cout << "Send file size" << std::endl;
184-
io.SendBlock(Block::MakeBlock(0LL, index.size()));
207+
io.SendBlock(Block::MakeBlock(0LL, hash_index.size()));
185208
std::cout << "The file length has been sent successfully" << std::endl;
186209
std::cout << "start send file data" << std::endl;
187-
for (auto a: index) {
188-
io.SendBlock(a);
210+
#pragma omp parallel for num_threads(thread_count)
211+
for (auto i = 0; i < hash_index.size(); i++) {
212+
io.SendBlock(hash_index[i]);
189213
}
214+
190215
std::cout << "file data send successfully" << std::endl;
191216
auto start_time = std::chrono::steady_clock::now();
192217
block c;
@@ -198,6 +223,7 @@ namespace SEALPIRKEYWORD {
198223
server.set_galois_key(0, *galois_keys1);
199224
/*Before generating the database file, it is necessary to remove the special
200225
* symbol '-' from the second column of the raw data*/
226+
//#pragma omp parallel for num_threads(thread_count)
201227
for (int i = 0; i < file_data.size(); i++) {
202228
std::string temp = file_data[i][2];
203229
temp.erase(std::remove(temp.begin(), temp.end(), '-'), temp.end());
@@ -220,26 +246,22 @@ namespace SEALPIRKEYWORD {
220246

221247
std::string isend(1, '0');
222248
io.ReceiveString(isend);
249+
int reply_toltal_size = 0;
223250
while (isend == "1") {
224-
225251
block a;
226252
io.ReceiveBlock(a);
227253
int query_size = Block::BlockToInt64(a);
228254
std::cout << "rec_query_size " << query_size << std::endl;
229-
230255
std::string query(query_size, '0');
231256
io.ReceiveString(query);
232257
std::stringstream server_stream;
233258
std::stringstream server_stream1(query);
234259
PirQuery query2 = server.deserialize_query(server_stream1);
235-
236260
PirReply reply = server.generate_reply(query2, 0);
237-
238261
int reply_size = server.serialize_reply(reply, server_stream);
239-
262+
reply_toltal_size += reply_size;
240263
io.SendBlock(Block::MakeBlock(0LL, reply_size));
241264
std::string str1 = server_stream.str();
242-
243265
io.SendString(str1);
244266
io.ReceiveString(isend);
245267
//if (isend == "0") return;
@@ -248,11 +270,10 @@ namespace SEALPIRKEYWORD {
248270
auto running_time = end_time - start_time;
249271
std::cout << "SealPIR:Server side takes time = "
250272
<< std::chrono::duration<double, std::milli>(running_time).count() << " ms" << std::endl;
251-
252-
}
253-
254-
255-
}
273+
std::cout << "SealPIR:Server Reply size= ["
274+
<< (double) (reply_toltal_size) / (1024 * 1024) << " MB]" << std::endl;
275+
}// server
276+
}// namespace
256277

257278

258279
#endif

Diff for: mpc/psi/ecdh_psi.h

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//
2+
// Created by 17579 on 2023/5/6.
3+
//
4+
5+
#ifndef KUNLUN_ECDH_PSI_H
6+
#define KUNLUN_ECDH_PSI_H
7+
8+
#endif //KUNLUN_ECDH_PSI_H

0 commit comments

Comments
 (0)