From ea29524d104c2ef926b63be1dc1c28e5a9eafc01 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Fri, 22 Nov 2024 09:45:47 -0800 Subject: [PATCH 01/48] dpf v1 --- CMakePresets.json | 4 +- cryptoTools | 2 +- frontend/benchmark.h | 1115 +++++++++++++++-------------- frontend/main.cpp | 6 + libOTe/Tools/Dpf/RegularDpf.h | 503 +++++++++++++ libOTe_Tests/BgciksOT_Tests.h | 19 - libOTe_Tests/CMakeLists.txt | 18 +- libOTe_Tests/RegularDpf_Tests.cpp | 168 +++++ libOTe_Tests/RegularDpf_Tests.h | 6 + libOTe_Tests/UnitTests.cpp | 5 +- 10 files changed, 1301 insertions(+), 545 deletions(-) create mode 100644 libOTe/Tools/Dpf/RegularDpf.h delete mode 100644 libOTe_Tests/BgciksOT_Tests.h create mode 100644 libOTe_Tests/RegularDpf_Tests.cpp create mode 100644 libOTe_Tests/RegularDpf_Tests.h diff --git a/CMakePresets.json b/CMakePresets.json index 111c3e40..4451b953 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -64,9 +64,9 @@ "ENABLE_MR": "ON", "ENABLE_SIMPLESTOT": "ON", "ENABLE_GMP": false, - "ENABLE_RELIC": true, + "ENABLE_RELIC": false, "ENABLE_SODIUM": false, - "ENABLE_BOOST": true, + "ENABLE_BOOST": false, "ENABLE_BITPOLYMUL": true, "FETCH_AUTO": "ON", "ENABLE_CIRCUITS": true, diff --git a/cryptoTools b/cryptoTools index 3f1dd5c9..409c6d5c 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 3f1dd5c909b8df3bc593a35669e25fce1f675975 +Subproject commit 409c6d5c88bd8851eaa7818e2e11dd0c405c3188 diff --git a/frontend/benchmark.h b/frontend/benchmark.h index e76ad641..03823553 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -15,678 +15,753 @@ #include "libOTe/Tools/CoeffCtx.h" #include "libOTe/Tools/TungstenCode/TungstenCode.h" #include "libOTe/Tools/ExConvCodeOld/ExConvCodeOld.h" +#include "libOTe/Tools/Dpf/RegularDpf.h" namespace osuCrypto { - inline void QCCodeBench(CLP& cmd) - { + inline void QCCodeBench(CLP& cmd) + { #ifdef ENABLE_BITPOLYMUL - u64 trials = cmd.getOr("t", 10); + u64 trials = cmd.getOr("t", 10); - // the message length of the code. - // The noise vector will have size n=2*k. - // the user can use - // -k X - // to state that exactly X rows should be used or - // -kk X - // to state that 2^X rows should be used. - u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); + // the message length of the code. + // The noise vector will have size n=2*k. + // the user can use + // -k X + // to state that exactly X rows should be used or + // -kk X + // to state that 2^X rows should be used. + u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); - u64 n = k * 2; + u64 n = k * 2; - // verbose flag. + // verbose flag. - oc::Timer timer; - QuasiCyclicCode code; - code.init2(k, n); - std::vector c0(code.size(), ZeroBlock); - for (auto t = 0ull; t < trials; ++t) - { + oc::Timer timer; + QuasiCyclicCode code; + code.init2(k, n); + std::vector c0(code.size(), ZeroBlock); + for (auto t = 0ull; t < trials; ++t) + { - timer.setTimePoint("reset"); - code.dualEncode(c0); - timer.setTimePoint("encode"); - } + timer.setTimePoint("reset"); + code.dualEncode(c0); + timer.setTimePoint("encode"); + } - if (!cmd.isSet("quiet")) - std::cout << timer << std::endl; + if (!cmd.isSet("quiet")) + std::cout << timer << std::endl; #endif - } + } - inline void EACodeBench(CLP& cmd) - { - u64 trials = cmd.getOr("t", 10); + inline void EACodeBench(CLP& cmd) + { + u64 trials = cmd.getOr("t", 10); - // the message length of the code. - // The noise vector will have size n=2*k. - // the user can use - // -k X - // to state that exactly X rows should be used or - // -kk X - // to state that 2^X rows should be used. - u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); + // the message length of the code. + // The noise vector will have size n=2*k. + // the user can use + // -k X + // to state that exactly X rows should be used or + // -kk X + // to state that 2^X rows should be used. + u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); - u64 n = cmd.getOr("n", k * cmd.getOr("R", 5.0)); + u64 n = cmd.getOr("n", k * cmd.getOr("R", 5.0)); - // the weight of the code - u64 w = cmd.getOr("w", 7); + // the weight of the code + u64 w = cmd.getOr("w", 7); - // size for the accumulator (# random transitions) + // size for the accumulator (# random transitions) - // verbose flag. - bool v = cmd.isSet("v"); + // verbose flag. + bool v = cmd.isSet("v"); - EACode code; - code.config(k, n, w); + EACode code; + code.config(k, n, w); - if (v) - { - std::cout << "n: " << code.mCodeSize << std::endl; - std::cout << "k: " << code.mMessageSize << std::endl; - std::cout << "w: " << code.mExpanderWeight << std::endl; - } + if (v) + { + std::cout << "n: " << code.mCodeSize << std::endl; + std::cout << "k: " << code.mMessageSize << std::endl; + std::cout << "w: " << code.mExpanderWeight << std::endl; + } - std::vector x(code.mCodeSize), y(code.mMessageSize); - Timer timer, verbose; + std::vector x(code.mCodeSize), y(code.mMessageSize); + Timer timer, verbose; - if (v) - code.setTimer(verbose); + if (v) + code.setTimer(verbose); - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - code.dualEncode(x, y, {}); - timer.setTimePoint("encode"); - } + timer.setTimePoint("_____________________"); + for (u64 i = 0; i < trials; ++i) + { + code.dualEncode(x, y, {}); + timer.setTimePoint("encode"); + } - std::cout << "EA " << std::endl; - std::cout << timer << std::endl; + std::cout << "EA " << std::endl; + std::cout << timer << std::endl; - if (v) - std::cout << verbose << std::endl; - } + if (v) + std::cout << verbose << std::endl; + } - inline void ExConvCodeBench(CLP& cmd) - { - u64 trials = cmd.getOr("t", 10); + inline void ExConvCodeBench(CLP& cmd) + { + u64 trials = cmd.getOr("t", 10); - // the message length of the code. - // The noise vector will have size n=2*k. - // the user can use - // -k X - // to state that exactly X rows should be used or - // -kk X - // to state that 2^X rows should be used. - u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); + // the message length of the code. + // The noise vector will have size n=2*k. + // the user can use + // -k X + // to state that exactly X rows should be used or + // -kk X + // to state that 2^X rows should be used. + u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); - u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); + u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); - // the weight of the code - u64 w = cmd.getOr("w", 7); + // the weight of the code + u64 w = cmd.getOr("w", 7); - // size for the accumulator (# random transitions) - u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); + // size for the accumulator (# random transitions) + u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); - bool gf128 = cmd.isSet("gf128"); + bool gf128 = cmd.isSet("gf128"); - // verbose flag. - bool v = cmd.isSet("v"); - bool sys = cmd.isSet("sys"); + // verbose flag. + bool v = cmd.isSet("v"); + bool sys = cmd.isSet("sys"); - ExConvCode code; - code.config(k, n, w, a, sys); + ExConvCode code; + code.config(k, n, w, a, sys); - if (v) - { - std::cout << "n: " << code.mCodeSize << std::endl; - std::cout << "k: " << code.mMessageSize << std::endl; - //std::cout << "w: " << code.mExpanderWeight << std::endl; - } + if (v) + { + std::cout << "n: " << code.mCodeSize << std::endl; + std::cout << "k: " << code.mMessageSize << std::endl; + //std::cout << "w: " << code.mExpanderWeight << std::endl; + } - std::vector x(code.mCodeSize), y(code.mMessageSize * !sys); - Timer timer, verbose; + std::vector x(code.mCodeSize), y(code.mMessageSize * !sys); + Timer timer, verbose; - if (v) - code.setTimer(verbose); + if (v) + code.setTimer(verbose); - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - if (gf128) - code.dualEncode(x.begin(), {}); - else - code.dualEncode(x.begin(), {}); + timer.setTimePoint("_____________________"); + for (u64 i = 0; i < trials; ++i) + { + if (gf128) + code.dualEncode(x.begin(), {}); + else + code.dualEncode(x.begin(), {}); - timer.setTimePoint("encode"); - } + timer.setTimePoint("encode"); + } - std::cout << "EC " << std::endl; - std::cout << timer << std::endl; + std::cout << "EC " << std::endl; + std::cout << timer << std::endl; - if (v) - std::cout << verbose << std::endl; - } + if (v) + std::cout << verbose << std::endl; + } - inline void ExConvCodeOldBench(CLP& cmd) - { + inline void ExConvCodeOldBench(CLP& cmd) + { #ifdef LIBOTE_ENABLE_OLD_EXCONV - u64 trials = cmd.getOr("t", 10); + u64 trials = cmd.getOr("t", 10); - // the message length of the code. - // The noise vector will have size n=2*k. - // the user can use - // -k X - // to state that exactly X rows should be used or - // -kk X - // to state that 2^X rows should be used. - u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); + // the message length of the code. + // The noise vector will have size n=2*k. + // the user can use + // -k X + // to state that exactly X rows should be used or + // -kk X + // to state that 2^X rows should be used. + u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); - u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); + u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); - // the weight of the code - u64 w = cmd.getOr("w", 7); + // the weight of the code + u64 w = cmd.getOr("w", 7); - // size for the accumulator (# random transitions) - u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); + // size for the accumulator (# random transitions) + u64 a = cmd.getOr("a", roundUpTo(log2ceil(n), 8)); - bool gf128 = cmd.isSet("gf128"); + bool gf128 = cmd.isSet("gf128"); - // verbose flag. - bool v = cmd.isSet("v"); - bool sys = cmd.isSet("sys"); + // verbose flag. + bool v = cmd.isSet("v"); + bool sys = cmd.isSet("sys"); - ExConvCodeOld code; - code.config(k, n, w, a, sys); + ExConvCodeOld code; + code.config(k, n, w, a, sys); - if (v) - { - std::cout << "n: " << code.mCodeSize << std::endl; - std::cout << "k: " << code.mMessageSize << std::endl; - //std::cout << "w: " << code.mExpanderWeight << std::endl; - } + if (v) + { + std::cout << "n: " << code.mCodeSize << std::endl; + std::cout << "k: " << code.mMessageSize << std::endl; + //std::cout << "w: " << code.mExpanderWeight << std::endl; + } - std::vector x(code.mCodeSize), y(code.mMessageSize * !sys); - Timer timer, verbose; + std::vector x(code.mCodeSize), y(code.mMessageSize * !sys); + Timer timer, verbose; - if (v) - code.setTimer(verbose); + if (v) + code.setTimer(verbose); - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - code.dualEncode(x); + timer.setTimePoint("_____________________"); + for (u64 i = 0; i < trials; ++i) + { + code.dualEncode(x); - timer.setTimePoint("encode"); - } + timer.setTimePoint("encode"); + } - if (cmd.isSet("quiet") == false) - { - std::cout << "EC " << std::endl; - std::cout << timer << std::endl; - } - if (v) - std::cout << verbose << std::endl; + if (cmd.isSet("quiet") == false) + { + std::cout << "EC " << std::endl; + std::cout << timer << std::endl; + } + if (v) + std::cout << verbose << std::endl; #else - std::cout << "LIBOTE_ENABLE_OLD_EXCONV = false" << std::endl; + std::cout << "LIBOTE_ENABLE_OLD_EXCONV = false" << std::endl; #endif - } + } - inline void PprfBench(CLP& cmd) - { + inline void PprfBench(CLP& cmd) + { #ifdef ENABLE_SILENTOT - try - { - using Ctx = CoeffCtxGF2; - RegularPprfReceiver recver; - RegularPprfSender sender; + try + { + using Ctx = CoeffCtxGF2; + RegularPprfReceiver recver; + RegularPprfSender sender; - u64 trials = cmd.getOr("t", 10); + u64 trials = cmd.getOr("t", 10); - u64 w = cmd.getOr("w", 32); - u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 14)); + u64 w = cmd.getOr("w", 32); + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 14)); - PRNG prng0(ZeroBlock), prng1(ZeroBlock); - //block delta = prng0.get(); + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + //block delta = prng0.get(); - auto sock = coproto::LocalAsyncSocket::makePair(); + auto sock = coproto::LocalAsyncSocket::makePair(); - Timer rTimer; - auto s = rTimer.setTimePoint("start"); - auto ctx = Ctx{}; - auto vals = Ctx::Vec(w); - auto out0 = Ctx::Vec(n / w * w); - auto out1 = Ctx::Vec(n / w * w); + Timer rTimer; + auto s = rTimer.setTimePoint("start"); + auto ctx = Ctx{}; + auto vals = Ctx::Vec(w); + auto out0 = Ctx::Vec(n / w * w); + auto out1 = Ctx::Vec(n / w * w); - for (u64 t = 0; t < trials; ++t) - { - sender.configure(n / w, w); - recver.configure(n / w, w); + for (u64 t = 0; t < trials; ++t) + { + sender.configure(n / w, w); + recver.configure(n / w, w); - std::vector> baseSend(sender.baseOtCount()); - std::vector baseRecv(sender.baseOtCount()); - BitVector baseChoice(sender.baseOtCount()); - sender.setBase(baseSend); - recver.setBase(baseRecv); - recver.setChoiceBits(baseChoice); + std::vector> baseSend(sender.baseOtCount()); + std::vector baseRecv(sender.baseOtCount()); + BitVector baseChoice(sender.baseOtCount()); + sender.setBase(baseSend); + recver.setBase(baseRecv); + recver.setChoiceBits(baseChoice); - auto p0 = sender.expand(sock[0], vals, prng0.get(), out0, PprfOutputFormat::Interleaved, true, 1, ctx); - auto p1 = recver.expand(sock[1], out1, PprfOutputFormat::Interleaved, true, 1, ctx); + auto p0 = sender.expand(sock[0], vals, prng0.get(), out0, PprfOutputFormat::Interleaved, true, 1, ctx); + auto p1 = recver.expand(sock[1], out1, PprfOutputFormat::Interleaved, true, 1, ctx); - rTimer.setTimePoint("r start"); - coproto::sync_wait(macoro::when_all_ready( - std::move(p0), std::move(p1))); - rTimer.setTimePoint("r done"); + rTimer.setTimePoint("r start"); + coproto::sync_wait(macoro::when_all_ready( + std::move(p0), std::move(p1))); + rTimer.setTimePoint("r done"); - } - auto e = rTimer.setTimePoint("end"); + } + auto e = rTimer.setTimePoint("end"); - auto time = std::chrono::duration_cast(e - s).count(); - auto avgTime = time / double(trials); - auto timePer512 = avgTime / n * 512; - std::cout << "OT n:" << n << ", " << - avgTime << "ms/batch, " << timePer512 << "ms/512ot" << std::endl; + auto time = std::chrono::duration_cast(e - s).count(); + auto avgTime = time / double(trials); + auto timePer512 = avgTime / n * 512; + std::cout << "OT n:" << n << ", " << + avgTime << "ms/batch, " << timePer512 << "ms/512ot" << std::endl; - std::cout << rTimer << std::endl; + std::cout << rTimer << std::endl; - std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } #else - std::cout << "ENABLE_SILENTOT = false" << std::endl; + std::cout << "ENABLE_SILENTOT = false" << std::endl; #endif - } - - inline void TungstenCodeBench(CLP& cmd) - { - u64 trials = cmd.getOr("t", 10); - - // the message length of the code. - // The noise vector will have size n=2*k. - // the user can use - // -k X - // to state that exactly X rows should be used or - // -kk X - // to state that 2^X rows should be used. - u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); - - u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); - - // verbose flag. - bool v = cmd.isSet("v"); - - experimental::TungstenCode code; - code.config(k, n); - code.mNumIter = cmd.getOr("iter", 2); - - if (v) - { - std::cout << "n: " << code.mCodeSize << std::endl; - std::cout << "k: " << code.mMessageSize << std::endl; - } - - AlignedUnVector x(code.mCodeSize); - Timer timer, verbose; - - - timer.setTimePoint("_____________________"); - for (u64 i = 0; i < trials; ++i) - { - code.dualEncode(x.data(), {}); - - timer.setTimePoint("encode"); - } - - if (cmd.isSet("quiet") == false) - { - std::cout << "tungsten " << std::endl; - std::cout << timer << std::endl; - } - if (v) - std::cout << verbose << std::endl; - } - - - inline void transpose(const CLP& cmd) - { + } + + inline void TungstenCodeBench(CLP& cmd) + { + u64 trials = cmd.getOr("t", 10); + + // the message length of the code. + // The noise vector will have size n=2*k. + // the user can use + // -k X + // to state that exactly X rows should be used or + // -kk X + // to state that 2^X rows should be used. + u64 k = cmd.getOr("k", 1ull << cmd.getOr("kk", 10)); + + u64 n = cmd.getOr("n", k * cmd.getOr("R", 2.0)); + + // verbose flag. + bool v = cmd.isSet("v"); + + experimental::TungstenCode code; + code.config(k, n); + code.mNumIter = cmd.getOr("iter", 2); + + if (v) + { + std::cout << "n: " << code.mCodeSize << std::endl; + std::cout << "k: " << code.mMessageSize << std::endl; + } + + AlignedUnVector x(code.mCodeSize); + Timer timer, verbose; + + + timer.setTimePoint("_____________________"); + for (u64 i = 0; i < trials; ++i) + { + code.dualEncode(x.data(), {}); + + timer.setTimePoint("encode"); + } + + if (cmd.isSet("quiet") == false) + { + std::cout << "tungsten " << std::endl; + std::cout << timer << std::endl; + } + if (v) + std::cout << verbose << std::endl; + } + + + inline void transpose(const CLP& cmd) + { #ifdef ENABLE_AVX - u64 trials = cmd.getOr("trials", 1ull << 18); - { + u64 trials = cmd.getOr("trials", 1ull << 18); + { - AlignedArray data; + AlignedArray data; - Timer timer; - auto start0 = timer.setTimePoint("b"); + Timer timer; + auto start0 = timer.setTimePoint("b"); - for (u64 i = 0; i < trials; ++i) - { - avx_transpose128(data.data()); - } + for (u64 i = 0; i < trials; ++i) + { + avx_transpose128(data.data()); + } - auto end0 = timer.setTimePoint("b"); + auto end0 = timer.setTimePoint("b"); - for (u64 i = 0; i < trials; ++i) - { - sse_transpose128(data.data()); - } + for (u64 i = 0; i < trials; ++i) + { + sse_transpose128(data.data()); + } - auto end1 = timer.setTimePoint("b"); + auto end1 = timer.setTimePoint("b"); - std::cout << "avx " << std::chrono::duration_cast(end0 - start0).count() << std::endl; - std::cout << "sse " << std::chrono::duration_cast(end1 - end0).count() << std::endl; - } + std::cout << "avx " << std::chrono::duration_cast(end0 - start0).count() << std::endl; + std::cout << "sse " << std::chrono::duration_cast(end1 - end0).count() << std::endl; + } - { - AlignedArray data; + { + AlignedArray data; - Timer timer; - auto start1 = timer.setTimePoint("b"); + Timer timer; + auto start1 = timer.setTimePoint("b"); - for (u64 i = 0; i < trials * 8; ++i) - { - avx_transpose128(data.data()); - } + for (u64 i = 0; i < trials * 8; ++i) + { + avx_transpose128(data.data()); + } - auto start0 = timer.setTimePoint("b"); + auto start0 = timer.setTimePoint("b"); - for (u64 i = 0; i < trials; ++i) - { - avx_transpose128x1024(data.data()); - } + for (u64 i = 0; i < trials; ++i) + { + avx_transpose128x1024(data.data()); + } - auto end0 = timer.setTimePoint("b"); + auto end0 = timer.setTimePoint("b"); - for (u64 i = 0; i < trials; ++i) - { - sse_transpose128x1024(*(std::array, 128>*)data.data()); - } + for (u64 i = 0; i < trials; ++i) + { + sse_transpose128x1024(*(std::array, 128>*)data.data()); + } - auto end1 = timer.setTimePoint("b"); + auto end1 = timer.setTimePoint("b"); - std::cout << "avx " << std::chrono::duration_cast(start0 - start1).count() << std::endl; - std::cout << "avx " << std::chrono::duration_cast(end0 - start0).count() << std::endl; - std::cout << "sse " << std::chrono::duration_cast(end1 - end0).count() << std::endl; - } + std::cout << "avx " << std::chrono::duration_cast(start0 - start1).count() << std::endl; + std::cout << "avx " << std::chrono::duration_cast(end0 - start0).count() << std::endl; + std::cout << "sse " << std::chrono::duration_cast(end1 - end0).count() << std::endl; + } #endif - } + } - inline void SilentOtBench(const CLP& cmd) - { + inline void SilentOtBench(const CLP& cmd) + { #ifdef ENABLE_SILENTOT - try - { + try + { - SilentOtExtSender sender; - SilentOtExtReceiver recver; + SilentOtExtSender sender; + SilentOtExtReceiver recver; - u64 trials = cmd.getOr("t", 10); + u64 trials = cmd.getOr("t", 10); - u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); - MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); - std::cout << multType << std::endl; + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); + MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); + std::cout << multType << std::endl; - recver.mMultType = multType; - sender.mMultType = multType; + recver.mMultType = multType; + sender.mMultType = multType; - PRNG prng0(ZeroBlock), prng1(ZeroBlock); - block delta = prng0.get(); + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + block delta = prng0.get(); - auto sock = coproto::LocalAsyncSocket::makePair(); + auto sock = coproto::LocalAsyncSocket::makePair(); - Timer sTimer; - Timer rTimer; - recver.setTimer(rTimer); - sender.setTimer(rTimer); - sTimer.setTimePoint("start"); - auto s = sTimer.setTimePoint("start"); + Timer sTimer; + Timer rTimer; + recver.setTimer(rTimer); + sender.setTimer(rTimer); + sTimer.setTimePoint("start"); + auto s = sTimer.setTimePoint("start"); - for (u64 t = 0; t < trials; ++t) - { - sender.configure(n); - recver.configure(n); + for (u64 t = 0; t < trials; ++t) + { + sender.configure(n); + recver.configure(n); - auto choice = recver.sampleBaseChoiceBits(prng0); - std::vector> sendBase(sender.silentBaseOtCount()); - std::vector recvBase(recver.silentBaseOtCount()); - sender.setSilentBaseOts(sendBase); - recver.setSilentBaseOts(recvBase); + auto choice = recver.sampleBaseChoiceBits(prng0); + std::vector> sendBase(sender.silentBaseOtCount()); + std::vector recvBase(recver.silentBaseOtCount()); + sender.setSilentBaseOts(sendBase); + recver.setSilentBaseOts(recvBase); - auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); - auto p1 = recver.silentReceiveInplace(n, prng1, sock[1], ChoiceBitPacking::True); + auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); + auto p1 = recver.silentReceiveInplace(n, prng1, sock[1], ChoiceBitPacking::True); - rTimer.setTimePoint("r start"); - coproto::sync_wait(macoro::when_all_ready( - std::move(p0), std::move(p1))); - rTimer.setTimePoint("r done"); + rTimer.setTimePoint("r start"); + coproto::sync_wait(macoro::when_all_ready( + std::move(p0), std::move(p1))); + rTimer.setTimePoint("r done"); - } - auto e = rTimer.setTimePoint("end"); + } + auto e = rTimer.setTimePoint("end"); - if (cmd.isSet("quiet") == false) - { + if (cmd.isSet("quiet") == false) + { - auto time = std::chrono::duration_cast(e - s).count(); - auto avgTime = time / double(trials); - auto timePer512 = avgTime / n * 512; - std::cout << "OT n:" << n << ", " << - avgTime << "ms/batch, " << timePer512 << "ms/512ot" << std::endl; + auto time = std::chrono::duration_cast(e - s).count(); + auto avgTime = time / double(trials); + auto timePer512 = avgTime / n * 512; + std::cout << "OT n:" << n << ", " << + avgTime << "ms/batch, " << timePer512 << "ms/512ot" << std::endl; - std::cout << sTimer << std::endl; - std::cout << rTimer << std::endl; + std::cout << sTimer << std::endl; + std::cout << rTimer << std::endl; - std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; - } - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } #else - std::cout << "ENABLE_SILENTOT = false" << std::endl; + std::cout << "ENABLE_SILENTOT = false" << std::endl; #endif - } + } - inline void VoleBench2(const CLP& cmd) - { + inline void VoleBench2(const CLP& cmd) + { #ifdef ENABLE_SILENT_VOLE - try - { + try + { - SilentVoleSender sender; - SilentVoleReceiver recver; + SilentVoleSender sender; + SilentVoleReceiver recver; - u64 trials = cmd.getOr("t", 10); + u64 trials = cmd.getOr("t", 10); - u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); - MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); - std::cout << multType << std::endl; + u64 n = cmd.getOr("n", 1ull << cmd.getOr("nn", 20)); + MultType multType = (MultType)cmd.getOr("m", (int)MultType::ExConv7x24); + std::cout << multType << std::endl; - recver.mMultType = multType; - sender.mMultType = multType; + recver.mMultType = multType; + sender.mMultType = multType; - std::vector> baseSend(128); - std::vector baseRecv(128); - BitVector baseChoice(128); - PRNG prng(CCBlock); - baseChoice.randomize(prng); - for (u64 i = 0; i < 128; ++i) - { - baseSend[i] = prng.get(); - baseRecv[i] = baseSend[i][baseChoice[i]]; - } + std::vector> baseSend(128); + std::vector baseRecv(128); + BitVector baseChoice(128); + PRNG prng(CCBlock); + baseChoice.randomize(prng); + for (u64 i = 0; i < 128; ++i) + { + baseSend[i] = prng.get(); + baseRecv[i] = baseSend[i][baseChoice[i]]; + } #ifdef ENABLE_SOFTSPOKEN_OT - sender.mOtExtRecver.emplace(); - sender.mOtExtSender.emplace(); - recver.mOtExtRecver.emplace(); - recver.mOtExtSender.emplace(); - sender.mOtExtRecver->setBaseOts(baseSend); - recver.mOtExtRecver->setBaseOts(baseSend); - sender.mOtExtSender->setBaseOts(baseRecv, baseChoice); - recver.mOtExtSender->setBaseOts(baseRecv, baseChoice); + sender.mOtExtRecver.emplace(); + sender.mOtExtSender.emplace(); + recver.mOtExtRecver.emplace(); + recver.mOtExtSender.emplace(); + sender.mOtExtRecver->setBaseOts(baseSend); + recver.mOtExtRecver->setBaseOts(baseSend); + sender.mOtExtSender->setBaseOts(baseRecv, baseChoice); + recver.mOtExtSender->setBaseOts(baseRecv, baseChoice); #endif // ENABLE_SOFTSPOKEN_OT - PRNG prng0(ZeroBlock), prng1(ZeroBlock); - block delta = prng0.get(); + PRNG prng0(ZeroBlock), prng1(ZeroBlock); + block delta = prng0.get(); - auto sock = coproto::LocalAsyncSocket::makePair(); + auto sock = coproto::LocalAsyncSocket::makePair(); - Timer sTimer; - Timer rTimer; - sTimer.setTimePoint("start"); - rTimer.setTimePoint("start"); + Timer sTimer; + Timer rTimer; + sTimer.setTimePoint("start"); + rTimer.setTimePoint("start"); - auto t0 = std::thread([&] { - for (u64 t = 0; t < trials; ++t) - { - auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); + auto t0 = std::thread([&] { + for (u64 t = 0; t < trials; ++t) + { + auto p0 = sender.silentSendInplace(delta, n, prng0, sock[0]); - char c = 0; + char c = 0; - coproto::sync_wait(sock[0].send(std::move(c))); - coproto::sync_wait(sock[0].recv(c)); - sTimer.setTimePoint("__"); - coproto::sync_wait(sock[0].send(std::move(c))); - coproto::sync_wait(sock[0].recv(c)); - sTimer.setTimePoint("s start"); - coproto::sync_wait(p0); - sTimer.setTimePoint("s done"); - } - }); + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("__"); + coproto::sync_wait(sock[0].send(std::move(c))); + coproto::sync_wait(sock[0].recv(c)); + sTimer.setTimePoint("s start"); + coproto::sync_wait(p0); + sTimer.setTimePoint("s done"); + } + }); - for (u64 t = 0; t < trials; ++t) - { - auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); - char c=0; - coproto::sync_wait(sock[1].send(std::move(c))); - coproto::sync_wait(sock[1].recv(c)); + for (u64 t = 0; t < trials; ++t) + { + auto p1 = recver.silentReceiveInplace(n, prng1, sock[1]); + char c = 0; + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); - rTimer.setTimePoint("__"); - coproto::sync_wait(sock[1].send(std::move(c))); - coproto::sync_wait(sock[1].recv(c)); + rTimer.setTimePoint("__"); + coproto::sync_wait(sock[1].send(std::move(c))); + coproto::sync_wait(sock[1].recv(c)); - rTimer.setTimePoint("r start"); - coproto::sync_wait(p1); - rTimer.setTimePoint("r done"); + rTimer.setTimePoint("r start"); + coproto::sync_wait(p1); + rTimer.setTimePoint("r done"); - } + } - t0.join(); - std::cout << sTimer << std::endl; - std::cout << rTimer << std::endl; + t0.join(); + std::cout << sTimer << std::endl; + std::cout << rTimer << std::endl; - std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } + std::cout << sock[0].bytesReceived() / trials << " " << sock[1].bytesReceived() / trials << " bytes per " << std::endl; + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } #else - std::cout << "ENABLE_Silent_VOLE = false" << std::endl; + std::cout << "ENABLE_Silent_VOLE = false" << std::endl; #endif - } - - - void AESBenchmark(const oc::CLP& cmd) - { - u64 n = roundUpTo(cmd.getOr("n", 1ull << cmd.getOr("nn", 20)), 8); - u64 t =cmd.getOr("t", 10); - using AES_ = AES;// details::AES; - - auto unroll8 = [](AES_& aes, block* __restrict s) - { - block b[8]; - b[0] = AES_::firstFn(s[0], aes.mRoundKey[0]); - b[1] = AES_::firstFn(s[1], aes.mRoundKey[0]); - b[2] = AES_::firstFn(s[2], aes.mRoundKey[0]); - b[3] = AES_::firstFn(s[3], aes.mRoundKey[0]); - b[4] = AES_::firstFn(s[4], aes.mRoundKey[0]); - b[5] = AES_::firstFn(s[5], aes.mRoundKey[0]); - b[6] = AES_::firstFn(s[6], aes.mRoundKey[0]); - b[7] = AES_::firstFn(s[7], aes.mRoundKey[0]); - - for (u64 i = 1; i < 9; ++i) - { - b[0] = AES_::roundFn(b[0], aes.mRoundKey[i]); - b[1] = AES_::roundFn(b[1], aes.mRoundKey[i]); - b[2] = AES_::roundFn(b[2], aes.mRoundKey[i]); - b[3] = AES_::roundFn(b[3], aes.mRoundKey[i]); - b[4] = AES_::roundFn(b[4], aes.mRoundKey[i]); - b[5] = AES_::roundFn(b[5], aes.mRoundKey[i]); - b[6] = AES_::roundFn(b[6], aes.mRoundKey[i]); - b[7] = AES_::roundFn(b[7], aes.mRoundKey[i]); - } - - - b[0] = AES_::penultimateFn(b[0], aes.mRoundKey[9]); - b[1] = AES_::penultimateFn(b[1], aes.mRoundKey[9]); - b[2] = AES_::penultimateFn(b[2], aes.mRoundKey[9]); - b[3] = AES_::penultimateFn(b[3], aes.mRoundKey[9]); - b[4] = AES_::penultimateFn(b[4], aes.mRoundKey[9]); - b[5] = AES_::penultimateFn(b[5], aes.mRoundKey[9]); - b[6] = AES_::penultimateFn(b[6], aes.mRoundKey[9]); - b[7] = AES_::penultimateFn(b[7], aes.mRoundKey[9]); - s[0] = AES_::finalFn(b[0], aes.mRoundKey[10]); - s[1] = AES_::finalFn(b[1], aes.mRoundKey[10]); - s[2] = AES_::finalFn(b[2], aes.mRoundKey[10]); - s[3] = AES_::finalFn(b[3], aes.mRoundKey[10]); - s[4] = AES_::finalFn(b[4], aes.mRoundKey[10]); - s[5] = AES_::finalFn(b[5], aes.mRoundKey[10]); - s[6] = AES_::finalFn(b[6], aes.mRoundKey[10]); - s[7] = AES_::finalFn(b[7], aes.mRoundKey[10]); - - }; - - oc::AlignedUnVector x(n); - AES_ aes(block(42352345, 3245345234676534)); - Timer timer; - timer.setTimePoint("begin"); - for (u64 tt = 0; tt < t; ++tt) - { - for (u64 i = 0; i < n; i += 8) - { - unroll8(aes, x.data() + i); - } - timer.setTimePoint("unroll"); - } - - for (u64 tt = 0; tt < t; ++tt) - { - for (u64 i = 0; i < n; i += 8) - { - aes.ecbEncBlocks<8>(x.data() + i, x.data() + i); - } - timer.setTimePoint("aes <>"); - } - - for (u64 tt = 0; tt < t; ++tt) - { - aes.ecbEncBlocks(x, x); - timer.setTimePoint("aes "); - } - - std::cout << timer << std::endl; - - } + } + + + void AESBenchmark(const oc::CLP& cmd) + { + u64 n = roundUpTo(cmd.getOr("n", 1ull << cmd.getOr("nn", 20)), 8); + u64 t = cmd.getOr("t", 10); + using AES_ = AES;// details::AES; + + auto unroll8 = [](AES_& aes, block* __restrict s) + { + block b[8]; + b[0] = AES_::firstFn(s[0], aes.mRoundKey[0]); + b[1] = AES_::firstFn(s[1], aes.mRoundKey[0]); + b[2] = AES_::firstFn(s[2], aes.mRoundKey[0]); + b[3] = AES_::firstFn(s[3], aes.mRoundKey[0]); + b[4] = AES_::firstFn(s[4], aes.mRoundKey[0]); + b[5] = AES_::firstFn(s[5], aes.mRoundKey[0]); + b[6] = AES_::firstFn(s[6], aes.mRoundKey[0]); + b[7] = AES_::firstFn(s[7], aes.mRoundKey[0]); + + for (u64 i = 1; i < 9; ++i) + { + b[0] = AES_::roundFn(b[0], aes.mRoundKey[i]); + b[1] = AES_::roundFn(b[1], aes.mRoundKey[i]); + b[2] = AES_::roundFn(b[2], aes.mRoundKey[i]); + b[3] = AES_::roundFn(b[3], aes.mRoundKey[i]); + b[4] = AES_::roundFn(b[4], aes.mRoundKey[i]); + b[5] = AES_::roundFn(b[5], aes.mRoundKey[i]); + b[6] = AES_::roundFn(b[6], aes.mRoundKey[i]); + b[7] = AES_::roundFn(b[7], aes.mRoundKey[i]); + } + + + b[0] = AES_::penultimateFn(b[0], aes.mRoundKey[9]); + b[1] = AES_::penultimateFn(b[1], aes.mRoundKey[9]); + b[2] = AES_::penultimateFn(b[2], aes.mRoundKey[9]); + b[3] = AES_::penultimateFn(b[3], aes.mRoundKey[9]); + b[4] = AES_::penultimateFn(b[4], aes.mRoundKey[9]); + b[5] = AES_::penultimateFn(b[5], aes.mRoundKey[9]); + b[6] = AES_::penultimateFn(b[6], aes.mRoundKey[9]); + b[7] = AES_::penultimateFn(b[7], aes.mRoundKey[9]); + s[0] = AES_::finalFn(b[0], aes.mRoundKey[10]); + s[1] = AES_::finalFn(b[1], aes.mRoundKey[10]); + s[2] = AES_::finalFn(b[2], aes.mRoundKey[10]); + s[3] = AES_::finalFn(b[3], aes.mRoundKey[10]); + s[4] = AES_::finalFn(b[4], aes.mRoundKey[10]); + s[5] = AES_::finalFn(b[5], aes.mRoundKey[10]); + s[6] = AES_::finalFn(b[6], aes.mRoundKey[10]); + s[7] = AES_::finalFn(b[7], aes.mRoundKey[10]); + + }; + + oc::AlignedUnVector x(n); + AES_ aes(block(42352345, 3245345234676534)); + Timer timer; + timer.setTimePoint("begin"); + for (u64 tt = 0; tt < t; ++tt) + { + for (u64 i = 0; i < n; i += 8) + { + unroll8(aes, x.data() + i); + } + timer.setTimePoint("unroll"); + } + + for (u64 tt = 0; tt < t; ++tt) + { + for (u64 i = 0; i < n; i += 8) + { + aes.ecbEncBlocks<8>(x.data() + i, x.data() + i); + } + timer.setTimePoint("aes <>"); + } + + for (u64 tt = 0; tt < t; ++tt) + { + aes.ecbEncBlocks(x, x); + timer.setTimePoint("aes "); + } + + std::cout << timer << std::endl; + + } + + void RegularDpfBenchmark(const oc::CLP& cmd) + { + PRNG prng(block(231234, 321312)); + u64 trials = cmd.getOr("t", 100); + u64 domain = 1ull << cmd.getOr("d", 10); + u64 numPoints = cmd.getOr("p", 64); + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + for (u64 i = 0; i < numPoints; ++i) + { + points1[i] = prng.get(); + points0[i] = (prng.get() % domain) ^ points1[i]; + values0[i] = prng.get(); + values1[i] = prng.get(); + } + + + auto sock = coproto::LocalAsyncSocket::makePair(); + + Timer timer; + + std::array dpf; + dpf[0].init(0, domain, points0, values0); + dpf[1].init(1, domain, points1, values1); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + output[0].resize(numPoints, domain); + output[1].resize(numPoints, domain); + + for (u64 tt = 0; tt < trials; ++tt) + { + + timer.setTimePoint("start"); + macoro::sync_wait(macoro::when_all_ready( + dpf[0].expand(output[0], prng, sock[0]), + dpf[1].expand(output[1], prng, sock[1]) + )); + timer.setTimePoint("finish"); + + dpf[0].init(0, domain, points0, values0); + dpf[1].init(1, domain, points1, values1); + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + } + + if (cmd.isSet("v")) + std::cout << timer << std::endl; + } } \ No newline at end of file diff --git a/frontend/main.cpp b/frontend/main.cpp index 978976ee..ce2ccf44 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -119,6 +119,12 @@ int main(int argc, char** argv) TungstenCodeBench(cmd); else if (cmd.isSet("aes")) AESBenchmark(cmd); + else if (cmd.isSet("dpf")) + RegularDpfBenchmark(cmd); + else + { + std::cout << "unknown benchmark" << std::endl; + } return 0; } diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h new file mode 100644 index 00000000..ecfd1e09 --- /dev/null +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -0,0 +1,503 @@ +#pragma once + + +#include "cryptoTools/Common/Defines.h" +#include "coproto/Socket/Socket.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" + +namespace osuCrypto +{ + struct RegularDpf + { + enum class OutputFormat + { + // The i'th row holds the i'th leaf for all trees. + // The j'th tree is in the j'th column. + ByLeafIndex, + + // The i'th row holds the i'th tree. + // The j'th leaf is in the j'th column. + ByTreeIndex, + + }; + + OutputFormat mOutputFormat = OutputFormat::ByLeafIndex; + + u64 mPartyIdx = 0; + + u64 mDomain = 0; + + u64 mDepth = 0; + + std::vector mPoints; + + std::vector mValues; + + oc::BitVector mChoiceBits; + + std::vector mRecvOts; + std::vector> mSendOts; + + u64 mOtIdx = 0; + + u8 lsb(const block& b) + { + return b.get(0) & 1; + } + + void init( + u64 partyIdx, + u64 domain, + span points, + span values) + { + if (partyIdx > 1) + throw RTE_LOC; + if (domain < 2) + throw RTE_LOC; + if (points.size() != values.size()) + throw RTE_LOC; + + mPartyIdx = partyIdx; + mDomain = domain; + mDepth = oc::log2ceil(domain); + + mPoints.clear(); + mValues.clear(); + mPoints.insert(mPoints.end(), points.begin(), points.end()); + mValues.insert(mValues.end(), values.begin(), values.end()); + + } + +#define SIMD8(VAR, STATEMENT) \ + { constexpr u64 VAR = 0; STATEMENT; }\ + { constexpr u64 VAR = 1; STATEMENT; }\ + { constexpr u64 VAR = 2; STATEMENT; }\ + { constexpr u64 VAR = 3; STATEMENT; }\ + { constexpr u64 VAR = 4; STATEMENT; }\ + { constexpr u64 VAR = 5; STATEMENT; }\ + { constexpr u64 VAR = 6; STATEMENT; }\ + { constexpr u64 VAR = 7; STATEMENT; }\ + do{}while(0) + + template + macoro::task<> expand( + Output&& output, + PRNG& prng, + coproto::Socket& sock) + { + if constexpr (std::is_same, Matrix>::value) + { + if (output.rows() != mPoints.size()) + throw RTE_LOC; + if (output.cols() != mDomain) + throw RTE_LOC; + } + + u64 numPoints = mPoints.size(); + u64 numPoints8 = numPoints / 8 * 8; + + + // shares of S' + //std::vector> s(mDepth + 2); + auto pow2 = 1ull << log2ceil(mDomain); + std::array, 2> s; + s[mDepth & 1].resize(pow2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth & 1) ^ 1].resize(pow2/2, numPoints, oc::AllocType::Uninitialized); + + //s[0].resize(1, mPoints.size()); + prng.get(s[0].data(), 1); + + + // share of t + std::array, 2> t; + t[0].resize(s[0].rows(), s[0].cols()); + t[1].resize(s[1].rows(), s[1].cols()); + for (u64 i = 0; i < numPoints; ++i) + t[0](0,i) = mPartyIdx; + //std::vector> t(mDepth + 2); + //t[0].resize(1, mPoints.size()); + //for (auto& tt : t[0]) + // tt = mPartyIdx; + + std::array, 2> tau; + tau[0].resize(mPoints.size()); + tau[1].resize(mPoints.size()); + + std::array hashes{ + block(223142132554234532,345324534532452345), + block(476657546875476456,849723947534923433), + }; + std::array, 2> z, zg; + z[0].resize(mPoints.size()); + z[1].resize(mPoints.size()); + zg[0].resize(mPoints.size()); + zg[1].resize(mPoints.size()); + AlignedUnVector sigma(mPoints.size()); + BitVector negAlphaj(mPoints.size()); + AlignedUnVector diff(mPoints.size()); + + + { + //s[1].resize(2, mPoints.size()); + //t[1].resize(2, mPoints.size()); + + setBytes(z[0], 0); + setBytes(z[1], 0); + + auto spi = s[0][0]; + auto sc0 = s[1][0]; + auto sc1 = s[1][1]; + for (u64 k = 0; k < numPoints; ++k) + { + sc0[k] = hashes[0].hashBlock(spi[k]); + sc1[k] = hashes[1].hashBlock(spi[k]); + + z[0][k] ^= sc0[k]; + z[1][k] ^= sc1[k]; + } + } + + for (u64 iter = 1; iter <= mDepth; ++iter) + { + //auto& sp = s[iter - 1]; + auto& tp = t[(iter - 1) & 1]; + auto& sc = s[iter & 1]; + auto& tc = t[iter & 1]; + auto& sg = s[(iter + 1) & 1]; + + auto size = 1ull << iter; + auto size2 = 1ull << (iter + 1); + + if (iter != mDepth) + { + //sg.resize(size2, mPoints.size()); + //t[iter + 1].resize(size2, mPoints.size()); + + setBytes(zg[0], 0); + setBytes(zg[1], 0); + } + + for (u64 k = 0; k < mPoints.size(); ++k) + { + auto alphaj = *oc::BitIterator(&mPoints[k], mDepth - iter); + tau[0][k] = lsb(z[0][k]) ^ alphaj ^ mPartyIdx; + tau[1][k] = lsb(z[1][k]) ^ alphaj; + diff[k] = z[0][k] ^ z[1][k]; + negAlphaj[k] = alphaj ^ mPartyIdx; + } + + co_await multiply(negAlphaj, diff, diff, sock); + // sigma = z[1^alpha[j]] + for (u64 k = 0; k < mPoints.size(); ++k) + sigma[k] = diff[k] ^ z[0][k]; + + // reveal + u64 buffSize = sigma.size() * 16 + divCeil(mPoints.size() * 2, 8); + AlignedUnVector buffer(buffSize); + copyBytesMin(buffer, sigma); + auto bitIter = BitIterator(&buffer[numPoints * 16]); + for (u64 i = 0; i < mPoints.size(); ++i) + { + *bitIter++ = tau[0][i]; + *bitIter++ = tau[1][i]; + } + if (bitIter.mByte >= buffer.data() + buffer.size() && bitIter.mShift) + throw RTE_LOC; + co_await sock.send(std::move(buffer)); + buffer.resize(buffSize); + bitIter = BitIterator(&buffer[numPoints * 16]); + co_await sock.recv(buffer); + for (u64 k = 0; k < mPoints.size(); ++k) + { + block sk = *(block*)&buffer[k * sizeof(block)]; + sigma[k] ^= sk; + tau[0][k] ^= *bitIter++; + tau[1][k] ^= *bitIter++; + } + + if (iter == mDepth) + { + + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) + { +#if defined(NDEBUG) + auto tpl = tp.data(L); + auto scl0 = sc.data(L2 + 0); + auto scl1 = sc.data(L2 + 1); + auto tcl0 = tc.data(L2 + 0); + auto tcl1 = tc.data(L2 + 1); +#else + auto tpl = tp[L]; + auto scl0 = sc[L2 + 0]; + auto scl1 = sc[L2 + 1]; + auto tcl0 = tc[L2 + 0]; + auto tcl1 = tc[L2 + 1]; +#endif + + for (u64 k = 0; k < numPoints8; k += 8) + { + block T[8]; + SIMD8(q, T[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl0[k + q] = lsb(scl0[k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, tcl1[k + q] = lsb(scl1[k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl0[k + q] ^= T[q]); + SIMD8(q, scl1[k + q] ^= T[q]); + } + + for (u64 k = numPoints8; k < mPoints.size(); ++k) + { + auto T = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + tc[L2 + 0][k] = lsb(sc[L2 + 0][k]) ^ tp[L][k] & tau[0][k]; + tc[L2 + 1][k] = lsb(sc[L2 + 1][k]) ^ tp[L][k] & tau[1][k]; + sc[L2 + 0][k] ^= T; + sc[L2 + 1][k] ^= T; + } + } + } + else + { + + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) + { +#if defined(NDEBUG) + auto tpl = tp.data(L); + auto scl0 = sc.data(L2 + 0); + auto scl1 = sc.data(L2 + 1); + auto tcl0 = tc.data(L2 + 0); + auto tcl1 = tc.data(L2 + 1); + + auto sg00 = sg.data(L4 + 0); + auto sg10 = sg.data(L4 + 1); + auto sg01 = sg.data(L4 + 2); + auto sg11 = sg.data(L4 + 3); +#else + + auto tpl = tp[L]; + auto scl0 = sc[L2 + 0]; + auto scl1 = sc[L2 + 1]; + auto tcl0 = tc[L2 + 0]; + auto tcl1 = tc[L2 + 1]; + + auto sg00 = sg[L4 + 0]; + auto sg10 = sg[L4 + 1]; + auto sg01 = sg[L4 + 2]; + auto sg11 = sg[L4 + 3]; +#endif + + for (u64 k = 0; k < numPoints8; k += 8) + { + block T[8]; + SIMD8(q, T[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl0[k + q] = lsb(scl0[k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, scl0[k + q] ^= T[q]); + + hashes[0].ecbEncBlocks<8>(&scl0[k], &sg10[k]); + SIMD8(q, sg00[k + q] = AES::roundEnc(sg10[k + q], scl0[k + q])); + SIMD8(q, sg10[k + q] = sg10[k + q] + scl0[k + q]); + + SIMD8(q, zg[0][k + q] ^= sg00[k + q]); + SIMD8(q, zg[1][k + q] ^= sg10[k + q]); + + SIMD8(q, tcl1[k + q] = lsb(scl1[k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl1[k + q] ^= T[q]); + + hashes[0].ecbEncBlocks<8>(&scl1[k], &sg11[k]); + SIMD8(q, sg01[k + q] = AES::roundEnc(sg11[k + q], scl1[k + q])); + SIMD8(q, sg11[k + q] = sg11[k + q] + scl1[k + q]); + SIMD8(q, zg[0][k + q] ^= sg01[k + q]); + SIMD8(q, zg[1][k + q] ^= sg11[k + q]); + } + + for (u64 k = numPoints8; k < mPoints.size(); ++k) + { + auto T = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + + tcl0[k] = lsb(scl0[k]) ^ tpl[k] & tau[0][k]; + scl0[k] ^= T; + + sg10[k] = hashes[0].ecbEncBlock(scl0[k]); + sg00[k] = AES::roundEnc(sg10[k], scl0[k]); + sg10[k] = sg10[k] + scl0[k]; + + zg[0][k] ^= sg00[k]; + zg[1][k] ^= sg10[k]; + + tcl1[k] = lsb(scl1[k]) ^ tpl[k] & tau[1][k]; + scl1[k] ^= T; + + sg11[k] = hashes[0].ecbEncBlock(scl1[k]); + sg01[k] = AES::roundEnc(sg11[k], scl1[k]); + sg11[k] = sg11[k] + scl1[k]; + + zg[0][k] ^= sg01[k]; + zg[1][k] ^= sg11[k]; + } + } + } + + std::swap(z, zg); + } + + if (mValues.size()) + { + + AlignedUnVector gamma(mPoints.size()); + for (u64 k = 0; k < mPoints.size(); ++k) + { + diff[k] = zg[0][k] ^ zg[1][k] ^ mValues[k]; + } + co_await sock.send(std::move(diff)); + co_await sock.recv(gamma); + for (u64 k = 0; k < mPoints.size(); ++k) + { + gamma[k] = zg[0][k] ^ zg[1][k] ^ mValues[k] ^ gamma[k]; + } + + auto& sd = s[mDepth&1]; + auto& td = t[mDepth&1]; + for (u64 i = 0; i < mDomain; ++i) + { +#if defined(NDEBUG) + auto sdi = sd.data(i); + auto tdi = td.data(i); +#else + auto sdi = sd[i]; + auto tdi = td[i]; +#endif + + for (u64 k = 0; k < numPoints8; k += 8) + { + block T[8]; + + SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); + SIMD8(q, output(k + q, i) = sdi[k + q] ^ T[q]); + } + for (u64 k = numPoints8; k < mPoints.size(); ++k) + { + auto T = block::allSame(-tdi[k]) & gamma[k]; + output(k, i) = sdi[k] ^ T; + } + } + } + } + + + + + // We are given two OTs, one in each direction. Let us denote them as + // + // a0 b0 + // c00 c01 + // + // b1 a1 + // c10 c11 + // + // such that + // + // a0 * b0 = (c00 + c01) + // a1 * b1 = (c10 + c11) + // + // Note that we write these OTs in OLE format, that is for OT (m0,m1),(g,mg) + // we have a0=g, b0=(m0+m1), c00=mg, c01=m0 and similar for the second + // instance. + // + // We first convert these two "OTs/OLEs" into a random beaver triple + // + // [a] * [b] = [c'] + // + // We do this by computing + // + // [a] = (a0, a1) + // [b] = (b1, b0) + // [c'] = (c00+c10+a0b1, c01+c11+a1b0) + // + // As you can see, all 4 cross terms are present. Given this beaver triple + // we can use the standard protocol. We reveal + // + // phi = [x] + [a] + // theta = [y] + [b] + // + // [zy] = [c'] + theta a + phi b + theta phi + // = ab + (y+b) a + (x+a) b + (y+b)(x+a) + // = ab + ab + ya + xb + ab + yx + ya + xb + ab + // = xy + // + macoro::task<> multiply(const oc::BitVector& x, span y, span xy, coproto::Socket& sock) + { + if (x.size() != y.size() || x.size() != xy.size()) + throw RTE_LOC; + BitVector a0; a0.append(mChoiceBits, x.size(), mOtIdx); + AlignedUnVector A0(x.size()), C(x.size()), theta(x.size()), b1(x.size()); + for (u64 j = 0; j < x.size(); ++j) + { + A0[j] = block(-u64(a0[j]), -u64(a0[j])); + auto c00 = mRecvOts[mOtIdx + j]; + + auto c10 = mSendOts[mOtIdx + j][0]; + + b1[j] = mSendOts[mOtIdx + j][0] ^ mSendOts[mOtIdx + j][1]; + // C0' = c00+c10+a0b1 + C[j] = c00 ^ c10 ^ (b1[j] & A0[j]); + + theta[j] = y[j] ^ b1[j]; + } + auto phi = x ^ a0; + while (phi.size() % 8) + phi.pushBack(0); + + AlignedUnVector buffer(theta.size() + phi.sizeBlocks()); + memcpy(buffer.data(), theta.data(), theta.size() * sizeof(block)); + memcpy(buffer.data() + theta.size(), phi.data(), phi.sizeBytes()); + + co_await sock.send(std::move(buffer)); + + buffer.resize(theta.size() + phi.sizeBlocks()); + co_await sock.recv(buffer); + span theta1(buffer.data(), theta.size()); + BitVector phi1((u8*)&buffer[theta.size()], phi.size()); + + phi ^= phi1; + for (u64 j = 0; j < x.size(); ++j) + { + auto Phi = block(-u64(phi[j]), -u64(phi[j])); + theta[j] ^= theta1[j]; + xy[j] = C[j] ^ theta[j] & A0[j] ^ Phi & b1[j]; + + if (mPartyIdx) + xy[j] ^= theta[j] & Phi; + } + + + mOtIdx += x.size(); + + } + + u64 baseOtCount() const { return mDepth * mPoints.size(); } + + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + if (baseSendOts.size() != baseOtCount() || + recvBaseOts.size() != baseOtCount() || + baseChoices.size() != baseOtCount()) + throw RTE_LOC; + + mSendOts.clear(); + mRecvOts.clear(); + mSendOts.insert(mSendOts.end(), baseSendOts.begin(), baseSendOts.end()); + mRecvOts.insert(mRecvOts.end(), recvBaseOts.begin(), recvBaseOts.end()); + mChoiceBits = baseChoices; + mOtIdx = 0; + } + + + }; + +} + +#undef SIMD8 \ No newline at end of file diff --git a/libOTe_Tests/BgciksOT_Tests.h b/libOTe_Tests/BgciksOT_Tests.h deleted file mode 100644 index 19e4bc97..00000000 --- a/libOTe_Tests/BgciksOT_Tests.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once -// © 2016 Peter Rindal. -// © 2022 Visa. -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -#include - -void SilentPprf_Test(const oc::CLP& cmd); -void SilentPprf_trans_Test(const oc::CLP& cmd); -void SilentOT_Test(const oc::CLP& cmd); - -void bitShift_test(const oc::CLP& cmd); -void modp_test(const oc::CLP& cmd); - -void SilentOT_mul_Test(const oc::CLP& cmd); \ No newline at end of file diff --git a/libOTe_Tests/CMakeLists.txt b/libOTe_Tests/CMakeLists.txt index d219908b..474f08d0 100644 --- a/libOTe_Tests/CMakeLists.txt +++ b/libOTe_Tests/CMakeLists.txt @@ -1,7 +1,21 @@ -#project(libOTe_Tests) +set(SRCS + BaseOT_Tests.cpp + bitpolymul_Tests.cpp + Common.cpp + EACode_Tests.cpp + ExCOnvCode_Tests.cpp + NcoOT_Tests.cpp + OT_Tests.cpp + Pprf_Tests.cpp + RegularDpf_Tests.cpp + SilentOT_Tests.cpp + Softspoken_Tests.cpp + TungstenCode_Tests.cpp + UnitTests.cpp + Vole_Tests.cpp +) -file(GLOB SRCS *.cpp) add_library(libOTe_Tests STATIC ${SRCS}) target_link_libraries(libOTe_Tests libOTe) diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp new file mode 100644 index 00000000..41c810a9 --- /dev/null +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -0,0 +1,168 @@ +#include "RegularDpf_Tests.h" +#include "libOTe/Tools/Dpf/RegularDpf.h" +#include "coproto/Socket/LocalAsyncSock.h" + +using namespace oc; + +void RegularDpf_Multiply_Test(const CLP& cmd) +{ + u64 n = 13; + PRNG prng(block(231234, 321312)); + std::array dpf; + dpf[0].mPartyIdx = 0; + dpf[1].mPartyIdx = 1; + dpf[0].mSendOts.push_back(prng.get()); + dpf[1].mSendOts.push_back(prng.get()); + dpf[0].mChoiceBits.pushBack(0); + dpf[1].mChoiceBits.pushBack(1); + dpf[0].mRecvOts.push_back(dpf[1].mSendOts[0][dpf[0].mChoiceBits[0]]); + dpf[1].mRecvOts.push_back(dpf[0].mSendOts[0][dpf[1].mChoiceBits[0]]); + + + { + u64 i = 0; + u64 a0 = dpf[0].mChoiceBits[i]; + block A0 = block(-a0, -a0); + auto c00 = dpf[0].mRecvOts[i]; + + auto b1 = dpf[0].mSendOts[i][0] ^ dpf[0].mSendOts[i][1]; + auto c10 = dpf[0].mSendOts[i][0]; + + // C0' = c00+c10+a0b1 + auto C0 = c00 ^ c10 ^ (b1 & A0); + + u64 a1 = dpf[1].mChoiceBits[i]; + block A1 = block(-a1, -a1); + auto c01 = dpf[1].mRecvOts[i]; + + auto b0 = dpf[1].mSendOts[i][1] ^ dpf[1].mSendOts[i][1]; + auto c11 = dpf[1].mSendOts[i][0]; + + // C0' = c00+c10+a0b1 + auto C1 = c11 ^ c01 ^ (b0 & A1); + + auto a = a0 ^ a1; + auto B = b0 ^ b1; + auto C = C0 ^ C1; + + if (a == 0 && C != oc::ZeroBlock) + throw RTE_LOC; + if (a == 1 && C != B) + throw RTE_LOC; + } + + auto sock = coproto::LocalAsyncSocket::makePair(); + + for (u64 i = 0; i < 100; ++i) + { + //std::cout << "-=========================-" std::endl; + + for (u64 j = 0; j < n; ++j) + { + dpf[0].mSendOts.push_back(prng.get()); + dpf[1].mSendOts.push_back(prng.get()); + dpf[0].mChoiceBits.pushBack(prng.getBit()); + dpf[1].mChoiceBits.pushBack(prng.getBit()); + dpf[0].mRecvOts.push_back(dpf[1].mSendOts.back()[dpf[0].mChoiceBits.back()]); + dpf[1].mRecvOts.push_back(dpf[0].mSendOts.back()[dpf[1].mChoiceBits.back()]); + } + + + BitVector x0(n), x1(n); + x0.randomize(prng); + x1.randomize(prng); + std::vector xy0(n), xy1(n), y0(n), y1(n); + + prng.get(y0.data(), y0.size()); + prng.get(y1.data(), y1.size()); + + macoro::sync_wait(macoro::when_all_ready( + dpf[0].multiply(x0, y0, xy0, sock[0]), + dpf[1].multiply(x1, y1, xy1, sock[1]) + )); + + for (u64 j = 0; j < n; ++j) + { + + u64 x = x0[j] ^ x1[j]; + auto y = y0[j] ^ y1[j]; + auto xy = xy0[j] ^ xy1[j]; + auto exp = block(-x, -x) & y; + if (xy != exp) + { + std::cout << "i " << i << std::endl; + std::cout << "act " << xy << " " << xy0[j] << " + " << xy1[j] << std::endl; + std::cout << "exp " << exp << std::endl; + throw RTE_LOC; + } + } + } +} + +void RegularDpf_Proto_Test(const CLP& cmd) +{ + PRNG prng(block(231234, 321312)); + u64 domain = 8; + u64 numPoints = 11; + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + for (u64 i = 0; i < numPoints; ++i) + { + points1[i] = prng.get(); + points0[i] = (prng.get() % domain) ^ points1[i]; + values0[i] = prng.get(); + values1[i] = prng.get(); + } + + std::array dpf; + dpf[0].init(0, domain, points0, values0); + dpf[1].init(1, domain, points1, values1); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + output[0].resize(numPoints, domain); + output[1].resize(numPoints, domain); + + auto sock = coproto::LocalAsyncSocket::makePair(); + macoro::sync_wait(macoro::when_all_ready( + dpf[0].expand(output[0], prng, sock[0]), + dpf[1].expand(output[1], prng, sock[1]) + )); + + + for (u64 i = 0; i < domain; ++i) + { + for (u64 k = 0; k < numPoints; ++k) + { + auto p = points0[k] ^ points1[k]; + auto act = output[0][k][i] ^ output[1][k][i]; + auto exp = i == p ? (values0[k] ^ values1[k]) : ZeroBlock; + if (exp != act) + throw RTE_LOC; + } + } +} \ No newline at end of file diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/RegularDpf_Tests.h new file mode 100644 index 00000000..8ef7fdcc --- /dev/null +++ b/libOTe_Tests/RegularDpf_Tests.h @@ -0,0 +1,6 @@ + +#pragma once +#include "cryptoTools/Common/CLP.h" + +void RegularDpf_Multiply_Test(const oc::CLP& cmd); +void RegularDpf_Proto_Test(const oc::CLP& cmd); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 37d86a7c..809194ee 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -16,6 +16,7 @@ #include "libOTe/Tools/LDPC/Mtx.h" #include "libOTe_Tests/Pprf_Tests.h" #include "libOTe_Tests/TungstenCode_Tests.h" +#include "libOTe_Tests/RegularDpf_Tests.h" using namespace osuCrypto; namespace tests_libOTe @@ -57,7 +58,9 @@ namespace tests_libOTe tc.add("Tools_Pprf_ByTreeIndex_test ", Tools_Pprf_ByTreeIndex_test); tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); - + tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); + tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); + tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); From 0d619a2c9f0af5e461063aa500c5f064defe447c Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 4 Dec 2024 10:39:33 -0800 Subject: [PATCH 02/48] partia --- frontend/H4.cpp | 918 ++++++++++++++++++++++++++++++ frontend/benchmark.h | 12 +- frontend/main.cpp | 3 + libOTe/Tools/Dpf/DpfMult.h | 163 ++++++ libOTe/Tools/Dpf/RegularDpf.h | 495 ++++++---------- libOTe/Tools/Dpf/SparseDpf.h | 335 +++++++++++ libOTe_Tests/RegularDpf_Tests.cpp | 51 +- libOTe_Tests/RegularDpf_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 1 + 9 files changed, 1657 insertions(+), 322 deletions(-) create mode 100644 frontend/H4.cpp create mode 100644 libOTe/Tools/Dpf/DpfMult.h create mode 100644 libOTe/Tools/Dpf/SparseDpf.h diff --git a/frontend/H4.cpp b/frontend/H4.cpp new file mode 100644 index 00000000..56c3afc1 --- /dev/null +++ b/frontend/H4.cpp @@ -0,0 +1,918 @@ +///////////////////////////////////////////////////////////////////////////// +//// Example source code for blog post: +//// "C++ Coroutines: Understanding Symmetric-Transfer" +//// +//// Implementation of a naive 'task' coroutine type. +// +//#include +//#include +//#include +//// using namespace std; +// +//#ifndef H4_H +//#define H4_H +// +//#define H4_VERSION "4.0.8" +// +//#ifndef H4_USERLOOP +//#define H4_USERLOOP 1 // improves performance +//#endif +//#define H4_COUNT_LOOPS 0 // DIAGNOSTICS +//#define H4_HOOK_TASKS 0 +// +//#define H4_JITTER_LO 100 // Entropy lower bound +//#define H4_JITTER_HI 350 // Entropy upper bound +//#define H4_Q_CAPACITY 10 // Default Q capacity +//#define H4_Q_ABS_MIN 6 // Absolute minimum Q capacity +// +//#define H4_DEBUG 0 +// +//#define H4_SAFETY_TIME 200 // ms, the time space where h4 could fix rollover issue, +// // too long might let more functions called earler if these falls just between millis() rollover and this period, just after the rollover, +// // too tight might cause missing it (if the h4.loop() didn't take control at the short period) +// +//#if H4_DEBUG +//#define H4_Pirntf(f_, ...) Serial.printf((f_), ##__VA_ARGS__) +//#else +//#define H4_Pirntf(f_, ...) +//#endif +// +// +//enum { +// H4_CHUNKER_ID = 90, +// H4AT_SCAVENGER_ID, +// H4AS_SSE_KA_ID, +// H4AS_WS_KA_ID, +// H4AMC_RCX_ID, +// H4AMC_KA_ID +//}; // must not grow past 99! +// +//// #include +// +//#include +//#include +//#include +//#include +//#include +//#include +//#include +//#include +//#include +//#define __PRETTY_FUNCTION__ __FUNCSIG__ +//#ifdef ARDUINO_ARCH_RP2040 +//#define h4rebootCore rp2040.restart +//#elif defined(ARDUINO) +//#define h4rebootCore ESP.restart +//#else +//void somef() {} +//#define h4rebootCore somef +//#endif +//#define H4_BOARD ARDUINO_BOARD +// +//uint32_t globMillis; +//uint32_t millis() { +// return globMillis; +//} +// +//void debugFunction(std::string f) { std::cout << f << std::endl; } +//void h4reboot(); +// +//void HAL_enableInterrupts(); +//void HAL_disableInterrupts(); +// +//uint64_t millis64(); +// +// +//class task; +//using H4_TASK_PTR = task*; +//using H4_TIMER = H4_TASK_PTR; +// +//class H4Delay; +//struct H4Coroutine {}; +// +//using H4_FN_COUNT = std::function; +//using H4_FN_TASK = std::function; +//using H4_FN_TIF = std::function; +//using H4_FN_VOID = std::function; +//using H4_FN_COROUTINE = std::function; +//using H4_FN_RTPTR = H4_FN_COUNT; +//// +//using H4_INT_MAP = std::unordered_map; +//using H4_TIMER_MAP = std::unordered_map; +//// +// +//#define CSTR(x) x.c_str() +//#define ME H4::context +//#define MY(x) H4::context->x +//#define TAG(x) (u+((x)*100)) +// +//extern H4_TASK_PTR& H4_context; +// +//class H4Countdown { +//public: +// uint32_t count; +// H4Countdown(uint32_t start = 1) { count = start; } +// uint32_t operator()() { return --count; } +//}; +// +//class H4Random : public H4Countdown { +//public: +// H4Random(uint32_t tmin = 0, uint32_t tmax = 0); +//}; +// +//// +//// T A S K +//// +//class task { +// bool harakiri = false; +// +// void _chain(); +// void _destruct(); +// friend class H4Delay; +//public: +// uint64_t id; +// H4_FN_VOID f; +// H4_FN_COROUTINE fcoro; +// uint32_t rmin = 0; +// uint32_t rmax = 0; +// H4_FN_COUNT reaper; +// H4_FN_VOID chain; +// // H4_FN_COROUTINE chaincoro; +// uint32_t uid = 0; +// bool singleton = false; +// H4_FN_VOID lastRites = [] {}; +// size_t len = 0; +// uint64_t at; +// uint32_t nrq = 0; +// void* partial = NULL; +// +// bool operator()(const task* lhs, const task* rhs) const; +// void operator()(); +// +// task() {} // only for comparison operator +// +// task( +// H4_FN_VOID _f, +// uint32_t _m, +// uint32_t _x, +// H4_FN_COUNT _r, +// H4_FN_VOID _c, +// uint32_t _u = 0, +// bool _s = false +// ); +// +// task( +// H4_FN_COROUTINE _f, +// uint32_t _m, +// uint32_t _x, +// H4_FN_COUNT _r, +// H4_FN_VOID _c, +// uint32_t _u = 0, +// bool _s = false +// ); +// +// ~task() {}//H4_Pirntf("T=%u TASK DTOR %p\n",millis(),this); } +// +// static void cancelSingleton(uint32_t id); +// uint32_t cleardown(uint32_t t); +// // The many ways to die... :) +// uint32_t endF(); // finalise: finishEarly +// uint32_t endU(); // unconditional finishNow; +// uint32_t endC(H4_FN_TIF); // conditional +// uint32_t endK(); // kill, chop etc +// // +// void createPartial(void* d, size_t l); +// void getPartial(void* d) { memcpy(d, partial, len); } +// void putPartial(void* d) { memcpy(partial, d, len); } +// void requeue(); +// void schedule(); +// static uint32_t randomRange(uint32_t lo, uint32_t hi); // move to h4 +//}; +// +//class H4Coroutine +//{ +// +// // task* owner; +// uint32_t duration; +// task* owner = nullptr; +// task* resumer = nullptr; +//public: +// class promise_type { +// // uint32_t duration; +// task* owner = nullptr; +// task* resumer = nullptr; +// friend class H4Coroutine; +// public: +// H4Coroutine get_return_object() noexcept; +// std::suspend_never initial_suspend() noexcept; +// void return_void() noexcept; +// void unhandled_exception() noexcept; +// struct final_awaiter; +// final_awaiter final_suspend() noexcept; +// +// void cancel(); +// }; +// std::coroutine_handle _coro; +// +// explicit H4Coroutine(std::coroutine_handle h) : _coro(h) { debugFunction(__PRETTY_FUNCTION__); printf("this=%p\n", this); printf("h=%p\n", h.address()); } +// ~H4Coroutine() { +// debugFunction(__PRETTY_FUNCTION__); +// // printf("this=%p\n", this); +// // printf("_coro=%p\tduration=%u\towner=%p\tresumer=%p\n", _coro,duration,owner,resumer); +// // if (_coro) _coro.destroy(); +// } +//}; +// +//class H4Delay { +// task* owner; +// uint32_t duration; +//public: +// +// explicit H4Delay(uint32_t duration, task* caller = H4_context) : duration(duration), owner(caller) { +// // debugFunction(__PRETTY_FUNCTION__); printf("this=%p\n", this); +// // printf("_coro=%p\tduration=%u\towner=%p\tresumer=%p\n", _coro,duration,owner,resumer); +// } +// +// bool await_ready() noexcept; +// void await_suspend(const std::coroutine_handle h) noexcept; +// void await_resume() noexcept; +//}; +//// +//// H 4 +//// +// +//class H4 : public std::priority_queue, task> { // H4P 35500 - 35700 +// friend class task; +// H4_TIMER_MAP singles; +// std::vector loopChain; +//public: +// std::unordered_map unloadables; +// static H4_TASK_PTR context; +// static std::map> suspendedTasks; +// +// +// void loop(); +// void setup(); +// +// H4(uint32_t baud = 0, size_t qSize = H4_Q_CAPACITY) { +// reserve(qSize); +// if (baud) { +// // Serial.begin(baud); +// H4_Pirntf("\nH4 RUNNING %s\n", H4_VERSION); +// } +// } +// +// H4_TASK_PTR every(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR nTimes(uint32_t n, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR nTimesRandom(uint32_t n, uint32_t msec, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR once(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR queueFunction(H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t msec, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR repeatWhile(H4_FN_COUNT w, uint32_t msec, H4_FN_VOID fn = []() {}, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR repeatWhileEver(H4_FN_COUNT w, uint32_t msec, H4_FN_VOID fn = []() {}, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// +// H4_TASK_PTR every(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR nTimes(uint32_t n, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR nTimesRandom(uint32_t n, uint32_t msec, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR once(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR queueFunction(H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t msec, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR repeatWhile(H4_FN_COUNT w, uint32_t msec, H4_FN_COROUTINE fn = [](H4Coroutine) -> H4Delay { return H4Delay(0); }, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// H4_TASK_PTR repeatWhileEver(H4_FN_COUNT w, uint32_t msec, H4_FN_COROUTINE fn = [](H4Coroutine) -> H4Delay { return H4Delay(0); }, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); +// +// H4_TASK_PTR cancel(H4_TASK_PTR t = context) { return endK(t); } // ? rv ? +// void cancel(std::initializer_list l) { for (auto const t : l) cancel(t); } +// void cancelAll(H4_FN_VOID fn = nullptr); +// void cancelSingleton(uint32_t s) { task::cancelSingleton(s); } +// void cancelSingleton(std::initializer_list l) { for (auto const i : l) cancelSingleton(i); } +// uint32_t finishEarly(H4_TASK_PTR t = context) { return endF(t); } +// uint32_t finishNow(H4_TASK_PTR t = context) { return endU(t); } +// bool finishIf(H4_TASK_PTR t, H4_FN_TIF f) { return endC(t, f); } +// // syscall only +// size_t _capacity() { return c.capacity(); } +// std::vector _copyQ(); +// void _hookLoop(H4_FN_VOID f, uint32_t subid); +// bool _unHook(uint32_t token); +// +// // protected: +// uint32_t gpFramed(task* t, std::function f); +// bool has(task* t) { return find(c.begin(), c.end(), t) != c.end(); } +// uint32_t endF(task* t); +// uint32_t endU(task* t); +// bool endC(task* t, H4_FN_TIF f); +// task* endK(task* t); +// void qt(task* t); +// void reserve(size_t n) { c.reserve(n); } +// H4_FN_TASK taskEvent = [](task*, uint32_t) {}; +// // +//#if H4_HOOK_TASKS +// static H4_FN_TASK taskHook; +// +// void _hookTask(H4_FN_TASK f) { taskHook = f; } +// static std::string dumpTask(task* t, uint32_t faze); +// static void addTaskNames(H4_INT_MAP names); +// static std::string getTaskType(uint32_t t); +// static const char* getTaskName(uint32_t t); +//#else +// static void addTaskNames(H4_INT_MAP names) {} +//#endif +// static void dumpQ(); +// // public: +// task* add(H4_FN_VOID _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u = 0, bool _s = false); +// task* add(H4_FN_COROUTINE _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u = 0, bool _s = false); +//}; +// +//template +//class pr { +// size_t size = sizeof(T); +// +// template +// T2 put(T2 v) { +// memcpy(MY(partial), reinterpret_cast(&v), size); +// return get(); +// } +// template +// T2 get() { return (*(reinterpret_cast(MY(partial)))); } +// +//public: +// pr(T v) { +// if (!MY(partial)) { +// MY(partial) = reinterpret_cast(malloc(size)); +// put(v); +// } +// } +// +// pr operator=(const T other) { return put(other); } +// +// operator T() { return get(); } +// +// T operator +(T v) { return get() + v; } +// +// T operator +=(T v) { return put(get() + v); } +// +// T* operator->() const { +// return reinterpret_cast(MY(partial)); +// } +//}; +// +//extern H4 h4; +// +//template +//static void h4Chunker(T& x, std::function fn, uint32_t lo = H4_JITTER_LO, uint32_t hi = H4_JITTER_HI, H4_FN_VOID final = nullptr) { +// H4_TIMER p = h4.repeatWhile( +// H4Countdown(x.size()), +// task::randomRange(lo, hi), // arbitrary +// [=]() { +// typename T::iterator thunk; +// ME->getPartial(&thunk); +// fn(thunk++); +// ME->putPartial((void*)&thunk); +// // yield(); +// }, +// final, +// H4_CHUNKER_ID); +// typename T::iterator chunkIt = x.begin(); +// p->createPartial((void*)&chunkIt, sizeof(typename T::iterator)); +// p->lastRites = [=] { +// free(p->partial); +// p->partial = nullptr; +// }; +//} +// +//#endif // H4_H +// +//#define __attribute__(X) +// +//////////////////////////// H4.cpp ///////////////////////////// +//#ifdef ARDUINO_ARCH_ESP32 +//portMUX_TYPE h4_mutex = portMUX_INITIALIZER_UNLOCKED; +//void HAL_enableInterrupts() { portEXIT_CRITICAL(&h4_mutex); } +//void HAL_disableInterrupts() { portENTER_CRITICAL(&h4_mutex); } +//#else +//void HAL_enableInterrupts() { /* interrupts(); */ } +//void HAL_disableInterrupts() { /* noInterrupts(); */ } +//#endif +//// +//// and ...here we go! +//// +//void __attribute__((weak)) h4setup(); +//void __attribute__((weak)) h4UserLoop(); +// +//H4_TIMER H4::context = nullptr; +//H4_TASK_PTR& H4_context = H4::context; +// +//std::map> H4::suspendedTasks; +// +//void h4reboot() { h4rebootCore(); } +// +//H4Random::H4Random(uint32_t rmin, uint32_t rmax) { count = task::randomRange(rmin, rmax); } +// +//__attribute__((weak)) H4_INT_MAP h4TaskNames = {}; +// +//#if H4_COUNT_LOOPS +//uint32_t h4Nloops = 0; +//#endif +// +//H4Delay H4Delay::promise_type::get_return_object() noexcept { +// debugFunction(__PRETTY_FUNCTION__); +// return H4Delay(std::coroutine_handle::from_promise(*this)); +//} +//std::suspend_never H4Delay::promise_type::initial_suspend() noexcept { debugFunction(__PRETTY_FUNCTION__); return {}; } +//void H4Delay::promise_type::return_void() noexcept { debugFunction(__PRETTY_FUNCTION__); } +//void H4Delay::promise_type::unhandled_exception() noexcept { debugFunction(__PRETTY_FUNCTION__); std::terminate(); } +//struct H4Delay::promise_type::final_awaiter { +// bool await_ready() noexcept { debugFunction(__PRETTY_FUNCTION__); return false; } +// bool await_suspend(std::coroutine_handle h) noexcept { +// debugFunction(__PRETTY_FUNCTION__); +// printf("h=%p\n", h.address()); +// auto owner = h.promise().owner; +// if (owner) owner->_destruct(); +// H4::suspendedTasks.erase(owner); +// // [ ] IF NOT IMMEDIATEREQUEUE: MANAGE REQUEUE AND CHAIN CALLS. +// return false; +// } +// void await_resume() noexcept { debugFunction(__PRETTY_FUNCTION__); } +//}; +//H4Delay::promise_type::final_awaiter H4Delay::promise_type::final_suspend() noexcept { return {}; } +// +//bool H4Delay::await_ready() noexcept { debugFunction(__PRETTY_FUNCTION__); return false; } +// +//void H4Delay::await_suspend(const std::coroutine_handle h) noexcept { +// debugFunction(__PRETTY_FUNCTION__); +// printf("h=%p\n", h.address()); +// // Schedule the resumer. +// _coro = h; +// resumer = h4.once(duration, [this] { +// +// debugFunction(__PRETTY_FUNCTION__); +// _coro.resume(); +// }); +// h.promise().owner = owner; +// h.promise().resumer = resumer; +// H4::suspendedTasks[owner] = _coro; +//} +// +//void H4Delay::await_resume() noexcept { +// debugFunction(__PRETTY_FUNCTION__); +// resumer = nullptr; +//} +// +// +//void H4Delay::promise_type::cancel() { +// debugFunction(__PRETTY_FUNCTION__); +// auto _coro = std::coroutine_handle::from_promise(*this); +// printf("_coro=%p\n", _coro.address()); +// if (_coro) { +// // _coro.promise().owner = nullptr; +// _coro.destroy(); +// } +// if (resumer) { +// h4.cancel(resumer); +// resumer = nullptr; +// } +// H4::suspendedTasks.erase(owner); +// owner = nullptr; +//} +// +// +//void H4::dumpQ() {} +// +//uint64_t millis64() { +// static volatile uint64_t overflow = 0; +// static volatile uint32_t lastSample = 0; +// static const uint64_t kOverflowIncrement = static_cast(0x100000000); +// +// uint64_t overflowSample; +// uint32_t sample; +// +// // Tracking timer wrap assumes that this function gets called with +// // a period that is less than 1/2 the timer range. +// HAL_disableInterrupts(); +// sample = millis(); +// +// if (lastSample > sample) +// { +// overflow = overflow + kOverflowIncrement; +// } +// +// lastSample = sample; +// overflowSample = overflow; +// HAL_enableInterrupts(); +// +// return (overflowSample | static_cast(sample)); +//} +//// +//// task +//// +//task::task( +// H4_FN_VOID _f, +// uint32_t _m, +// uint32_t _x, +// H4_FN_COUNT _r, +// H4_FN_VOID _c, +// uint32_t _u, +// bool _s +//) : +// f{ _f }, +// rmin{ _m }, +// rmax{ _x }, +// reaper{ _r }, +// chain{ _c }, +// uid{ _u }, +// singleton{ _s } +//{ +// static uint64_t count = 0; +// count++; +// id = count; +// if (_s) { +// uint32_t id = _u % 100; +// if (h4.singles.count(id)) h4.singles[id]->endK(); +// h4.singles[id] = this; +// } +// schedule(); +//} +//task::task( +// H4_FN_COROUTINE _f, +// uint32_t _m, +// uint32_t _x, +// H4_FN_COUNT _r, +// H4_FN_VOID _c, +// uint32_t _u, +// bool _s +//) : +// fcoro{ _f }, +// rmin{ _m }, +// rmax{ _x }, +// reaper{ _r }, +// chain{ _c }, +// uid{ _u }, +// singleton{ _s } +//{ +// static uint64_t count = 0; +// count++; +// id = count; +// if (_s) { +// uint32_t id = _u % 100; +// if (h4.singles.count(id)) h4.singles[id]->endK(); +// h4.singles[id] = this; +// } +// schedule(); +//} +// +//bool task::operator() (const task* lhs, const task* rhs) const { return ((lhs->at > rhs->at) || (lhs->at == rhs->at && lhs->id > rhs->id)) ? true : false; } +//H4Coroutine h4dummy; +//void task::operator()() { +// if (harakiri) _destruct(); // for clean exits +// else { +// std::cout << "CALLING " << (f ? "F" : fcoro ? "FCORO" : "UNDEFINED") << std::endl; +// if (f) f(); +// else fcoro(h4dummy); +// // f(); +// bool thisis_suspended = H4::suspendedTasks.count(this); +// // CURRENTLY: THIS ONLY PREVENTS DESTRUCTION AT THIS POINT, IN FUTURE: RELAY REQUEUE & CHAIN .. +// if (reaper) { // it's finite +// if (!(reaper())) { // ...and it just ended +// _chain(); // run chain function if there is one +// if ((rmin == rmax) && rmin) { +// rmin = 86400000; // reque in +24 hrs +// rmax = 0; +// reaper = nullptr; // and every day after +// requeue(); +// } +// else if (!thisis_suspended) _destruct(); +// } +// else requeue(); +// } +// else requeue(); +// } +//} +// +//void task::_chain() { if (chain) h4.add(chain, 0, 0, H4Countdown(1), nullptr, uid); } // prevents tag rescaling during the pass +// +//void task::cancelSingleton(uint32_t s) { if (h4.singles.count(s)) h4.singles[s]->endK(); } +// +//uint32_t task::cleardown(uint32_t pass) { +// if (singleton) { +// uint32_t id = uid % 100; +// h4.singles.erase(id); +// } +// return pass; +//} +// +//void task::_destruct() { +// debugFunction(__PRETTY_FUNCTION__); +//#if H4_HOOK_TASKS +// H4::taskHook(this, 4); +//#endif +// lastRites(); +// if (partial) free(partial); +// delete this; +//} +//// The many ways to die... :) +//uint32_t task::endF() { +// // H4_Pirntf("ENDF %p\n",this); +// reaper = H4Countdown(1); +// at = 0; +// return cleardown(1 + nrq); +//} +// +//uint32_t task::endU() { +// // H4_Pirntf("ENDU %p\n",this); +// _chain(); +// return nrq + endK(); +//} +// +//uint32_t task::endC(H4_FN_TIF f) { +// bool rv = f(this); +// if (rv) return endF(); +// return rv; +//} +// +//uint32_t task::endK() { +// debugFunction(__PRETTY_FUNCTION__); +// // H4_Pirntf("ENDK %p\n",this); +// auto it = std::find_if(H4::suspendedTasks.begin(), H4::suspendedTasks.end(), [this](const std::pair> p) { return p.first == this; }); +// bool thisiscoro = it != H4::suspendedTasks.end(); +// std::cout << "\tthisiscoro=" << thisiscoro << std::endl; +// if (thisiscoro) { +// it->second.promise().cancel(); +// } +// harakiri = true; +// return cleardown(at = 0); +//} +// +//uint32_t task::randomRange(uint32_t rmin, uint32_t rmax) { return rmax > rmin ? (rand() % (rmax - rmin)) + rmin : rmin; } +// +//void task::requeue() { +// nrq++; +// schedule(); +// h4.qt(this); +//} +// +//void task::schedule() { at = millis64() + randomRange(rmin, rmax); } +// +//void task::createPartial(void* d, size_t l) { +// partial = malloc(l); +// memcpy(partial, d, l); +// len = l; +//} +//// +//// H4 +//// +//task* H4::add(H4_FN_VOID _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u, bool _s) { +// task* t = new task(_f, _m, _x, _r, _c, _u, _s); +//#if H4_HOOK_TASKS +// H4::taskHook(t, 1); +//#endif +// qt(t); +// return t; +//} +//task* H4::add(H4_FN_COROUTINE _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u, bool _s) { +// task* t = new task(_f, _m, _x, _r, _c, _u, _s); +//#if H4_HOOK_TASKS +// H4::taskHook(t, 1); +//#endif +// qt(t); +// return t; +//} +// +//uint32_t H4::gpFramed(task* t, H4_FN_RTPTR f) { +// uint32_t rv = 0; +// printf("t=%p, f=%p\n", t, f); +// if (t) { +// HAL_disableInterrupts(); +// if (has(t) || (t == H4::context) || H4::suspendedTasks.count(t)) rv = f(); // fix bug where context = 0! +// HAL_enableInterrupts(); +// } +// return rv; +//} +// +//uint32_t H4::endF(task* t) { return gpFramed(t, [=] { return t->endF(); }); } +// +//uint32_t H4::endU(task* t) { return gpFramed(t, [=] { return t->endU(); }); } +// +//bool H4::endC(task* t, H4_FN_TIF f) { return gpFramed(t, [=] { return t->endC(f); }); } +// +//task* H4::endK(task* t) { +// debugFunction(__PRETTY_FUNCTION__); +// return reinterpret_cast(gpFramed(t, [=] { return t->endK(); })); } +// +//void H4::qt(task* t) { +// HAL_disableInterrupts(); +// push(t); +// HAL_enableInterrupts(); +//#if H4_HOOK_TASKS +// H4::taskHook(t, 2); +//#endif +//} +//// +//extern void h4setup(); +// +//std::vector H4::_copyQ() { +// std::vector t; +// HAL_disableInterrupts(); +// t = c; +// HAL_enableInterrupts(); +// return t; +//} +// +//void H4::_hookLoop(H4_FN_VOID f, uint32_t subid) { +// if (f) { +// unloadables[subid] = loopChain.size(); +// loopChain.push_back(f); +// } +//} +// +//bool H4::_unHook(uint32_t subid) { +// if (unloadables.count(subid)) { +// loopChain.erase(loopChain.begin() + unloadables[subid]); +// unloadables.erase(subid); +// return true; +// } +// return false; +//} +// +//void setup() { +// h4.setup(); +// h4setup(); +//} +// +//void loop() { +// h4.loop(); +//} +// +//void H4::cancelAll(H4_FN_VOID f) { +// HAL_disableInterrupts(); +// while (!empty()) { +// top()->endK(); +// pop(); +// } +// HAL_enableInterrupts(); +// if (f) f(); +//} +// +//H4_TASK_PTR H4::every(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, nullptr, fnc, TAG(3), s); } +// +//H4_TASK_PTR H4::everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, nullptr, fnc, TAG(4), s); } +// +//H4_TASK_PTR H4::nTimes(uint32_t n, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(n), fnc, TAG(5), s); } +// +//H4_TASK_PTR H4::nTimesRandom(uint32_t n, uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(n), fnc, TAG(6), s); } +// +//H4_TASK_PTR H4::once(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(1), fnc, TAG(7), s); } +// +//H4_TASK_PTR H4::onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(1), fnc, TAG(8), s); } +// +//H4_TASK_PTR H4::queueFunction(H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, 0, 0, H4Countdown(1), fnc, TAG(9), s); } +// +//H4_TASK_PTR H4::randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Random(tmin, tmax), fnc, TAG(10), s); } +// +//H4_TASK_PTR H4::randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Random(tmin, tmax), fnc, TAG(11), s); } +// +//H4_TASK_PTR H4::repeatWhile(H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, fncd, fnc, TAG(12), s); } +// +//H4_TASK_PTR H4::repeatWhileEver(H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { +// return add(fn, msec, 0, fncd, +// std::bind([this](H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { +// fnc(); +// repeatWhileEver(fncd, msec, fn, fnc, u, s); +// }, fncd, msec, fn, fnc, u, s), +// TAG(13), s); +//} +// +//H4_TASK_PTR H4::every(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, nullptr, fnc, TAG(3), s); } +// +//H4_TASK_PTR H4::everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, nullptr, fnc, TAG(4), s); } +// +//H4_TASK_PTR H4::nTimes(uint32_t n, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(n), fnc, TAG(5), s); } +// +//H4_TASK_PTR H4::nTimesRandom(uint32_t n, uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(n), fnc, TAG(6), s); } +// +//H4_TASK_PTR H4::once(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(1), fnc, TAG(7), s); } +// +//H4_TASK_PTR H4::onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(1), fnc, TAG(8), s); } +// +//H4_TASK_PTR H4::queueFunction(H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, 0, 0, H4Countdown(1), fnc, TAG(9), s); } +// +//H4_TASK_PTR H4::randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Random(tmin, tmax), fnc, TAG(10), s); } +// +//H4_TASK_PTR H4::randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Random(tmin, tmax), fnc, TAG(11), s); } +// +//H4_TASK_PTR H4::repeatWhile(H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, fncd, fnc, TAG(12), s); } +// +//H4_TASK_PTR H4::repeatWhileEver(H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { +// return add(fn, msec, 0, fncd, +// std::bind([this](H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { +// fnc(); +// repeatWhileEver(fncd, msec, fn, fnc, u, s); +// }, fncd, msec, fn, fnc, u, s), +// TAG(13), s); +//} +// +//void H4::setup() { +//} +// +//void H4::loop() { +// task* t = nullptr; +// uint64_t now = millis64(); +// HAL_disableInterrupts(); +// if (size()) { +// if (((int64_t)(top()->at - now)) < 1) { +// t = top(); +// pop(); +// } +// } +// HAL_enableInterrupts(); +// if (t) { // H4P 35000 35100 +// H4::context = t; +// // H4_Pirntf("T=%u H4context <-- %p\n",millis(),t); +// (*t)(); +// // H4_Pirntf("T=%u H4context --> %p\n",millis(),t); +// // dumpQ(); +// }; +// // +// for (auto const& f : loopChain) f(); +//#if H4_USERLOOP +// h4UserLoop(); +//#endif +//#if H4_COUNT_LOOPS +// h4Nloops++; +//#endif +//} +// +//H4 h4(0); +//int H4main() { +// setup(); +// // Emulating while(1) loop. +// while (millis() < 10000) { +// if (!(millis() % 5)) +// std::cout << " T= " << millis() << "ms" << std::endl; +// // Each millisecond runs thousands of iterations, simulate a few: +// for (auto i = 0; i < 20; i++) +// loop(); +// globMillis++; +// } +// return 0; +//} +// +//H4Delay someF() { +// debugFunction(__PRETTY_FUNCTION__); +// printf("on 500, awaiting 400 ms:\n"); +// // auto currentContext = H4::context; +// // h4.once(100, [currentContext]{ debugFunction(__PRETTY_FUNCTION__); h4.cancel(currentContext); }); +// co_await H4Delay(400); +// printf("400ms awaited!\n"); +//} +//void h4setup() { +// // h4.once(1000, []{ printf("1000ms elapsed\n"); }); +// /* h4.queueFunction([]() ->H4Delay { +// // for (auto i=0 ; i<20; i++) { +// // printf("i=%d\n", i); +// co_await H4Delay(5); +// // } +// }); */ +// /* h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4Chunker(vs,[](std::vector::iterator it){ printf("Processing [%s]\n", *it.data());}, 100,200); +// std::vector vs {"Hello", "World"}; +// for (auto &v : vs) { +// printf ("Processing [%s]\n", v.data()); +// co_await H4Delay(task::randomRange(100,200)); +// } +// }); +// h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4.nTimes(20, 5, []{ printf("i=%d\n", ME->nrq);}); +// for (auto i = 0; i < 20; i++) { +// printf("i=%d\n", i); +// co_await H4Delay(5); // Delay asynchronously :) +// } +// printf("Chain Function\n"); +// }); +// +// h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4.every(100, []{printf("Some processing\n"); }); +// while (true) { +// printf("Some processing\n"); +// co_await H4Delay(100); +// } +// }); */ +// auto context = h4.once(500, someF); +// h4.once(1000, [context] { debugFunction(__PRETTY_FUNCTION__); h4.cancel(context); }); +//} +//void h4UserLoop() { +// +//} +// +///* +// Coroutines: +// - co_await H4Delay({$Time}); +// - co_await H4Delay(0) does queue the continuation to the next loop iteration. +// - The function signature should return H4Delay type instead of void, and can accepts H4Coroutine Parameter. +// - Finishing the timer can be done by h4.cancel($task) or h4.FinishNow/h4.FinishIf/h4.cancel. Where they all destroy the coroutine handle. +// - FinishNow/FinishIf would call the chain function, cancel does not. +// - The chain or requeue of the coroutine gets called immediately, (Not after the coroutine function itself finishes), therefore if some h4.nTimes() function gets called it'd be rescheduled once it call the coroutine function. Also the chain would be scheduled just after calling the last function even if it's a coroutine. and it would co_await. +// +// +// */ \ No newline at end of file diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 03823553..b24b8ee1 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -715,8 +715,8 @@ namespace osuCrypto Timer timer; std::array dpf; - dpf[0].init(0, domain, points0, values0); - dpf[1].init(1, domain, points1, values1); + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); auto baseCount = dpf[0].baseOtCount(); @@ -750,13 +750,13 @@ namespace osuCrypto timer.setTimePoint("start"); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(output[0], prng, sock[0]), - dpf[1].expand(output[1], prng, sock[1]) + dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k,i) = v; }, prng, sock[0]), + dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; }, prng, sock[1]) )); timer.setTimePoint("finish"); - dpf[0].init(0, domain, points0, values0); - dpf[1].init(1, domain, points1, values1); + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); } diff --git a/frontend/main.cpp b/frontend/main.cpp index ce2ccf44..346f217f 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -29,6 +29,7 @@ #include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" #include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" +int H4main(); using namespace osuCrypto; #ifdef ENABLE_IKNP @@ -81,6 +82,8 @@ void minimal() int main(int argc, char** argv) { + //H4main(); + //return 0; CLP cmd; cmd.parse(argc, argv); diff --git a/libOTe/Tools/Dpf/DpfMult.h b/libOTe/Tools/Dpf/DpfMult.h new file mode 100644 index 00000000..cabeae7a --- /dev/null +++ b/libOTe/Tools/Dpf/DpfMult.h @@ -0,0 +1,163 @@ +#pragma once + + +#include "cryptoTools/Common/Defines.h" +#include "coproto/Socket/Socket.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" + +namespace osuCrypto +{ + struct DpfMult + { + + u64 mPartyIdx = 0; + + u64 mTotalMults = 0; + + oc::BitVector mChoiceBits; + + std::vector mRecvOts; + std::vector> mSendOts; + + u64 mOtIdx = 0; + + u8 lsb(const block& b) + { + return b.get(0) & 1; + } + + void init( + u64 partyIdx, + u64 n) + { + if (partyIdx > 1) + throw RTE_LOC; + + mPartyIdx = partyIdx; + mTotalMults = n; + mOtIdx = 0; + mSendOts.clear(); + mRecvOts.clear(); + mChoiceBits.resize(0); + } + + + // We are given two OTs, one in each direction. Let us denote them as + // + // a0 b0 + // c00 c01 + // + // b1 a1 + // c10 c11 + // + // such that + // + // a0 * b0 = (c00 + c01) + // a1 * b1 = (c10 + c11) + // + // Note that we write these OTs in OLE format, that is for OT (m0,m1),(g,mg) + // we have a0=g, b0=(m0+m1), c00=mg, c01=m0 and similar for the second + // instance. + // + // We first convert these two "OTs/OLEs" into a random beaver triple + // + // [a] * [b] = [c'] + // + // We do this by computing + // + // [a] = (a0, a1) + // [b] = (b1, b0) + // [c'] = (c00+c10+a0b1, c01+c11+a1b0) + // + // As you can see, all 4 cross terms are present. Given this beaver triple + // we can use the standard protocol. We reveal + // + // phi = [x] + [a] + // theta = [y] + [b] + // + // [zy] = [c'] + theta a + phi b + theta phi + // = ab + (y+b) a + (x+a) b + (y+b)(x+a) + // = ab + ab + ya + xb + ab + yx + ya + xb + ab + // = xy + // + macoro::task<> multiply(const oc::BitVector& x, span y, span xy, coproto::Socket& sock) + { + if (x.size() != y.size() || x.size() != xy.size()) + throw RTE_LOC; + if (x.size() + mOtIdx > mTotalMults) + throw RTE_LOC; + + BitVector a0; a0.append(mChoiceBits, x.size(), mOtIdx); + AlignedUnVector A0(x.size()), C(x.size()), theta(x.size()), b1(x.size()); + for (u64 j = 0; j < x.size(); ++j) + { + A0[j] = block(-u64(a0[j]), -u64(a0[j])); + auto c00 = mRecvOts[mOtIdx + j]; + + auto c10 = mSendOts[mOtIdx + j][0]; + + b1[j] = mSendOts[mOtIdx + j][0] ^ mSendOts[mOtIdx + j][1]; + // C0' = c00+c10+a0b1 + C[j] = c00 ^ c10 ^ (b1[j] & A0[j]); + + theta[j] = y[j] ^ b1[j]; + } + auto phi = x ^ a0; + while (phi.size() % 8) + phi.pushBack(0); + + AlignedUnVector buffer(theta.size() + phi.sizeBlocks()); + memcpy(buffer.data(), theta.data(), theta.size() * sizeof(block)); + memcpy(buffer.data() + theta.size(), phi.data(), phi.sizeBytes()); + + co_await sock.send(std::move(buffer)); + + buffer.resize(theta.size() + phi.sizeBlocks()); + co_await sock.recv(buffer); + span theta1(buffer.data(), theta.size()); + BitVector phi1((u8*)&buffer[theta.size()], phi.size()); + + phi ^= phi1; + for (u64 j = 0; j < x.size(); ++j) + { + auto Phi = block(-u64(phi[j]), -u64(phi[j])); + theta[j] ^= theta1[j]; + xy[j] = C[j] ^ theta[j] & A0[j] ^ Phi & b1[j]; + + if (mPartyIdx) + xy[j] ^= theta[j] & Phi; + } + + + mOtIdx += x.size(); + + } + + u64 baseOtCount() const { return mTotalMults; } + + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + if (baseSendOts.size() != baseOtCount() || + recvBaseOts.size() != baseOtCount() || + baseChoices.size() != baseOtCount()) + throw RTE_LOC; + + mSendOts.clear(); + mRecvOts.clear(); + mSendOts.insert(mSendOts.end(), baseSendOts.begin(), baseSendOts.end()); + mRecvOts.insert(mRecvOts.end(), recvBaseOts.begin(), recvBaseOts.end()); + mChoiceBits = baseChoices; + mOtIdx = 0; + } + + + }; + +} + +#undef SIMD8 \ No newline at end of file diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index ecfd1e09..b009f361 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -7,6 +7,8 @@ #include "cryptoTools/Common/BitVector.h" #include "cryptoTools/Common/Matrix.h" +#include "DpfMult.h" + namespace osuCrypto { struct RegularDpf @@ -31,16 +33,9 @@ namespace osuCrypto u64 mDepth = 0; - std::vector mPoints; - - std::vector mValues; - - oc::BitVector mChoiceBits; - - std::vector mRecvOts; - std::vector> mSendOts; + u64 mNumPoints = 0; - u64 mOtIdx = 0; + DpfMult mMultiplier; u8 lsb(const block& b) { @@ -50,25 +45,20 @@ namespace osuCrypto void init( u64 partyIdx, u64 domain, - span points, - span values) + u64 numPoints) { if (partyIdx > 1) throw RTE_LOC; if (domain < 2) throw RTE_LOC; - if (points.size() != values.size()) + if (!numPoints) throw RTE_LOC; + mDepth = oc::log2ceil(domain); mPartyIdx = partyIdx; mDomain = domain; - mDepth = oc::log2ceil(domain); - - mPoints.clear(); - mValues.clear(); - mPoints.insert(mPoints.end(), points.begin(), points.end()); - mValues.insert(mValues.end(), values.begin(), values.end()); - + mNumPoints = numPoints; + mMultiplier.init(partyIdx, numPoints * mDepth); } #define SIMD8(VAR, STATEMENT) \ @@ -82,417 +72,308 @@ namespace osuCrypto { constexpr u64 VAR = 7; STATEMENT; }\ do{}while(0) - template + template< + typename Output + > macoro::task<> expand( + span points, + span values, Output&& output, PRNG& prng, coproto::Socket& sock) { if constexpr (std::is_same, Matrix>::value) { - if (output.rows() != mPoints.size()) + if (output.rows() != mNumPoints) throw RTE_LOC; if (output.cols() != mDomain) throw RTE_LOC; } + if (points.size() != mNumPoints) + throw RTE_LOC; + if (values.size() && values.size() != mNumPoints) + throw RTE_LOC; - u64 numPoints = mPoints.size(); + u64 numPoints = points.size(); u64 numPoints8 = numPoints / 8 * 8; // shares of S' - //std::vector> s(mDepth + 2); auto pow2 = 1ull << log2ceil(mDomain); std::array, 2> s; s[mDepth & 1].resize(pow2, numPoints, oc::AllocType::Uninitialized); - s[(mDepth & 1) ^ 1].resize(pow2/2, numPoints, oc::AllocType::Uninitialized); - - //s[0].resize(1, mPoints.size()); - prng.get(s[0].data(), 1); - + s[(mDepth & 1) ^ 1].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); // share of t std::array, 2> t; t[0].resize(s[0].rows(), s[0].cols()); t[1].resize(s[1].rows(), s[1].cols()); for (u64 i = 0; i < numPoints; ++i) - t[0](0,i) = mPartyIdx; - //std::vector> t(mDepth + 2); - //t[0].resize(1, mPoints.size()); - //for (auto& tt : t[0]) - // tt = mPartyIdx; + t[0](0, i) = mPartyIdx; - std::array, 2> tau; - tau[0].resize(mPoints.size()); - tau[1].resize(mPoints.size()); - - std::array hashes{ - block(223142132554234532,345324534532452345), - block(476657546875476456,849723947534923433), - }; - std::array, 2> z, zg; - z[0].resize(mPoints.size()); - z[1].resize(mPoints.size()); - zg[0].resize(mPoints.size()); - zg[1].resize(mPoints.size()); - AlignedUnVector sigma(mPoints.size()); - BitVector negAlphaj(mPoints.size()); - AlignedUnVector diff(mPoints.size()); +#if defined(NDEBUG) + auto getRow = [](auto&& m, u64 i) {return m.data(i); }; +#else + auto getRow = [](auto&& m, u64 i) {return m[i]; }; +#endif + std::array, 2> tau; + tau[0].resize(mNumPoints); + tau[1].resize(mNumPoints); - { - //s[1].resize(2, mPoints.size()); - //t[1].resize(2, mPoints.size()); + std::array, 2> z; + z[0].resize(mNumPoints); + z[1].resize(mNumPoints); + AlignedUnVector sigma(mNumPoints); + BitVector negAlphaj(mNumPoints); + AlignedUnVector diff(mNumPoints); - setBytes(z[0], 0); - setBytes(z[1], 0); - auto spi = s[0][0]; + { + // we skip level 0 and set level 1 to be random auto sc0 = s[1][0]; auto sc1 = s[1][1]; for (u64 k = 0; k < numPoints; ++k) { - sc0[k] = hashes[0].hashBlock(spi[k]); - sc1[k] = hashes[1].hashBlock(spi[k]); + sc0[k] = prng.get(); + sc1[k] = prng.get(); - z[0][k] ^= sc0[k]; - z[1][k] ^= sc1[k]; + z[0][k] = sc0[k]; + z[1][k] = sc1[k]; } } for (u64 iter = 1; iter <= mDepth; ++iter) { - //auto& sp = s[iter - 1]; + // the parent level auto& tp = t[(iter - 1) & 1]; + + // the child level auto& sc = s[iter & 1]; auto& tc = t[iter & 1]; + + // the grandchild level auto& sg = s[(iter + 1) & 1]; auto size = 1ull << iter; - auto size2 = 1ull << (iter + 1); - if (iter != mDepth) + // + for (u64 k = 0; k < mNumPoints; ++k) { - //sg.resize(size2, mPoints.size()); - //t[iter + 1].resize(size2, mPoints.size()); - - setBytes(zg[0], 0); - setBytes(zg[1], 0); - } - - for (u64 k = 0; k < mPoints.size(); ++k) - { - auto alphaj = *oc::BitIterator(&mPoints[k], mDepth - iter); + auto alphaj = *oc::BitIterator(&points[k], mDepth - iter); tau[0][k] = lsb(z[0][k]) ^ alphaj ^ mPartyIdx; tau[1][k] = lsb(z[1][k]) ^ alphaj; diff[k] = z[0][k] ^ z[1][k]; negAlphaj[k] = alphaj ^ mPartyIdx; } - co_await multiply(negAlphaj, diff, diff, sock); + co_await mMultiplier.multiply(negAlphaj, diff, diff, sock); // sigma = z[1^alpha[j]] - for (u64 k = 0; k < mPoints.size(); ++k) + for (u64 k = 0; k < mNumPoints; ++k) sigma[k] = diff[k] ^ z[0][k]; - // reveal - u64 buffSize = sigma.size() * 16 + divCeil(mPoints.size() * 2, 8); - AlignedUnVector buffer(buffSize); - copyBytesMin(buffer, sigma); - auto bitIter = BitIterator(&buffer[numPoints * 16]); - for (u64 i = 0; i < mPoints.size(); ++i) + // reveal sigma and tau + u64 buffSize = sigma.size() * 16 + divCeil(mNumPoints * 2, 8); + AlignedUnVector sendBuff(buffSize), recvBuff(buffSize); + copyBytesMin(sendBuff, sigma); + auto sendBitIter = BitIterator(&sendBuff[numPoints * 16]); + auto recvBitIter = BitIterator(&recvBuff[numPoints * 16]); + for (u64 i = 0; i < mNumPoints; ++i) { - *bitIter++ = tau[0][i]; - *bitIter++ = tau[1][i]; + *sendBitIter++ = tau[0][i]; + *sendBitIter++ = tau[1][i]; } - if (bitIter.mByte >= buffer.data() + buffer.size() && bitIter.mShift) - throw RTE_LOC; - co_await sock.send(std::move(buffer)); - buffer.resize(buffSize); - bitIter = BitIterator(&buffer[numPoints * 16]); - co_await sock.recv(buffer); - for (u64 k = 0; k < mPoints.size(); ++k) + co_await sock.send(std::move(sendBuff)); + co_await sock.recv(recvBuff); + for (u64 k = 0; k < mNumPoints; ++k) { - block sk = *(block*)&buffer[k * sizeof(block)]; + block sk = *(block*)&recvBuff[k * sizeof(block)]; sigma[k] ^= sk; - tau[0][k] ^= *bitIter++; - tau[1][k] ^= *bitIter++; + tau[0][k] ^= *recvBitIter++; + tau[1][k] ^= *recvBitIter++; } - if (iter == mDepth) + + if (iter != mDepth) { + setBytes(z[0], 0); + setBytes(z[1], 0); + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) { -#if defined(NDEBUG) - auto tpl = tp.data(L); - auto scl0 = sc.data(L2 + 0); - auto scl1 = sc.data(L2 + 1); - auto tcl0 = tc.data(L2 + 0); - auto tcl1 = tc.data(L2 + 1); -#else - auto tpl = tp[L]; - auto scl0 = sc[L2 + 0]; - auto scl1 = sc[L2 + 1]; - auto tcl0 = tc[L2 + 0]; - auto tcl1 = tc[L2 + 1]; -#endif + // parent control bits + auto tpl = getRow(tp, L); + + // child seed + std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + + // child control bit + std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + + // grandchild seeds + std::array sgl{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; for (u64 k = 0; k < numPoints8; k += 8) { - block T[8]; - SIMD8(q, T[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl0[k + q] = lsb(scl0[k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, tcl1[k + q] = lsb(scl1[k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl0[k + q] ^= T[q]); - SIMD8(q, scl1[k + q] ^= T[q]); - } + block temp[8]; + SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, scl[0][k + q] ^= temp[q]); - for (u64 k = numPoints8; k < mPoints.size(); ++k) - { - auto T = block::allSame(-tpl[k + 0]) & sigma[k + 0]; - tc[L2 + 0][k] = lsb(sc[L2 + 0][k]) ^ tp[L][k] & tau[0][k]; - tc[L2 + 1][k] = lsb(sc[L2 + 1][k]) ^ tp[L][k] & tau[1][k]; - sc[L2 + 0][k] ^= T; - sc[L2 + 1][k] ^= T; - } - } - } - else - { - for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) - { -#if defined(NDEBUG) - auto tpl = tp.data(L); - auto scl0 = sc.data(L2 + 0); - auto scl1 = sc.data(L2 + 1); - auto tcl0 = tc.data(L2 + 0); - auto tcl1 = tc.data(L2 + 1); - - auto sg00 = sg.data(L4 + 0); - auto sg10 = sg.data(L4 + 1); - auto sg01 = sg.data(L4 + 2); - auto sg11 = sg.data(L4 + 3); -#else + mAesFixedKey.ecbEncBlocks<8>(&scl[0][k], &sgl[1][k]); + SIMD8(q, sgl[0][k + q] = AES::roundEnc(sgl[1][k + q], scl[0][k + q])); + SIMD8(q, sgl[1][k + q] = sgl[1][k + q] + scl[0][k + q]); - auto tpl = tp[L]; - auto scl0 = sc[L2 + 0]; - auto scl1 = sc[L2 + 1]; - auto tcl0 = tc[L2 + 0]; - auto tcl1 = tc[L2 + 1]; + SIMD8(q, z[0][k + q] ^= sgl[0][k + q]); + SIMD8(q, z[1][k + q] ^= sgl[1][k + q]); - auto sg00 = sg[L4 + 0]; - auto sg10 = sg[L4 + 1]; - auto sg01 = sg[L4 + 2]; - auto sg11 = sg[L4 + 3]; -#endif + SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl[1][k + q] ^= temp[q]); - for (u64 k = 0; k < numPoints8; k += 8) - { - block T[8]; - SIMD8(q, T[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl0[k + q] = lsb(scl0[k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, scl0[k + q] ^= T[q]); - - hashes[0].ecbEncBlocks<8>(&scl0[k], &sg10[k]); - SIMD8(q, sg00[k + q] = AES::roundEnc(sg10[k + q], scl0[k + q])); - SIMD8(q, sg10[k + q] = sg10[k + q] + scl0[k + q]); - - SIMD8(q, zg[0][k + q] ^= sg00[k + q]); - SIMD8(q, zg[1][k + q] ^= sg10[k + q]); - - SIMD8(q, tcl1[k + q] = lsb(scl1[k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl1[k + q] ^= T[q]); - - hashes[0].ecbEncBlocks<8>(&scl1[k], &sg11[k]); - SIMD8(q, sg01[k + q] = AES::roundEnc(sg11[k + q], scl1[k + q])); - SIMD8(q, sg11[k + q] = sg11[k + q] + scl1[k + q]); - SIMD8(q, zg[0][k + q] ^= sg01[k + q]); - SIMD8(q, zg[1][k + q] ^= sg11[k + q]); + mAesFixedKey.ecbEncBlocks<8>(&scl[1][k], &sgl[3][k]); + SIMD8(q, sgl[2][k + q] = AES::roundEnc(sgl[3][k + q], scl[1][k + q])); + SIMD8(q, sgl[3][k + q] = sgl[3][k + q] + scl[1][k + q]); + SIMD8(q, z[0][k + q] ^= sgl[2][k + q]); + SIMD8(q, z[1][k + q] ^= sgl[3][k + q]); } - for (u64 k = numPoints8; k < mPoints.size(); ++k) + for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto T = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + + tcl[0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; + scl[0][k] ^= temp; - tcl0[k] = lsb(scl0[k]) ^ tpl[k] & tau[0][k]; - scl0[k] ^= T; + sgl[1][k] = mAesFixedKey.ecbEncBlock(scl[0][k]); + sgl[0][k] = AES::roundEnc(sgl[1][k], scl[0][k]); + sgl[1][k] = sgl[1][k] + scl[0][k]; - sg10[k] = hashes[0].ecbEncBlock(scl0[k]); - sg00[k] = AES::roundEnc(sg10[k], scl0[k]); - sg10[k] = sg10[k] + scl0[k]; + z[0][k] ^= sgl[0][k]; + z[1][k] ^= sgl[1][k]; - zg[0][k] ^= sg00[k]; - zg[1][k] ^= sg10[k]; + tcl[1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; + scl[1][k] ^= temp; - tcl1[k] = lsb(scl1[k]) ^ tpl[k] & tau[1][k]; - scl1[k] ^= T; - - sg11[k] = hashes[0].ecbEncBlock(scl1[k]); - sg01[k] = AES::roundEnc(sg11[k], scl1[k]); - sg11[k] = sg11[k] + scl1[k]; + sgl[3][k] = mAesFixedKey.ecbEncBlock(scl[1][k]); + sgl[2][k] = AES::roundEnc(sgl[3][k], scl[1][k]); + sgl[3][k] = sgl[3][k] + scl[1][k]; - zg[0][k] ^= sg01[k]; - zg[1][k] ^= sg11[k]; + z[0][k] ^= sgl[2][k]; + z[1][k] ^= sgl[3][k]; } } } + } + - std::swap(z, zg); + // fixing the last layer + { + auto size = 1ull << mDepth; + + auto& tp = t[(mDepth - 1) & 1]; + auto& sc = s[mDepth & 1]; + auto& tc = t[mDepth & 1]; + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) + { + // parent control bits + auto tpl = getRow(tp, L); + + // child seed + std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + + // child control bit + std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + + for (u64 k = 0; k < numPoints8; k += 8) + { + block temp[8]; + SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl[0][k + q] ^= temp[q]); + SIMD8(q, scl[1][k + q] ^= temp[q]); + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + tc[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; + tc[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; + sc[L2 + 0][k] ^= temp; + sc[L2 + 1][k] ^= temp; + } + } } - if (mValues.size()) + if (values.size()) { - AlignedUnVector gamma(mPoints.size()); - for (u64 k = 0; k < mPoints.size(); ++k) + AlignedUnVector gamma(mNumPoints); + for (u64 k = 0; k < mNumPoints; ++k) { - diff[k] = zg[0][k] ^ zg[1][k] ^ mValues[k]; + diff[k] = z[0][k] ^ z[1][k] ^ values[k]; } co_await sock.send(std::move(diff)); co_await sock.recv(gamma); - for (u64 k = 0; k < mPoints.size(); ++k) + for (u64 k = 0; k < mNumPoints; ++k) { - gamma[k] = zg[0][k] ^ zg[1][k] ^ mValues[k] ^ gamma[k]; + gamma[k] = z[0][k] ^ z[1][k] ^ values[k] ^ gamma[k]; } - auto& sd = s[mDepth&1]; - auto& td = t[mDepth&1]; + auto& sd = s[mDepth & 1]; + auto& td = t[mDepth & 1]; for (u64 i = 0; i < mDomain; ++i) { -#if defined(NDEBUG) - auto sdi = sd.data(i); - auto tdi = td.data(i); -#else - auto sdi = sd[i]; - auto tdi = td[i]; -#endif + auto sdi = getRow(sd, i); + auto tdi = getRow(td, i); for (u64 k = 0; k < numPoints8; k += 8) { block T[8]; SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); - SIMD8(q, output(k + q, i) = sdi[k + q] ^ T[q]); + SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k+q])); } - for (u64 k = numPoints8; k < mPoints.size(); ++k) + for (u64 k = numPoints8; k < mNumPoints; ++k) { auto T = block::allSame(-tdi[k]) & gamma[k]; - output(k, i) = sdi[k] ^ T; + output(k, i, sdi[k] ^ T, tdi[k]); } } } - } - - - - - // We are given two OTs, one in each direction. Let us denote them as - // - // a0 b0 - // c00 c01 - // - // b1 a1 - // c10 c11 - // - // such that - // - // a0 * b0 = (c00 + c01) - // a1 * b1 = (c10 + c11) - // - // Note that we write these OTs in OLE format, that is for OT (m0,m1),(g,mg) - // we have a0=g, b0=(m0+m1), c00=mg, c01=m0 and similar for the second - // instance. - // - // We first convert these two "OTs/OLEs" into a random beaver triple - // - // [a] * [b] = [c'] - // - // We do this by computing - // - // [a] = (a0, a1) - // [b] = (b1, b0) - // [c'] = (c00+c10+a0b1, c01+c11+a1b0) - // - // As you can see, all 4 cross terms are present. Given this beaver triple - // we can use the standard protocol. We reveal - // - // phi = [x] + [a] - // theta = [y] + [b] - // - // [zy] = [c'] + theta a + phi b + theta phi - // = ab + (y+b) a + (x+a) b + (y+b)(x+a) - // = ab + ab + ya + xb + ab + yx + ya + xb + ab - // = xy - // - macoro::task<> multiply(const oc::BitVector& x, span y, span xy, coproto::Socket& sock) - { - if (x.size() != y.size() || x.size() != xy.size()) - throw RTE_LOC; - BitVector a0; a0.append(mChoiceBits, x.size(), mOtIdx); - AlignedUnVector A0(x.size()), C(x.size()), theta(x.size()), b1(x.size()); - for (u64 j = 0; j < x.size(); ++j) - { - A0[j] = block(-u64(a0[j]), -u64(a0[j])); - auto c00 = mRecvOts[mOtIdx + j]; - - auto c10 = mSendOts[mOtIdx + j][0]; - - b1[j] = mSendOts[mOtIdx + j][0] ^ mSendOts[mOtIdx + j][1]; - // C0' = c00+c10+a0b1 - C[j] = c00 ^ c10 ^ (b1[j] & A0[j]); - - theta[j] = y[j] ^ b1[j]; - } - auto phi = x ^ a0; - while (phi.size() % 8) - phi.pushBack(0); - - AlignedUnVector buffer(theta.size() + phi.sizeBlocks()); - memcpy(buffer.data(), theta.data(), theta.size() * sizeof(block)); - memcpy(buffer.data() + theta.size(), phi.data(), phi.sizeBytes()); - - co_await sock.send(std::move(buffer)); - - buffer.resize(theta.size() + phi.sizeBlocks()); - co_await sock.recv(buffer); - span theta1(buffer.data(), theta.size()); - BitVector phi1((u8*)&buffer[theta.size()], phi.size()); - - phi ^= phi1; - for (u64 j = 0; j < x.size(); ++j) + else { - auto Phi = block(-u64(phi[j]), -u64(phi[j])); - theta[j] ^= theta1[j]; - xy[j] = C[j] ^ theta[j] & A0[j] ^ Phi & b1[j]; - - if (mPartyIdx) - xy[j] ^= theta[j] & Phi; + auto& sd = s[mDepth & 1]; + auto& td = t[mDepth & 1]; + for (u64 i = 0; i < mDomain; ++i) + { + auto sdi = getRow(sd, i); + auto tdi = getRow(td, i); + for (u64 k = 0; k < numPoints8; k += 8) + { + SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); + } + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + output(k, i, sdi[k], tdi[k]); + } + } } + } - mOtIdx += x.size(); - + u64 baseOtCount() const { + return mMultiplier.baseOtCount(); } - u64 baseOtCount() const { return mDepth * mPoints.size(); } - void setBaseOts( span> baseSendOts, span recvBaseOts, const oc::BitVector& baseChoices) { - if (baseSendOts.size() != baseOtCount() || - recvBaseOts.size() != baseOtCount() || - baseChoices.size() != baseOtCount()) - throw RTE_LOC; - - mSendOts.clear(); - mRecvOts.clear(); - mSendOts.insert(mSendOts.end(), baseSendOts.begin(), baseSendOts.end()); - mRecvOts.insert(mRecvOts.end(), recvBaseOts.begin(), recvBaseOts.end()); - mChoiceBits = baseChoices; - mOtIdx = 0; + mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); } diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Tools/Dpf/SparseDpf.h new file mode 100644 index 00000000..a24159cb --- /dev/null +++ b/libOTe/Tools/Dpf/SparseDpf.h @@ -0,0 +1,335 @@ +#pragma once + + +#include "cryptoTools/Common/Defines.h" +#include "coproto/Socket/Socket.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" + +namespace osuCrypto +{ + + struct SparseDpf + { + u64 mPartyIdx = 0; + + u64 mDomain = 0; + + struct Point + { + // the point's true address + u32 mAddress; + + // the number of points before this point. + u32 mFinalRank; + }; + + BitVector reverse(BitVector b) + { + BitVector r(b.size()); + for (u64 i = 0; i < b.size(); ++i) + r[r.size() - 1 - i] = b[i]; + return r; + } + std::string print(u32 p, u32 bitCount) + { + auto low = reverse(BitVector((u8*)&p, bitCount)); + auto hgh = reverse(BitVector((u8*)&p, 32 - bitCount, bitCount)); + std::stringstream ss; + ss << hgh << "." << low; + return ss.str(); + } + + // For the given Expand node, mDepth[0] is how far down the + // expanded left child should be copied. mDepth[1] is how far down the + // expanded right child should be copied. + // + // 0 indicates that it is a final node and should be copied to the + // final output. + struct Expand + { + std::array mDepth; + }; + + enum class Initial + { + None, + Final, + Expand + }; + + bool mDebug = true; + std::vector> mInitals; + std::vector>> mFinals; + std::vector>> mExpands; + + + // A partition represents a path of mDepth degree 1 nodes + // followed a degree 2 node. e.g. + // * <| 0 + // * <| 1 + // * <| 2 + // * * + // This would have depth 2. A partition also + // contains two sets of points that are under + // the left and right subtree. + struct Partition + { + // The sets of points that are contained in the left and right subtree. + std::array, 2> mSets; + + // The number of degree 1 nodes that lead to the degree 2 node. + u32 mDepth; + + //u32 mPrefix; + + u32 mLowBitCount; + }; + + Partition partition(span points, u32 lowBitCount) + { +#ifndef NDEBUG + u32 prefix = points[0] >> (lowBitCount + 1); + for (auto pp : points) + if (pp >> (lowBitCount + 1) != prefix) + throw RTE_LOC; +#endif + + assert(points.size() > 1); + auto iter = std::find_if(points.begin(), points.end(), [lowBitCount](auto v) {return (v >> lowBitCount) & 1; }); + if (iter == points.begin() || iter == points.end()) + { + assert(lowBitCount); + auto p = partition(points, lowBitCount - 1); + ++p.mDepth; + return p; + } + + return Partition{ + .mSets{ span(points.begin(), iter), span(iter, points.end()) }, + .mDepth{0}, + .mLowBitCount = lowBitCount }; + } + + void getLevels(Partition& par, u64 treeIdx, span points) + { + std::cout << "-> {"; + for (auto j = 0; j < 2; ++j) + { + if (j) + std::cout << ',' << std::endl; + else + std::cout << std::endl; + + for (auto p : par.mSets[j]) + { + std::cout << print(p, par.mLowBitCount) << std::endl; + } + } + std::cout << "}" << std::endl; + + Expand expand; + for (u64 i = 0; i < 2; ++i) + { + if (par.mSets[i].size() == 1) + { + auto idx = std::distance(points.data(), &par.mSets[i][0]); + mFinals[treeIdx][par.mLowBitCount].push_back(idx); + expand.mDepth[i] = 0; + //expand = (Expand)((u8)i | (u8)expand); + } + else + { + + // * <| par + // * <| + // * <| + // * * <| p2 + // * <| + // * <| + // * * + + assert(par.mSets[i].size()); + auto bIdx2 = par.mLowBitCount - 1 - par.mDepth; + auto p2 = partition(par.mSets[i], bIdx2); + getLevels(p2, treeIdx, points); + expand.mDepth[i] = p2.mDepth + 1; + //expand.mOutputs[i] = Address{ .mLevel{bIdx2}, .mIndex{mSizes[treeIdx][bIdx2]++} }; + } + } + mExpands[treeIdx][par.mLowBitCount].push_back(expand); + } + + + void init( + u64 partyIdx, + u64 domain, + MatrixView sparsePoints) + { + mPartyIdx = partyIdx; + mDomain = domain; + + auto depth = log2ceil(domain); + auto preDepth = log2ceil(sparsePoints.cols()); + auto preDomain = 1ull << preDepth; + auto bitCount = depth - preDepth; + auto shift = bitCount - 1; + u32 mask = (1ull << bitCount) - 1; + + mExpands.resize(sparsePoints.rows()); + mFinals.resize(sparsePoints.rows()); + mInitals.resize(sparsePoints.rows()); + + for (u64 r = 0; r < sparsePoints.rows(); ++r) + { + assert(std::is_sorted(sparsePoints[r].begin(), sparsePoints[r].end())); + + mInitals[r].resize(preDomain); + mExpands[r].resize(bitCount); + mFinals[r].resize(bitCount + 1); + std::vector> points(preDomain); + auto iter = sparsePoints[r].begin(); + while (iter != sparsePoints[r].end()) + { + auto p = *iter; + auto idx = p >> bitCount; + auto end = std::find_if(iter, sparsePoints[r].end(), [idx, bitCount](auto v) {return (v >> bitCount) != idx; }); + points[idx] = span(iter, end); + iter = end; + } + for (u64 c = 0; c < sparsePoints.cols(); ++c) + { + std::cout << "(" << sparsePoints(r, c) << ", " << c << ") "; + } + std::cout << std::endl; + + for (u32 c = 0; c < points.size(); ++c) + { + std::cout << " group " << c << std::endl; + //std::sort(points[c].begin(), points[c].end(), [](auto& a, auto& b) {return a.mAddress < b.mAddress; }); + for (auto p : points[c]) + { + std::cout << print(p, bitCount) << std::endl; + } + + if (points[c].size() == 1) + { + mInitals[r][c] = Initial::Final; + auto idx = std::distance(sparsePoints.data(), &points[c][0]); + mFinals[r].back().push_back(points[c][0]); + } + else if (points[c].size() > 1) + { + mInitals[r][c] = Initial::Expand; + auto par = partition(points[c], bitCount - 1); + getLevels(par, r, sparsePoints); + } + else + { + mInitals[r][c] = Initial::None; + } + } + + std::vector set(sparsePoints.cols()); + std::vector>> states(bitCount + 1); + states.back().resize(preDomain); + for (auto point : sparsePoints[r]) + states.back()[point >> bitCount].push_back(point); + auto copyIter = mFinals[r].back().begin(); + for (u64 i = 0; i < mInitals[r].size(); ++i) + { + switch (mInitals[r][i]) + { + case Initial::None: + break; + case Initial::Final: + { + + if (states.back()[i].size() != 1) + throw RTE_LOC; + auto idx = *copyIter++; + if (sparsePoints[r][idx] != states.back()[i][0]) + throw RTE_LOC; + set[idx] = 1; + break; + } + case Initial::Expand: + states[bitCount - 1].push_back(states.back()[i]); + break; + } + } + + + if (mDebug) + { + + for (u64 i = bitCount; i != 0; --i) + { + auto& ex = mExpands[r][i - 1]; + auto& state = states[i - 1]; + auto copyIter = mFinals[r][i - 1].begin(); + + for (u64 j = 0; j < ex.size(); ++j) + { + std::cout << "expand " << i << " " << j << std::endl; + auto mid = std::partition(state[j].begin(), state[j].end(), [&](auto& a) {return !(a >> (i - 1) & 1); }); + for (auto p : state[j]) + { + + std::cout << print(p, i); + if (p == *mid) + std::cout << " <- mid"; + std::cout << std::endl; + } + std::cout << std::endl; + std::array, 2> sets{ span(state[j].begin(), mid), span(mid, state[j].end()) }; + for (u64 k = 0; k < 2; ++k) + { + if (ex[j].mDepth[k] == 0) + { + if (state[j].size() != 2) + throw RTE_LOC; + auto c0 = *copyIter++; + //auto c1 = *copyIter++; + + auto p0 = sparsePoints[r][c0]; + if (p0 != state[j][k]) + throw RTE_LOC; + + set[c0] = 1; + } + else + { + if (sets[k].size() <= 1) + throw RTE_LOC; + auto next = i - 1 - ex[j].mDepth[k]; + std::cout << "pushing set " << k << " to lvl " << next << std::endl; + states[next].push_back(std::vector(sets[k].begin(), sets[k].end())); + } + } + } + + } + if (std::find(set.begin(), set.end(), 0) != set.end()) + throw RTE_LOC;; + } + } + + } + + + template + macoro::task<> expand( + Output&& output, + PRNG& prng, + coproto::Socket& sock) + { + } + + + }; + +} + +#undef SIMD8 \ No newline at end of file diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 41c810a9..c7e8afc1 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -1,14 +1,16 @@ #include "RegularDpf_Tests.h" #include "libOTe/Tools/Dpf/RegularDpf.h" #include "coproto/Socket/LocalAsyncSock.h" - +#include "libOTe/Tools/Dpf/SparseDpf.h" +#include +#include using namespace oc; void RegularDpf_Multiply_Test(const CLP& cmd) { u64 n = 13; PRNG prng(block(231234, 321312)); - std::array dpf; + std::array dpf; dpf[0].mPartyIdx = 0; dpf[1].mPartyIdx = 1; dpf[0].mSendOts.push_back(prng.get()); @@ -102,7 +104,7 @@ void RegularDpf_Multiply_Test(const CLP& cmd) void RegularDpf_Proto_Test(const CLP& cmd) { PRNG prng(block(231234, 321312)); - u64 domain = 8; + u64 domain = 131; u64 numPoints = 11; std::vector points0(numPoints); std::vector points1(numPoints); @@ -117,8 +119,8 @@ void RegularDpf_Proto_Test(const CLP& cmd) } std::array dpf; - dpf[0].init(0, domain, points0, values0); - dpf[1].init(1, domain, points1, values1); + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); auto baseCount = dpf[0].baseOtCount(); @@ -144,13 +146,16 @@ void RegularDpf_Proto_Test(const CLP& cmd) dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); std::array, 2> output; + std::array, 2> tags; output[0].resize(numPoints, domain); output[1].resize(numPoints, domain); + tags[0].resize(numPoints, domain); + tags[1].resize(numPoints, domain); auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(output[0], prng, sock[0]), - dpf[1].expand(output[1], prng, sock[1]) + dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t; }, prng, sock[0]), + dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t; }, prng, sock[1]) )); @@ -160,9 +165,37 @@ void RegularDpf_Proto_Test(const CLP& cmd) { auto p = points0[k] ^ points1[k]; auto act = output[0][k][i] ^ output[1][k][i]; - auto exp = i == p ? (values0[k] ^ values1[k]) : ZeroBlock; + auto t = i == p; + auto tAct = tags[0][k][i] ^ tags[1][k][i]; + auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; if (exp != act) throw RTE_LOC; + if (t != tAct) + throw RTE_LOC; } } -} \ No newline at end of file +} + +void SparseDpf_Proto_Test(const oc::CLP& cmd) +{ + PRNG prng(block(32324, 2342)); + u64 numPoints = 1; + u64 domain = 256; + oc::SparseDpf dpf; + Matrix sparsePoints(numPoints, domain / 10); + std::vector set(domain); + std::iota(set.begin(), set.end(), 0); + for(u64 i = 0; i < sparsePoints.size(); ++i) + { + auto j = prng.get() % set.size(); + std::swap(set[j], set.back()); + sparsePoints(i) = set.back(); + set.pop_back(); + } + + for (u64 i = 0; i < sparsePoints.rows(); ++i) + { + std::sort(sparsePoints[i].begin(), sparsePoints[i].end()); + } + dpf.init(0, domain, sparsePoints); +} diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/RegularDpf_Tests.h index 8ef7fdcc..b2304d9d 100644 --- a/libOTe_Tests/RegularDpf_Tests.h +++ b/libOTe_Tests/RegularDpf_Tests.h @@ -4,3 +4,4 @@ void RegularDpf_Multiply_Test(const oc::CLP& cmd); void RegularDpf_Proto_Test(const oc::CLP& cmd); +void SparseDpf_Proto_Test(const oc::CLP& cmd); diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 809194ee..da56bd71 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -60,6 +60,7 @@ namespace tests_libOTe tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); + tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); From 6d1ade0292e76bdec7d5ff46203e29c59a382753 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Thu, 23 Jan 2025 10:14:16 -0800 Subject: [PATCH 03/48] sparse dpf --- cryptoTools | 2 +- libOTe/Tools/Dpf/DpfMult.h | 4 + libOTe/Tools/Dpf/RegularDpf.h | 13 +- libOTe/Tools/Dpf/SparseDpf.h | 757 ++++++++++++++++++++---------- libOTe_Tests/RegularDpf_Tests.cpp | 141 ++++-- 5 files changed, 646 insertions(+), 271 deletions(-) diff --git a/cryptoTools b/cryptoTools index 409c6d5c..d92336ec 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 409c6d5c88bd8851eaa7818e2e11dd0c405c3188 +Subproject commit d92336ecde55fcd0918def7efda948e41b510965 diff --git a/libOTe/Tools/Dpf/DpfMult.h b/libOTe/Tools/Dpf/DpfMult.h index cabeae7a..abcd462b 100644 --- a/libOTe/Tools/Dpf/DpfMult.h +++ b/libOTe/Tools/Dpf/DpfMult.h @@ -23,6 +23,8 @@ namespace osuCrypto u64 mOtIdx = 0; + bool hasBaseOts() const { return mChoiceBits.size(); } + u8 lsb(const block& b) { return b.get(0) & 1; @@ -88,6 +90,8 @@ namespace osuCrypto throw RTE_LOC; if (x.size() + mOtIdx > mTotalMults) throw RTE_LOC; + if (hasBaseOts() == false) + throw RTE_LOC; BitVector a0; a0.append(mChoiceBits, x.size(), mOtIdx); AlignedUnVector A0(x.size()), C(x.size()), theta(x.size()), b1(x.size()); diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index b009f361..bb3da230 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -143,16 +143,23 @@ namespace osuCrypto } } + // at each iteration we first correct the parent level. + // The parent level has two syblings which are random. + // We need to correct the inactive child so that both parties + // hold the same seed (a sharing of zero). + // + // we then expand the parent to level to get the children level. + // We compute left and right sums for the children. for (u64 iter = 1; iter <= mDepth; ++iter) { - // the parent level + // the grand parent level auto& tp = t[(iter - 1) & 1]; - // the child level + // the parent level auto& sc = s[iter & 1]; auto& tc = t[iter & 1]; - // the grandchild level + // the child level auto& sg = s[(iter + 1) & 1]; auto size = 1ull << iter; diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Tools/Dpf/SparseDpf.h index a24159cb..2804fa2a 100644 --- a/libOTe/Tools/Dpf/SparseDpf.h +++ b/libOTe/Tools/Dpf/SparseDpf.h @@ -14,322 +14,601 @@ namespace osuCrypto { u64 mPartyIdx = 0; + u64 mNumPoints = 0; + u64 mDomain = 0; - struct Point - { - // the point's true address - u32 mAddress; + u64 mDenseDepth = 0; - // the number of points before this point. - u32 mFinalRank; - }; - BitVector reverse(BitVector b) + RegularDpf mRegDpf; + + DpfMult mMultiplier; + + void init( + u64 partyIdx, + u64 numPoints, + u64 domain, + u64 denseDepth + ) { - BitVector r(b.size()); - for (u64 i = 0; i < b.size(); ++i) - r[r.size() - 1 - i] = b[i]; - return r; + mNumPoints = numPoints; + mPartyIdx = partyIdx; + mDomain = domain; + mDenseDepth = std::min(denseDepth, log2ceil(mDomain)); + auto depth = log2ceil(mDomain) - mDenseDepth; + mMultiplier.init(mPartyIdx, depth * mNumPoints); + if (mDenseDepth) + mRegDpf.init(mPartyIdx, 1ull << mDenseDepth, numPoints); } - std::string print(u32 p, u32 bitCount) + + u8 lsb(const block& b) { - auto low = reverse(BitVector((u8*)&p, bitCount)); - auto hgh = reverse(BitVector((u8*)&p, 32 - bitCount, bitCount)); - std::stringstream ss; - ss << hgh << "." << low; - return ss.str(); + return b.get(0) & 1; } - // For the given Expand node, mDepth[0] is how far down the - // expanded left child should be copied. mDepth[1] is how far down the - // expanded right child should be copied. - // - // 0 indicates that it is a final node and should be copied to the - // final output. - struct Expand + + u64 baseOtCount() const { - std::array mDepth; - }; + return log2ceil(mDomain) * mNumPoints; + } + - enum class Initial + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) { - None, - Final, - Expand - }; + auto count = baseOtCount(); + if (baseSendOts.size() != count) + throw RTE_LOC; + if (recvBaseOts.size() != count) + throw RTE_LOC; + if (baseChoices.size() != count) + throw RTE_LOC; + + auto denseCount = mRegDpf.baseOtCount(); + auto + sDense = baseSendOts.subspan(0, denseCount), + sRest = baseSendOts.subspan(denseCount); + + auto + rDense = recvBaseOts.subspan(0, denseCount), + rRest = recvBaseOts.subspan(denseCount); + + BitVector cDense, cRest; + cDense.append(baseChoices, denseCount); + cRest.append(baseChoices, count - denseCount, denseCount); + + if (denseCount) + mRegDpf.setBaseOts(sDense, rDense, cDense); + + mMultiplier.setBaseOts(sRest, rRest, cRest); + } + - bool mDebug = true; - std::vector> mInitals; - std::vector>> mFinals; - std::vector>> mExpands; - - - // A partition represents a path of mDepth degree 1 nodes - // followed a degree 2 node. e.g. - // * <| 0 - // * <| 1 - // * <| 2 - // * * - // This would have depth 2. A partition also - // contains two sets of points that are under - // the left and right subtree. struct Partition { - // The sets of points that are contained in the left and right subtree. - std::array, 2> mSets; + // the index of the bit that partitions the left and right + // sets. + span mRange; + span::iterator mMid; + + Partition() = default; + Partition(span range, span::iterator mid) + : mRange(range), mMid(mid) + { + } - // The number of degree 1 nodes that lead to the degree 2 node. - u32 mDepth; - //u32 mPrefix; - u32 mLowBitCount; + std::array, 2> children() + { + return + { span(mRange.begin(), mMid), span(mMid, mRange.end()) }; + } + + std::string print(u64 bitIdx) + { + std::stringstream ss; + ss << "bit " << bitIdx << " val " << (1 << bitIdx) << " {"; + --bitIdx; + for (auto iter = mRange.begin(); iter != mRange.end(); ++iter) + { + if (iter == mMid) + ss << ","; + + auto upper = *iter >> bitIdx; + auto lower = *iter & ((1 << bitIdx) - 1); + + ss << " " << upper << "." << lower; + } + ss << "}"; + return ss.str(); + } }; - Partition partition(span points, u32 lowBitCount) + + // the upper bits of points[i] are all the same and points is sorted. + // "upper bits" are defined as bits indexed by {upperBitsBegin,...,31} + // This function will look at the bits at index upperBitsBegin and paritions + // the points into two sets. + std::pair partition(span points, u32 upperBitsBegin) { + if (points.size() == 1) + return { 0, Partition{points, points.end()} }; #ifndef NDEBUG - u32 prefix = points[0] >> (lowBitCount + 1); - for (auto pp : points) - if (pp >> (lowBitCount + 1) != prefix) - throw RTE_LOC; + if (std::adjacent_find(points.begin(), points.end(), std::greater{}) != points.end()) + throw RTE_LOC; + if (std::any_of(points.begin(), points.end(), + [upperBitsBegin, prefix = points[0] >> (upperBitsBegin)](auto v) {return v >> (upperBitsBegin) != prefix; })) + throw RTE_LOC; + if (points.size() <= 1) + throw RTE_LOC; #endif - assert(points.size() > 1); - auto iter = std::find_if(points.begin(), points.end(), [lowBitCount](auto v) {return (v >> lowBitCount) & 1; }); - if (iter == points.begin() || iter == points.end()) - { - assert(lowBitCount); - auto p = partition(points, lowBitCount - 1); - ++p.mDepth; - return p; - } - - return Partition{ - .mSets{ span(points.begin(), iter), span(iter, points.end()) }, - .mDepth{0}, - .mLowBitCount = lowBitCount }; + Partition par; + do { + assert(upperBitsBegin != 0); + --upperBitsBegin; + par.mMid = std::upper_bound( + points.begin(), points.end(), 0, + [upperBitsBegin](auto, auto b) {return (b >> upperBitsBegin) & 1; }); + } while (par.mMid == points.begin() || par.mMid == points.end()); + + par.mRange = points; + return { upperBitsBegin + 1, par }; } - void getLevels(Partition& par, u64 treeIdx, span points) + struct Tree { - std::cout << "-> {"; - for (auto j = 0; j < 2; ++j) - { - if (j) - std::cout << ',' << std::endl; - else - std::cout << std::endl; + std::vector> mPartitions; + std::vector> mSeeds; + std::vector> mTags; + std::vector> mChild; + std::vector> mParent; - for (auto p : par.mSets[j]) - { - std::cout << print(p, par.mLowBitCount) << std::endl; - } + + std::vector> mZ; + std::vector mC; + std::vector> mTau; + std::vector mSigma; + + void resize(u64 depth) + { + mPartitions.resize(depth); + mSeeds.resize(depth); + mTags.resize(depth); + mChild.resize(depth); + mParent.resize(depth); + mZ.resize(depth); + mC.resize(depth); + mTau.resize(depth); + mSigma.resize(depth); } - std::cout << "}" << std::endl; - Expand expand; - for (u64 i = 0; i < 2; ++i) + struct Level { - if (par.mSets[i].size() == 1) + Tree* mTree; + u64 mIdx; + + void push_back(u8 child, u8 parentLevel, Partition& b, block seed, u8 tag) { - auto idx = std::distance(points.data(), &par.mSets[i][0]); - mFinals[treeIdx][par.mLowBitCount].push_back(idx); - expand.mDepth[i] = 0; - //expand = (Expand)((u8)i | (u8)expand); + mTree->mPartitions[mIdx].push_back(b); + mTree->mSeeds[mIdx].push_back(seed); + mTree->mTags[mIdx].push_back(tag); + mTree->mChild[mIdx].push_back(child); + mTree->mParent[mIdx].push_back(parentLevel); } - else - { - // * <| par - // * <| - // * <| - // * * <| p2 - // * <| - // * <| - // * * - - assert(par.mSets[i].size()); - auto bIdx2 = par.mLowBitCount - 1 - par.mDepth; - auto p2 = partition(par.mSets[i], bIdx2); - getLevels(p2, treeIdx, points); - expand.mDepth[i] = p2.mDepth + 1; - //expand.mOutputs[i] = Address{ .mLevel{bIdx2}, .mIndex{mSizes[treeIdx][bIdx2]++} }; + u64 size() const + { + return mTree->mPartitions[mIdx].size(); } + }; + Level operator[](u64 i) + { + return { this, i }; } - mExpands[treeIdx][par.mLowBitCount].push_back(expand); - } - void init( - u64 partyIdx, - u64 domain, - MatrixView sparsePoints) + + }; + + template + macoro::task<> expand( + span points, + span values, + Output&& output, + PRNG& prng, + MatrixView sparsePoints, + coproto::Socket& sock) { - mPartyIdx = partyIdx; - mDomain = domain; + if (points.size() != sparsePoints.rows()) + throw RTE_LOC; + u64 depth = log2ceil(mDomain) - mDenseDepth; - auto depth = log2ceil(domain); - auto preDepth = log2ceil(sparsePoints.cols()); - auto preDomain = 1ull << preDepth; - auto bitCount = depth - preDepth; - auto shift = bitCount - 1; - u32 mask = (1ull << bitCount) - 1; + std::vector trees(mNumPoints); + for (u64 i = 0; i < mNumPoints; ++i) + { + trees[i].resize(depth + 1); + } + using T = block; + std::unique_ptr mem; + std::vector> leafValues(mNumPoints); + std::vector> leafTags(mNumPoints); + u64 totalSize = 0; + for (u64 i = 0; i < mNumPoints; ++i) + totalSize += sparsePoints[i].size(); + + mem.reset(new u8[totalSize * (sizeof(T) + 1)]()); + auto iter = mem.get(); + for (u64 i = 0; i < mNumPoints; ++i) + { + leafValues[i] = span((T*)iter, sparsePoints[i].size()); + iter += leafValues[i].size_bytes(); + } + for (u64 i = 0; i < mNumPoints; ++i) + { + leafTags[i] = span(iter, sparsePoints[i].size()); + iter += leafTags[i].size_bytes(); + } - mExpands.resize(sparsePoints.rows()); - mFinals.resize(sparsePoints.rows()); - mInitals.resize(sparsePoints.rows()); + //for (u64 i = 0; i < mNumPoints; ++i) + //{ + // for (auto p : sparsePoints[i]) + // std::cout << p << " "; + // std::cout << std::endl; + //} - for (u64 r = 0; r < sparsePoints.rows(); ++r) + + if (mDenseDepth) { - assert(std::is_sorted(sparsePoints[r].begin(), sparsePoints[r].end())); - - mInitals[r].resize(preDomain); - mExpands[r].resize(bitCount); - mFinals[r].resize(bitCount + 1); - std::vector> points(preDomain); - auto iter = sparsePoints[r].begin(); - while (iter != sparsePoints[r].end()) - { - auto p = *iter; - auto idx = p >> bitCount; - auto end = std::find_if(iter, sparsePoints[r].end(), [idx, bitCount](auto v) {return (v >> bitCount) != idx; }); - points[idx] = span(iter, end); - iter = end; - } - for (u64 c = 0; c < sparsePoints.cols(); ++c) - { - std::cout << "(" << sparsePoints(r, c) << ", " << c << ") "; - } - std::cout << std::endl; - for (u32 c = 0; c < points.size(); ++c) + if (mDenseDepth > log2ceil(mDomain)) + throw RTE_LOC; + + std::vector densePoints(points.size()); + for (u64 i = 0; i < points.size(); ++i) + densePoints[i] = points[i] >> depth; + Matrix seeds(points.size(), 1ull << mDenseDepth); + Matrix tags(points.size(), 1ull << mDenseDepth); + co_await mRegDpf.expand(densePoints, {}, [&](auto treeIdx, auto leafIdx, auto seed, auto tag) { + seeds(treeIdx, leafIdx) = seed; + tags(treeIdx, leafIdx) = tag; + }, prng, sock); + + for (u64 r = 0; r < sparsePoints.rows(); ++r) { - std::cout << " group " << c << std::endl; - //std::sort(points[c].begin(), points[c].end(), [](auto& a, auto& b) {return a.mAddress < b.mAddress; }); - for (auto p : points[c]) + //auto seed = seeds[i] + auto& tree = trees[r]; + auto iter = sparsePoints[r].begin(); + auto end = sparsePoints[r].end(); + while (iter != end) { - std::cout << print(p, bitCount) << std::endl; - } + auto p = *iter; + auto bin = p >> depth; + auto seed = seeds(r, bin); + auto tag = tags(r, bin); + + auto e = std::find_if(iter, end, [bin, depth](auto v) {return (v >> depth) != bin; }); + auto points = span(iter, e); + if (points.size() == 1) + { + auto idx = std::distance(sparsePoints.data(r), &points[0]); + leafValues[r][idx] = seed; + leafTags[r][idx] = tag; + //std::cout << "p " << mPartyIdx << " leaf " << idx << " seed " << seed << " " << int(tag) << std::endl; + } + else if (points.size()) + { + auto [delta, root] = partition(points, depth); - if (points[c].size() == 1) - { - mInitals[r][c] = Initial::Final; - auto idx = std::distance(sparsePoints.data(), &points[c][0]); - mFinals[r].back().push_back(points[c][0]); + block cSeeds[2]; + cSeeds[0] = mAesFixedKey.hashBlock(seed ^ ZeroBlock); + cSeeds[1] = mAesFixedKey.hashBlock(seed ^ OneBlock); + + auto children = root.children(); + for (u64 j = 0; j < 2; ++j) + { + auto [delta2, b2] = partition(children[j], delta); + //if (!mPartyIdx) + // std::cout << b2.print(delta2) << std::endl; + //std::cout << "p " << mPartyIdx << " d " << delta << " j " << j <<" bin " << bin << " seed " << cSeeds[j] << " " << int(tag) << std::endl; + tree[delta2].push_back(j, delta, b2, cSeeds[j], tag); + tree.mZ[delta][j] ^= cSeeds[j]; + tree.mC[delta] = 1; + } + } + iter = e; } - else if (points[c].size() > 1) + } + } + else + { + + for (u64 r = 0; r < mNumPoints; ++r) + { + auto points = sparsePoints[r]; + auto& tree = trees[r]; + // range must be sorted and unique + //if (std::adjacent_find(points.begin(), points.end(), std::greater{}) != points.end()) + // throw RTE_LOC; + + auto [delta, b] = partition(points, depth); + //if (!mPartyIdx) + // std::cout << b.print(delta) << std::endl; + + auto children = b.children(); + for (u64 j = 0; j < 2; ++j) { - mInitals[r][c] = Initial::Expand; - auto par = partition(points[c], bitCount - 1); - getLevels(par, r, sparsePoints); + auto [delta2, b2] = partition(children[j], delta); + //if (!mPartyIdx) + // std::cout << b2.print(delta2) << std::endl; + block seed = prng.get(); + //std::cout << "p " << mPartyIdx << " d " << delta << " j " << j << " seed " << seed << std::endl; + tree[delta2].push_back(j, delta, b2, seed, mPartyIdx); + tree.mZ[delta][j] = seed; + tree.mC[delta] = 1; } + } + } + + + for (u64 d = depth; d; --d) + { + //std::cout << "-----" << d << "-----" << std::endl; + + std::vector z0(mNumPoints); + std::vector z1(mNumPoints); + + BitVector negAlpha(mNumPoints); + std::vector> taus(mNumPoints); + std::vector sigmas(mNumPoints); + bool used = false; + for (u64 r = 0; r < mNumPoints; ++r) + { + auto& tree = trees[r]; + if (tree.mC[d] == 0) + tree.mZ[d] = prng.get(); else + used = true; + + auto alphaD = (points[r] >> (d - 1)) & 1; + taus[r][0] = lsb(tree.mZ[d][0]) ^ alphaD ^ mPartyIdx; + taus[r][1] = lsb(tree.mZ[d][1]) ^ alphaD; + + //std::cout << "p " << mPartyIdx << " d " << d << " z " << tree.mZ[d][0] << " " <" << d << "<" << std::endl; + trees[r].mSigma[d] = sigmas[r]; + trees[r].mTau[d] = taus[r]; } } - std::vector set(sparsePoints.cols()); - std::vector>> states(bitCount + 1); - states.back().resize(preDomain); - for (auto point : sparsePoints[r]) - states.back()[point >> bitCount].push_back(point); - auto copyIter = mFinals[r].back().begin(); - for (u64 i = 0; i < mInitals[r].size(); ++i) + auto dNext = d - 1; + if (dNext == 0) + break; + + //std::cout << "vvvvv" << dNext << "vvvvv" << std::endl; + + for (u64 r = 0; r < mNumPoints; ++r) { - switch (mInitals[r][i]) - { - case Initial::None: - break; - case Initial::Final: - { + auto& tree = trees[r]; + auto size = tree.mSeeds[dNext].size(); - if (states.back()[i].size() != 1) - throw RTE_LOC; - auto idx = *copyIter++; - if (sparsePoints[r][idx] != states.back()[i][0]) - throw RTE_LOC; - set[idx] = 1; - break; - } - case Initial::Expand: - states[bitCount - 1].push_back(states.back()[i]); - break; + for (u64 i = 0; i < size; ++i) + { + auto& seed = tree.mSeeds[dNext][i]; + auto par = tree.mPartitions[dNext][i]; + auto tag = tree.mTags[dNext][i]; + auto child = tree.mChild[dNext][i]; + auto parent = tree.mParent[dNext][i]; + auto pTau = tree.mTau[parent][child]; + auto pSigma = tree.mSigma[parent]; + + auto cTag = lsb(seed) ^ tag * pTau; + //auto old = seed; + seed = seed ^ (pSigma & block::allSame(-tag)); + + //std::cout <<"p " << mPartyIdx << " d " << dNext << " i " << i << " " + // < " << seed << " via >"<< int(parent)<< "< " << pSigma << " t " << int(tag) << std::endl; + + std::array cSeed; + cSeed[0] = mAesFixedKey.hashBlock(seed ^ ZeroBlock); + cSeed[1] = mAesFixedKey.hashBlock(seed ^ OneBlock); + + tree.mZ[dNext][0] = tree.mZ[dNext][0] ^ cSeed[0]; + tree.mZ[dNext][1] = tree.mZ[dNext][1] ^ cSeed[1]; + tree.mC[dNext] = 1; + + auto children = par.children(); + for (u64 j = 0; j < 2; ++j) + { + auto [cd, cPar] = partition(children[j], dNext); + + //if(!mPartyIdx) + // std::cout << cPar.print(cd) << std::endl; + tree.mPartitions[cd].push_back(cPar); + tree.mChild[cd].push_back(j); + tree.mParent[cd].push_back(dNext); + tree.mSeeds[cd].push_back(cSeed[j]); + tree.mTags[cd].push_back(cTag); + } } } + } + std::vector gamma(values.begin(), values.end()); + for (u64 r = 0; r < mNumPoints; ++r) + { + auto& tree = trees[r]; + auto size = tree.mSeeds[0].size(); - if (mDebug) + for (u64 i = 0; i < size; ++i) { + auto& seed = tree.mSeeds[0][i]; + auto& tag = tree.mTags[0][i]; + auto j = tree.mChild[0][i]; + auto parent = tree.mParent[0][i]; + auto par = tree.mPartitions[0][i]; + auto pTau = tree.mTau[parent][j]; + auto pSigma = tree.mSigma[parent]; + + auto b = std::distance(sparsePoints.data(r), par.mRange.data()); + leafTags[r][b] = lsb(seed) ^ tag * pTau; + leafValues[r][b] = seed ^ (pSigma & block::allSame(-tag)); + gamma[r] = gamma[r] ^ leafValues[r][b]; + //std::cout << "p " << mPartyIdx << " d " << 0 << " i " << i << " " + // << seed << " -> " << leafValues[r][b] << " via >" << int(parent) << "< " << pSigma << " t " << int(tag) << std::endl; - for (u64 i = bitCount; i != 0; --i) - { - auto& ex = mExpands[r][i - 1]; - auto& state = states[i - 1]; - auto copyIter = mFinals[r][i - 1].begin(); + } + } - for (u64 j = 0; j < ex.size(); ++j) - { - std::cout << "expand " << i << " " << j << std::endl; - auto mid = std::partition(state[j].begin(), state[j].end(), [&](auto& a) {return !(a >> (i - 1) & 1); }); - for (auto p : state[j]) - { + co_await reveal(gamma, sock); + //std::cout << "-----------final-------------" << std::endl; + for (u64 r = 0; r < mNumPoints; ++r) + { + auto& tree = trees[r]; + auto size = sparsePoints[r].size(); + for (u64 i = 0; i < size; ++i) + { + assert(leafValues[r][i] != oc::ZeroBlock); + auto val = leafValues[r][i] ^ (gamma[r] & block::allSame(-leafTags[r][i])); - std::cout << print(p, i); - if (p == *mid) - std::cout << " <- mid"; - std::cout << std::endl; - } - std::cout << std::endl; - std::array, 2> sets{ span(state[j].begin(), mid), span(mid, state[j].end()) }; - for (u64 k = 0; k < 2; ++k) - { - if (ex[j].mDepth[k] == 0) - { - if (state[j].size() != 2) - throw RTE_LOC; - auto c0 = *copyIter++; - //auto c1 = *copyIter++; - - auto p0 = sparsePoints[r][c0]; - if (p0 != state[j][k]) - throw RTE_LOC; - - set[c0] = 1; - } - else - { - if (sets[k].size() <= 1) - throw RTE_LOC; - auto next = i - 1 - ex[j].mDepth[k]; - std::cout << "pushing set " << k << " to lvl " << next << std::endl; - states[next].push_back(std::vector(sets[k].begin(), sets[k].end())); - } - } - } + //std::cout << "p " << mPartyIdx << " d " << 0 << " i " << i << " " + // << leafValues[r][i] << " -> " << val << " via " << gamma[r] << " t " << int(leafTags[r][i]) << std::endl; - } - if (std::find(set.begin(), set.end(), 0) != set.end()) - throw RTE_LOC;; + output(r, i, val, leafTags[r][i]); } } + co_return; } - template - macoro::task<> expand( - Output&& output, - PRNG& prng, - coproto::Socket& sock) + + //std::vector gamma(values.begin(), values.end()); + //for (u64 r = 0; r < mNumPoints; ++r) + //{ + // auto& tree = trees[r]; + // auto size = sparsePoints[r].size(); + // if (tree.mSeeds[0].size() != size) + // throw RTE_LOC; + // for (u64 i = 0; i < size; ++i) + // { + // auto& seed = tree.mSeeds[0][i]; + // auto& tag = tree.mTags[0][i]; + // auto j = tree.mChild[0][i]; + // auto parent = tree.mParent[0][i]; + // auto par = tree.mPartitions[0][i]; + // auto pTau = tree.mTau[parent][j]; + // auto pSigma = tree.mSigma[parent]; + + // //auto old = seed; + + // auto b = std::distance(sparsePoints.data(r), par.mRange.data()); + // leafTags[r][b] = lsb(seed) ^ tag * pTau; + // leafValues[r][b] = seed ^ (pSigma & block::allSame(-tag)); + + // //std::cout << "p " << mPartyIdx << " d " << 0 << " i " << i << " " + // // << old << " -> " << seed << " via >" << int(parent) << "< " << pSigma << " t " << int(tag) << std::endl; + + + // gamma[r] = gamma[r] ^ seed; + // } + //} + + //co_await reveal(gamma, sock); + ////std::cout << "-----------final-------------" << std::endl; + //for (u64 r = 0; r < mNumPoints; ++r) + //{ + // //auto& tree = trees[r]; + // auto size = sparsePoints[r].size(); + // for (u64 i = 0; i < size; ++i) + // { + // //auto seed = tree.mSeeds[0][i]; + // //auto tag = tree.mTags[0][i]; + + + // //auto old = seed; + // auto val = leafValues[r][i] ^ (gamma[r] & block::allSame(-leafTags[r][i])); + + // //std::cout << "p " << mPartyIdx << " d " << 0 << " i " << i << " " + // // << old << " -> " << seed << " via " << gamma[r] << " t " << int(tag) << std::endl; + + // output(r, i, val, leafTags[r][i]); + // } + //} + + + + macoro::task<> reveal(span sigma, span> tau, coproto::Socket& sock) { + if (sigma.size() != tau.size()) + throw RTE_LOC; + std::vector sBuff(sigma.begin(), sigma.end()); + std::vector> tBuff(tau.begin(), tau.end()); + co_await macoro::when_all_ready( + sock.send(std::move(sBuff)), + sock.send(std::move(tBuff)) + ); + sBuff.resize(sigma.size()); + tBuff.resize(tau.size()); + co_await macoro::when_all_ready( + sock.recv(sBuff), + sock.recv(tBuff) + ); + for (u64 i = 0; i < sigma.size(); ++i) + { + sigma[i] = sigma[i] ^ sBuff[i]; + tau[i][0] = tau[i][0] ^ tBuff[i][0]; + tau[i][1] = tau[i][1] ^ tBuff[i][1]; + } } + macoro::task<> reveal(span sigma, coproto::Socket& sock) + { + std::vector sBuff(sigma.begin(), sigma.end()); + co_await sock.send(std::move(sBuff)); + sBuff.resize(sigma.size()); + co_await sock.recv(sBuff); + for (u64 i = 0; i < sigma.size(); ++i) + { + sigma[i] = sigma[i] ^ sBuff[i]; + } + } + }; } -#undef SIMD8 \ No newline at end of file +#undef SIMD8 + diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index c7e8afc1..bccb9aef 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -11,15 +11,25 @@ void RegularDpf_Multiply_Test(const CLP& cmd) u64 n = 13; PRNG prng(block(231234, 321312)); std::array dpf; - dpf[0].mPartyIdx = 0; - dpf[1].mPartyIdx = 1; - dpf[0].mSendOts.push_back(prng.get()); - dpf[1].mSendOts.push_back(prng.get()); - dpf[0].mChoiceBits.pushBack(0); - dpf[1].mChoiceBits.pushBack(1); - dpf[0].mRecvOts.push_back(dpf[1].mSendOts[0][dpf[0].mChoiceBits[0]]); - dpf[1].mRecvOts.push_back(dpf[0].mSendOts[0][dpf[1].mChoiceBits[0]]); + dpf[0].init(0, n); + dpf[1].init(1, n); + std::array>, 2> sendOts; + std::array, 2> recvOts; + std::array choices; + for (u64 i = 0; i < 2; ++i) + { + sendOts[i].resize(n); + recvOts[i].resize(n); + choices[i].resize(n); + choices[i].randomize(prng); + prng.get(sendOts[i].data(), sendOts[i].size()); + for (u64 j = 0; j < n; ++j) + recvOts[i][j] = sendOts[i][j][choices[i][j]]; + } + + dpf[0].setBaseOts(sendOts[0], recvOts[1], choices[1]); + dpf[1].setBaseOts(sendOts[1], recvOts[0], choices[0]); { u64 i = 0; @@ -59,16 +69,15 @@ void RegularDpf_Multiply_Test(const CLP& cmd) { //std::cout << "-=========================-" std::endl; - for (u64 j = 0; j < n; ++j) + for (u64 i = 0; i < 2; ++i) { - dpf[0].mSendOts.push_back(prng.get()); - dpf[1].mSendOts.push_back(prng.get()); - dpf[0].mChoiceBits.pushBack(prng.getBit()); - dpf[1].mChoiceBits.pushBack(prng.getBit()); - dpf[0].mRecvOts.push_back(dpf[1].mSendOts.back()[dpf[0].mChoiceBits.back()]); - dpf[1].mRecvOts.push_back(dpf[0].mSendOts.back()[dpf[1].mChoiceBits.back()]); + choices[i].randomize(prng); + prng.get(sendOts[i].data(), sendOts[i].size()); + for (u64 j = 0; j < n; ++j) + recvOts[i][j] = sendOts[i][j][choices[i][j]]; } - + dpf[0].setBaseOts(sendOts[0], recvOts[1], choices[1]); + dpf[1].setBaseOts(sendOts[1], recvOts[0], choices[0]); BitVector x0(n), x1(n); x0.randomize(prng); @@ -165,7 +174,7 @@ void RegularDpf_Proto_Test(const CLP& cmd) { auto p = points0[k] ^ points1[k]; auto act = output[0][k][i] ^ output[1][k][i]; - auto t = i == p; + auto t = i == p ? 1 : 0; auto tAct = tags[0][k][i] ^ tags[1][k][i]; auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; if (exp != act) @@ -180,22 +189,98 @@ void SparseDpf_Proto_Test(const oc::CLP& cmd) { PRNG prng(block(32324, 2342)); u64 numPoints = 1; - u64 domain = 256; - oc::SparseDpf dpf; - Matrix sparsePoints(numPoints, domain / 10); - std::vector set(domain); - std::iota(set.begin(), set.end(), 0); - for(u64 i = 0; i < sparsePoints.size(); ++i) + u64 domain = 1773; + u64 dense = 4; + u64 fraction = 16; + + auto index{ std::vector(numPoints) }; + auto value{ std::vector(numPoints) }; + std::array points{ std::vector(numPoints),std::vector(numPoints) }; + std::array values{ std::vector(numPoints), std::vector(numPoints) }; + oc::SparseDpf dpf[2]; + Matrix sparsePoints(numPoints, domain / fraction); + + for (u64 j = 0; j < numPoints; ++j) { - auto j = prng.get() % set.size(); - std::swap(set[j], set.back()); - sparsePoints(i) = set.back(); - set.pop_back(); + std::vector set(domain); + std::iota(set.begin(), set.end(), 0); + for (u64 i = 0; i < sparsePoints.cols(); ++i) + { + auto k = prng.get() % set.size(); + std::swap(set[k], set.back()); + sparsePoints(j, i) = set.back(); + set.pop_back(); + } } for (u64 i = 0; i < sparsePoints.rows(); ++i) { std::sort(sparsePoints[i].begin(), sparsePoints[i].end()); + index[i] = prng.get() % sparsePoints.cols(); + value[i] = prng.get(); + auto alpha = sparsePoints(i, index[i]); + //std::cout << "alpha " << alpha << " " << oc::BitVector((u8*)&alpha, log2ceil(domain)) << std::endl; + points[0][i] = prng.get(); + points[1][i] = points[0][i] ^ sparsePoints(i, index[i]); + values[0][i] = prng.get(); + values[1][i] = values[0][i] ^ value[i]; + } + + + dpf[0].init(0, numPoints, domain, dense); + dpf[1].init(1, numPoints, domain, dense); + auto sock = coproto::LocalAsyncSocket::makePair(); + + auto baseCount = dpf[0].baseOtCount(); + + std::array>, 2> sendOts; + std::array, 2> recvOts; + std::array choices; + for (u64 i = 0; i < 2; ++i) + { + sendOts[i].resize(baseCount); + recvOts[i].resize(baseCount); + choices[i].resize(baseCount); + choices[i].randomize(prng); + prng.get(sendOts[i].data(), sendOts[i].size()); + for (u64 j = 0; j < baseCount; ++j) + recvOts[i][j] = sendOts[i][j][choices[i][j]]; + } + + dpf[0].setBaseOts(sendOts[0], recvOts[1], choices[1]); + dpf[1].setBaseOts(sendOts[1], recvOts[0], choices[0]); + + + std::array, 2> out; + std::array, 2> tags; + out[0].resize(numPoints, sparsePoints.cols()); + out[1].resize(numPoints, sparsePoints.cols()); + tags[0].resize(numPoints, sparsePoints.cols()); + tags[1].resize(numPoints, sparsePoints.cols()); + auto r = macoro::sync_wait( + macoro::when_all_ready( + dpf[0].expand(points[0], values[0], [&](auto k, auto i, auto v, auto t) { out[0](k, i) = v; tags[0](k, i) = t; }, prng, sparsePoints, sock[0]), + dpf[1].expand(points[1], values[1], [&](auto k, auto i, auto v, auto t) { out[1](k, i) = v; tags[1](k, i) = t; }, prng, sparsePoints, sock[1]) + )); + + + std::get<0>(r).result(); + std::get<1>(r).result(); + + for (u64 i = 0; i < numPoints; ++i) + { + for (u64 j = 0; j < sparsePoints.cols(); ++j) + { + auto active = index[i] == j ? 1 : 0; + + auto tag = tags[0](i, j) ^ tags[1](i, j); + if (tag != active) + throw RTE_LOC; + + auto act = out[0](i, j) ^ out[1](i, j); + auto exp = active ? value[i] : ZeroBlock; + if (act != exp) + throw RTE_LOC; + } } - dpf.init(0, domain, sparsePoints); } From 23ef8f13d84452a8913822e5f3d0787760d5680b Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 27 Jan 2025 01:02:05 -0800 Subject: [PATCH 04/48] ported foliage and some optimizations --- CMakePresets.json | 7 +- libOTe/Tools/Foliage/F4Ops.h | 214 ++++ libOTe/Tools/Foliage/FoliageMain.cpp | 304 ++++++ libOTe/Tools/Foliage/FoliageUtils.h | 263 +++++ libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp | 138 +++ libOTe/Tools/Foliage/fft/FoliageFFT_bench.h | 13 + libOTe/Tools/Foliage/fft/FoliageFft.cpp | 311 ++++++ libOTe/Tools/Foliage/fft/FoliageFft.h | 37 + libOTe/Tools/Foliage/spfss_test.cpp | 158 +++ libOTe/Tools/Foliage/tri-dpf/.gitignore | 5 + libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp | 317 ++++++ libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h | 35 + .../Tools/Foliage/tri-dpf/FoliageDpf_test.cpp | 166 +++ .../Tools/Foliage/tri-dpf/FoliageDpf_test.h | 9 + libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h | 83 ++ libOTe/Tools/Foliage/tri-dpf/LICENSE | 9 + libOTe/Tools/Foliage/tri-dpf/README.md | 116 +++ libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h | 68 ++ libOTe/Tools/Foliage/uint128.h | 801 +++++++++++++++ libOTe_Tests/CMakeLists.txt | 5 +- libOTe_Tests/Foliage_Tests.cpp | 954 ++++++++++++++++++ libOTe_Tests/Foliage_Tests.h | 11 + libOTe_Tests/UnitTests.cpp | 6 + 23 files changed, 4024 insertions(+), 6 deletions(-) create mode 100644 libOTe/Tools/Foliage/F4Ops.h create mode 100644 libOTe/Tools/Foliage/FoliageMain.cpp create mode 100644 libOTe/Tools/Foliage/FoliageUtils.h create mode 100644 libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp create mode 100644 libOTe/Tools/Foliage/fft/FoliageFFT_bench.h create mode 100644 libOTe/Tools/Foliage/fft/FoliageFft.cpp create mode 100644 libOTe/Tools/Foliage/fft/FoliageFft.h create mode 100644 libOTe/Tools/Foliage/spfss_test.cpp create mode 100644 libOTe/Tools/Foliage/tri-dpf/.gitignore create mode 100644 libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp create mode 100644 libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h create mode 100644 libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp create mode 100644 libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h create mode 100644 libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h create mode 100644 libOTe/Tools/Foliage/tri-dpf/LICENSE create mode 100644 libOTe/Tools/Foliage/tri-dpf/README.md create mode 100644 libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h create mode 100644 libOTe/Tools/Foliage/uint128.h create mode 100644 libOTe_Tests/Foliage_Tests.cpp create mode 100644 libOTe_Tests/Foliage_Tests.h diff --git a/CMakePresets.json b/CMakePresets.json index 4451b953..fc21c0f1 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -8,7 +8,7 @@ "generator": "Ninja", "binaryDir": "${sourceDir}/out/build/${presetName}", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_BUILD_TYPE": "RelWithDebInfo", "FETCH_AUTO": true, "ENABLE_ALL_OT": true, "ENABLE_SSE": true, @@ -36,9 +36,8 @@ "microsoft.com/VisualStudioRemoteSettings/CMake/1.0": { "sourceDir": "$env{HOME}/.vs/$ms{projectDirName}", "copySourcesOptions": { - "exclusionList": [ ".vs", "out/build", "out/install", "out/boost*", "out/relic/build", ".git" ] - }, - "rsyncCommandArgs": [ "-t", "--delete", "--include=${sourceDir}/out/macoro/*", "--verbose" ] + "exclusionList": [ ".vs", "out", ".git" ] + } } } }, diff --git a/libOTe/Tools/Foliage/F4Ops.h b/libOTe/Tools/Foliage/F4Ops.h new file mode 100644 index 00000000..a533a689 --- /dev/null +++ b/libOTe/Tools/Foliage/F4Ops.h @@ -0,0 +1,214 @@ +#pragma once + +#include "libOTe/Tools/Foliage/FoliageUtils.h" + +namespace osuCrypto +{ + //typedef __int128 int128_t; + //typedef unsigned __int128 uint128_t; + + // Samples a non-zero element of F4 + inline uint8_t rand_f4x(PRNG& prng) + { + uint8_t t; + unsigned char rand_byte; + + // loop until we have two bits where at least one is non-zero + while (1) + { + rand_byte = prng.get(); + t = 0; + t |= rand_byte & 1; + t = t << 1; + t |= (rand_byte >> 1) & 1; + + if (t != 0 && t != 4) + return t; + } + } + + // Multiplies two elements of F4 (optionally: 4 elements packed into uint8_t) + // and returns the result. + inline uint8_t mult_f4(uint8_t a, uint8_t b) + { + u8 tmp = ((a & 0b10) & (b & 0b10)); + uint8_t res = tmp ^ (((a & 0b10) & ((b & 0b01) << 1)) ^ (((a & 0b01) << 1) & (b & 0b10))); + res |= ((a & 0b01) & (b & 0b01)) ^ (tmp >> 1); + return res; + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint8_t + // resulting in a matrix with 4 columns. + inline void multiply_fft_8( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint8_t pattern = 0xaa; + uint8_t mask_h = pattern; // 0b101010101010101001010 + uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint8_t tmp; + uint8_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint16_t + // resulting in a matrix with 8 columns. + inline void multiply_fft_16( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint16_t pattern = 0xaaaa; + uint16_t mask_h = pattern; // 0b101010101010101001010 + uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint16_t tmp; + uint16_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint32_t + // resulting in a matrix with 16 columns. + inline void multiply_fft_32( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint32_t pattern = 0xaaaaaaaa; + uint32_t mask_h = pattern; // 0b101010101010101001010 + uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint32_t tmp; + uint32_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint64_t + // resulting in a matrix with 32 columns. + inline void multiply_fft_64( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; + uint64_t mask_h = pattern; // 0b101010101010101001010 + uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint64_t tmp; + uint64_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + + + // samples the a polynomials and axa polynomials + inline void sample_a_and_a2(span fft_a, span fft_a2, size_t poly_size, size_t c, PRNG& prng) + { + if (c > 16) + throw RTE_LOC; + + prng.get(fft_a.data(), poly_size); + + // make a_0 the identity polynomial (in FFT space) i.e., all 1s + for (size_t i = 0; i < poly_size; i++) + { + fft_a[i] = (fft_a[i] & ~3ull) | 1; + } + + // FOR DEBUGGING: set fft_a to the identity + // for (size_t i = 0; i < poly_size; i++) + // { + // fft_a[i] = (0xaaaa >> 1); + // } + + uint32_t prod; + for (size_t j = 0; j < c; j++) + { + for (size_t k = 0; k < c; k++) + { + for (size_t i = 0; i < poly_size; i++) + { + auto a = (fft_a[i] >> (2 * j)) & 0b11; + auto b = (fft_a[i] >> (2 * k)) & 0b11; + auto a1 = a & 1; + auto a2 = a & 2; + auto b1 = b & 1; + auto b2 = b & 2; + + { + u8 tmp = (a2 & b2); + prod = tmp ^ ((a2 & (b1 << 1)) ^ ((a1 << 1) & b2)); + prod |= (a1 & b1) ^ (tmp >> 1); + //return res; + } + //prod = mult_f4(, ); + size_t slot = j * c + k; + fft_a2[i] |= prod << (2 * slot); + } + } + } + } + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/FoliageMain.cpp b/libOTe/Tools/Foliage/FoliageMain.cpp new file mode 100644 index 00000000..7914afd5 --- /dev/null +++ b/libOTe/Tools/Foliage/FoliageMain.cpp @@ -0,0 +1,304 @@ +#include +#include +#include + +#include "libOTe/Tools/Foliage/F4Ops.h" +#include "libOTe/Tools/Foliage/fft/FoliageFft.h" + +#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" +#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +// Benchmarks are less documented compared to test.c; see test.c to +// better understand what is done here for timing purposes. + +#define DPF_MSG_SIZE 8 +namespace osuCrypto +{ + + + double bench_pcg(size_t n, size_t c, size_t t) + { + if (c > 4) + { + printf("ERROR: currently only implemented for c <= 4"); + exit(0); + } + + const size_t poly_size = ipow(3, n); + PRNG prng(block(342)); + + //************************************************************************ + // Step 0: Sample the global (1, a1 ... a_c-1) polynomials + //************************************************************************ + AlignedUnVector fft_a(poly_size); + AlignedUnVector fft_a2(poly_size); + sample_a_and_a2(fft_a, fft_a2, poly_size, c, prng); + + //************************************************************************ + // Step 1: Sample DPF keys for the cross product. + // For benchmarking purposes, we sample random DPF functions for a + // sufficiently large domain size to express a block of coefficients. + //************************************************************************ + size_t dpf_domain_bits = ceil(log_base(poly_size / (t * DPF_MSG_SIZE * 64), 3)); + printf("dpf_domain_bits = %zu \n", dpf_domain_bits); + + size_t seed_size_bits = (128 * (dpf_domain_bits * 3 + 1) + DPF_MSG_SIZE * 128) * c * c * t * t; + printf("PCG seed size: %.2f MB\n", seed_size_bits / 8000000.0); + + size_t dpf_block_size = DPF_MSG_SIZE * ipow(3, dpf_domain_bits); + size_t block_size = ceil(poly_size / t); + + printf("block_size = %zu \n", block_size); + + std::vectordpf_keys_A(c * c * t * t); + std::vectordpf_keys_B(c * c * t * t); + + // Sample PRF keys for the DPFs + PRFKeys prf_keys; + prf_keys.gen(prng); + + // Sample DPF keys for each of the t errors in the t blocks + for (size_t i = 0; i < c; i++) + { + for (size_t j = 0; j < c; j++) + { + for (size_t k = 0; k < t; k++) + { + for (size_t l = 0; l < t; l++) + { + size_t index = i * c * t * t + j * t * t + k * t + l; + + // Pick a random index for benchmarking purposes + size_t alpha = random_index(block_size, prng); + + // Pick a random output message for benchmarking purposes + uint128_t beta[DPF_MSG_SIZE]; + prng.get(beta, DPF_MSG_SIZE); + + // Message (beta) is of size 8 blocks of 128 bits + DPFGen(prf_keys, dpf_domain_bits, alpha, beta, DPF_MSG_SIZE, dpf_keys_A[index], dpf_keys_B[index], prng); + } + } + } + } + + //************************************************ + printf("Benchmarking PCG evaluation \n"); + //************************************************ + + // Allocate memory for the DPF outputs (this is reused for each evaluation) + AlignedUnVector shares(dpf_block_size); + AlignedUnVector cache(dpf_block_size); + + // Allocate memory for the concatenated DPF outputs + const size_t packed_block_size = ceil(block_size / 64.0); + const size_t packed_poly_size = t * packed_block_size; + AlignedUnVector packed_polys(c * c * packed_poly_size); + + // Allocate memory for the output FFT + AlignedUnVector fft_u(poly_size); + + // Allocate memory for the final inner product + AlignedUnVector z_poly(poly_size); + AlignedUnVector res_poly_mat(poly_size); + + //************************************************************************ + // Step 3: Evaluate all the DPFs to recover shares of the c*c polynomials. + //************************************************************************ + + clock_t time; + time = clock(); + + size_t key_index; + uint128_t* poly_block; + size_t i, j, k, l, w; + for (i = 0; i < c; i++) + { + for (j = 0; j < c; j++) + { + const size_t poly_index = i * c + j; + uint128_t* packed_poly = &packed_polys[poly_index * packed_poly_size]; + + for (k = 0; k < t; k++) + { + poly_block = &packed_poly[k * packed_block_size]; + + for (l = 0; l < t; l++) + { + key_index = i * c * t * t + j * t * t + k * t + l; + + DPFFullDomainEval(dpf_keys_A[key_index], cache, shares); + + for (w = 0; w < packed_block_size; w++) + poly_block[w] ^= shares[w]; + } + } + } + } + + //************************************************************************ + // Step 3: Compute the transpose of the polynomials to pack them into + // the parallel FFT format. + // + // TODO: this is the bottleneck of the computation and can be improved + // using SIMD operations for performing matrix transposes (see TODO in test.c). + //************************************************************************ + for (size_t i = 0; i < c * c; i++) + { + size_t poly_index = i * packed_poly_size; + const uint128_t* poly = &packed_polys[poly_index]; + +#ifdef ENABLE_SSE + _mm_prefetch((char*)poly, _MM_HINT_T2); +#endif // ENABLE_SSE + + size_t block_idx, packed_coeff_idx, coeff_idx; + //uint8_t packed_bit_idx; + uint128_t packed_coeff; + + block_idx = 0; + packed_coeff_idx = 0; + coeff_idx = 0; + + for (size_t k = 0; k < poly_size - 64; k += 64) + { + packed_coeff = poly[block_idx * packed_block_size + packed_coeff_idx]; + +#ifdef ENABLE_SSE + _mm_prefetch((char*)&fft_u[k], _MM_HINT_T2); +#endif // ENABLE_SSE + //__builtin_prefetch(&fft_u[k], 0, 0); + //__builtin_prefetch(&fft_u[k], 1, 0); + + for (size_t l = 0; l < 64; l++) + { + packed_coeff = packed_coeff >> 2; + fft_u[k + l] |= static_cast(packed_coeff) & 0b11; + fft_u[k + l] = fft_u[k + l] << 2; + } + + packed_coeff_idx++; + coeff_idx += 64; + + if (coeff_idx > block_size) + { + coeff_idx = 0; + block_idx++; + packed_coeff_idx = 0; + +#ifdef ENABLE_SSE + _mm_prefetch((char*)&poly[block_idx * packed_block_size], _MM_HINT_T2); + //__builtin_prefetch(&poly[block_idx * packed_block_size], 0, 2); +#endif // ENABLE_SSE + } + } + + packed_coeff = poly[block_idx * packed_block_size + packed_coeff_idx]; + for (size_t k = poly_size - 64 + 1; k < poly_size; k++) + { + packed_coeff = packed_coeff >> 2; + fft_u[k] |= static_cast(packed_coeff) & 0b11 ; + fft_u[k] = fft_u[k] << 2; + } + } + + fft_recursive_uint32(fft_u, n, poly_size / 3); + multiply_fft_32(fft_a2, fft_u, res_poly_mat, poly_size); + + // Perform column-wise XORs to get the result + for (size_t i = 0; i < poly_size; i++) + { + // XOR the (packed) columns into the accumulator + for (size_t j = 0; j < c * c; j++) + { + z_poly[i] ^= res_poly_mat[i] & 0b11; + res_poly_mat[i] = res_poly_mat[i] >> 2; + } + } + + time = clock() - time; + double time_taken = ((double)time) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Eval time (total) %f ms\n", time_taken); + printf("DONE\n\n"); + + //DestroyPRFKey(prf_keys); + //free(fft_a); + //free(fft_a2); + //free(dpf_keys_A); + //free(dpf_keys_B); + //free(shares); + //free(cache); + //free(fft_u); + //free(packed_polys); + //free(res_poly_mat); + //free(z_poly); + + return time_taken; + } + + void printUsage() + { + printf("Usage: ./pcg [OPTIONS]\n"); + printf("Options:\n"); + printf(" --test\tTests correctness of the PCG.\n"); + printf(" --bench\tBenchmarks the PCG on conservative and aggressive parameters.\n"); + } + + void runBenchmarks(size_t n, size_t c, size_t t, int num_trials) + { + double time = 0; + + for (int i = 0; i < num_trials; i++) + { + time += bench_pcg(n, c, t); + printf("Done with trial %i of %i\n", i + 1, num_trials); + } + printf("******************************************\n"); + printf("Avg time (N=3^%zu, c=%zu, t=%zu): %0.4f ms\n", n, c, t, time / num_trials); + printf("******************************************\n\n"); + } + + int main_foliage(int argc, char** argv) + { + int num_trials = 5; + + for (int i = 1; i < argc; i++) + { + if (strcmp(argv[i], "--bench") == 0) + { + printf("******************************************\n"); + printf("Benchmarking PCG with conservative parameters (c=4, t=27)\n"); + runBenchmarks(14, 4, 27, num_trials); + runBenchmarks(16, 4, 27, num_trials); + runBenchmarks(18, 4, 27, num_trials); + + printf("******************************************\n"); + printf("Benchmarking PCG with aggressive parameters (c=3, t=27)\n"); + runBenchmarks(14, 3, 27, num_trials); + runBenchmarks(16, 3, 27, num_trials); + runBenchmarks(18, 3, 27, num_trials); + } + //else if (strcmp(argv[i], "--test") == 0) + //{ + // printf("******************************************\n"); + // printf("Testing PCG\n"); + // foliage_pcg_test(); + // printf("******************************************\n"); + // printf("PASS\n"); + // printf("******************************************\n\n"); + //} + else + { + printUsage(); + } + } + + if (argc == 1) + printUsage(); + + return 0; + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/FoliageUtils.h b/libOTe/Tools/Foliage/FoliageUtils.h new file mode 100644 index 00000000..c54a4f66 --- /dev/null +++ b/libOTe/Tools/Foliage/FoliageUtils.h @@ -0,0 +1,263 @@ +#pragma once +#include "cryptoTools/Crypto/AES.h" +#include "cryptoTools/Crypto/PRNG.h" +#include +#include "uint128.h" +namespace osuCrypto +{ + using uint128_t = absl::uint128_t; + //using int128_t = block; + //using uint128_t = block; + //using uint128_t = __uint128_t; + //struct uint128_t + //{ + // std::array mVals; + + // uint128_t() = default; + // uint128_t(const uint128_t&) = default; + // uint128_t& operator=(const uint128_t&) = default; + + // uint128_t(const u64& v) : mVals({ v,0 }) {}; + + // bool operator==(const uint128_t& o) const { return mVals[0] == o.mVals[0] && mVals[1] == o.mVals[1]; } + // bool operator!=(const uint128_t& o) const { return !(*this == o); } + + // bool operator==(const u64& o) const { return *this == uint128_t{ o }; } + // bool operator!=(const u64& o) const { return *this != uint128_t{ o }; } + // bool operator==(const int& o) const { return *this == uint128_t{ u64(o) }; } + // bool operator!=(const int& o) const { return *this != uint128_t{ u64(o) }; } + + + // uint128_t operator^(const uint128_t&o) const { + // uint128_t r = *this; + // r ^= o; + // return r; + // } + // uint128_t& operator^=(const uint128_t& o) + // { + // mVals[0] ^= o.mVals[0]; + // mVals[1] ^= o.mVals[1]; + // return *this; + // } + + // uint128_t operator&(const uint128_t&o) const { + // uint128_t r = *this; + // r &= o; + // return r; + // } + // uint128_t& operator&=(const uint128_t&o) + // { + // mVals[0] &= o.mVals[0]; + // mVals[1] &= o.mVals[1]; + // return *this; + // } + + + // uint128_t operator+(const uint128_t&o) const + // { + // uint128_t r = *this; + // r += o; + // return r; + // } + // uint128_t& operator+=(const uint128_t&o) + // { + // u64 v; + // char cout = _addcarry_u64(0, mVals[0], o.mVals[0], &mVals[0]); + // _addcarry_u64(cout, mVals[1], o.mVals[1], &mVals[1]); + // return *this; + // } + + + // uint128_t operator-(const uint128_t&o) const + // { + // uint128_t r = *this; + // r -= o; + // return r; + // } + // uint128_t& operator-=(const uint128_t&o) + // { + // auto borrow = _subborrow_u64(0, mVals[0], o.mVals[0], &mVals[0]); + // _subborrow_u64(borrow, mVals[1], o.mVals[1], &mVals[1]); + // return *this; + // } + + + // uint128_t operator>>(u64 s) const + // { + // auto r = *this; + // r >>= s; + // return r; + // } + // uint128_t& operator>>=(u64 s) + // { + // assert(s <= 128); + // if (s < 64) + // { + // mVals[0] = (mVals[0] >> s) | (mVals[1] << (64-s)); + // mVals[1] >>= s; + // } + // else + // { + // s = s - 64; + // mVals[0] = mVals[1] >> s; + // mVals[1] = 0; + // } + // return *this; + // } + + // uint128_t operator<<(u64 s) const + // { + // auto r = *this; + // r <<= s; + // return r; + // } + // uint128_t& operator<<=(u64 s) + // { + // assert(s <= 128); + // if (s < 64) + // { + // mVals[1] = (mVals[1] << s) | (mVals[0] >> (64 - s)); + // mVals[0] <<= s; + // } + // else + // { + // s = s - 64; + // mVals[1] = mVals[0] << s; + // mVals[0] = 0; + // } + // return *this; + // } + + // uint128_t operator>>(int s) const { return *this >> u64(s); } + // uint128_t& operator>>=(int s) { return *this >>= u64(s); } + + // uint128_t operator<<(int s) const { return *this << u64(s); } + // uint128_t& operator<<=(int s) { return *this >>= u64(s); } + + // operator u64 () const + // { + // return mVals[0]; + // } + + + //}; + + + + inline void printBytes(void* p, int num) + { + unsigned char* c = (unsigned char*)p; + for (int i = 0; i < num; i++) + { + printf("%02x", c[i]); + } + printf("\n"); + } + + // Samples a uniformly random value between 0 and max via rejection sampling. + inline uint64_t random_index(uint64_t max, PRNG& prng) + { + if (max == 0) + return 0; + + return prng.get() % (max + 1); + //while (1) + //{ + + // // Use rejection sampling to ensure uniformity + // if (rand_value <= (UINT64_MAX - (UINT64_MAX % (max + 1)))) + // return rand_value % (max + 1); + //} + } + + // Samples a random trit (0,1,2) via rejection sampling + inline uint8_t rand_trit(PRNG& prng) + { + uint8_t t; + + while (1) + { + //RAND_bytes(&rand_byte, 1); + t = prng.get(); + if (t <= 170) // Rejecting values greater than 170 + return t % 3; + } + } + + // Reverses the order of elements in an array of uint8_t values + inline void reverse_uint8_array(span trits, size_t size) + { + size_t i = 0; + size_t j = size - 1; + + while (i < j) + { + // Swap elements at positions i and j + uint8_t temp = trits[i]; + trits[i] = trits[j]; + trits[j] = temp; + + // Move towards the center of the array + i++; + j--; + } + } + + // Converts an array of trits (not packed) into their integer representation. + inline size_t trits_to_int(span trits, size_t size) + { + reverse_uint8_array(trits, size); + size_t result = 0; + for (size_t i = 0; i < size; i++) + result = result * 3 + (size_t)trits[i]; + + return result; + } + + // Converts an integer into ternary representation (each trit = 0,1,2) + inline void int_to_trits(size_t n, span trits, size_t size) + { + for (size_t i = 0; i < size; i++) + trits[i] = 0; + + size_t index = 0; + while (n > 0 && index < size) + { + trits[index] = (uint8_t)(n % 3); + n = n / 3; + index++; + } + } + + // Computes the log of `a` base `base` + inline double log_base(double a, double base) + { + return std::log2(a) / std::log2(base); + } + + // Compute base^exp without the floating-point precision + // errors of the built-in pow function. + inline size_t ipow(size_t base, size_t exp) + { + if (exp == 1) + return base; + + if (exp == 0) + return 1; + + size_t result = 1; + while (1) + { + if (exp & 1) + result *= base; + exp >>= 1; + if (!exp) + break; + base *= base; + } + + return result; + } + + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp b/libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp new file mode 100644 index 00000000..d10d7bc7 --- /dev/null +++ b/libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp @@ -0,0 +1,138 @@ +//#include +//#include +//#include +//#include + +#include +#include + +#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +#include "cryptoTools/Common/Aligned.h" +#include "cryptoTools/Crypto/PRNG.h" + +#include "libOTe/Tools/Foliage/FoliageUtils.h" + +#define NUMVARS 16 + +namespace osuCrypto +{ + + + double Foliage_FFT64_bench() + { + size_t num_vars = NUMVARS; + size_t num_coeffs = ipow(3, num_vars); + AlignedUnVector coeffs (num_coeffs); + PRNG prng(block(342)); + prng.get(coeffs.data(), num_coeffs); + + //************************************************ + printf("Benchmarking FFT evaluation with uint64_t packing \n"); + //************************************************ + + clock_t t; + t = clock(); + fft_recursive_uint64(coeffs, num_vars, num_coeffs / 3); + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("FFT (uint64) eval time (total) %f ms\n", time_taken); + + return time_taken; + } + + double Foliage_FFT32_bench() + { + size_t num_vars = NUMVARS; + size_t num_coeffs = ipow(3, num_vars); + + AlignedUnVector < uint32_t> coeffs(num_coeffs); + PRNG prng(block(342)); + prng.get(coeffs.data(), num_coeffs); + + //************************************************ + printf("Benchmarking FFT evaluation with uint32_t packing \n"); + //************************************************ + + clock_t t; + t = clock(); + fft_recursive_uint32(coeffs, num_vars, num_coeffs / 3); + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("FFT (uint32) eval time (total) %f ms\n", time_taken); + + + return time_taken; + } + + double Foliage_FFT8_bench() + { + size_t num_vars = NUMVARS; + size_t num_coeffs = ipow(3, num_vars); + AlignedUnVector coeffs (num_coeffs); + PRNG prng(block(342)); + prng.get(coeffs.data(), num_coeffs); + + //************************************************ + printf("Benchmarking FFT evaluation without packing \n"); + //************************************************ + + clock_t t; + t = clock(); + fft_recursive_uint8(coeffs, num_vars, num_coeffs / 3); + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("FFT (uint8) eval time (total) %f ms\n", time_taken); + + //free(coeffs); + + return time_taken; + } + + int mainFFT(int argc, char** argv) + { + double time = 0; + int testTrials = 5; + + printf("******************************************\n"); + printf("Testing FFT (uint8 packing)\n"); + for (int i = 0; i < testTrials; i++) + { + time += Foliage_FFT8_bench(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("DONE\n"); + printf("Avg time: %0.2f\n", time / testTrials); + printf("******************************************\n\n"); + + printf("******************************************\n"); + printf("Testing FFT (uint32 packing) \n"); + time = 0; + for (int i = 0; i < testTrials; i++) + { + time += Foliage_FFT32_bench(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("DONE\n"); + printf("Avg time: %0.2f\n", time / testTrials); + printf("******************************************\n\n"); + + printf("******************************************\n"); + printf("Testing FFT (uint64 packing) \n"); + time = 0; + for (int i = 0; i < testTrials; i++) + { + time += Foliage_FFT64_bench(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("DONE\n"); + printf("Avg time: %0.2f\n", time / testTrials); + printf("******************************************\n\n"); + return 0; + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h b/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h new file mode 100644 index 00000000..922b1a57 --- /dev/null +++ b/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h @@ -0,0 +1,13 @@ +#pragma once + + +namespace osuCrypto +{ + + double Foliage_FFT8_bench(); + double Foliage_FFT32_bench(); + double Foliage_FFT64_bench(); + + + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFft.cpp b/libOTe/Tools/Foliage/fft/FoliageFft.cpp new file mode 100644 index 00000000..54d8f615 --- /dev/null +++ b/libOTe/Tools/Foliage/fft/FoliageFft.cpp @@ -0,0 +1,311 @@ +#include +#include +#include "libOTe/Tools/Foliage/fft/FoliageFft.h" + +namespace osuCrypto { + + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint64( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint64( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint64( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint64_t tL, tM; + uint64_t mult, xor_h, xor_l; + + uint64_t* coeffsL = &coeffs[0]; + uint64_t* coeffsM = &coeffs[num_coeffs]; + uint64_t* coeffsR = &coeffs[2 * num_coeffs]; + + const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; + const uint64_t mask_h = pattern; // 0b101010101010101001010 + const uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint32( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint32( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint32( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint32_t tL, tM; + uint32_t mult, xor_h, xor_l; + + uint32_t* coeffsL = &coeffs[0]; + uint32_t* coeffsM = &coeffs[num_coeffs]; + uint32_t* coeffsR = &coeffs[2 * num_coeffs]; + + const uint32_t pattern = 0xaaaaaaaa; + const uint32_t mask_h = pattern; // 0b101010101010101001010 + const uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint16( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint16( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint16( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint16_t tL, tM; + uint16_t mult, xor_h, xor_l; + + uint16_t* coeffsL = &coeffs[0]; + uint16_t* coeffsM = &coeffs[num_coeffs]; + uint16_t* coeffsR = &coeffs[2 * num_coeffs]; + + const uint16_t pattern = 0xaaaa; + const uint16_t mask_h = pattern; // 0b101010101010101001010 + const uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint8( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint8( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint8( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint8_t tL, tM; + uint8_t mult, xor_h, xor_l; + + uint8_t* coeffsL = &coeffs[0]; + uint8_t* coeffsM = &coeffs[num_coeffs]; + uint8_t* coeffsR = &coeffs[2 * num_coeffs]; + + const uint8_t pattern = 0xaa; + const uint8_t mask_h = pattern; // 0b101010101010101001010 + const uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFft.h b/libOTe/Tools/Foliage/fft/FoliageFft.h new file mode 100644 index 00000000..ffbff8e9 --- /dev/null +++ b/libOTe/Tools/Foliage/fft/FoliageFft.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include "cryptoTools/Common/Defines.h" + +//#include "libOTe/Tools/Foliage/utils.h" +namespace osuCrypto { + + //typedef __int128 int128_t; + //typedef unsigned __int128 uint128_t; + + // FFT for (up to) 32 polynomials over F4 + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 16 polynomials over F4 + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 8 polynomials over F4 + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 4 polynomials over F4 + void fft_recursive_uint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + +} diff --git a/libOTe/Tools/Foliage/spfss_test.cpp b/libOTe/Tools/Foliage/spfss_test.cpp new file mode 100644 index 00000000..d5111f15 --- /dev/null +++ b/libOTe/Tools/Foliage/spfss_test.cpp @@ -0,0 +1,158 @@ +#include +#include +#include + +#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" +#include "FoliageUtils.h" + +#define SUMT 730 // sum of T DPFs + +#define FULLEVALDOMAIN 10 +#define MESSAGESIZE 8 +#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) +namespace osuCrypto +{ + + //size_t randIndex() + //{ + // srand(time(NULL)); + // return ((size_t)rand()) % ((size_t)MAXRANDINDEX); + //} + + //uint128_t randMsg() + //{ + // uint128_t msg; + // RAND_bytes((uint8_t*)&msg, sizeof(uint128_t)); + // return msg; + //} + + double benchmark_spfss() + { + size_t num_leaves = ipow(3, FULLEVALDOMAIN); + size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points + PRNG prng(block(3423423)); + + size_t secret_index = prng.get() % MAXRANDINDEX; + uint128_t secret_msg = prng.get(); + size_t msg_len = MESSAGESIZE; + + PRFKeys prf_keys; + prf_keys.gen(prng); + + std::vector kA(SUMT); + std::vector kB(SUMT); + + clock_t t; + t = clock(); + + for (size_t i = 0; i < SUMT; i++) + DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA[i], kB[i], prng); + + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + return time_taken; + } + + double benchmarkAES() + { + size_t num_leaves = ipow(3, FULLEVALDOMAIN); + size_t size = FULLEVALDOMAIN; + PRNG prng(block(3423423)); + + PRFKeys prf_keys; + prf_keys.gen(prng); + + AlignedUnVector data_in (num_leaves * MESSAGESIZE); + AlignedUnVector data_out(num_leaves * MESSAGESIZE); + AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); + AlignedUnVector tmp; + + // fill with unique data + for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) + data_tmp[i] = (uint128_t)i; + + // make the input data pseudorandom for correct timing + PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); + + //************************************************ + // Benchmark AES encryption time required in DPF loop + //************************************************ + + clock_t t; + t = clock(); + + for (size_t n = 0; n < SUMT; n++) + { + size_t num_nodes = 1; + for (size_t i = 0; i < size; i++) + { + PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); + PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); + PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); + + tmp = data_out; + data_out = data_in; + data_in = tmp; + + num_nodes *= 3; + } + // compute AES part of output extension + PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); + } + + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + return time_taken; + } + + int mainSpfss(int argc, char** argv) + { + + double time = 0; + int testTrials = 10; + + //printf("******************************************\n"); + //printf("Testing DPF.FullEval\n"); + //for (int i = 0; i < testTrials; i++) + //{ + // time += foliage_spfss_test(); + // printf("Done with trial %i of %i\n", i + 1, testTrials); + //} + //printf("******************************************\n"); + //printf("PASS\n"); + //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); + //printf("******************************************\n\n"); + + time = 0; + printf("******************************************\n"); + printf("Benchmarking DPF.Gen\n"); + for (int i = 0; i < testTrials; i++) + { + time += benchmark_spfss(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("Avg time: %0.4f ms\n", time / testTrials); + printf("******************************************\n\n"); + + time = 0; + printf("******************************************\n"); + printf("Benchmarking AES\n"); + for (int i = 0; i < testTrials; i++) + { + time += benchmarkAES(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("Avg time: %0.2f ms\n", time / testTrials); + printf("******************************************\n\n"); + + return 0; + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/tri-dpf/.gitignore b/libOTe/Tools/Foliage/tri-dpf/.gitignore new file mode 100644 index 00000000..71035e10 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/.gitignore @@ -0,0 +1,5 @@ +*.json +*.o +*.a +.DS_Store +bin diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp new file mode 100644 index 00000000..81fc77ef --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp @@ -0,0 +1,317 @@ + +#include "FoliageDpf.h" + +#include "libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h" + + +//#include + +#define LOG_BATCH_SIZE 6 // operate in smallish batches to maximize cache hits +namespace osuCrypto +{ + // Naming conventions: + // - A,B refer to shares given to parties A and B + // - 0,1,2 refer to the branch index in the ternary tree + + void DPFGen( + PRFKeys& prf_keys, + size_t domain_size, + size_t index, + span msg_blocks, + size_t msg_block_len, + DPFKey& k0, + DPFKey& k1, + PRNG& prng) + { + + // starting seeds given to each party + uint128_t seedA = prng.get(); + uint128_t seedB = prng.get(); + + // correction word provided to both parties + // (one correction word per level) + std::vector sCW0(domain_size); + std::vector sCW1(domain_size); + std::vector sCW2(domain_size); + + // variables for the intermediate values + uint128_t parent, parentA, parentB, sA0, sA1, sA2, sB0, sB1, sB2; + + // current parent value (xor of the two seeds) + parent = seedA ^ seedB; + + // control bit of the parent on the special path must always be set to 1 + // so as to apply the corresponding correction word + if (get_lsb(parent) == uint128_t{ 0 }) + seedA = flip_lsb(seedA); + + parentA = seedA; + parentB = seedB; + + uint8_t prev_control_bit_A, prev_control_bit_B; + + for (size_t i = 0; i < domain_size; i++) + { + prev_control_bit_A = static_cast(get_lsb(parentA)); + prev_control_bit_B = static_cast(get_lsb(parentB)); + + // expand the starting seeds of each party + PRFEval(prf_keys.prf_key0, parentA, sA0); + PRFEval(prf_keys.prf_key1, parentA, sA1); + PRFEval(prf_keys.prf_key2, parentA, sA2); + PRFEval(prf_keys.prf_key0, parentB, sB0); + PRFEval(prf_keys.prf_key1, parentB, sB1); + PRFEval(prf_keys.prf_key2, parentB, sB2); + + // on-path correction word is set to random + // so as to be indistinguishable from the real correction words + uint128_t r = prng.get(); + + // get the current trit (ternary bit) of the special index + uint8_t trit = get_trit(index, domain_size, i); + + switch (trit) + { + case 0: + parent = sA0 ^ sB0 ^ r; + if (get_lsb(parent) == 0) + r = flip_lsb(r); + + sCW0[i] = r; + sCW1[i] = sA1 ^ sB1; + sCW2[i] = sA2 ^ sB2; + + if (get_lsb(parentA) == 1) + { + parentA = sA0 ^ r; + parentB = sB0; + } + else + { + parentA = sA0; + parentB = sB0 ^ r; + } + + break; + + case 1: + parent = sA1 ^ sB1 ^ r; + if (get_lsb(parent) == 0) + r = flip_lsb(r); + + sCW0[i] = sA0 ^ sB0; + sCW1[i] = r; + sCW2[i] = sA2 ^ sB2; + + if (get_lsb(parentA) == 1) + { + parentA = sA1 ^ r; + parentB = sB1; + } + else + { + parentA = sA1; + parentB = sB1 ^ r; + } + + break; + + case 2: + parent = sA2 ^ sB2 ^ r; + if (get_lsb(parent) == 0) + r = flip_lsb(r); + + sCW0[i] = sA0 ^ sB0; + sCW1[i] = sA1 ^ sB1; + sCW2[i] = r; + + if (get_lsb(parentA) == 1) + { + parentA = sA2 ^ r; + parentB = sB2; + } + else + { + parentA = sA2; + parentB = sB2 ^ r; + } + + break; + + default: + printf("error: not a ternary digit!\n"); + exit(0); + } + } + + // set the last correction word to correct the output to msg + uint128_t leaf_seedA, leaf_seedB; + uint8_t last_trit = get_trit(index, domain_size, domain_size - 1); + if (last_trit == 0) + { + leaf_seedA = sA0 ^ uint128_t(prev_control_bit_A * sCW0[domain_size - 1]); + leaf_seedB = sB0 ^ uint128_t(prev_control_bit_B * sCW0[domain_size - 1]); + } + else if (last_trit == 1) + { + leaf_seedA = sA1 ^ uint128_t(prev_control_bit_A * sCW1[domain_size - 1]); + leaf_seedB = sB1 ^ uint128_t(prev_control_bit_B * sCW1[domain_size - 1]); + } + + else if (last_trit == 2) + { + leaf_seedA = sA2 ^ uint128_t(prev_control_bit_A * sCW2[domain_size - 1]); + leaf_seedB = sB2 ^ uint128_t(prev_control_bit_B * sCW2[domain_size - 1]); + } + + AlignedUnVector outputA(msg_block_len); + AlignedUnVector outputB(msg_block_len); + AlignedUnVector cache(msg_block_len); + AlignedUnVector outputCW(msg_block_len); + + outputA[0] = leaf_seedA; + outputB[0] = leaf_seedB; + + ExtendOutput(prf_keys, outputA, cache, 1, msg_block_len); + ExtendOutput(prf_keys, outputB, cache, 1, msg_block_len); + + for (size_t i = 0; i < msg_block_len; i++) + outputCW[i] = outputA[i] ^ outputB[i] ^ msg_blocks[i]; + + // memcpy all the generated values into two keys + // 16 = sizeof(uint128_t) + size_t key_size = sizeof(uint128_t); // initial seed size; + key_size += 3 * domain_size * sizeof(uint128_t); // correction words + key_size += sizeof(uint128_t) * msg_block_len; // output correction word + + k0.prf_keys = &prf_keys; + k0.k.resize(key_size); + k0.size = domain_size; + k0.msg_len = msg_block_len; + memcpy(&k0.k[0], &seedA, 16); + memcpy(&k0.k[16], &sCW0[0], domain_size * 16); + memcpy(&k0.k[16 * domain_size + 16], &sCW1[0], domain_size * 16); + memcpy(&k0.k[16 * 2 * domain_size + 16], &sCW2[0], domain_size * 16); + memcpy(&k0.k[16 * 3 * domain_size + 16], &outputCW[0], msg_block_len * 16); + + k1.prf_keys = &prf_keys; + k1.k.resize(key_size); + k1.size = domain_size; + k1.msg_len = msg_block_len; + memcpy(&k1.k[0], &seedB, 16); + memcpy(&k1.k[16], &sCW0[0], domain_size * 16); + memcpy(&k1.k[16 * domain_size + 16], &sCW1[0], domain_size * 16); + memcpy(&k1.k[16 * 2 * domain_size + 16], &sCW2[0], domain_size * 16); + memcpy(&k1.k[16 * 3 * domain_size + 16], &outputCW[0], msg_block_len * 16); + + //free(outputA); + //free(outputB); + //free(cache); + //free(outputCW); + } + + // evaluates the full DPF domain; much faster than + // batching the evaluation points since each level of the DPF tree + // is only expanded once. + void DPFFullDomainEval( + DPFKey& key, + span cache, + span output) + { + size_t size = key.size; + span k = key.k; + PRFKeys& prf_keys = *key.prf_keys; + + if (size % 2 == 1) + { + auto tmp = cache; + cache = output; + output = tmp; + } + + // full_eval_size = pow(3, size); + const size_t num_leaves = ipow(3, size); + + memcpy(&output[0], &k[0], 16); // output[0] is the start seed + const uint128_t* sCW0 = (uint128_t*)&k[16]; + const uint128_t* sCW1 = (uint128_t*)&k[16 * size + 16]; + const uint128_t* sCW2 = (uint128_t*)&k[16 * 2 * size + 16]; + + // inner loop variables related to node expansion + // and correction word application + span tmp; + size_t idx0, idx1, idx2; + uint8_t cb = 0; + + // batching variables related to chunking of inner loop processing + // for the purpose of maximizing cache hits + size_t max_batch_size = ipow(3, LOG_BATCH_SIZE); + size_t batch, num_batches, batch_size, offset; + + size_t num_nodes = 1; + for (uint8_t i = 0; i < size; i++) + { + if (i < LOG_BATCH_SIZE) + { + batch_size = num_nodes; + num_batches = 1; + } + else + { + batch_size = max_batch_size; + num_batches = num_nodes / max_batch_size; + } + + offset = 0; + for (batch = 0; batch < num_batches; batch++) + { + PRFBatchEval(prf_keys.prf_key0, output.subspan(offset), cache.subspan(offset), batch_size); + PRFBatchEval(prf_keys.prf_key1, output.subspan(offset), cache.subspan(num_nodes + offset), batch_size); + PRFBatchEval(prf_keys.prf_key2, output.subspan(offset), cache.subspan((num_nodes * 2) + offset), batch_size); + + idx0 = offset; + idx1 = num_nodes + offset; + idx2 = (num_nodes * 2) + offset; + + while (idx0 < offset + batch_size) + { + cb = static_cast(output[idx0]) & 1; // gets the LSB of the parent + cache[idx0] ^= (cb * sCW0[i]); + cache[idx1] ^= (cb * sCW1[i]); + cache[idx2] ^= (cb * sCW2[i]); + + idx0++; + idx1++; + idx2++; + } + + offset += batch_size; + } + + tmp = output; + output = cache; + cache = tmp; + + num_nodes *= 3; + } + + const size_t output_length = key.msg_len * num_leaves; + const size_t msg_len = key.msg_len; + uint128_t* outputCW = (uint128_t*)&k[16 * 3 * size + 16]; + ExtendOutput(prf_keys, output, cache, num_leaves, output_length); + + size_t j = 0; + for (size_t i = 0; i < num_leaves; i++) + { + // TODO: a bit hacky, assumes that cache[i*msg_len] = old_output[i] + // which is the case internally in ExtendOutput. It would be good + // to remove this assumption however using memcpy is costly... + + if (cache[i * msg_len] & uint128_t{ 1 }) // parent control bit + { + for (j = 0; j < msg_len; j++) + output[i * msg_len + j] ^= outputCW[j]; + } + } + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h new file mode 100644 index 00000000..1ca23568 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h @@ -0,0 +1,35 @@ +#pragma once + +#include +#include + +#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" + + +namespace osuCrypto +{ + struct DPFKey + { + PRFKeys* prf_keys; + AlignedUnVector k; + size_t msg_len; + size_t size; + }; + + void DPFGen( + PRFKeys& prf_keys, + size_t domain_size, + size_t index, + span msg_blocks, + size_t msg_block_len, + DPFKey& k0, + DPFKey& k1, + PRNG& prng); + + void DPFFullDomainEval( + DPFKey& k, + span cache, + span output); + +} diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp new file mode 100644 index 00000000..30ded078 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp @@ -0,0 +1,166 @@ +//#include +//#include +//#include +//#include +#include +#include +#include + +#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" +//#include "libOTe/Tools/Foliage/tri-dpf/FoliageHalfDpf.h" +#include + +#define FULLEVALDOMAIN 14 +#define MESSAGESIZE 2 +#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) +namespace osuCrypto +{ + size_t randIndex(PRNG& prng) + { + return prng.get() % (size_t)MAXRANDINDEX; + } + //using int128_t = uint128_t; + uint128_t randMsg(PRNG& prng) + { + return prng.get(); + //uint128_t msg; + //RAND_bytes((uint8_t*)&msg, sizeof(uint128_t)); + //return msg; + } + + double benchmark_dpfGen() + { + size_t num_leaves = ipow(3, FULLEVALDOMAIN); + size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points + PRNG prng(block(3423423)); + size_t secret_index = randIndex(prng); + uint128_t secret_msg = randMsg(prng); + size_t msg_len = 1; + + PRFKeys prf_keys; + prf_keys.gen(prng); + + DPFKey kA; + DPFKey kB; + + clock_t t; + t = clock(); + DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA, kB, prng); + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + return time_taken; + } + + double benchmark_dpfAES() + { + size_t num_leaves = ipow(3, FULLEVALDOMAIN); + size_t size = FULLEVALDOMAIN; + + PRNG prng(block(3423423)); + PRFKeys prf_keys; + prf_keys.gen(prng); + + AlignedUnVector data_in(num_leaves * MESSAGESIZE); + AlignedUnVector data_out(num_leaves * MESSAGESIZE); + AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); + AlignedUnVector tmp; + + // fill with unique data + for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) + data_tmp[i] = (uint128_t)i; + + // make the input data pseudorandom for correct timing + PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); + + //************************************************ + // Benchmark AES encryption time required in DPF loop + //************************************************ + + clock_t t; + t = clock(); + size_t num_nodes = 1; + for (size_t i = 0; i < size; i++) + { + PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); + PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); + PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); + + tmp = data_out; + data_out = data_in; + data_in = tmp; + + num_nodes *= 3; + } + + // compute AES part of output extension + PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); + + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + return time_taken; + } + + int main_test_tridpf(int argc, char** argv) + { + + double time = 0; + int testTrials = 3; + + //printf("******************************************\n"); + //printf("Testing DPF.FullEval\n"); + //for (int i = 0; i < testTrials; i++) + //{ + // time += foliage_dpf_test(); + // printf("Done with trial %i of %i\n", i + 1, testTrials); + //} + //printf("******************************************\n"); + //printf("PASS\n"); + //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); + //printf("******************************************\n\n"); + + //time = 0; + //printf("******************************************\n"); + //printf("Testing HalfDPF.FullEval\n"); + //for (int i = 0; i < testTrials; i++) + //{ + // time += foliage_Halfdpf_test(); + // printf("Done with trial %i of %i\n", i + 1, testTrials); + //} + //printf("******************************************\n"); + //printf("PASS\n"); + //printf("HalfDPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); + //printf("******************************************\n\n"); + + time = 0; + printf("******************************************\n"); + printf("Benchmarking DPF.Gen\n"); + for (int i = 0; i < testTrials; i++) + { + time += benchmark_dpfGen(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("Avg time: %0.4f ms\n", time / testTrials); + printf("******************************************\n\n"); + + time = 0; + printf("******************************************\n"); + printf("Benchmarking AES\n"); + for (int i = 0; i < testTrials; i++) + { + time += benchmark_dpfAES(); + printf("Done with trial %i of %i\n", i + 1, testTrials); + } + printf("******************************************\n"); + printf("Avg time: %0.2f ms\n", time / testTrials); + printf("******************************************\n\n"); + + return 0; + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h new file mode 100644 index 00000000..9776388a --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h @@ -0,0 +1,9 @@ +#pragma once + + +namespace osuCrypto +{ + void foliage_dpf_test(); + void foliage_Halfdpf_test(); + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h b/libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h new file mode 100644 index 00000000..0d69e0a4 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h @@ -0,0 +1,83 @@ +#pragma once + + +#include +#include "cryptoTools/Crypto/AES.h" +//#include "utils.h" +#include "libOTe/Tools/Foliage/FoliageUtils.h" + +namespace osuCrypto +{ + + + using EVP_CIPHER_CTX = oc::AES; + struct PRFKeys + { + PRFKeys() = default; + + + void gen(PRNG& prng) + { + prf_key0.setKey(prng.get()); + prf_key1.setKey(prng.get()); + prf_key2.setKey(prng.get()); + prf_key_ext.setKey(prng.get()); + } + + + + EVP_CIPHER_CTX prf_key0; + EVP_CIPHER_CTX prf_key1; + EVP_CIPHER_CTX prf_key2; + EVP_CIPHER_CTX prf_key_ext; + }; + + //void PRFKeyGen(struct PRFKeys* prf_keys); + //void DestroyPRFKey(struct PRFKeys* prf_keys); + + // XOR with input to prevent inversion using Davies–Meyer construction + inline void PRFEval(EVP_CIPHER_CTX& ctx, uint128_t& input, uint128_t& outputs) + { + block in, out; + copyBytes(in, input); + out = ctx.hashBlock(in); + copyBytes(outputs, out); + } + + // PRF used to expand the DPF tree. Just a call to AES-ECB. + // Note: we use ECB-mode (instead of CTR) as we want to manage each block separately. + // XOR with input to prevent inversion using Davies–Meyer construction + inline void PRFBatchEval(EVP_CIPHER_CTX& ctx, span input, span outputs, u64 num_blocks) + { + if (num_blocks > input.size()) + throw RTE_LOC; + if (num_blocks > outputs.size()) + throw RTE_LOC; + ctx.hashBlocks((block*)input.data(), num_blocks, (block*)outputs.data()); + } + + // extends the output by the provided factor using the PRG + inline void ExtendOutput( + PRFKeys& prf_keys, + span output, + span cache, + const size_t output_size, + const size_t new_output_size) + { + + if (new_output_size % output_size != 0) + throw std::runtime_error("ERROR: new_output_size needs to be a multiple of output_size. " LOCATION); + if (new_output_size < output_size) + throw std::runtime_error("ERROR: new_output_size < output_size" LOCATION); + + size_t factor = new_output_size / output_size; + + for (size_t i = 0; i < output_size; i++) + { + for (size_t j = 0; j < factor; j++) + cache[i * factor + j] = output[i] ^ uint128_t{ j }; + } + + PRFBatchEval(prf_keys.prf_key_ext, cache, output, new_output_size); + } +} diff --git a/libOTe/Tools/Foliage/tri-dpf/LICENSE b/libOTe/Tools/Foliage/tri-dpf/LICENSE new file mode 100644 index 00000000..2aa6fcd1 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/LICENSE @@ -0,0 +1,9 @@ +MIT License + +Copyright © 2024 Sacha Servan-Schreiber + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Softwareâ€), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS ISâ€, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/libOTe/Tools/Foliage/tri-dpf/README.md b/libOTe/Tools/Foliage/tri-dpf/README.md new file mode 100644 index 00000000..f628b4f1 --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/README.md @@ -0,0 +1,116 @@ +# Ternary-tree DPF Implementation + +A simple C implementation of Distributed Point Functions (DPFs) with several performance optimizations. + +Optimizations include: + +- Ternary instead of a binary tree (increases communication slightly but improves evaluation performance by having a flatter tree). +- Using batched AES for fast PRF evaluation with AES-NI. +- The half-tree optimization of [Guo et al.](https://eprint.iacr.org/2022/1431.pdf), however, this only improves performance by 2\%-4\% in the ternary-tree case. + +## Dependencies + +- OpenSSL +- GNU Make +- Cmake +- Clang + +## Getting everything to run (tested on Ubuntu, CentOS, and MacOS) + +| Install dependencies (Ubuntu): | Install dependencies (CentOS): | +| -------------------------------------- | ------------------------------------------- | +| `sudo apt-get install build-essential` | `sudo yum groupinstall 'Development Tools'` | +| `sudo apt-get install cmake` | `sudo yum install cmake` | +| `sudo apt install libssl-dev` | `sudo yum install openssl-devel` | +| `sudo apt install clang` | `sudo yum install clang` | + +## Running tests and benchmarks + +``` +make +./bin/test +``` + +## Possible extensions (TODOs): + +- Arbitrary output size and full domain evaluation optimization of [Boyle et al.](https://eprint.iacr.org/2018/707). +- Serialization for DPF keys. + +## Minimal example + +```c +size_t domain_size = 10; +size_t num_leaves = ipow(3, domain_size); // domain of size 3^10 + +size_t secret_index = 5; +uint128_t secret_msg = 1; + +// common PRF keys +struct PRFKeys *prf_keys = malloc(sizeof(struct PRFKeys)); +PRFKeyGen(prf_keys); + +// DPF keys for each party +struct DPFKey *kA = malloc(sizeof(struct DPFKey)); +struct DPFKey *kB = malloc(sizeof(struct DPFKey)); + +DPFGen(prf_keys, domain_size, secret_index, &secret_msg, 1, kA, kB); + +uint128_t *shares0 = malloc(sizeof(uint128_t) * num_leaves); +uint128_t *shares1 = malloc(sizeof(uint128_t) * num_leaves); + +// cache is used to speed up evaluations when running many +// DPF evaluations sequentially +uint128_t *cache = malloc(sizeof(uint128_t) * num_leaves); + +// evaluate the DPF using the key of party A +DPFFullDomainEval(kA, cache, shares0); + +// evaluate the DPF using the key of party B +DPFFullDomainEval(kB, cache, shares1); + +DestroyPRFKey(prf_keys); +free(kA); +free(kB); +free(shares0); +free(shares1); +free(cache); +``` + +#### Performance on M1 Macbook Pro + +Domain of size $3^{14} \approx 2^{22}$ and message size of 256 bits. + +``` +****************************************** +Testing DPF.FullEval +****************************************** +PASS +Avg time for DPF.FullEval: 68.29 ms +****************************************** + +****************************************** +Testing HalfDPF.FullEval +****************************************** +PASS +Avg time for HalfDPF.FullEval: 65.38 ms +****************************************** +``` + +## Citation + +``` +@misc{foleage, + author = {Maxime Bombar and Dung Bui and Geoffroy Couteau and Alain Couvreur and Clément Ducros and Sacha Servan-Schreiber}, + title = {FOLEAGE: $\mathbb{F}_4$OLE-Based Multi-Party Computation for Boolean Circuits}, + howpublished = {Cryptology ePrint Archive, Paper 2024/429}, + year = {2024}, + note = {\url{https://eprint.iacr.org/2024/429}}, + url = {https://eprint.iacr.org/2024/429} +} + +``` + +## âš ï¸ Important Warning + +This implementation is intended for _research purposes only_. The code has NOT been vetted by security experts. +As such, no portion of the code should be used in any real-world or production setting! diff --git a/libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h b/libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h new file mode 100644 index 00000000..7910b05a --- /dev/null +++ b/libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h @@ -0,0 +1,68 @@ +#pragma once + + +#include +#include +#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "cryptoTools/Common/BitIterator.h" + +namespace osuCrypto +{ + + static inline uint128_t flip_lsb(uint128_t input) + { + return input ^ uint128_t{ 1 }; + } + + static inline uint128_t get_lsb(uint128_t input) + { + return input & uint128_t{ 1 }; + } + + static inline int get_trit(uint64_t x, int size, int t) + { + std::vector ternary(size); + for (int i = 0; i < size; i++) + { + ternary[i] = x % 3; + x /= 3; + } + + return ternary[t]; + } + + static inline int get_bit(uint128_t x, int size, int b) + { + return *oc::BitIterator((u8*)&x, (size - b)); + //return ((x) >> (size - b)) & 1; + } + + //static void printBytes(void* p, int num) + //{ + // unsigned char* c = (unsigned char*)p; + // for (int i = 0; i < num; i++) + // { + // printf("%02x", c[i]); + // } + // printf("\n"); + //} + + //// Compute base^exp without the floating-point precision + //// errors of the built-in pow function. + //static inline int ipow(int base, int exp) + //{ + // int result = 1; + // while (1) + // { + // if (exp & 1) + // result *= base; + // exp >>= 1; + // if (!exp) + // break; + // base *= base; + // } + + // return result; + //} + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/uint128.h b/libOTe/Tools/Foliage/uint128.h new file mode 100644 index 00000000..b38d9012 --- /dev/null +++ b/libOTe/Tools/Foliage/uint128.h @@ -0,0 +1,801 @@ +// +// Copyright 2017 The Abseil Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// ----------------------------------------------------------------------------- +// File: int128_t.h +// ----------------------------------------------------------------------------- +// +// This header file defines 128-bit integer types, `uint128_t` and `int128_t`. + +#ifndef ABSL_INT128_H_ +#define ABSL_INT128_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define ABSL_IS_LITTLE_ENDIAN +#if defined(_MSC_VER) +// In very old versions of MSVC and when the /Zc:wchar_t flag is off, wchar_t is +// a typedef for unsigned short. Otherwise wchar_t is mapped to the __wchar_t +// builtin type. We need to make sure not to define operator wchar_t() +// alongside operator unsigned short() in these instances. +#define ABSL_INTERNAL_WCHAR_T __wchar_t +#if defined(_M_X64) +#include +#pragma intrinsic(_umul128) +#endif // defined(_M_X64) +#else // defined(_MSC_VER) +#define ABSL_INTERNAL_WCHAR_T wchar_t +#endif // defined(_MSC_VER) + +#ifdef _WIN32 +#ifdef abslint128_t_EXPORTS +#define ABSL_DLL __declspec(dllexport) +#else +#define ABSL_DLL __declspec(dllimport) +#endif +#else // _WIN32 +#define ABSL_DLL +#endif // _WIN32 + +// ABSL_ATTRIBUTE_ALWAYS_INLINE +// ABSL_ATTRIBUTE_NOINLINE +// +// Forces functions to either inline or not inline. Introduced in gcc 3.1. +#if defined(__GNUC__) || defined(__clang__) +#define ABSL_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +#elif defined(_MSC_VER) && !__INTEL_COMPILER && _MSC_VER >= 1310 // since Visual Studio .NET 2003 +#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline __forceinline +#else +#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline +#endif + +// ABSL_INTERNAL_ASSUME(cond) +// Informs the compiler than a condition is always true and that it can assume +// it to be true for optimization purposes. The call has undefined behavior if +// the condition is false. +// In !NDEBUG mode, the condition is checked with an assert(). +// NOTE: The expression must not have side effects, as it will only be evaluated +// in some compilation modes and not others. +// +// Example: +// +// int x = ...; +// ABSL_INTERNAL_ASSUME(x >= 0); +// // The compiler can optimize the division to a simple right shift using the +// // assumption specified above. +// int y = x / 16; +// +#if !defined(NDEBUG) +#define ABSL_INTERNAL_ASSUME(cond) assert(cond) +#elif ABSL_HAVE_BUILTIN(__builtin_assume) +#define ABSL_INTERNAL_ASSUME(cond) __builtin_assume(cond) +#elif defined(__GNUC__) || ABSL_HAVE_BUILTIN(__builtin_unreachable) +#define ABSL_INTERNAL_ASSUME(cond) \ + do { \ + if (!(cond)) __builtin_unreachable(); \ + } while (0) +#elif defined(_MSC_VER) +#define ABSL_INTERNAL_ASSUME(cond) __assume(cond) +#else +#define ABSL_INTERNAL_ASSUME(cond) \ + do { \ + static_cast(false && (cond)); \ + } while (0) +#endif + +namespace absl { + + + // uint128_t + // + // An unsigned 128-bit integer type. The API is meant to mimic an intrinsic type + // as closely as is practical, including exhibiting undefined behavior in + // analogous cases (e.g. division by zero). This type is intended to be a + // drop-in replacement once C++ supports an intrinsic `uint128_t_t` type; when + // that occurs, existing well-behaved uses of `uint128_t` will continue to work + // using that new type. + // + // Note: code written with this type will continue to compile once `uint128_t_t` + // is introduced, provided the replacement helper functions + // `Uint128(Low|High)64()` and `MakeUint128()` are made. + // + // A `uint128_t` supports the following: + // + // * Implicit construction from integral types + // * Explicit conversion to integral types + // + // Additionally, if your compiler supports `__int128_t`, `uint128_t` is + // interoperable with that type. (Abseil checks for this compatibility through + // the `ABSL_HAVE_INTRINSIC_INT128` macro.) + // + // However, a `uint128_t` differs from intrinsic integral types in the following + // ways: + // + // * Errors on implicit conversions that do not preserve value (such as + // loss of precision when converting to float values). + // * Requires explicit construction from and conversion to floating point + // types. + // * Conversion to integral types requires an explicit static_cast() to + // mimic use of the `-Wnarrowing` compiler flag. + // * The alignment requirement of `uint128_t` may differ from that of an + // intrinsic 128-bit integer type depending on platform and build + // configuration. + // + // Example: + // + // float y = absl::Uint128Max(); // Error. uint128_t cannot be implicitly + // // converted to float. + // + // absl::uint128_t v; + // uint64_t i = v; // Error + // uint64_t i = static_cast(v); // OK + // + class +#if defined(ABSL_HAVE_INTRINSIC_INT128) + alignas(unsigned __int128_t) +#endif // ABSL_HAVE_INTRINSIC_INT128 + uint128_t { + public: + uint128_t() = default; + + // Constructors from arithmetic types + constexpr uint128_t(int v); // NOLINT(runtime/explicit) + constexpr uint128_t(unsigned int v); // NOLINT(runtime/explicit) + constexpr uint128_t(long v); // NOLINT(runtime/int) + constexpr uint128_t(unsigned long v); // NOLINT(runtime/int) + constexpr uint128_t(long long v); // NOLINT(runtime/int) + constexpr uint128_t(unsigned long long v); // NOLINT(runtime/int) +#ifdef ABSL_HAVE_INTRINSIC_INT128 + constexpr uint128_t(__int128_t v); // NOLINT(runtime/explicit) + constexpr uint128_t(unsigned __int128_t v); // NOLINT(runtime/explicit) +#endif // ABSL_HAVE_INTRINSIC_INT128 + explicit uint128_t(float v); + explicit uint128_t(double v); + explicit uint128_t(long double v); + + // Assignment operators from arithmetic types + uint128_t& operator=(int v); + uint128_t& operator=(unsigned int v); + uint128_t& operator=(long v); // NOLINT(runtime/int) + uint128_t& operator=(unsigned long v); // NOLINT(runtime/int) + uint128_t& operator=(long long v); // NOLINT(runtime/int) + uint128_t& operator=(unsigned long long v); // NOLINT(runtime/int) +#ifdef ABSL_HAVE_INTRINSIC_INT128 + uint128_t& operator=(__int128_t v); + uint128_t& operator=(unsigned __int128_t v); +#endif // ABSL_HAVE_INTRINSIC_INT128 + + // Conversion operators to other arithmetic types + constexpr explicit operator bool() const; + constexpr explicit operator char() const; + constexpr explicit operator signed char() const; + constexpr explicit operator unsigned char() const; + constexpr explicit operator char16_t() const; + constexpr explicit operator char32_t() const; + constexpr explicit operator ABSL_INTERNAL_WCHAR_T() const; + constexpr explicit operator short() const; // NOLINT(runtime/int) + // NOLINTNEXTLINE(runtime/int) + constexpr explicit operator unsigned short() const; + constexpr explicit operator int() const; + constexpr explicit operator unsigned int() const; + constexpr explicit operator long() const; // NOLINT(runtime/int) + // NOLINTNEXTLINE(runtime/int) + constexpr explicit operator unsigned long() const; + // NOLINTNEXTLINE(runtime/int) + constexpr explicit operator long long() const; + // NOLINTNEXTLINE(runtime/int) + constexpr explicit operator unsigned long long() const; +#ifdef ABSL_HAVE_INTRINSIC_INT128 + constexpr explicit operator __int128_t() const; + constexpr explicit operator unsigned __int128_t() const; +#endif // ABSL_HAVE_INTRINSIC_INT128 + explicit operator float() const; + explicit operator double() const; + explicit operator long double() const; + + // Trivial copy constructor, assignment operator and destructor. + + // Arithmetic operators. + uint128_t& operator+=(uint128_t other); + uint128_t& operator-=(uint128_t other); + uint128_t& operator*=(uint128_t other); + // Long division/modulo for uint128_t. + uint128_t& operator/=(uint128_t other); + uint128_t& operator%=(uint128_t other); + uint128_t operator++(int); + uint128_t operator--(int); + uint128_t& operator<<=(int); + uint128_t& operator>>=(int); + uint128_t& operator&=(uint128_t other); + uint128_t& operator|=(uint128_t other); + uint128_t& operator^=(uint128_t other); + uint128_t& operator++(); + uint128_t& operator--(); + + // Uint128Low64() + // + // Returns the lower 64-bit value of a `uint128_t` value. + friend constexpr uint64_t Uint128Low64(uint128_t v); + + // Uint128High64() + // + // Returns the higher 64-bit value of a `uint128_t` value. + friend constexpr uint64_t Uint128High64(uint128_t v); + + // MakeUInt128() + // + // Constructs a `uint128_t` numeric value from two 64-bit unsigned integers. + // Note that this factory function is the only way to construct a `uint128_t` + // from integer values greater than 2^64. + // + // Example: + // + // absl::uint128_t big = absl::MakeUint128(1, 0); + friend constexpr uint128_t MakeUint128(uint64_t high, uint64_t low); + + // Uint128Max() + // + // Returns the highest value for a 128-bit unsigned integer. + friend constexpr uint128_t Uint128Max(); + + // Support for absl::Hash. + template + friend H AbslHashValue(H h, uint128_t v) { + return H::combine(std::move(h), Uint128High64(v), Uint128Low64(v)); + } + + // Combined division/modulo for a 128-bit unsigned integer. + static void DivMod(uint128_t dividend, uint128_t divisor, uint128_t* quotient_ret, + uint128_t* remainder_ret); + + static std::string ToFormattedString(uint128_t v, std::ios_base::fmtflags flags = std::ios_base::fmtflags()); + + static std::string ToString(uint128_t v); + + private: + constexpr uint128_t(uint64_t high, uint64_t low); + + // TODO(strel) Update implementation to use __int128_t once all users of + // uint128_t are fixed to not depend on alignof(uint128_t) == 8. Also add + // alignas(16) to class definition to keep alignment consistent across + // platforms. +#if defined(ABSL_IS_LITTLE_ENDIAN) + uint64_t lo_; + uint64_t hi_; +#elif defined(ABSL_IS_BIG_ENDIAN) + uint64_t hi_; + uint64_t lo_; +#else // byte order +#error "Unsupported byte order: must be little-endian or big-endian." +#endif // byte order + }; + + // allow uint128_t to be logged + std::ostream& operator<<(std::ostream& os, uint128_t v); + + // TODO(strel) add operator>>(std::istream&, uint128_t) + + constexpr uint128_t Uint128Max() { + return uint128_t((std::numeric_limits::max)(), + (std::numeric_limits::max)()); + } + +} // namespace absl + +// Specialized numeric_limits for uint128_t. +namespace std { + template <> + class numeric_limits { + public: + static constexpr bool is_specialized = true; + static constexpr bool is_signed = false; + static constexpr bool is_integer = true; + static constexpr bool is_exact = true; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = false; + static constexpr bool has_signaling_NaN = false; + static constexpr float_denorm_style has_denorm = denorm_absent; + static constexpr bool has_denorm_loss = false; + static constexpr float_round_style round_style = round_toward_zero; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = true; + static constexpr int digits = 128; + static constexpr int digits10 = 38; + static constexpr int max_digits10 = 0; + static constexpr int radix = 2; + static constexpr int min_exponent = 0; + static constexpr int min_exponent10 = 0; + static constexpr int max_exponent = 0; + static constexpr int max_exponent10 = 0; +#ifdef ABSL_HAVE_INTRINSIC_INT128 + static constexpr bool traps = numeric_limits::traps; +#else // ABSL_HAVE_INTRINSIC_INT128 + static constexpr bool traps = numeric_limits::traps; +#endif // ABSL_HAVE_INTRINSIC_INT128 + static constexpr bool tinyness_before = false; + + static constexpr absl::uint128_t(min)() { return 0; } + static constexpr absl::uint128_t lowest() { return 0; } + static constexpr absl::uint128_t(max)() { return absl::Uint128Max(); } + static constexpr absl::uint128_t epsilon() { return 0; } + static constexpr absl::uint128_t round_error() { return 0; } + static constexpr absl::uint128_t infinity() { return 0; } + static constexpr absl::uint128_t quiet_NaN() { return 0; } + static constexpr absl::uint128_t signaling_NaN() { return 0; } + static constexpr absl::uint128_t denorm_min() { return 0; } + }; +} // namespace std + + +// -------------------------------------------------------------------------- +// Implementation details follow +// -------------------------------------------------------------------------- +namespace absl { + + constexpr uint128_t MakeUint128(uint64_t high, uint64_t low) { + return uint128_t(high, low); + } + + // Assignment from integer types. + + inline uint128_t& uint128_t::operator=(int v) { return *this = uint128_t(v); } + + inline uint128_t& uint128_t::operator=(unsigned int v) { + return *this = uint128_t(v); + } + + inline uint128_t& uint128_t::operator=(long v) { // NOLINT(runtime/int) + return *this = uint128_t(v); + } + + // NOLINTNEXTLINE(runtime/int) + inline uint128_t& uint128_t::operator=(unsigned long v) { + return *this = uint128_t(v); + } + + // NOLINTNEXTLINE(runtime/int) + inline uint128_t& uint128_t::operator=(long long v) { + return *this = uint128_t(v); + } + + // NOLINTNEXTLINE(runtime/int) + inline uint128_t& uint128_t::operator=(unsigned long long v) { + return *this = uint128_t(v); + } + +#ifdef ABSL_HAVE_INTRINSIC_INT128 + inline uint128_t& uint128_t::operator=(__int128_t v) { + return *this = uint128_t(v); + } + + inline uint128_t& uint128_t::operator=(unsigned __int128_t v) { + return *this = uint128_t(v); + } +#endif // ABSL_HAVE_INTRINSIC_INT128 + + + // Arithmetic operators. + + uint128_t operator<<(uint128_t lhs, int amount); + uint128_t operator>>(uint128_t lhs, int amount); + uint128_t operator+(uint128_t lhs, uint128_t rhs); + uint128_t operator-(uint128_t lhs, uint128_t rhs); + uint128_t operator*(uint128_t lhs, uint128_t rhs); + uint128_t operator/(uint128_t lhs, uint128_t rhs); + uint128_t operator%(uint128_t lhs, uint128_t rhs); + + inline uint128_t& uint128_t::operator<<=(int amount) { + *this = *this << amount; + return *this; + } + + inline uint128_t& uint128_t::operator>>=(int amount) { + *this = *this >> amount; + return *this; + } + + inline uint128_t& uint128_t::operator+=(uint128_t other) { + *this = *this + other; + return *this; + } + + inline uint128_t& uint128_t::operator-=(uint128_t other) { + *this = *this - other; + return *this; + } + + inline uint128_t& uint128_t::operator*=(uint128_t other) { + *this = *this * other; + return *this; + } + + inline uint128_t& uint128_t::operator/=(uint128_t other) { + *this = *this / other; + return *this; + } + + inline uint128_t& uint128_t::operator%=(uint128_t other) { + *this = *this % other; + return *this; + } + + constexpr uint64_t Uint128Low64(uint128_t v) { return v.lo_; } + + constexpr uint64_t Uint128High64(uint128_t v) { return v.hi_; } + + // Constructors from integer types. + +#if defined(ABSL_IS_LITTLE_ENDIAN) + + constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) + : lo_{ low }, hi_{ high } { + } + + constexpr uint128_t::uint128_t(int v) + : lo_{ static_cast(v) }, + hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { + } + constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) + : lo_{ static_cast(v) }, + hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { + } + constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) + : lo_{ static_cast(v) }, + hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { + } + + constexpr uint128_t::uint128_t(unsigned int v) : lo_{ v }, hi_{ 0 } {} + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::uint128_t(unsigned long v) : lo_{ v }, hi_{ 0 } {} + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::uint128_t(unsigned long long v) : lo_{ v }, hi_{ 0 } {} + +#ifdef ABSL_HAVE_INTRINSIC_INT128 + constexpr uint128_t::uint128_t(__int128_t v) + : lo_{ static_cast(v & ~uint64_t{0}) }, + hi_{ static_cast(static_cast(v) >> 64) } { + } + constexpr uint128_t::uint128_t(unsigned __int128_t v) + : lo_{ static_cast(v & ~uint64_t{0}) }, + hi_{ static_cast(v >> 64) } { + } +#endif // ABSL_HAVE_INTRINSIC_INT128 + +#elif defined(ABSL_IS_BIG_ENDIAN) + + constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) + : hi_{ high }, lo_{ low } { + } + + constexpr uint128_t::uint128_t(int v) + : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, + lo_{ static_cast(v) } { + } + constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) + : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, + lo_{ static_cast(v) } { + } + constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) + : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, + lo_{ static_cast(v) } { + } + + constexpr uint128_t::uint128_t(unsigned int v) : hi_{ 0 }, lo_{ v } {} + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::uint128_t(unsigned long v) : hi_{ 0 }, lo_{ v } {} + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::uint128_t(unsigned long long v) : hi_{ 0 }, lo_{ v } {} + +#ifdef ABSL_HAVE_INTRINSIC_INT128 + constexpr uint128_t::uint128_t(__int128_t v) + : hi_{ static_cast(static_cast(v) >> 64) }, + lo_{ static_cast(v & ~uint64_t{0}) } { + } + constexpr uint128_t::uint128_t(unsigned __int128_t v) + : hi_{ static_cast(v >> 64) }, + lo_{ static_cast(v & ~uint64_t{0}) } { + } +#endif // ABSL_HAVE_INTRINSIC_INT128 + + constexpr uint128_t::uint128_t(int128_t v) + : hi_{ static_cast(Int128High64(v)) }, lo_{ Int128Low64(v) } { + } + +#else // byte order +#error "Unsupported byte order: must be little-endian or big-endian." +#endif // byte order + +// Conversion operators to integer types. + + constexpr uint128_t::operator bool() const { return lo_ || hi_; } + + constexpr uint128_t::operator char() const { return static_cast(lo_); } + + constexpr uint128_t::operator signed char() const { + return static_cast(lo_); + } + + constexpr uint128_t::operator unsigned char() const { + return static_cast(lo_); + } + + constexpr uint128_t::operator char16_t() const { + return static_cast(lo_); + } + + constexpr uint128_t::operator char32_t() const { + return static_cast(lo_); + } + + constexpr uint128_t::operator ABSL_INTERNAL_WCHAR_T() const { + return static_cast(lo_); + } + + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::operator short() const { return static_cast(lo_); } + + constexpr uint128_t::operator unsigned short() const { // NOLINT(runtime/int) + return static_cast(lo_); // NOLINT(runtime/int) + } + + constexpr uint128_t::operator int() const { return static_cast(lo_); } + + constexpr uint128_t::operator unsigned int() const { + return static_cast(lo_); + } + + // NOLINTNEXTLINE(runtime/int) + constexpr uint128_t::operator long() const { return static_cast(lo_); } + + constexpr uint128_t::operator unsigned long() const { // NOLINT(runtime/int) + return static_cast(lo_); // NOLINT(runtime/int) + } + + constexpr uint128_t::operator long long() const { // NOLINT(runtime/int) + return static_cast(lo_); // NOLINT(runtime/int) + } + + constexpr uint128_t::operator unsigned long long() const { // NOLINT(runtime/int) + return static_cast(lo_); // NOLINT(runtime/int) + } + +#ifdef ABSL_HAVE_INTRINSIC_INT128 + constexpr uint128_t::operator __int128_t() const { + return (static_cast<__int128_t>(hi_) << 64) + lo_; + } + + constexpr uint128_t::operator unsigned __int128_t() const { + return (static_cast(hi_) << 64) + lo_; + } +#endif // ABSL_HAVE_INTRINSIC_INT128 + + // Conversion operators to floating point types. + + inline uint128_t::operator float() const { + return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); + } + + inline uint128_t::operator double() const { + return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); + } + + inline uint128_t::operator long double() const { + return static_cast(lo_) + + std::ldexp(static_cast(hi_), 64); + } + + // Comparison operators. + + inline bool operator==(uint128_t lhs, uint128_t rhs) { + return (Uint128Low64(lhs) == Uint128Low64(rhs) && + Uint128High64(lhs) == Uint128High64(rhs)); + } + + inline bool operator!=(uint128_t lhs, uint128_t rhs) { + return !(lhs == rhs); + } + + inline bool operator<(uint128_t lhs, uint128_t rhs) { +#ifdef ABSL_HAVE_INTRINSIC_INT128 + return static_cast(lhs) < + static_cast(rhs); +#else + return (Uint128High64(lhs) == Uint128High64(rhs)) + ? (Uint128Low64(lhs) < Uint128Low64(rhs)) + : (Uint128High64(lhs) < Uint128High64(rhs)); +#endif + } + + inline bool operator>(uint128_t lhs, uint128_t rhs) { return rhs < lhs; } + + inline bool operator<=(uint128_t lhs, uint128_t rhs) { return !(rhs < lhs); } + + inline bool operator>=(uint128_t lhs, uint128_t rhs) { return !(lhs < rhs); } + + // Unary operators. + + inline uint128_t operator-(uint128_t val) { + uint64_t hi = ~Uint128High64(val); + uint64_t lo = ~Uint128Low64(val) + 1; + if (lo == 0) ++hi; // carry + return MakeUint128(hi, lo); + } + + inline bool operator!(uint128_t val) { + return !Uint128High64(val) && !Uint128Low64(val); + } + + // Logical operators. + + inline uint128_t operator~(uint128_t val) { + return MakeUint128(~Uint128High64(val), ~Uint128Low64(val)); + } + + inline uint128_t operator|(uint128_t lhs, uint128_t rhs) { + return MakeUint128(Uint128High64(lhs) | Uint128High64(rhs), + Uint128Low64(lhs) | Uint128Low64(rhs)); + } + + inline uint128_t operator&(uint128_t lhs, uint128_t rhs) { + return MakeUint128(Uint128High64(lhs) & Uint128High64(rhs), + Uint128Low64(lhs) & Uint128Low64(rhs)); + } + + inline uint128_t operator^(uint128_t lhs, uint128_t rhs) { + return MakeUint128(Uint128High64(lhs) ^ Uint128High64(rhs), + Uint128Low64(lhs) ^ Uint128Low64(rhs)); + } + + inline uint128_t& uint128_t::operator|=(uint128_t other) { + hi_ |= other.hi_; + lo_ |= other.lo_; + return *this; + } + + inline uint128_t& uint128_t::operator&=(uint128_t other) { + hi_ &= other.hi_; + lo_ &= other.lo_; + return *this; + } + + inline uint128_t& uint128_t::operator^=(uint128_t other) { + hi_ ^= other.hi_; + lo_ ^= other.lo_; + return *this; + } + + // Arithmetic operators. + + inline uint128_t operator<<(uint128_t lhs, int amount) { +#ifdef ABSL_HAVE_INTRINSIC_INT128 + return static_cast(lhs) << amount; +#else + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount != 0) { + return MakeUint128( + (Uint128High64(lhs) << amount) | (Uint128Low64(lhs) >> (64 - amount)), + Uint128Low64(lhs) << amount); + } + return lhs; + } + return MakeUint128(Uint128Low64(lhs) << (amount - 64), 0); +#endif + } + + inline uint128_t operator>>(uint128_t lhs, int amount) { +#ifdef ABSL_HAVE_INTRINSIC_INT128 + return static_cast(lhs) >> amount; +#else + // uint64_t shifts of >= 64 are undefined, so we will need some + // special-casing. + if (amount < 64) { + if (amount != 0) { + return MakeUint128(Uint128High64(lhs) >> amount, + (Uint128Low64(lhs) >> amount) | + (Uint128High64(lhs) << (64 - amount))); + } + return lhs; + } + return MakeUint128(0, Uint128High64(lhs) >> (amount - 64)); +#endif + } + + inline uint128_t operator+(uint128_t lhs, uint128_t rhs) { + uint128_t result = MakeUint128(Uint128High64(lhs) + Uint128High64(rhs), + Uint128Low64(lhs) + Uint128Low64(rhs)); + if (Uint128Low64(result) < Uint128Low64(lhs)) { // check for carry + return MakeUint128(Uint128High64(result) + 1, Uint128Low64(result)); + } + return result; + } + + inline uint128_t operator-(uint128_t lhs, uint128_t rhs) { + uint128_t result = MakeUint128(Uint128High64(lhs) - Uint128High64(rhs), + Uint128Low64(lhs) - Uint128Low64(rhs)); + if (Uint128Low64(lhs) < Uint128Low64(rhs)) { // check for carry + return MakeUint128(Uint128High64(result) - 1, Uint128Low64(result)); + } + return result; + } + + inline uint128_t operator*(uint128_t lhs, uint128_t rhs) { +#if defined(ABSL_HAVE_INTRINSIC_INT128) + // TODO(strel) Remove once alignment issues are resolved and unsigned __int128_t + // can be used for uint128_t storage. + return static_cast(lhs) * + static_cast(rhs); +#elif defined(_MSC_VER) && defined(_M_X64) + uint64_t carry; + uint64_t low = _umul128(Uint128Low64(lhs), Uint128Low64(rhs), &carry); + return MakeUint128(Uint128Low64(lhs) * Uint128High64(rhs) + + Uint128High64(lhs) * Uint128Low64(rhs) + carry, + low); +#else // ABSL_HAVE_INTRINSIC128 + uint64_t a32 = Uint128Low64(lhs) >> 32; + uint64_t a00 = Uint128Low64(lhs) & 0xffffffff; + uint64_t b32 = Uint128Low64(rhs) >> 32; + uint64_t b00 = Uint128Low64(rhs) & 0xffffffff; + uint128_t result = + MakeUint128(Uint128High64(lhs) * Uint128Low64(rhs) + + Uint128Low64(lhs) * Uint128High64(rhs) + a32 * b32, + a00 * b00); + result += uint128_t(a32 * b00) << 32; + result += uint128_t(a00 * b32) << 32; + return result; +#endif // ABSL_HAVE_INTRINSIC128 + } + + // Increment/decrement operators. + + inline uint128_t uint128_t::operator++(int) { + uint128_t tmp(*this); + *this += 1; + return tmp; + } + + inline uint128_t uint128_t::operator--(int) { + uint128_t tmp(*this); + *this -= 1; + return tmp; + } + + inline uint128_t& uint128_t::operator++() { + *this += 1; + return *this; + } + + inline uint128_t& uint128_t::operator--() { + *this -= 1; + return *this; + } + + + +} // namespace absl + +#undef ABSL_INTERNAL_WCHAR_T + +#endif // ABSL_INT128_H_ \ No newline at end of file diff --git a/libOTe_Tests/CMakeLists.txt b/libOTe_Tests/CMakeLists.txt index 474f08d0..e92b0e08 100644 --- a/libOTe_Tests/CMakeLists.txt +++ b/libOTe_Tests/CMakeLists.txt @@ -4,16 +4,17 @@ set(SRCS bitpolymul_Tests.cpp Common.cpp EACode_Tests.cpp - ExCOnvCode_Tests.cpp + ExConvCode_Tests.cpp NcoOT_Tests.cpp OT_Tests.cpp Pprf_Tests.cpp RegularDpf_Tests.cpp SilentOT_Tests.cpp - Softspoken_Tests.cpp + SoftSpoken_Tests.cpp TungstenCode_Tests.cpp UnitTests.cpp Vole_Tests.cpp + Foliage_Tests.cpp ) add_library(libOTe_Tests STATIC ${SRCS}) diff --git a/libOTe_Tests/Foliage_Tests.cpp b/libOTe_Tests/Foliage_Tests.cpp new file mode 100644 index 00000000..670d86ce --- /dev/null +++ b/libOTe_Tests/Foliage_Tests.cpp @@ -0,0 +1,954 @@ + +#include "Foliage_Tests.h" +#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" +#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +//#include "libOTe/Tools/Foliage/tri-dpf/FoliageHalfDpf.h" +#include "libOTe/Tools/Foliage/F4Ops.h" +#include "cryptoTools/Common/Matrix.h" +namespace osuCrypto +{ + //u8 extractF4(const uint128_t& val, u8 idx) + //{ + // auto byteIdx = idx / 4; + // auto bitIdx = idx % 4; + // auto byte = ((u8*)&val)[byteIdx]; + // return (byte >> (bitIdx * 2)) & 0b11; + //} + int popcount(uint128_t x) + { + std::array xArr; + memcpy(xArr.data(), &x, 16); + return popcount(xArr[0]) + popcount(xArr[1]); + } + + std::array extractF4(const uint128_t& val) + { + std::array ret; + const char* ptr = (const char*)&val; + for (u8 i = 0; i < 16; ++i) + { + ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + } + return ret; + } + + void testOutputCorrectness( + span shares0, + span shares1, + size_t num_outputs, + size_t secret_index, + span secret_msg, + size_t msg_len) + { + for (size_t i = 0; i < msg_len; i++) + { + uint128_t shareA = shares0[secret_index * msg_len + i]; + uint128_t shareB = shares1[secret_index * msg_len + i]; + uint128_t res = shareA ^ shareB; + + if (res != secret_msg[i]) + { + printf("FAIL (wrong message)\n"); + exit(0); + } + } + + for (size_t i = 0; i < num_outputs; i++) + { + if (i == secret_index) + continue; + + for (size_t j = 0; j < msg_len; j++) + { + uint128_t shareA = shares0[i * msg_len + j]; + uint128_t shareB = shares1[i * msg_len + j]; + uint128_t res = shareA ^ shareB; + + if (res != 0) + { + printf("FAIL (non-zero) %zu\n", i); + printBytes(&shareA, 16); + printBytes(&shareB, 16); + + exit(0); + } + } + } + } + + void printOutputShares( + uint128_t* shares0, + uint128_t* shares1, + size_t num_outputs, + size_t msg_len) + { + for (size_t i = 0; i < num_outputs; i++) + { + for (size_t j = 0; j < msg_len; j++) + { + uint128_t shareA = shares0[i * msg_len + j]; + uint128_t shareB = shares1[i * msg_len + j]; + //uint128_t res = shareA ^ shareB; + + printf("(%zu, %zu) %zu\n", i, j, msg_len); + printBytes(&shareA, 16); + printBytes(&shareB, 16); + } + } + } + + + + void testOutputCorrectness_spf( + span shares0, + span shares1, + size_t num_outputs, + size_t secret_index, + span secret_msg, + size_t msg_len) + { + for (size_t i = 0; i < msg_len; i++) + { + uint128_t shareA = shares0[secret_index * msg_len + i]; + uint128_t shareB = shares1[secret_index * msg_len + i]; + uint128_t res = shareA ^ shareB; + + if (res != secret_msg[i]) + { + printf("FAIL (wrong message)\n"); + throw RTE_LOC; + } + } + + for (size_t i = 0; i < num_outputs; i++) + { + if (i == secret_index) + continue; + + for (size_t j = 0; j < msg_len; j++) + { + uint128_t shareA = shares0[i * msg_len + j]; + uint128_t shareB = shares1[i * msg_len + j]; + uint128_t res = shareA ^ shareB; + + if (res != 0) + { + printf("FAIL (non-zero) %zu\n", i); + printBytes(&shareA, 16); + printBytes(&shareB, 16); + throw RTE_LOC; + //exit(0); + } + } + } + } + + void printOutputShares_spf( + uint128_t* shares0, + uint128_t* shares1, + size_t num_outputs, + size_t msg_len) + { + for (size_t i = 0; i < num_outputs; i++) + { + for (size_t j = 0; j < msg_len; j++) + { + uint128_t shareA = shares0[i * msg_len + j]; + uint128_t shareB = shares1[i * msg_len + j]; + //uint128_t res = shareA ^ shareB; + + printf("(%zu, %zu) %zu\n", i, j, msg_len); + printBytes(&shareA, 16); + printBytes(&shareB, 16); + } + } + } + + void foliage_spfss_test() + { + + size_t SUMT = 730;// sum of T DPFs + size_t FULLEVALDOMAIN = 10; + size_t MESSAGESIZE = 8; + size_t MAXRANDINDEX = ipow(3, FULLEVALDOMAIN); + + const size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points + const size_t msg_len = MESSAGESIZE; + PRNG prng(block(3423423)); + + size_t num_leaves = ipow(3, size); + + size_t secret_index = prng.get() % MAXRANDINDEX; + + // sample a random message of size msg_len + std::vector secret_msg(msg_len); + for (size_t i = 0; i < msg_len; i++) + secret_msg[i] = prng.get(); + + PRFKeys prf_keys; + prf_keys.gen(prng); + + std::vector kA(SUMT); + std::vector kB(SUMT); + + for (size_t i = 0; i < SUMT; i++) + DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA[i], kB[i], prng); + + std::vector shares0(num_leaves * msg_len); + std::vector shares1(num_leaves * msg_len); + std::vector cache(num_leaves * msg_len); + + //************************************************ + // Test full domain evaluation + //************************************************ + + for (size_t i = 0; i < SUMT; i++) + DPFFullDomainEval(kA[i], cache, shares0); + + clock_t t; + t = clock(); + + for (size_t i = 0; i < SUMT; i++) + DPFFullDomainEval(kB[i], cache, shares1); // we can reuse the same shares and cache + + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + // printOutputShares(shares0, shares1, num_leaves, msg_len); + + testOutputCorrectness_spf( + shares0, + shares1, + num_leaves, + secret_index, + secret_msg, + msg_len); + + //DestroyPRFKey(prf_keys); + //free(kA); + //free(kB); + //free(shares0); + //free(shares1); + //free(cache); + + } + + + void foliage_dpf_test() + { + const size_t size = 14; // evaluation will result in 3^size points + const size_t msg_len = 2; + PRNG prng(block(342134)); + + size_t num_leaves = ipow(3, size); + + size_t secret_index = prng.get() % ipow(3, size); + + // sample a random message of size msg_len + std::vector secret_msg(msg_len); + for (size_t i = 0; i < msg_len; i++) + secret_msg[i] = prng.get(); + + PRFKeys prf_keys; + prf_keys.gen(prng); + + DPFKey kA; + DPFKey kB; + + DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA, kB, prng); + + std::vector shares0(num_leaves * msg_len); + std::vector shares1(num_leaves * msg_len); + std::vector cache(num_leaves * msg_len); + + //************************************************ + // Test full domain evaluation + //************************************************ + + DPFFullDomainEval(kA, cache, shares0); + + clock_t t; + t = clock(); + DPFFullDomainEval(kB, cache, shares1); + t = clock() - t; + double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("Time %f ms\n", time_taken); + + // printOutputShares(shares0, shares1, num_leaves, msg_len); + + testOutputCorrectness( + shares0, + shares1, + num_leaves, + secret_index, + secret_msg, + msg_len); + + } + + + + // This test case implements Figure 1 from https://eprint.iacr.org/2024/429.pdf. + // It uses /libs/fft and libs/tri-dpf extensively. + // Several simplifying design choices are made: + // 1. We assume that c*c <= 16 so that we can use a parallel FFT packing of F4 + // elements using a uint32_t type. + // 2. We assume that t is a power of 3 so that the block size of each error + // vector divides the size of the polynomial. This makes the code significantly + // more readable and easier to understand. + + // TODO[feature]: The current implementation assumes that C*C <= 16 in order + // to parallelize the FFTs and other components. Making the code work with + // arbitrary values of C is left for future work. + + // TODO[feature]: modularize the different components of the test case and + // design more unit tests. + + u64 log3Ceil(u64 x) + { + if (x == 0) return 0; + u64 i = 0; + u64 v = 1; + while (v < x) + { + v *= 3; + i++; + } + assert(i == ceil(log_base(x, 3))); + + return i; + } + + // This test evaluates the full PCG.Expand for both parties and + // checks correctness of the resulting OLE correlation. + void foliage_pcg_test(const CLP& cmd) + { + bool check = !cmd.isSet("noCheck"); + auto N = 14; // 3^N number of OLEs generated in total + + // The C and T parameters are computed using the SageMath script that can be + // found in https://github.com/mbombar/estimator_folding + + auto C = 4;// compression factor + auto T = 27;// noise weight + + + clock_t time; + time = clock(); + PRNG prng(block(54233453245)); + + const size_t n = N; + const size_t c = C; + const size_t t = T; + + // 3^n + const size_t poly_size = ipow(3, n); + + //************************************************************************ + // Step 0: Sample the global (1, a1 ... a_c-1) polynomials + //************************************************************************ + std::vector fft_a(poly_size); + std::vector fft_a2(poly_size); + sample_a_and_a2(fft_a, fft_a2, poly_size, c, prng); + + //************************************************************************ + // Here, we figure out a good block size for the error vectors such that + // t*block_size = 3^n and block_size/L*128 is close to a power of 3. + // We pack L=256 coefficients of F4 into each DPF output (note that larger + // packing values are also okay, but they will do increase key size). + //************************************************************************ + size_t dpf_domain_bits = log3Ceil(divCeil(poly_size, t * 256.0)); + if (dpf_domain_bits == 0) + dpf_domain_bits = 1; + + printf("DPF domain bits %zu \n", dpf_domain_bits); + + // 4*128 ==> 256 coefficients in F4 + size_t dpf_block_size = 4 * ipow(3, dpf_domain_bits); + + printf("dpf_block_size = %zu\n", dpf_block_size); + + // Note: We assume that t is a power of 3 and so it divides poly_size + assert(poly_size % t == 0); + + // the size of a single regular block. We have t blocks in each polynomial + // poly_size = 2^n / t = 3^{n-3} + size_t block_size = poly_size / t; + + printf("block_size = %zu \n", block_size); + + printf("[ ]Done with Step 0 (sampling the public values)\n"); + + //************************************************************************ + // Step 1: Sample error polynomials eA and eB (c polynomials in total) + // each polynomial is t-sparse and has degree (t * block_size) = poly_size. + //************************************************************************ + std::vector err_polys_A(c * poly_size); + std::vector err_polys_B(c * poly_size); + + // coefficients associated with each error vector + std::vector err_poly_coeffs_A(c * t); + std::vector err_poly_coeffs_B(c * t); + + // positions of the T errors in each error vector + std::vector err_poly_positions_A(c * t); + std::vector err_poly_positions_B(c * t); + + for (size_t i = 0; i < c; i++) + { + for (size_t j = 0; j < t; j++) + { + size_t offset = i * t + j; + + // random *non-zero* coefficients in F4 + uint8_t a = rand_f4x(prng); + uint8_t b = rand_f4x(prng); + err_poly_coeffs_A[offset] = a; + err_poly_coeffs_B[offset] = b; + + // random index within the block + size_t pos_A = random_index(block_size - 1, prng); + size_t pos_B = random_index(block_size - 1, prng); + + if (pos_A >= block_size || pos_B >= block_size) + { + printf("FAIL: position > block_size: %zu, %zu\n", pos_A, pos_B); + throw RTE_LOC; + //exit(0); + } + + err_poly_positions_A[offset] = pos_A; + err_poly_positions_B[offset] = pos_B; + + // set the coefficient at the error position to the error value + err_polys_A[i * poly_size + j * block_size + pos_A] = a; + err_polys_B[i * poly_size + j * block_size + pos_B] = b; + } + } + + // Compute FFT of eA and eB in packed form. + // Note that because c = 4, we can pack 4 FFTs into a uint8_t + std::vector fft_eA(poly_size); + std::vector fft_eB(poly_size); + uint8_t coeff_A, coeff_B; + + // This loop essentially computes a transpose to pack the coefficients + // of each polynomial into one "row" of the parallel FFT matrix + for (size_t j = 0; j < c; j++) + { + for (size_t i = 0; i < poly_size; i++) + { + // extract the i-th coefficient of the j-th error polynomial + coeff_A = err_polys_A[j * poly_size + i]; + coeff_B = err_polys_B[j * poly_size + i]; + + // pack the extracted coefficient into the j-th FFT slot + fft_eA[i] |= (coeff_A << (2 * j)); + fft_eB[i] |= (coeff_B << (2 * j)); + } + } + + // Evaluate the FFTs on the error polynomials eA and eB + fft_recursive_uint8(fft_eA, n, poly_size / 3); + fft_recursive_uint8(fft_eB, n, poly_size / 3); + + printf("[. ]Done with Step 1 (sampling error vectors)\n"); + + //************************************************************************ + // Step 2: compute the inner product xA = and xB = + //************************************************************************ + + // Initialize polynomials to zero (accumulators for inner product) + std::vector x_poly_A(poly_size); + std::vector x_poly_B(poly_size); + + // Compute the coordinate-wise multiplication over the packed FFT result + std::vector res_poly_A(poly_size); + std::vector res_poly_B(poly_size); + multiply_fft_8(fft_a, fft_eA, res_poly_A, poly_size); // a*eA + multiply_fft_8(fft_a, fft_eB, res_poly_B, poly_size); // a*eB + + // XOR the result into the accumulator. + // Specifically, we XOR all the columns of the FFT result to get a + // vector of size poly_size. + for (size_t j = 0; j < c; j++) + { + for (size_t i = 0; i < poly_size; i++) + { + x_poly_A[i] ^= (res_poly_A[i] >> (2 * j)) & 0b11; + x_poly_B[i] ^= (res_poly_B[i] >> (2 * j)) & 0b11; + } + } + + printf("[.. ]Done with Step 2 (computing the local vectors)\n"); + + //************************************************************************ + // Step 3: Compute cross product (eA x eB) using the position vectors + //************************************************************************ + std::vector err_poly_cross_coeffs(c * c * t * t); + std::vector err_poly_cross_positions(c * c * t * t); + std::vector err_polys_cross(c * c * poly_size); + std::vector trit_decomp_A(n); + std::vector trit_decomp_B(n); + std::vector trit_decomp(n); + + for (size_t iA = 0; iA < c; iA++) + { + for (size_t iB = 0; iB < c; iB++) + { + size_t poly_index = iA * c * t * t + iB * t * t; + std::vector next_idx(t); + + for (size_t jA = 0; jA < t; jA++) + { + for (size_t jB = 0; jB < t; jB++) + { + // jA-th coefficient value of the iA-th polynomial + uint8_t vA = err_poly_coeffs_A[iA * t + jA]; + + // jB-th coefficient value of the iB-th polynomial + uint8_t vB = err_poly_coeffs_B[iB * t + jB]; + + // Resulting cross-product coefficient + uint8_t v = mult_f4(vA, vB); + + // Compute the position (in the full polynomial) + size_t posA = jA * block_size + err_poly_positions_A[iA * t + jA]; + size_t posB = jB * block_size + err_poly_positions_B[iB * t + jB]; + + if (err_polys_A[iA * poly_size + posA] == 0) + { + printf("FAIL: Incorrect position recovered\n"); + throw RTE_LOC; + //exit(0); + } + + if (err_polys_B[iB * poly_size + posB] == 0) + { + printf("FAIL: Incorrect position recovered\n"); + throw RTE_LOC; + } + + // Decompose the position into the ternary basis + int_to_trits(posA, trit_decomp_A, n); + int_to_trits(posB, trit_decomp_B, n); + + // printf("[DEBUG]: posA=%zu, posB=%zu\n", posA, posB); + + // Sum ternary decomposition coordinate-wise to + // get the new position (in ternary). + for (size_t k = 0; k < n; k++) + { + // printf("[DEBUG]: trits_A[%zu]=%i, trits_B[%zu]=%i\n", + // k, trit_decomp_A[k], k, trit_decomp_B[k]); + trit_decomp[k] = (trit_decomp_A[k] + trit_decomp_B[k]) % 3; + } + + // Get back the resulting cross-product position as an integer + size_t pos = trits_to_int(trit_decomp, n); + size_t block_idx = floor(pos / block_size); // block index in polynomial + //size_t in_block_idx = pos % block_size; // index within the block + + err_polys_cross[(iA * c + iB) * poly_size + pos] ^= v; + + size_t idx = next_idx[block_idx]; + next_idx[block_idx]++; + + // printf("[DEBUG]: pos=%zu, block_idx=%zu, idx=%zu\n", pos, block_idx, idx); + err_poly_cross_coeffs[poly_index + block_idx * t + idx] = v; + err_poly_cross_positions[poly_index + block_idx * t + idx] = pos % block_size; + } + } + + for (size_t k = 0; k < t; k++) + { + if (next_idx[k] > t) + { + std::cout << "FAIL: next_idx > t at the end: " << next_idx[k] << std::endl; + throw RTE_LOC; + } + } + + //free(next_idx); + } + } + + // cleanup temporary values + //free(trit_decomp); + //free(trit_decomp_A); + //free(trit_decomp_B); + + printf("[... ]Done with Step 3 (computing the cross product)\n"); + + //************************************************************************ + // Step 4: Sample the DPF keys for the cross product (eA x eB) + //************************************************************************ + + std::vector dpf_keys_A(c * c * t * t); + std::vector dpf_keys_B(c * c * t * t); + + // Sample PRF keys for the DPFs + PRFKeys prf_keys; + prf_keys.gen(prng); + + // Sample DPF keys for each of the t errors in the t blocks + for (size_t i = 0; i < c; i++) + { + for (size_t j = 0; j < c; j++) + { + for (size_t k = 0; k < t; k++) + { + for (size_t l = 0; l < t; l++) + { + size_t index = i * c * t * t + j * t * t + k * t + l; + + // Parse the index into the right format + size_t alpha = err_poly_cross_positions[index]; + + // Output message index in the DPF output space + // which consists of 256 F4 elements + size_t alpha_0 = floor(alpha / 256.0); + + // Coeff index in the block of 256 coefficients + size_t alpha_1 = alpha % 256; + + // Coeff index in the uint128_t output (64 elements of F4) + size_t packed_idx = floor(alpha_1 / 64.0); + + // Bit index in the uint128_t ouput + size_t bit_idx = alpha_1 % 64; + + // Set the DPF message to the coefficient + uint128_t coeff = uint128_t(err_poly_cross_coeffs[index]); + + // Position coefficient into the block + std::array beta; // init to zero + setBytes(beta, 0); + beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + + // Message (beta) is of size 4 blocks of 128 bits + DPFGen(prf_keys, dpf_domain_bits, alpha_0, beta, 4, dpf_keys_A[index], dpf_keys_B[index], prng); + } + } + } + } + + printf("[.... ]Done with Step 4 (sampling DPF keys)\n"); + + //************************************************************************ + // Step 5: Evaluate the DPFs to compute shares of (eA x eB) + //************************************************************************ + + // Allocate memory for the DPF outputs (this is reused for each evaluation) + std::vector shares_A(dpf_block_size); + std::vector shares_B(dpf_block_size); + std::vector cache(dpf_block_size); + + // Allocate memory for the concatenated DPF outputs + size_t packed_block_size = divCeil(block_size, 64); + size_t packed_poly_size = t * packed_block_size; + + // printf("[DEBUG]: packed_block_size = %zu\n", packed_block_size); + // printf("[DEBUG]: packed_poly_size = %zu\n", packed_poly_size); + // + // each row is a block. every t rows is a polynomial. + Matrix packed_polys_A_(c * c * t, packed_block_size); + Matrix packed_polys_B_(c * c * t, packed_block_size); + //std::vector packed_polys_A(c * c * packed_poly_size); + //std::vector packed_polys_B(c * c * packed_poly_size); + + // Allocate memory for the output FFT + std::vectorfft_uA(poly_size); + std::vectorfft_uB(poly_size); + //std::vectorfft_uA2(poly_size); + //std::vectorfft_uB2(poly_size); + + // Allocate memory for the final inner product + std::vector z_poly_A(poly_size); + std::vector z_poly_B(poly_size); + std::vector res_poly_mat_A(poly_size); + std::vector res_poly_mat_B(poly_size); + + auto dpf_keys_A_iter = dpf_keys_A.begin(); + auto dpf_keys_B_iter = dpf_keys_B.begin(); + + for (size_t i = 0; i < c; i++) + { + for (size_t j = 0; j < c; j++) + { + const size_t poly_index = i * c + j; + + oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); + //uint128_t* packed_polyA = &packed_polys_A[poly_index * packed_poly_size]; + //uint128_t* packed_polyB = &packed_polys_B[poly_index * packed_poly_size]; + + for (size_t k = 0; k < t; k++) + { + span poly_blockA = packed_polyA_[k]; + span poly_blockB = packed_polyB_[k]; + + for (size_t l = 0; l < t; l++) + { + + DPFKey& dpf_keyA = *dpf_keys_A_iter++; + DPFKey& dpf_keyB = *dpf_keys_B_iter++; + + DPFFullDomainEval(dpf_keyA, cache, shares_A); + DPFFullDomainEval(dpf_keyB, cache, shares_B); + + // Sum all the DPFs for the current block together + // note that there is some extra "garbage" in the last + // block of uint128_t since 64 does not divide block_size. + // We deal with this slack later when packing the outputs + // into the parallel FFT matrix. + for (size_t w = 0; w < packed_block_size; w++) + { + poly_blockA[w] ^= shares_A[w]; + poly_blockB[w] ^= shares_B[w]; + } + } + } + } + } + + + if (check) + { + + // Here, we test to make sure all polynomials have at most t^2 errors + // and fail the test otherwise. + for (size_t i = 0; i < c; i++) + { + for (size_t j = 0; j < c; j++) + { + size_t err_count = 0; + size_t poly_index = i * c + j; + + oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); + //uint128_t* poly_A = &packed_polys_A[poly_index * packed_poly_size]; + //uint128_t* poly_B = &packed_polys_B[poly_index * packed_poly_size]; + + for (size_t p = 0; p < packed_poly_size; p++) + { + uint128_t res = packed_polyA_(p) ^ packed_polyB_(p); + if (res) + { + auto e = extractF4(res); + for (size_t l = 0; l < 64; l++) + { + //if (((res >> (2 * (63 - l))) & uint128_t(0b11)) != uint128_t(0)) + err_count += (e[l] | (e[l] >> 1)) & 1; + //if (e[l]) + // err_count++; + } + } + } + + // printf("[DEBUG]: Number of non-zero coefficients in poly (%zu,%zu) is %zu\n", i, j, err_count); + + if (err_count > t * t) + { + printf("FAIL: Number of non-zero coefficients is %zu > t*t\n", err_count); + throw RTE_LOC; + } + else if (err_count == 0) + { + printf("FAIL: Number of non-zero coefficients in poly (%zu,%zu) is %zu\n", i, j, err_count); + throw RTE_LOC; + } + } + } + } + printf("[..... ]Done with Step 5 (evaluating all DPFs)\n"); + + //************************************************************************ + // Step 6: Compute an FFT over the shares of (eA x eB) + //************************************************************************ + + // Pack the coefficients into FFT blocks + // + // TODO[optimization]: use AVX and fast matrix transposition algorithms. + // The transpose is the bottleneck of the current implementation and + // therefore improving this step can result in significant performance gains. + + if (check) + { + + for (size_t j = 0; j < c; j++) + { + for (size_t k = 0; k < c; k++) + { + std::vector test_poly_A(poly_size); + std::vector test_poly_B(poly_size); + + size_t poly_index = j * c + k; + + oc::MatrixView poly_A(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView poly_B(packed_polys_B_.data(poly_index * t), t, packed_block_size); + + //uint128_t* poly_A = &packed_polys_A[poly_index * packed_poly_size]; + //uint128_t* poly_B = &packed_polys_B[poly_index * packed_poly_size]; + + u64 i = 0; + for (u64 block_idx = 0; block_idx < t; ++block_idx) + { + for (u64 packed_idx = 0; packed_idx < packed_block_size; ++packed_idx) + { + auto coeffA = extractF4(poly_A(block_idx, packed_idx)); + auto coeffB = extractF4(poly_B(block_idx, packed_idx)); + + //auto idx = j * c + k; + //if (idx >= 16) + // throw RTE_LOC; + auto e = std::min(block_size - packed_idx * 64, 64); + for (u64 element_idx = 0; element_idx < e; ++element_idx) + { + test_poly_A[i] = coeffA[63 - element_idx]; + test_poly_B[i] = coeffB[63 - element_idx]; + ++i; + } + } + } + + for (size_t i = 0; i < poly_size; i++) + { + uint8_t exp_coeff = err_polys_cross[j * c * poly_size + k * poly_size + i]; + uint8_t got_coeff = test_poly_A[i] ^ test_poly_B[i]; + + if (got_coeff != exp_coeff) + { + printf("FAIL: incorrect cross coefficient at index %zu (%i =/= %i)\n", i, got_coeff, exp_coeff); + throw RTE_LOC; + } + } + + } + } + } + + // TODO[optimization]: for arbitrary values of C, we only need to perform + // C*(C+1)/2 FFTs which can lead to a more efficient implementation. + // Because we assume C=4, we have C*C = 16 which fits perfectly into a + // uint32 packing. + + for (size_t j = 0; j < c; j++) + { + for (size_t k = 0; k < c; k++) + { + size_t poly_index = (j * c + k);// *packed_poly_size; + + oc::MatrixView polyA(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView polyB(packed_polys_B_.data(poly_index * t), t, packed_block_size); + + u64 i = 0; + for (u64 block_idx = 0; block_idx < t; ++block_idx) + { + for (u64 packed_idx = 0; packed_idx < packed_block_size; ++packed_idx) + { + auto coeffA = extractF4(polyA(block_idx, packed_idx)); + auto coeffB = extractF4(polyB(block_idx, packed_idx)); + + //auto idx = j * c + k; + //if (idx >= 16) + // throw RTE_LOC; + auto e = std::min(block_size - packed_idx * 64, 64); + + for (u64 element_idx = 0; element_idx < e; ++element_idx) + { + fft_uA[i] |= u32{ coeffA[63 - element_idx] } << (2 * poly_index); + fft_uB[i] |= u32{ coeffB[63 - element_idx] } << (2 * poly_index); + ++i; + } + } + } + } + } + + fft_recursive_uint32(fft_uA, n, poly_size / 3); + fft_recursive_uint32(fft_uB, n, poly_size / 3); + + printf("[...... ]Done with Step 6 (computing FFTs)\n"); + + //************************************************************************ + // Step 7: Compute shares of z = + //************************************************************************ + multiply_fft_32(fft_a2, fft_uA, res_poly_mat_A, poly_size); + multiply_fft_32(fft_a2, fft_uB, res_poly_mat_B, poly_size); + + //size_t num_ffts = c * c; + + // XOR the (packed) columns into the accumulator. + // Specifically, we perform column-wise XORs to get the result. + uint128_t lsbMask, msbMask; + setBytes(lsbMask, 0b01010101); + setBytes(msbMask, 0b10101010); + for (size_t i = 0; i < poly_size; i++) + { + //auto resA = extractF4(res_poly_mat_A[i]); + //auto resB = extractF4(res_poly_mat_B[i]); + + z_poly_A[i] = + popcount(res_poly_mat_A[i] & lsbMask) & 1 | + (popcount(res_poly_mat_A[i] & msbMask) & 1) << 1; + + z_poly_B[i] = + popcount(res_poly_mat_B[i] & lsbMask) & 1 | + (popcount(res_poly_mat_B[i] & msbMask) & 1) << 1; + + //u8 aSum = 0; + + //for (size_t j = 0; j < c * c; j++) + //{ + // aSum ^= resA[j]; + //} + + //if ((aSum & 1) != aLsb) + // throw RTE_LOC; + //if (((aSum>>1) & 1) != aMsb) + // throw RTE_LOC; + + //for (size_t j = 0; j < c * c; j++) + //{ + // z_poly_A[i] ^= resA[j]; + // z_poly_B[i] ^= resB[j]; + //} + } + + // Now we check that we got the correct OLE correlations and fail + // the test otherwise. + for (size_t i = 0; i < poly_size; i++) + { + uint8_t res = z_poly_A[i] ^ z_poly_B[i]; + uint8_t exp = mult_f4(x_poly_A[i], x_poly_B[i]); + + // printf("[DEBUG]: Got: (%i,%i), Expected: (%i, %i)\n", + // (res >> 1) & 1, res & 1, (exp >> 1) & 1, exp & 1); + + if (res != exp) + { + printf("FAIL: Incorrect correlation output at index %zu\n", i); + printf("Got: (%i,%i), Expected: (%i, %i)\n", + (res >> 1) & 1, res & 1, (exp >> 1) & 1, exp & 1); + throw RTE_LOC; + + } + } + + time = clock() - time; + double time_taken = ((double)time) / (CLOCKS_PER_SEC / 1000.0); // ms + + printf("[.......]Done with Step 7 (recovering shares)\n\n"); + + printf("Time elapsed %f ms\n", time_taken); + + } + +} \ No newline at end of file diff --git a/libOTe_Tests/Foliage_Tests.h b/libOTe_Tests/Foliage_Tests.h new file mode 100644 index 00000000..5e84058b --- /dev/null +++ b/libOTe_Tests/Foliage_Tests.h @@ -0,0 +1,11 @@ +#pragma once +#include "cryptoTools/Common/CLP.h" +namespace osuCrypto +{ + + void foliage_spfss_test(); + void foliage_dpf_test(); + void foliage_pcg_test(const CLP& cmd); + + +} \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index da56bd71..1376a6de 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -17,6 +17,7 @@ #include "libOTe_Tests/Pprf_Tests.h" #include "libOTe_Tests/TungstenCode_Tests.h" #include "libOTe_Tests/RegularDpf_Tests.h" +#include "libOTe_Tests/Foliage_Tests.h" using namespace osuCrypto; namespace tests_libOTe @@ -62,6 +63,11 @@ namespace tests_libOTe tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); + tc.add("foliage_dpf_test ", foliage_dpf_test); + tc.add("foliage_spfss_test ", foliage_spfss_test); + tc.add("foliage_pcg_test ", foliage_pcg_test); + + tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); From c834400dc2fc46be9555e4b8a3ce67b510345e10 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 27 Jan 2025 17:10:56 -0800 Subject: [PATCH 05/48] FoliageF4Ole --- libOTe/Tools/Foliage/F4Ops.h | 34 ++- libOTe/Tools/Foliage/FoliagePcg.cpp | 414 ++++++++++++++++++++++++++++ libOTe/Tools/Foliage/FoliagePcg.h | 64 +++++ libOTe/Tools/Foliage/FoliageUtils.h | 76 ++++- libOTe/Tools/Foliage/uint128.h | 17 +- libOTe_Tests/Foliage_Tests.cpp | 170 ++++++++---- libOTe_Tests/Foliage_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 1 + 8 files changed, 694 insertions(+), 83 deletions(-) create mode 100644 libOTe/Tools/Foliage/FoliagePcg.cpp create mode 100644 libOTe/Tools/Foliage/FoliagePcg.h diff --git a/libOTe/Tools/Foliage/F4Ops.h b/libOTe/Tools/Foliage/F4Ops.h index a533a689..fd4a73d3 100644 --- a/libOTe/Tools/Foliage/F4Ops.h +++ b/libOTe/Tools/Foliage/F4Ops.h @@ -10,21 +10,12 @@ namespace osuCrypto // Samples a non-zero element of F4 inline uint8_t rand_f4x(PRNG& prng) { - uint8_t t; - unsigned char rand_byte; - - // loop until we have two bits where at least one is non-zero - while (1) + uint8_t t = 0; + while (t == 0) { - rand_byte = prng.get(); - t = 0; - t |= rand_byte & 1; - t = t << 1; - t |= (rand_byte >> 1) & 1; - - if (t != 0 && t != 4) - return t; + t = prng.get() & 3; } + return t; } // Multiplies two elements of F4 (optionally: 4 elements packed into uint8_t) @@ -37,6 +28,17 @@ namespace osuCrypto return res; } + inline void f4Mult( + block aLsb, block aMsb, + block bLsb, block bMsb, + block& cLsb, block& cMsb) + { + auto tmp = aMsb & bMsb;// msb only + cMsb = tmp ^ (aMsb & bLsb) ^ (aLsb & bMsb);// msb only + cLsb = (aLsb & bLsb) ^ tmp; + } + + // Multiplies two packed matrices of F4 elements column-by-column. // Note that here the "columns" are packed into an element of uint8_t // resulting in a matrix with 4 columns. @@ -47,8 +49,8 @@ namespace osuCrypto size_t poly_size) { const uint8_t pattern = 0xaa; - uint8_t mask_h = pattern; // 0b101010101010101001010 - uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 + uint8_t mask_h = pattern; // 0b10101010 + uint8_t mask_l = mask_h >> 1; // 0b01010101 uint8_t tmp; uint8_t a_h, a_l, b_h, b_l; @@ -177,6 +179,8 @@ namespace osuCrypto fft_a[i] = (fft_a[i] & ~3ull) | 1; } + //std::cout << "sampleA " << int(fft_a[0]) << int(fft_a[1]) << int(fft_a[2]) << int(fft_a[3]) << std::endl; + // FOR DEBUGGING: set fft_a to the identity // for (size_t i = 0; i < poly_size; i++) // { diff --git a/libOTe/Tools/Foliage/FoliagePcg.cpp b/libOTe/Tools/Foliage/FoliagePcg.cpp new file mode 100644 index 00000000..33adc175 --- /dev/null +++ b/libOTe/Tools/Foliage/FoliagePcg.cpp @@ -0,0 +1,414 @@ +#include "FoliagePcg.h" +#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foliage/F4Ops.h" +#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +#include "cryptoTools/Common/BitIterator.h" +#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" +#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" +namespace osuCrypto +{ + + + void FoliageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) + { + mPartyIdx = partyIdx; + mLog3N = log3Ceil(n); + mN = ipow(3, mLog3N); + if (mT % 3 != 0) + throw RTE_LOC; + + mDpfDomainDepth = std::max(1, log3Ceil(divCeil(mN, mT * 256))); + mDpfBlockSize = 4 * ipow(3, mDpfDomainDepth); + + mBlockSize = mN / mT; + if (mBlockSize < 8) + throw RTE_LOC; + + sampleA(block(431234234, 213434234123)); + + + //std::cout << "a " << hash(mFftA.data(), mFftA.size()) << std::endl; + //std::cout << "a2 " << hash(mFftASquared.data(), mFftASquared.size()) << std::endl; + + + } + + + void FoliageF4Ole::sampleA(block seed) + { + + if (mC > 4) + throw RTE_LOC; + + PRNG prng(seed); + mFftA.resize(mN); + mFftASquared.resize(0); + mFftASquared.resize(mN); + prng.get(mFftA.data(), mFftA.size()); + + // make a_0 the identity polynomial (in FFT space) i.e., all 1s + for (size_t i = 0; i < mN; i++) + { + mFftA[i] = (mFftA[i] & ~3) | 1; + } + + + // FOR DEBUGGING: set fft_a to the identity + // for (size_t i = 0; i < mN; i++) + // { + // mFftA[i] = (0xaaaa >> 1); + // } + uint32_t prod; + for (size_t i = 0; i < mN; i++) + { + mFftASquared[i] = 0; + for (size_t j = 0; j < mC; j++) + { + for (size_t k = 0; k < mC; k++) + { + auto a = (mFftA[i] >> (2 * j)) & 0b11; + auto b = (mFftA[i] >> (2 * k)) & 0b11; + auto a1 = a & 1; + auto a2 = a & 2; + auto b1 = b & 1; + auto b2 = b & 2; + + { + u8 tmp = (a2 & b2); + prod = tmp ^ ((a2 & (b1 << 1)) ^ ((a1 << 1) & b2)); + prod |= (a1 & b1) ^ (tmp >> 1); + //return res; + } + //prod = mult_f4(, ); + size_t slot = j * mC + k; + mFftASquared[i] |= prod << (2 * slot); + } + } + } + + //{ + // std::vector fft_a(mN); + // std::vector fft_a2(mN); + // PRNG APrng(block(431234234, 213434234123)); + // sample_a_and_a2(fft_a, fft_a2, mN, mC, APrng); + + // for (u64 i = 0; i < mN; ++i) + // { + // if (fft_a[i] != mFftA[i]) + // throw RTE_LOC; + // if (fft_a2[i] != mFftASquared[i]) + // throw RTE_LOC; + // } + + //} + } + + + macoro::task<> FoliageF4Ole::expand( + span ALsb, + span AMsb, + span CLsb, + span CMsb, + PRNG& prng, + coproto::Socket& sock) + { + if (divCeil(mN, 128) < ALsb.size()) + throw RTE_LOC; + if (ALsb.size() != AMsb.size() || ALsb.size() != CLsb.size() || ALsb.size() != CMsb.size()) + throw RTE_LOC; + + mSparseCoefficients.resize(mC, mT); + mSparsePositions.resize(mC, mT); + for (u64 i = 0; i < mC * mT; ++i) + { + while (mSparseCoefficients(i) == 0) + mSparseCoefficients(i) = prng.get() & 3; + mSparsePositions(i) = prng.get() % mBlockSize; + } + + + //std::cout << "pos " << hash(mSparsePositions.data(), mSparsePositions.size()) << std::endl; + //std::cout << "coeff " << hash(mSparseCoefficients.data(), mSparseCoefficients.size()) << std::endl; + + + if (mC != 4) + throw RTE_LOC; + + // we pack 4 FFTs into a single u8. + std::vector fftSparsePoly(mN); + for (u64 i = 0; i < mT; ++i) + { + for (u64 j = 0; j < mC; ++j) + { + auto pos = i * mBlockSize + mSparsePositions(j, i); + fftSparsePoly[pos] |= mSparseCoefficients(j, i) << (2 * j); + } + } + + //std::cout << "sparse " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; + + // switch from polynomial to FFT form + fft_recursive_uint8(fftSparsePoly, mLog3N, mN / 3); + + // multiply by the packed A polynomial + multiply_fft_8(mFftA, fftSparsePoly, fftSparsePoly, mN); + + //std::cout << "mult " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; + + + // compress the resume and set the output. + auto outSize = std::min(mN, ALsb.size() * 128); + std::vector A(mN); + for (u64 i = 0; i < outSize; ++i) + { + auto a = + ((fftSparsePoly[i] >> 0) ^ + (fftSparsePoly[i] >> 2) ^ + (fftSparsePoly[i] >> 4) ^ + (fftSparsePoly[i] >> 6)) & 3; + + *BitIterator(ALsb.data(), i) = a & 1; + *BitIterator(AMsb.data(), i) = (a >> 1) & 1; + + A[i] = a; + } + //std::cout << "compress " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; + + + std::vector prodPolyCoefficient(mC * mC * mT * mT); + std::vector prodPolyPosition(mC * mC * mT * mT); + //auto prodPolyCoefficientIter = prodPolyCoefficient.begin(); + //auto prodPolyPositionIter = prodPolyPosition.begin(); + std::vector tritA(mLog3N), tritB(mLog3N), trits(mLog3N); + + Matrix otherSparseCoefficients(mC, mT); + Matrix otherSparsePositions(mC, mT); + co_await sock.send(coproto::copy(mSparseCoefficients)); + co_await sock.send(coproto::copy(mSparsePositions)); + co_await sock.recv(otherSparseCoefficients); + co_await sock.recv(otherSparsePositions); + u64 polyOffset = 0; + u8 vA, vB; + for (u64 iA = 0; iA < mC; ++iA) + { + for (u64 iB = 0; iB < mC; ++iB) + { + std::vector nextIdx(mT); + + for (u64 jA = 0; jA < mT; ++jA) + { + for (u64 jB = 0; jB < mT; ++jB) + { + if (mPartyIdx == 0) + { + vA = mSparseCoefficients(iA, jA); + vB = otherSparseCoefficients(iB, jB); + auto posA = jA * mBlockSize + mSparsePositions(iA, jA); + auto posB = jB * mBlockSize + otherSparsePositions(iB, jB); + int_to_trits(posA, tritA); + int_to_trits(posB, tritB); + } + else + { + vA = otherSparseCoefficients(iA, jA); + vB = mSparseCoefficients(iB, jB); + auto posA = jA * mBlockSize + otherSparsePositions(iA, jA); + auto posB = jB * mBlockSize + mSparsePositions(iB, jB); + int_to_trits(posA, tritA); + int_to_trits(posB, tritB); + } + + for (size_t k = 0; k < mLog3N; k++) + { + trits[k] = (tritA[k] + tritB[k]) % 3; + } + + u64 pos = trits_to_int(trits); + auto blockIdx = pos / mBlockSize; + + size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; + prodPolyCoefficient[idx] = mult_f4(vA, vB); + prodPolyPosition[idx] = pos % mBlockSize; + } + } + + if (nextIdx != std::vector(mT, mT)) + throw RTE_LOC; + + polyOffset += mT * mT; + } + } + + + std::vector Dpfs(mC * mC * mT * mT); + + // Sample PRF keys for the DPFs + PRFKeys prf_keys; + PRNG prfSeedPrng(block(3412342134, 56453452362346)); + prf_keys.gen(prfSeedPrng); + + // Sample DPF keys for each of the t errors in the t blocks + u64 index = 0; + PRNG genPrng; + + //oc::RandomOracle dpfHash(16); + + for (u64 i = 0; i < mC; i++) + { + for (u64 j = 0; j < mC; j++) + { + for (u64 k = 0; k < mT; k++) + { + for (u64 l = 0; l < mT; l++, ++index) + { + //size_t index = i * c * t * t + j * t * t + k * t + l; + + // Parse the index into the right format + size_t alpha = prodPolyPosition[index]; + + // Output message index in the DPF output space + // which consists of 256 F4 elements + size_t alpha_0 = alpha / 256; + + // Coeff index in the block of 256 coefficients + size_t alpha_1 = alpha % 256; + + // Coeff index in the uint128_t output (64 elements of F4) + size_t packed_idx = alpha_1 / 64; + + // Bit index in the uint128_t ouput + size_t bit_idx = alpha_1 % 64; + + // Set the DPF message to the coefficient + uint128_t coeff = uint128_t(prodPolyCoefficient[index]); + + // Position coefficient into the block + std::array beta; // init to zero + setBytes(beta, 0); + beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + + // Message (beta) is of size 4 blocks of 128 bits + genPrng.SetSeed(block(index, 542345234)); + DPFKey _; + if (mPartyIdx) + { + DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); + } + else + { + DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); + } + + //dpfHash.Update(Dpfs[index].k.data(), Dpfs[index].k.size()); + //dpfHash.Update(Dpfs[index].msg_len); + //dpfHash.Update(Dpfs[index].size); + + } + } + } + } + + //block dpfHashVal; + //dpfHash.Final(dpfHashVal); + //std::cout << "dpf " << dpfHashVal << std::endl; + + std::vector shares(mDpfBlockSize); + std::vector cache(mDpfBlockSize); + + size_t packedBlockSize = divCeil(mBlockSize, 64); + Matrix blocks(mC * mC * mT, packedBlockSize); + + std::vector fft(mN), fftRes(mN); + + auto dpfIter = Dpfs.begin(); + //auto dpf_keys_B_iter = dpf_keys_B.begin(); + + for (size_t i = 0; i < mC; i++) + { + for (size_t j = 0; j < mC; j++) + { + const size_t poly_index = i * mC + j; + + oc::MatrixView packed_polyA_(blocks.data(poly_index * mT), mT, blocks.cols()); + + for (size_t k = 0; k < mT; k++) + { + span poly_blockA = packed_polyA_[k]; + + for (size_t l = 0; l < mT; l++) + { + + DPFKey& dpf = *dpfIter++; + + DPFFullDomainEval(dpf, cache, shares); + + // Sum all the DPFs for the current block together + // note that there is some extra "garbage" in the last + // block of uint128_t since 64 does not divide block_size. + // We deal with this slack later when packing the outputs + // into the parallel FFT matrix. + for (size_t w = 0; w < packedBlockSize; w++) + { + poly_blockA[w] ^= shares[w]; + } + } + } + } + } + + //std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; + + + for (size_t j = 0; j < mC; j++) + { + for (size_t k = 0; k < mC; k++) + { + size_t poly_index = (j * mC + k); + + oc::MatrixView poly(blocks.data(poly_index * mT), mT, packedBlockSize); + + u64 i = 0; + for (u64 block_idx = 0; block_idx < mT; ++block_idx) + { + for (u64 packed_idx = 0; packed_idx < packedBlockSize; ++packed_idx) + { + auto coeff = extractF4(poly(block_idx, packed_idx)); + auto e = std::min(mBlockSize - packed_idx * 64, 64); + + for (u64 element_idx = 0; element_idx < e; ++element_idx) + { + fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); + ++i; + } + } + } + } + } + //std::cout << "CIn " << hash(fft.data(), fft.size()) << std::endl; + + + fft_recursive_uint32(fft, mLog3N, mN / 3); + //std::cout << "Cfft " << hash(fft.data(), fft.size()) << std::endl; + multiply_fft_32(mFftASquared, fft, fftRes, mN); + + //std::cout << "C " << hash(fftRes.data(), fftRes.size()) << std::endl; + + + // XOR the (packed) columns into the accumulator. + // Specifically, we perform column-wise XORs to get the result. + uint128_t lsbMask, msbMask; + setBytes(lsbMask, 0b01010101); + setBytes(msbMask, 0b10101010); + for (size_t i = 0; i < outSize; i++) + { + //auto resA = extractF4(res_poly_mat_A[i]); + //auto resB = extractF4(res_poly_mat_B[i]); + + *BitIterator(CLsb.data(), i) = popcount(fftRes[i] & lsbMask) & 1; + *BitIterator(CMsb.data(), i) = popcount(fftRes[i] & msbMask) & 1; + } + + } + + +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/FoliagePcg.h b/libOTe/Tools/Foliage/FoliagePcg.h new file mode 100644 index 00000000..1ce762c5 --- /dev/null +++ b/libOTe/Tools/Foliage/FoliagePcg.h @@ -0,0 +1,64 @@ +#pragma once +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/Matrix.h" +#include "cryptoTools/Common/Aligned.h" +#include "coproto/Socket/Socket.h" +#include "cryptoTools/Crypto/PRNG.h" + +namespace osuCrypto +{ + + class FoliageF4Ole + { + public: + u64 mPartyIdx = 0; + + // log3 polynomial size + u64 mLog3N = 0; + + // the number of noisy positions per polynomial + u64 mT = 27; + + // the number of polynomials + u64 mC = 4; + + // the size of a polynomial, 3^mLog3N + u64 mN = 0; + + // The A poly in FFT format. We pack mC FFTs into a single u8. The + // first is hard coded to the identity polynomial. + AlignedUnVector mFftA; + + // The A^2 poly in FFT format. We pack mC^2 FFTs into a single u32. + AlignedVector mFftASquared; + + // depth of 3-ary DPF with 256 F4 values per leaf. + u64 mDpfDomainDepth = 0; + + u64 mDpfBlockSize = 0; + + // the number of F4 values per block. Each block will have 1 non-zero. + // A polynomial will have mT blocks. i.e. mN = mT * mBlockSize. + u64 mBlockSize = 0; + + // the coefficient of the sparse polynomial. + // the i'th row containts the coeffs for the i'th poly. + Matrix mSparseCoefficients; + + // the locations of the non-zeros in the j'th block of the sparse polynomial. + // the i'th row containts the coeffs for the i'th poly. + Matrix mSparsePositions; + + void init(u64 partyIdx, u64 n, PRNG& prng); + + macoro::task<> expand( + span ALsb, + span AMsb, + span CLsb, + span CMsb, PRNG& prng, coproto::Socket& sock); + + + + void sampleA(block seed); + }; +} diff --git a/libOTe/Tools/Foliage/FoliageUtils.h b/libOTe/Tools/Foliage/FoliageUtils.h index c54a4f66..52b18e8f 100644 --- a/libOTe/Tools/Foliage/FoliageUtils.h +++ b/libOTe/Tools/Foliage/FoliageUtils.h @@ -1,8 +1,12 @@ #pragma once #include "cryptoTools/Crypto/AES.h" #include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Crypto/RandomOracle.h" #include #include "uint128.h" +#include +#include + namespace osuCrypto { using uint128_t = absl::uint128_t; @@ -154,6 +158,27 @@ namespace osuCrypto printf("\n"); } + template + inline block hash(T* ptr, u64 size) + { + oc::RandomOracle ro(16); + ro.Update(ptr, size); + block f; + ro.Final(f); + return f; + } + + + inline std::string hex32(span ptr) + { + std::stringstream ss; + for (u64 i = 0; i < ptr.size(); ++i) + ss << std::setw(8)< trits, size_t size) + inline void reverse_uint8_array(span trits) { size_t i = 0; - size_t j = size - 1; + size_t j = trits.size() - 1; while (i < j) { @@ -204,24 +229,24 @@ namespace osuCrypto } // Converts an array of trits (not packed) into their integer representation. - inline size_t trits_to_int(span trits, size_t size) + inline size_t trits_to_int(span trits) { - reverse_uint8_array(trits, size); + reverse_uint8_array(trits); size_t result = 0; - for (size_t i = 0; i < size; i++) + for (size_t i = 0; i < trits.size(); i++) result = result * 3 + (size_t)trits[i]; return result; } // Converts an integer into ternary representation (each trit = 0,1,2) - inline void int_to_trits(size_t n, span trits, size_t size) + inline void int_to_trits(size_t n, span trits) { - for (size_t i = 0; i < size; i++) + for (size_t i = 0; i < trits.size(); i++) trits[i] = 0; size_t index = 0; - while (n > 0 && index < size) + while (n > 0 && index < trits.size()) { trits[index] = (uint8_t)(n % 3); n = n / 3; @@ -235,6 +260,21 @@ namespace osuCrypto return std::log2(a) / std::log2(base); } + inline u64 log3Ceil(u64 x) + { + if (x == 0) return 0; + u64 i = 0; + u64 v = 1; + while (v < x) + { + v *= 3; + i++; + } + //assert(i == ceil(log_base(x, 3))); + + return i; + } + // Compute base^exp without the floating-point precision // errors of the built-in pow function. inline size_t ipow(size_t base, size_t exp) @@ -259,5 +299,25 @@ namespace osuCrypto return result; } + inline int popcount(uint128_t x) + { + std::array xArr; + memcpy(xArr.data(), &x, 16); + return popcount(xArr[0]) + popcount(xArr[1]); + } + + inline std::array extractF4(const uint128_t& val) + { + std::array ret; + const char* ptr = (const char*)&val; + for (u8 i = 0; i < 16; ++i) + { + ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + } + return ret; + } } \ No newline at end of file diff --git a/libOTe/Tools/Foliage/uint128.h b/libOTe/Tools/Foliage/uint128.h index b38d9012..05fa058b 100644 --- a/libOTe/Tools/Foliage/uint128.h +++ b/libOTe/Tools/Foliage/uint128.h @@ -85,22 +85,11 @@ // // assumption specified above. // int y = x / 16; // -#if !defined(NDEBUG) -#define ABSL_INTERNAL_ASSUME(cond) assert(cond) -#elif ABSL_HAVE_BUILTIN(__builtin_assume) -#define ABSL_INTERNAL_ASSUME(cond) __builtin_assume(cond) -#elif defined(__GNUC__) || ABSL_HAVE_BUILTIN(__builtin_unreachable) -#define ABSL_INTERNAL_ASSUME(cond) \ - do { \ - if (!(cond)) __builtin_unreachable(); \ - } while (0) -#elif defined(_MSC_VER) + +#if defined(_MSC_VER) #define ABSL_INTERNAL_ASSUME(cond) __assume(cond) #else -#define ABSL_INTERNAL_ASSUME(cond) \ - do { \ - static_cast(false && (cond)); \ - } while (0) +#define ABSL_INTERNAL_ASSUME(cond) #endif namespace absl { diff --git a/libOTe_Tests/Foliage_Tests.cpp b/libOTe_Tests/Foliage_Tests.cpp index 670d86ce..922f7016 100644 --- a/libOTe_Tests/Foliage_Tests.cpp +++ b/libOTe_Tests/Foliage_Tests.cpp @@ -5,6 +5,8 @@ //#include "libOTe/Tools/Foliage/tri-dpf/FoliageHalfDpf.h" #include "libOTe/Tools/Foliage/F4Ops.h" #include "cryptoTools/Common/Matrix.h" +#include "libOTe/Tools/Foliage/FoliagePcg.h" +#include "coproto/Socket/LocalAsyncSock.h" namespace osuCrypto { //u8 extractF4(const uint128_t& val, u8 idx) @@ -14,26 +16,6 @@ namespace osuCrypto // auto byte = ((u8*)&val)[byteIdx]; // return (byte >> (bitIdx * 2)) & 0b11; //} - int popcount(uint128_t x) - { - std::array xArr; - memcpy(xArr.data(), &x, 16); - return popcount(xArr[0]) + popcount(xArr[1]); - } - - std::array extractF4(const uint128_t& val) - { - std::array ret; - const char* ptr = (const char*)&val; - for (u8 i = 0; i < 16; ++i) - { - ret[i * 4 + 0] = (ptr[i] >> 0) & 3; - ret[i * 4 + 1] = (ptr[i] >> 2) & 3; - ret[i * 4 + 2] = (ptr[i] >> 4) & 3; - ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; - } - return ret; - } void testOutputCorrectness( span shares0, @@ -310,27 +292,13 @@ namespace osuCrypto // TODO[feature]: modularize the different components of the test case and // design more unit tests. - u64 log3Ceil(u64 x) - { - if (x == 0) return 0; - u64 i = 0; - u64 v = 1; - while (v < x) - { - v *= 3; - i++; - } - assert(i == ceil(log_base(x, 3))); - - return i; - } // This test evaluates the full PCG.Expand for both parties and // checks correctness of the resulting OLE correlation. void foliage_pcg_test(const CLP& cmd) { bool check = !cmd.isSet("noCheck"); - auto N = 14; // 3^N number of OLEs generated in total + auto N = 12; // 3^N number of OLEs generated in total // The C and T parameters are computed using the SageMath script that can be // found in https://github.com/mbombar/estimator_folding @@ -341,7 +309,8 @@ namespace osuCrypto clock_t time; time = clock(); - PRNG prng(block(54233453245)); + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); const size_t n = N; const size_t c = C; @@ -355,7 +324,12 @@ namespace osuCrypto //************************************************************************ std::vector fft_a(poly_size); std::vector fft_a2(poly_size); - sample_a_and_a2(fft_a, fft_a2, poly_size, c, prng); + PRNG APrng(block(431234234, 213434234123)); + sample_a_and_a2(fft_a, fft_a2, poly_size, c, APrng); + + //std::cout << "a " << hash(fft_a.data(), fft_a.size()) << std::endl; + //std::cout << "a2 " << hash(fft_a2.data(), fft_a2.size()) << std::endl; + //************************************************************************ // Here, we figure out a good block size for the error vectors such that @@ -407,14 +381,14 @@ namespace osuCrypto size_t offset = i * t + j; // random *non-zero* coefficients in F4 - uint8_t a = rand_f4x(prng); - uint8_t b = rand_f4x(prng); + uint8_t a = rand_f4x(prng0); + uint8_t b = rand_f4x(prng1); err_poly_coeffs_A[offset] = a; err_poly_coeffs_B[offset] = b; // random index within the block - size_t pos_A = random_index(block_size - 1, prng); - size_t pos_B = random_index(block_size - 1, prng); + size_t pos_A = random_index(block_size - 1, prng0); + size_t pos_B = random_index(block_size - 1, prng1); if (pos_A >= block_size || pos_B >= block_size) { @@ -432,6 +406,13 @@ namespace osuCrypto } } + + //std::cout << "posA " << hash(err_poly_positions_A.data(), err_poly_positions_A.size()) << std::endl; + //std::cout << "posB " << hash(err_poly_positions_B.data(), err_poly_positions_B.size()) << std::endl; + //std::cout << "coeffA " << hash(err_poly_coeffs_A.data(), err_poly_coeffs_A.size()) << std::endl; + //std::cout << "coeffB " << hash(err_poly_coeffs_B.data(), err_poly_coeffs_B.size()) << std::endl; + + // Compute FFT of eA and eB in packed form. // Note that because c = 4, we can pack 4 FFTs into a uint8_t std::vector fft_eA(poly_size); @@ -454,6 +435,10 @@ namespace osuCrypto } } + //std::cout << "sparseA " << hash(fft_eA.data(), fft_eA.size()) << std::endl; + //std::cout << "sparseB " << hash(fft_eB.data(), fft_eB.size()) << std::endl; + + // Evaluate the FFTs on the error polynomials eA and eB fft_recursive_uint8(fft_eA, n, poly_size / 3); fft_recursive_uint8(fft_eB, n, poly_size / 3); @@ -474,6 +459,12 @@ namespace osuCrypto multiply_fft_8(fft_a, fft_eA, res_poly_A, poly_size); // a*eA multiply_fft_8(fft_a, fft_eB, res_poly_B, poly_size); // a*eB + + //std::cout << "multA " << hash(res_poly_A.data(), res_poly_A.size()) << std::endl; + //std::cout << "multB " << hash(res_poly_B.data(), res_poly_B.size()) << std::endl; + + + // XOR the result into the accumulator. // Specifically, we XOR all the columns of the FFT result to get a // vector of size poly_size. @@ -486,6 +477,10 @@ namespace osuCrypto } } + //std::cout << "compressA " << hash(x_poly_A.data(), x_poly_A.size()) << std::endl; + //std::cout << "compressB " << hash(x_poly_B.data(), x_poly_B.size()) << std::endl; + + printf("[.. ]Done with Step 2 (computing the local vectors)\n"); //************************************************************************ @@ -536,8 +531,8 @@ namespace osuCrypto } // Decompose the position into the ternary basis - int_to_trits(posA, trit_decomp_A, n); - int_to_trits(posB, trit_decomp_B, n); + int_to_trits(posA, trit_decomp_A); + int_to_trits(posB, trit_decomp_B); // printf("[DEBUG]: posA=%zu, posB=%zu\n", posA, posB); @@ -551,7 +546,7 @@ namespace osuCrypto } // Get back the resulting cross-product position as an integer - size_t pos = trits_to_int(trit_decomp, n); + size_t pos = trits_to_int(trit_decomp); size_t block_idx = floor(pos / block_size); // block index in polynomial //size_t in_block_idx = pos % block_size; // index within the block @@ -579,6 +574,7 @@ namespace osuCrypto } } + // cleanup temporary values //free(trit_decomp); //free(trit_decomp_A); @@ -595,7 +591,11 @@ namespace osuCrypto // Sample PRF keys for the DPFs PRFKeys prf_keys; - prf_keys.gen(prng); + PRNG prfSeedPrng(block(3412342134, 56453452362346)); + prf_keys.gen(prfSeedPrng); + PRNG genPrng; + oc::RandomOracle dpfHash0(16); + oc::RandomOracle dpfHash1(16); // Sample DPF keys for each of the t errors in the t blocks for (size_t i = 0; i < c; i++) @@ -633,12 +633,27 @@ namespace osuCrypto beta[packed_idx] = coeff << (2 * (63 - bit_idx)); // Message (beta) is of size 4 blocks of 128 bits - DPFGen(prf_keys, dpf_domain_bits, alpha_0, beta, 4, dpf_keys_A[index], dpf_keys_B[index], prng); + genPrng.SetSeed(block(index, 542345234)); + DPFGen(prf_keys, dpf_domain_bits, alpha_0, beta, 4, dpf_keys_A[index], dpf_keys_B[index], genPrng); + + + dpfHash0.Update(dpf_keys_A[index].k.data(), dpf_keys_A[index].k.size()); + dpfHash0.Update(dpf_keys_A[index].msg_len); + dpfHash0.Update(dpf_keys_A[index].size); + dpfHash1.Update(dpf_keys_B[index].k.data(), dpf_keys_B[index].k.size()); + dpfHash1.Update(dpf_keys_B[index].msg_len); + dpfHash1.Update(dpf_keys_B[index].size); } } } } + block dpfHashVal0, dpfHashVal1; + dpfHash0.Final(dpfHashVal0); + dpfHash1.Final(dpfHashVal1); + //std::cout << "dpfA " << dpfHashVal0 << std::endl; + //std::cout << "dpfB " << dpfHashVal1 << std::endl; + printf("[.... ]Done with Step 4 (sampling DPF keys)\n"); //************************************************************************ @@ -718,6 +733,9 @@ namespace osuCrypto } } + //std::cout << "blockA " << hash(packed_polys_A_.data(), packed_polys_A_.size()) << std::endl; + //std::cout << "blockB " << hash(packed_polys_B_.data(), packed_polys_B_.size()) << std::endl; + if (check) { @@ -872,9 +890,16 @@ namespace osuCrypto } } + //std::cout << "Cin0 " << hash(fft_uA.data(), fft_uA.size()) << std::endl; + //std::cout << "Cin1 " << hash(fft_uB.data(), fft_uB.size()) << std::endl; + fft_recursive_uint32(fft_uA, n, poly_size / 3); fft_recursive_uint32(fft_uB, n, poly_size / 3); + //std::cout << "Cfft0 " << hash(fft_uA.data(), fft_uA.size()) << std::endl; + //std::cout << "Cfft1 " << hash(fft_uB.data(), fft_uB.size()) << std::endl; + + printf("[...... ]Done with Step 6 (computing FFTs)\n"); //************************************************************************ @@ -882,6 +907,8 @@ namespace osuCrypto //************************************************************************ multiply_fft_32(fft_a2, fft_uA, res_poly_mat_A, poly_size); multiply_fft_32(fft_a2, fft_uB, res_poly_mat_B, poly_size); + //std::cout << "C0 " << hash(res_poly_mat_A.data(), res_poly_mat_A.size()) << std::endl; + //std::cout << "C1 " << hash(res_poly_mat_B.data(), res_poly_mat_B.size()) << std::endl; //size_t num_ffts = c * c; @@ -951,4 +978,55 @@ namespace osuCrypto } + + // This test evaluates the full PCG.Expand for both parties and + // checks correctness of the resulting OLE correlation. + void foliage_F4ole_test(const CLP& cmd) + { + std::array oles; + + auto logn = 12; + u64 n = ipow(3, logn); + auto blocks = divCeil(n, 128); + //PRNG prng(block(342342)); + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); + + oles[0].init(0, n, prng0); + oles[1].init(1, n, prng1); + auto sock = coproto::LocalAsyncSocket::makePair(); + std::vector + ALsb(blocks), + AMsb(blocks), + BLsb(blocks), + BMsb(blocks), + C0Lsb(blocks), + C0Msb(blocks), + C1Lsb(blocks), + C1Msb(blocks); + + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]), + oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]))); + std::get<0>(r).result(); + std::get<1>(r).result(); + + // Now we check that we got the correct OLE correlations and fail + // the test otherwise. + for (size_t i = 0; i < blocks; i++) + { + auto aLsb = C0Lsb[i] ^ C1Lsb[i]; + auto aMsb = C0Msb[i] ^ C1Msb[i]; + block mLsb, mMsb; + f4Mult( + ALsb[i], AMsb[i], + BLsb[i], BMsb[i], + mLsb, mMsb); + + if (aLsb != mLsb) + throw RTE_LOC; + if (aMsb != mMsb) + throw RTE_LOC; + } + } } \ No newline at end of file diff --git a/libOTe_Tests/Foliage_Tests.h b/libOTe_Tests/Foliage_Tests.h index 5e84058b..46b69053 100644 --- a/libOTe_Tests/Foliage_Tests.h +++ b/libOTe_Tests/Foliage_Tests.h @@ -6,6 +6,7 @@ namespace osuCrypto void foliage_spfss_test(); void foliage_dpf_test(); void foliage_pcg_test(const CLP& cmd); + void foliage_F4ole_test(const CLP& cmd); } \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 1376a6de..53ffe1aa 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -66,6 +66,7 @@ namespace tests_libOTe tc.add("foliage_dpf_test ", foliage_dpf_test); tc.add("foliage_spfss_test ", foliage_spfss_test); tc.add("foliage_pcg_test ", foliage_pcg_test); + tc.add("foliage_F4ole_test ", foliage_F4ole_test); tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); From b81ff603efcf06c81eb7dfb153bde52686b3ad3b Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 5 Feb 2025 02:58:14 -0800 Subject: [PATCH 06/48] foleage fft attempted opt and started tri dpf --- libOTe/Tools/Dpf/RegularDpf.h | 16 +- libOTe/Tools/Dpf/TriDpf.h | 405 +++++++++ libOTe/Tools/{Foliage => Foleage}/F4Ops.h | 2 +- .../FoleageMain.cpp} | 8 +- .../FoliagePcg.cpp => Foleage/FoleagePcg.cpp} | 89 +- .../FoliagePcg.h => Foleage/FoleagePcg.h} | 5 +- .../FoliageUtils.h => Foleage/FoleageUtils.h} | 4 +- libOTe/Tools/Foleage/PerfectShuffle.h | 437 +++++++++ .../fft/FoleageFFT_bench.cpp} | 16 +- libOTe/Tools/Foleage/fft/FoleageFFT_bench.h | 13 + libOTe/Tools/Foleage/fft/FoleageFft.cpp | 856 ++++++++++++++++++ libOTe/Tools/Foleage/fft/FoleageFft.h | 388 ++++++++ .../Tools/{Foliage => Foleage}/spfss_test.cpp | 4 +- .../{Foliage => Foleage}/tri-dpf/.gitignore | 0 .../tri-dpf/FoleageDpf.cpp} | 4 +- .../tri-dpf/FoleageDpf.h} | 4 +- .../tri-dpf/FoleageDpf_test.cpp} | 4 +- .../tri-dpf/FoleageDpf_test.h} | 0 .../tri-dpf/FoleagePrf.h} | 2 +- .../{Foliage => Foleage}/tri-dpf/LICENSE | 0 .../{Foliage => Foleage}/tri-dpf/README.md | 0 .../tri-dpf/TriDpfUtils.h | 2 +- libOTe/Tools/{Foliage => Foleage}/uint128.h | 0 libOTe/Tools/Foliage/fft/FoliageFFT_bench.h | 13 - libOTe/Tools/Foliage/fft/FoliageFft.cpp | 311 ------- libOTe/Tools/Foliage/fft/FoliageFft.h | 37 - libOTe_Tests/CMakeLists.txt | 2 +- .../{Foliage_Tests.cpp => Foleage_Tests.cpp} | 441 ++++++++- libOTe_Tests/Foleage_Tests.h | 14 + libOTe_Tests/Foliage_Tests.h | 12 - libOTe_Tests/RegularDpf_Tests.cpp | 80 ++ libOTe_Tests/UnitTests.cpp | 13 +- 32 files changed, 2722 insertions(+), 460 deletions(-) create mode 100644 libOTe/Tools/Dpf/TriDpf.h rename libOTe/Tools/{Foliage => Foleage}/F4Ops.h (99%) rename libOTe/Tools/{Foliage/FoliageMain.cpp => Foleage/FoleageMain.cpp} (97%) rename libOTe/Tools/{Foliage/FoliagePcg.cpp => Foleage/FoleagePcg.cpp} (83%) rename libOTe/Tools/{Foliage/FoliagePcg.h => Foleage/FoleagePcg.h} (93%) rename libOTe/Tools/{Foliage/FoliageUtils.h => Foleage/FoleageUtils.h} (98%) create mode 100644 libOTe/Tools/Foleage/PerfectShuffle.h rename libOTe/Tools/{Foliage/fft/FoliageFFT_bench.cpp => Foleage/fft/FoleageFFT_bench.cpp} (92%) create mode 100644 libOTe/Tools/Foleage/fft/FoleageFFT_bench.h create mode 100644 libOTe/Tools/Foleage/fft/FoleageFft.cpp create mode 100644 libOTe/Tools/Foleage/fft/FoleageFft.h rename libOTe/Tools/{Foliage => Foleage}/spfss_test.cpp (98%) rename libOTe/Tools/{Foliage => Foleage}/tri-dpf/.gitignore (100%) rename libOTe/Tools/{Foliage/tri-dpf/FoliageDpf.cpp => Foleage/tri-dpf/FoleageDpf.cpp} (99%) rename libOTe/Tools/{Foliage/tri-dpf/FoliageDpf.h => Foleage/tri-dpf/FoleageDpf.h} (84%) rename libOTe/Tools/{Foliage/tri-dpf/FoliageDpf_test.cpp => Foleage/tri-dpf/FoleageDpf_test.cpp} (97%) rename libOTe/Tools/{Foliage/tri-dpf/FoliageDpf_test.h => Foleage/tri-dpf/FoleageDpf_test.h} (100%) rename libOTe/Tools/{Foliage/tri-dpf/FoliagePrf.h => Foleage/tri-dpf/FoleagePrf.h} (97%) rename libOTe/Tools/{Foliage => Foleage}/tri-dpf/LICENSE (100%) rename libOTe/Tools/{Foliage => Foleage}/tri-dpf/README.md (100%) rename libOTe/Tools/{Foliage => Foleage}/tri-dpf/TriDpfUtils.h (96%) rename libOTe/Tools/{Foliage => Foleage}/uint128.h (100%) delete mode 100644 libOTe/Tools/Foliage/fft/FoliageFFT_bench.h delete mode 100644 libOTe/Tools/Foliage/fft/FoliageFft.cpp delete mode 100644 libOTe/Tools/Foliage/fft/FoliageFft.h rename libOTe_Tests/{Foliage_Tests.cpp => Foleage_Tests.cpp} (76%) create mode 100644 libOTe_Tests/Foleage_Tests.h delete mode 100644 libOTe_Tests/Foliage_Tests.h diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index bb3da230..b121edd3 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -13,20 +13,6 @@ namespace osuCrypto { struct RegularDpf { - enum class OutputFormat - { - // The i'th row holds the i'th leaf for all trees. - // The j'th tree is in the j'th column. - ByLeafIndex, - - // The i'th row holds the i'th tree. - // The j'th leaf is in the j'th column. - ByTreeIndex, - - }; - - OutputFormat mOutputFormat = OutputFormat::ByLeafIndex; - u64 mPartyIdx = 0; u64 mDomain = 0; @@ -54,7 +40,7 @@ namespace osuCrypto if (!numPoints) throw RTE_LOC; - mDepth = oc::log2ceil(domain); + mDepth = oc::log3ceil(domain); mPartyIdx = partyIdx; mDomain = domain; mNumPoints = numPoints; diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h new file mode 100644 index 00000000..af4e23a8 --- /dev/null +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -0,0 +1,405 @@ +#pragma once + + +#include "cryptoTools/Common/Defines.h" +#include "coproto/Socket/Socket.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Common/BitVector.h" +#include "cryptoTools/Common/Matrix.h" + +#include "DpfMult.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" + +namespace osuCrypto +{ + struct TriDpf + { + enum class OutputFormat + { + // The i'th row holds the i'th leaf for all trees. + // The j'th tree is in the j'th column. + ByLeafIndex, + + // The i'th row holds the i'th tree. + // The j'th leaf is in the j'th column. + ByTreeIndex, + + }; + + OutputFormat mOutputFormat = OutputFormat::ByLeafIndex; + + u64 mPartyIdx = 0; + + u64 mDomain = 0; + + u64 mDepth = 0; + + u64 mNumPoints = 0; + + //DpfMult mMultiplier; + + u8 lsb(const block& b) + { + return b.get(0) & 1; + } + + void init( + u64 partyIdx, + u64 domain, + u64 numPoints) + { + if (partyIdx > 1) + throw RTE_LOC; + if (domain < 2) + throw RTE_LOC; + if (!numPoints) + throw RTE_LOC; + + mDepth = log3ceil(domain); + mPartyIdx = partyIdx; + mDomain = domain; + mNumPoints = numPoints; + //mMultiplier.init(partyIdx, numPoints * mDepth); + } + +#define SIMD8(VAR, STATEMENT) \ + { constexpr u64 VAR = 0; STATEMENT; }\ + { constexpr u64 VAR = 1; STATEMENT; }\ + { constexpr u64 VAR = 2; STATEMENT; }\ + { constexpr u64 VAR = 3; STATEMENT; }\ + { constexpr u64 VAR = 4; STATEMENT; }\ + { constexpr u64 VAR = 5; STATEMENT; }\ + { constexpr u64 VAR = 6; STATEMENT; }\ + { constexpr u64 VAR = 7; STATEMENT; }\ + do{}while(0) + + template< + typename Output + > + macoro::task<> expand( + span points, + span values, + Output&& output, + PRNG& prng, + coproto::Socket& sock) + { + if constexpr (std::is_same, Matrix>::value) + { + if (output.rows() != mNumPoints) + throw RTE_LOC; + if (output.cols() != mDomain) + throw RTE_LOC; + } + if (points.size() != mNumPoints) + throw RTE_LOC; + if (values.size() && values.size() != mNumPoints) + throw RTE_LOC; + + for (u64 i = 0; i < mNumPoints; ++i) + { + u64 v = points[i]; + for (u64 j = 0; j < mDepth; ++j) + { + if ((v & 3) == 3) + throw std::runtime_error("TriDpf: invalid point sharing. Expects the input points to be shared over Z_3^D where each Z_3 elements takes up 2 bits of a the value. " LOCATION); + v >>= 2; + } + if(v) + throw std::runtime_error("TriDpf: invalid point sharing. point is larger than 3^D " LOCATION); + } + + u64 numPoints = points.size(); + u64 numPoints8 = numPoints / 8 * 8; + + + // shares of S' + auto pow2 = 1ull << log2ceil(mDomain); + std::array, 2> s; + s[mDepth & 1].resize(pow2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth & 1) ^ 1].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); + + // share of t + std::array, 2> t; + t[0].resize(s[0].rows(), s[0].cols()); + t[1].resize(s[1].rows(), s[1].cols()); + for (u64 i = 0; i < numPoints; ++i) + t[0](0, i) = mPartyIdx; + + +#if defined(NDEBUG) + auto getRow = [](auto&& m, u64 i) {return m.data(i); }; +#else + auto getRow = [](auto&& m, u64 i) {return m[i]; }; +#endif + std::array, 2> tau; + tau[0].resize(mNumPoints); + tau[1].resize(mNumPoints); + + std::array, 2> z; + z[0].resize(mNumPoints); + z[1].resize(mNumPoints); + AlignedUnVector sigma(mNumPoints); + BitVector negAlphaj(mNumPoints); + AlignedUnVector diff(mNumPoints); + + + { + // we skip level 0 and set level 1 to be random + auto sc0 = s[1][0]; + auto sc1 = s[1][1]; + for (u64 k = 0; k < numPoints; ++k) + { + sc0[k] = prng.get(); + sc1[k] = prng.get(); + + z[0][k] = sc0[k]; + z[1][k] = sc1[k]; + } + } + + // at each iteration we first correct the parent level. + // The parent level has two syblings which are random. + // We need to correct the inactive child so that both parties + // hold the same seed (a sharing of zero). + // + // we then expand the parent to level to get the children level. + // We compute left and right sums for the children. + for (u64 iter = 1; iter <= mDepth; ++iter) + { + // the grand parent level + auto& tp = t[(iter - 1) & 1]; + + // the parent level + auto& sc = s[iter & 1]; + auto& tc = t[iter & 1]; + + // the child level + auto& sg = s[(iter + 1) & 1]; + + auto size = 1ull << iter; + + // + for (u64 k = 0; k < mNumPoints; ++k) + { + auto alphaj = *oc::BitIterator(&points[k], mDepth - iter); + tau[0][k] = lsb(z[0][k]) ^ alphaj ^ mPartyIdx; + tau[1][k] = lsb(z[1][k]) ^ alphaj; + diff[k] = z[0][k] ^ z[1][k]; + negAlphaj[k] = alphaj ^ mPartyIdx; + } + + co_await mMultiplier.multiply(negAlphaj, diff, diff, sock); + // sigma = z[1^alpha[j]] + for (u64 k = 0; k < mNumPoints; ++k) + sigma[k] = diff[k] ^ z[0][k]; + + // reveal sigma and tau + u64 buffSize = sigma.size() * 16 + divCeil(mNumPoints * 2, 8); + AlignedUnVector sendBuff(buffSize), recvBuff(buffSize); + copyBytesMin(sendBuff, sigma); + auto sendBitIter = BitIterator(&sendBuff[numPoints * 16]); + auto recvBitIter = BitIterator(&recvBuff[numPoints * 16]); + for (u64 i = 0; i < mNumPoints; ++i) + { + *sendBitIter++ = tau[0][i]; + *sendBitIter++ = tau[1][i]; + } + co_await sock.send(std::move(sendBuff)); + co_await sock.recv(recvBuff); + for (u64 k = 0; k < mNumPoints; ++k) + { + block sk = *(block*)&recvBuff[k * sizeof(block)]; + sigma[k] ^= sk; + tau[0][k] ^= *recvBitIter++; + tau[1][k] ^= *recvBitIter++; + } + + + if (iter != mDepth) + { + + setBytes(z[0], 0); + setBytes(z[1], 0); + + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) + { + // parent control bits + auto tpl = getRow(tp, L); + + // child seed + std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + + // child control bit + std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + + // grandchild seeds + std::array sgl{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; + + for (u64 k = 0; k < numPoints8; k += 8) + { + block temp[8]; + SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, scl[0][k + q] ^= temp[q]); + + + mAesFixedKey.ecbEncBlocks<8>(&scl[0][k], &sgl[1][k]); + SIMD8(q, sgl[0][k + q] = AES::roundEnc(sgl[1][k + q], scl[0][k + q])); + SIMD8(q, sgl[1][k + q] = sgl[1][k + q] + scl[0][k + q]); + + SIMD8(q, z[0][k + q] ^= sgl[0][k + q]); + SIMD8(q, z[1][k + q] ^= sgl[1][k + q]); + + SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl[1][k + q] ^= temp[q]); + + mAesFixedKey.ecbEncBlocks<8>(&scl[1][k], &sgl[3][k]); + SIMD8(q, sgl[2][k + q] = AES::roundEnc(sgl[3][k + q], scl[1][k + q])); + SIMD8(q, sgl[3][k + q] = sgl[3][k + q] + scl[1][k + q]); + SIMD8(q, z[0][k + q] ^= sgl[2][k + q]); + SIMD8(q, z[1][k + q] ^= sgl[3][k + q]); + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + + tcl[0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; + scl[0][k] ^= temp; + + sgl[1][k] = mAesFixedKey.ecbEncBlock(scl[0][k]); + sgl[0][k] = AES::roundEnc(sgl[1][k], scl[0][k]); + sgl[1][k] = sgl[1][k] + scl[0][k]; + + z[0][k] ^= sgl[0][k]; + z[1][k] ^= sgl[1][k]; + + tcl[1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; + scl[1][k] ^= temp; + + sgl[3][k] = mAesFixedKey.ecbEncBlock(scl[1][k]); + sgl[2][k] = AES::roundEnc(sgl[3][k], scl[1][k]); + sgl[3][k] = sgl[3][k] + scl[1][k]; + + z[0][k] ^= sgl[2][k]; + z[1][k] ^= sgl[3][k]; + } + } + } + } + + + // fixing the last layer + { + auto size = 1ull << mDepth; + + auto& tp = t[(mDepth - 1) & 1]; + auto& sc = s[mDepth & 1]; + auto& tc = t[mDepth & 1]; + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) + { + // parent control bits + auto tpl = getRow(tp, L); + + // child seed + std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + + // child control bit + std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + + for (u64 k = 0; k < numPoints8; k += 8) + { + block temp[8]; + SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); + SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); + SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); + SIMD8(q, scl[0][k + q] ^= temp[q]); + SIMD8(q, scl[1][k + q] ^= temp[q]); + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + tc[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; + tc[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; + sc[L2 + 0][k] ^= temp; + sc[L2 + 1][k] ^= temp; + } + } + } + + if (values.size()) + { + + AlignedUnVector gamma(mNumPoints); + for (u64 k = 0; k < mNumPoints; ++k) + { + diff[k] = z[0][k] ^ z[1][k] ^ values[k]; + } + co_await sock.send(std::move(diff)); + co_await sock.recv(gamma); + for (u64 k = 0; k < mNumPoints; ++k) + { + gamma[k] = z[0][k] ^ z[1][k] ^ values[k] ^ gamma[k]; + } + + auto& sd = s[mDepth & 1]; + auto& td = t[mDepth & 1]; + for (u64 i = 0; i < mDomain; ++i) + { + auto sdi = getRow(sd, i); + auto tdi = getRow(td, i); + + for (u64 k = 0; k < numPoints8; k += 8) + { + block T[8]; + + SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); + SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k + q])); + } + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + auto T = block::allSame(-tdi[k]) & gamma[k]; + output(k, i, sdi[k] ^ T, tdi[k]); + } + } + } + else + { + auto& sd = s[mDepth & 1]; + auto& td = t[mDepth & 1]; + for (u64 i = 0; i < mDomain; ++i) + { + auto sdi = getRow(sd, i); + auto tdi = getRow(td, i); + for (u64 k = 0; k < numPoints8; k += 8) + { + SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); + } + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + output(k, i, sdi[k], tdi[k]); + } + } + } + } + + + u64 baseOtCount() const { + return mMultiplier.baseOtCount(); + } + + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); + } + + + }; + +} + +#undef SIMD8 \ No newline at end of file diff --git a/libOTe/Tools/Foliage/F4Ops.h b/libOTe/Tools/Foleage/F4Ops.h similarity index 99% rename from libOTe/Tools/Foliage/F4Ops.h rename to libOTe/Tools/Foleage/F4Ops.h index fd4a73d3..ba2925e1 100644 --- a/libOTe/Tools/Foliage/F4Ops.h +++ b/libOTe/Tools/Foleage/F4Ops.h @@ -1,6 +1,6 @@ #pragma once -#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" namespace osuCrypto { diff --git a/libOTe/Tools/Foliage/FoliageMain.cpp b/libOTe/Tools/Foleage/FoleageMain.cpp similarity index 97% rename from libOTe/Tools/Foliage/FoliageMain.cpp rename to libOTe/Tools/Foleage/FoleageMain.cpp index 7914afd5..37517dae 100644 --- a/libOTe/Tools/Foliage/FoliageMain.cpp +++ b/libOTe/Tools/Foleage/FoleageMain.cpp @@ -2,11 +2,11 @@ #include #include -#include "libOTe/Tools/Foliage/F4Ops.h" -#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +#include "libOTe/Tools/Foleage/F4Ops.h" +#include "libOTe/Tools/Foleage/fft/FoleageFft.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" #define MAX(a, b) ((a) > (b) ? (a) : (b)) diff --git a/libOTe/Tools/Foliage/FoliagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp similarity index 83% rename from libOTe/Tools/Foliage/FoliagePcg.cpp rename to libOTe/Tools/Foleage/FoleagePcg.cpp index 33adc175..5e0a9e4d 100644 --- a/libOTe/Tools/Foliage/FoliagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -1,20 +1,21 @@ -#include "FoliagePcg.h" -#include "libOTe/Tools/Foliage/FoliageUtils.h" -#include "libOTe/Tools/Foliage/F4Ops.h" -#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +#include "FoleagePcg.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" +#include "libOTe/Tools/Foleage/F4Ops.h" +#include "libOTe/Tools/Foleage/fft/FoleageFft.h" #include "cryptoTools/Common/BitIterator.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" namespace osuCrypto { - void FoliageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) + void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) { mPartyIdx = partyIdx; mLog3N = log3Ceil(n); mN = ipow(3, mLog3N); - if (mT % 3 != 0) + + if (mT != ipow(3, mLog3T)) throw RTE_LOC; mDpfDomainDepth = std::max(1, log3Ceil(divCeil(mN, mT * 256))); @@ -34,7 +35,7 @@ namespace osuCrypto } - void FoliageF4Ole::sampleA(block seed) + void FoleageF4Ole::sampleA(block seed) { if (mC > 4) @@ -104,7 +105,7 @@ namespace osuCrypto } - macoro::task<> FoliageF4Ole::expand( + macoro::task<> FoleageF4Ole::expand( span ALsb, span AMsb, span CLsb, @@ -112,6 +113,8 @@ namespace osuCrypto PRNG& prng, coproto::Socket& sock) { + setTimePoint("expand start"); + if (divCeil(mN, 128) < ALsb.size()) throw RTE_LOC; if (ALsb.size() != AMsb.size() || ALsb.size() != CLsb.size() || ALsb.size() != CMsb.size()) @@ -136,24 +139,33 @@ namespace osuCrypto // we pack 4 FFTs into a single u8. std::vector fftSparsePoly(mN); + //std::vector fftSparsePolyLsb(mN), fftSparsePolyMsb(mN); for (u64 i = 0; i < mT; ++i) { for (u64 j = 0; j < mC; ++j) { auto pos = i * mBlockSize + mSparsePositions(j, i); fftSparsePoly[pos] |= mSparseCoefficients(j, i) << (2 * j); + + //fftSparsePolyLsb[pos] |= (mSparseCoefficients(j, i) & 1) << j; + //fftSparsePolyMsb[pos] |= ((mSparseCoefficients(j, i) >> 1) & 1) << j; } } + setTimePoint("sparsePolySample"); + //std::cout << "sparse " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; // switch from polynomial to FFT form fft_recursive_uint8(fftSparsePoly, mLog3N, mN / 3); + //foleageFFT2<1>(fftSparsePolyLsb, fftSparsePolyMsb); + // multiply by the packed A polynomial multiply_fft_8(mFftA, fftSparsePoly, fftSparsePoly, mN); //std::cout << "mult " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; + setTimePoint("sparsePolyMul"); // compress the resume and set the output. @@ -172,14 +184,15 @@ namespace osuCrypto A[i] = a; } + setTimePoint("copyOutX"); //std::cout << "compress " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; std::vector prodPolyCoefficient(mC * mC * mT * mT); std::vector prodPolyPosition(mC * mC * mT * mT); - //auto prodPolyCoefficientIter = prodPolyCoefficient.begin(); - //auto prodPolyPositionIter = prodPolyPosition.begin(); - std::vector tritA(mLog3N), tritB(mLog3N), trits(mLog3N); + + std::vector tritABlk(mLog3T), tritBBlk(mLog3T), tritsBlk(mLog3T); + std::vector tritAPos(mLog3N - mLog3T), tritBPos(mLog3N - mLog3T), tritsPos(mLog3N - mLog3T); Matrix otherSparseCoefficients(mC, mT); Matrix otherSparsePositions(mC, mT); @@ -187,6 +200,8 @@ namespace osuCrypto co_await sock.send(coproto::copy(mSparsePositions)); co_await sock.recv(otherSparseCoefficients); co_await sock.recv(otherSparsePositions); + setTimePoint("sendRecv"); + u64 polyOffset = 0; u8 vA, vB; for (u64 iA = 0; iA < mC; ++iA) @@ -199,36 +214,46 @@ namespace osuCrypto { for (u64 jB = 0; jB < mT; ++jB) { + int_to_trits(jA, tritABlk); + int_to_trits(jB, tritBBlk); + + for (size_t k = 0; k < mLog3T; k++) + { + tritsBlk[k] = (tritABlk[k] + tritBBlk[k]) % 3; + } + u64 blockIdx = trits_to_int(tritsBlk); + + u64 posA_; + u64 posB_; + if (mPartyIdx == 0) { vA = mSparseCoefficients(iA, jA); vB = otherSparseCoefficients(iB, jB); - auto posA = jA * mBlockSize + mSparsePositions(iA, jA); - auto posB = jB * mBlockSize + otherSparsePositions(iB, jB); - int_to_trits(posA, tritA); - int_to_trits(posB, tritB); + posA_ = mSparsePositions(iA, jA); + posB_ = otherSparsePositions(iB, jB); + } else { vA = otherSparseCoefficients(iA, jA); vB = mSparseCoefficients(iB, jB); - auto posA = jA * mBlockSize + otherSparsePositions(iA, jA); - auto posB = jB * mBlockSize + mSparsePositions(iB, jB); - int_to_trits(posA, tritA); - int_to_trits(posB, tritB); + posA_ = otherSparsePositions(iA, jA); + posB_ = mSparsePositions(iB, jB); } + int_to_trits(posA_, tritAPos); + int_to_trits(posB_, tritBPos); - for (size_t k = 0; k < mLog3N; k++) + for(u64 k = 0; k < tritBPos.size(); ++k) { - trits[k] = (tritA[k] + tritB[k]) % 3; + tritsPos[k] = (tritAPos[k] + tritBPos[k]) % 3; } - u64 pos = trits_to_int(trits); - auto blockIdx = pos / mBlockSize; - + auto subblock_pos = trits_to_int(tritsPos); + size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; prodPolyCoefficient[idx] = mult_f4(vA, vB); - prodPolyPosition[idx] = pos % mBlockSize; + prodPolyPosition[idx] = subblock_pos; } } @@ -239,6 +264,7 @@ namespace osuCrypto } } + setTimePoint("sparseProductCompute"); std::vector Dpfs(mC * mC * mT * mT); @@ -307,6 +333,7 @@ namespace osuCrypto } } } + setTimePoint("dpfKeyGen"); //block dpfHashVal; //dpfHash.Final(dpfHashVal); @@ -355,6 +382,7 @@ namespace osuCrypto } } } + setTimePoint("dpfKeyEval"); //std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; @@ -384,6 +412,9 @@ namespace osuCrypto } } } + + setTimePoint("transpose"); + //std::cout << "CIn " << hash(fft.data(), fft.size()) << std::endl; @@ -393,6 +424,7 @@ namespace osuCrypto //std::cout << "C " << hash(fftRes.data(), fftRes.size()) << std::endl; + setTimePoint("fft"); // XOR the (packed) columns into the accumulator. // Specifically, we perform column-wise XORs to get the result. @@ -408,6 +440,9 @@ namespace osuCrypto *BitIterator(CMsb.data(), i) = popcount(fftRes[i] & msbMask) & 1; } + + setTimePoint("addCopyY"); + } diff --git a/libOTe/Tools/Foliage/FoliagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h similarity index 93% rename from libOTe/Tools/Foliage/FoliagePcg.h rename to libOTe/Tools/Foleage/FoleagePcg.h index 1ce762c5..67edb6c2 100644 --- a/libOTe/Tools/Foliage/FoliagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -4,11 +4,12 @@ #include "cryptoTools/Common/Aligned.h" #include "coproto/Socket/Socket.h" #include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Common/Timer.h" namespace osuCrypto { - class FoliageF4Ole + class FoleageF4Ole : public TimerAdapter { public: u64 mPartyIdx = 0; @@ -19,6 +20,8 @@ namespace osuCrypto // the number of noisy positions per polynomial u64 mT = 27; + u64 mLog3T = 3; + // the number of polynomials u64 mC = 4; diff --git a/libOTe/Tools/Foliage/FoliageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h similarity index 98% rename from libOTe/Tools/Foliage/FoliageUtils.h rename to libOTe/Tools/Foleage/FoleageUtils.h index 52b18e8f..a1f62a97 100644 --- a/libOTe/Tools/Foliage/FoliageUtils.h +++ b/libOTe/Tools/Foleage/FoleageUtils.h @@ -260,7 +260,7 @@ namespace osuCrypto return std::log2(a) / std::log2(base); } - inline u64 log3Ceil(u64 x) + inline u64 log3ceil(u64 x) { if (x == 0) return 0; u64 i = 0; @@ -277,7 +277,7 @@ namespace osuCrypto // Compute base^exp without the floating-point precision // errors of the built-in pow function. - inline size_t ipow(size_t base, size_t exp) + inline constexpr size_t ipow(size_t base, size_t exp) { if (exp == 1) return base; diff --git a/libOTe/Tools/Foleage/PerfectShuffle.h b/libOTe/Tools/Foleage/PerfectShuffle.h new file mode 100644 index 00000000..c5367fdb --- /dev/null +++ b/libOTe/Tools/Foleage/PerfectShuffle.h @@ -0,0 +1,437 @@ +#pragma once +#include "cryptoTools/Common/Defines.h" +#include +#include + +namespace osuCrypto +{ + + + + // given a shuffle on blocks of 2*Shift, shuffle + // them together to have block size Shift. + template + inline u32 cPerfectShuffle_round(u32 x) + { + static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); + u32 t; + t = (x ^ (x >> Shift)) & v; + x = x ^ t ^ (t << Shift); + return x; + } + + + + // Hackers Delight perfect shuffle, Sec 7.2. Interlace bits. + // https://doc.lagout.org/security/Hackers%20Delight.pdf + // + // input : abcd efgh ijkl mnop ABCD EFGH IJKL MNOP, + // output: aAbB cCdD eEfF gGhH iIjJ kKlL mMnN oOpP + inline u32 cPerfectShuffle(u16 x0, u16 x1) + { + u32 x = x0 | (u32{ x1 } << 16); + x = cPerfectShuffle_round<8>(x); + x = cPerfectShuffle_round<4>(x); + x = cPerfectShuffle_round<2>(x); + x = cPerfectShuffle_round<1>(x); + return x; + } + + // Hackers Delight perfect shuffle, Sec 7.2. Uninterlace bits. + // https://doc.lagout.org/security/Hackers%20Delight.pdf + // + // input : aAbB cCdD eEfF gGhH iIjJ kKlL mMnN oOpP + // output: abcd efgh ijkl mnop ABCD EFGH IJKL MNOP, + inline std::array cPerfectUnshuffle(u32 x) + { + x = cPerfectShuffle_round<1>(x); + x = cPerfectShuffle_round<2>(x); + x = cPerfectShuffle_round<4>(x); + x = cPerfectShuffle_round<8>(x); + + std::array r; + r[0] = x; + r[1] = x >> 16; + return r; + } + + // perfect shuffle the bits of `input0` and `input1` into `output`. + // bits from `input0` and `input1` alternate. + inline void cPerfectShuffle(span input0, span input1, span output) + { + if (input0.size() != input1.size()) + throw RTE_LOC; + if (input0.size() != (output.size() + 1) / 2) + throw RTE_LOC; + + u64 n32 = output.size() / sizeof(u32); + + auto in0 = (u16*)input0.data(); + auto in1 = (u16*)input1.data(); + auto out = (u32*)output.data(); + for (u64 i = 0; i < n32; ++i) + { + out[i] = cPerfectShuffle(in0[i], in1[i]); + } + + auto n8 = n32 * sizeof(u32); + if (output.size() != n8) + { + u16 x0 = 0, x1 = 0; + copyBytesMin(x0, input0.subspan(n8 / 2)); + copyBytesMin(x1, input1.subspan(n8 / 2)); + auto t = cPerfectShuffle(x0, x1); + copyBytesMin(output.subspan(n8), t); + } + } + + // perfect unshuffle the bits of `input` into `output0` and `output1`. + // even indexed bits of `input` go to `output0`. + inline void cPerfectUnshuffle(span input, span output0, span output1) + { + if (output0.size() != output1.size()) + throw RTE_LOC; + if (output0.size() != (input.size() + 1) / 2) + throw RTE_LOC; + u64 n32 = input.size() / sizeof(u32); + auto out0 = (u16*)output0.data(); + auto out1 = (u16*)output1.data(); + auto in = (u32*)input.data(); + for (u64 i = 0; i < n32; ++i) + { + auto t = cPerfectUnshuffle(in[i]); + assert((u8*)&(out0[i]) < output0.data() + output0.size()); + assert((u8*)&(out1[i]) < output1.data() + output1.size()); + + out0[i] = ((u16*)&t)[0]; + out1[i] = ((u16*)&t)[1]; + } + + auto n8 = n32 * sizeof(u32); + if (input.size() != n8) + { + // auto rem = output0.size() - n8 / 2; + u32 t = 0; + copyBytesMin(t, input.subspan(n8)); + auto r = cPerfectUnshuffle(t); + copyBytesMin(output0.subspan(n8 / 2), r[0]); + copyBytesMin(output1.subspan(n8 / 2), r[1]); + } + } + +#ifdef ENABLE_SSE + + // given a shuffle on blocks of 2*Shift, shuffle + // them together to have block size Shift. + template + inline void ssePerfectShuffle_round(oc::block& x) + { + static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); + oc::block t; + + //t = (x ^ (x >> shift)) & 0x0000FF00; + t = _mm_srli_epi32(x, Shift); + t = _mm_xor_si128(t, x); + t = _mm_and_si128(t, _mm_set_epi32(v, v, v, v)); + + // x = x ^ t ^ (t << shift); + x = _mm_xor_si128(t, x); + t = _mm_slli_epi32(t, Shift); + x = _mm_xor_si128(t, x); + } + + // given a shuffle on blocks of 2*Shift, shuffle + // them together to have block size Shift. + template + inline void ssePerfectShuffle_round(oc::block* x) + { + static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); + oc::block t[8]; + auto V = _mm_set_epi32(v, v, v, v); + + //t = (x ^ (x >> shift)) & 0x0000FF00; + t[0] = _mm_srli_epi32(x[0], Shift); + t[1] = _mm_srli_epi32(x[1], Shift); + t[2] = _mm_srli_epi32(x[2], Shift); + t[3] = _mm_srli_epi32(x[3], Shift); + t[4] = _mm_srli_epi32(x[4], Shift); + t[5] = _mm_srli_epi32(x[5], Shift); + t[6] = _mm_srli_epi32(x[6], Shift); + t[7] = _mm_srli_epi32(x[7], Shift); + + t[0] = _mm_xor_si128(t[0], x[0]); + t[1] = _mm_xor_si128(t[1], x[1]); + t[2] = _mm_xor_si128(t[2], x[2]); + t[3] = _mm_xor_si128(t[3], x[3]); + t[4] = _mm_xor_si128(t[4], x[4]); + t[5] = _mm_xor_si128(t[5], x[5]); + t[6] = _mm_xor_si128(t[6], x[6]); + t[7] = _mm_xor_si128(t[7], x[7]); + + t[0] = _mm_and_si128(t[0], V); + t[1] = _mm_and_si128(t[1], V); + t[2] = _mm_and_si128(t[2], V); + t[3] = _mm_and_si128(t[3], V); + t[4] = _mm_and_si128(t[4], V); + t[5] = _mm_and_si128(t[5], V); + t[6] = _mm_and_si128(t[6], V); + t[7] = _mm_and_si128(t[7], V); + + // x = x ^ t ^ (t << shift); + x[0] = _mm_xor_si128(t[0], x[0]); + x[1] = _mm_xor_si128(t[1], x[1]); + x[2] = _mm_xor_si128(t[2], x[2]); + x[3] = _mm_xor_si128(t[3], x[3]); + x[4] = _mm_xor_si128(t[4], x[4]); + x[5] = _mm_xor_si128(t[5], x[5]); + x[6] = _mm_xor_si128(t[6], x[6]); + x[7] = _mm_xor_si128(t[7], x[7]); + t[0] = _mm_slli_epi32(t[0], Shift); + t[1] = _mm_slli_epi32(t[1], Shift); + t[2] = _mm_slli_epi32(t[2], Shift); + t[3] = _mm_slli_epi32(t[3], Shift); + t[4] = _mm_slli_epi32(t[4], Shift); + t[5] = _mm_slli_epi32(t[5], Shift); + t[6] = _mm_slli_epi32(t[6], Shift); + t[7] = _mm_slli_epi32(t[7], Shift); + x[0] = _mm_xor_si128(t[0], x[0]); + x[1] = _mm_xor_si128(t[1], x[1]); + x[2] = _mm_xor_si128(t[2], x[2]); + x[3] = _mm_xor_si128(t[3], x[3]); + x[4] = _mm_xor_si128(t[4], x[4]); + x[5] = _mm_xor_si128(t[5], x[5]); + x[6] = _mm_xor_si128(t[6], x[6]); + x[7] = _mm_xor_si128(t[7], x[7]); + } + + inline oc::block ssePerfectShuffle(u64 x0, u64 x1) + { + // perfect shuffle the bytes. + const oc::block b = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + oc::block y = _mm_set_epi64x(x1, x0); + y = _mm_shuffle_epi8(y, b); + + // perfect shuffle the bits. + ssePerfectShuffle_round<4>(y); + ssePerfectShuffle_round<2>(y); + ssePerfectShuffle_round<1>(y); + return y; + } + + inline std::array ssePerfectUnshuffle(oc::block y) + { + // perfect shuffle the bits. + ssePerfectShuffle_round<1>(y); + ssePerfectShuffle_round<2>(y); + ssePerfectShuffle_round<4>(y); + + // perfect shuffle the bytes. + const oc::block b = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + y = _mm_shuffle_epi8(y, b); + + return std::bit_cast>(y); + } + + // perfect shuffle 4 blocks on x0,x1 into 8 blocks of y. + inline void ssePerfectShuffle(const oc::block* x0, const oc::block* x1, oc::block* y) + { + // perfect shuffle the bytes. + const oc::block b = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + y[0] = _mm_set_epi64x(((u64*)x1)[0], ((u64*)x0)[0]); + y[1] = _mm_set_epi64x(((u64*)x1)[1], ((u64*)x0)[1]); + y[2] = _mm_set_epi64x(((u64*)x1)[2], ((u64*)x0)[2]); + y[3] = _mm_set_epi64x(((u64*)x1)[3], ((u64*)x0)[3]); + y[4] = _mm_set_epi64x(((u64*)x1)[4], ((u64*)x0)[4]); + y[5] = _mm_set_epi64x(((u64*)x1)[5], ((u64*)x0)[5]); + y[6] = _mm_set_epi64x(((u64*)x1)[6], ((u64*)x0)[6]); + y[7] = _mm_set_epi64x(((u64*)x1)[7], ((u64*)x0)[7]); + y[0] = _mm_shuffle_epi8(y[0], b); + y[1] = _mm_shuffle_epi8(y[1], b); + y[2] = _mm_shuffle_epi8(y[2], b); + y[3] = _mm_shuffle_epi8(y[3], b); + y[4] = _mm_shuffle_epi8(y[4], b); + y[5] = _mm_shuffle_epi8(y[5], b); + y[6] = _mm_shuffle_epi8(y[6], b); + y[7] = _mm_shuffle_epi8(y[7], b); + + // perfect shuffle the bits. + ssePerfectShuffle_round<4>(y); + ssePerfectShuffle_round<2>(y); + ssePerfectShuffle_round<1>(y); + } + + // perfect unshuffle 8 blocks of y into 4 blocks on x0,x1 into. + inline void ssePerfectUnshuffle(const oc::block* yy, oc::block* x0, oc::block* x1) + { + std::array y; + std::copy((u8*)yy, (u8*)(yy + y.size()), (u8*)y.data()); + // m emcpy(y.data(), yy, sizeof(y)); + + // perfect shuffle the bits. + ssePerfectShuffle_round<1>(y.data()); + ssePerfectShuffle_round<2>(y.data()); + ssePerfectShuffle_round<4>(y.data()); + + // perfect shuffle the bytes. + const oc::block b = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); + y[0] = _mm_shuffle_epi8(y[0], b); + y[1] = _mm_shuffle_epi8(y[1], b); + y[2] = _mm_shuffle_epi8(y[2], b); + y[3] = _mm_shuffle_epi8(y[3], b); + y[4] = _mm_shuffle_epi8(y[4], b); + y[5] = _mm_shuffle_epi8(y[5], b); + y[6] = _mm_shuffle_epi8(y[6], b); + y[7] = _mm_shuffle_epi8(y[7], b); + + + u64* yyy = (u64*)y.data(); + u64* xx1 = (u64*)x1; + u64* xx0 = (u64*)x0; + xx0[0] = yyy[0]; + xx1[0] = yyy[1]; + xx0[1] = yyy[2]; + xx1[1] = yyy[3]; + xx0[2] = yyy[4]; + xx1[2] = yyy[5]; + xx0[3] = yyy[6]; + xx1[3] = yyy[7]; + + xx0[4] = yyy[8]; + xx1[4] = yyy[9]; + xx0[5] = yyy[10]; + xx1[5] = yyy[11]; + xx0[6] = yyy[12]; + xx1[6] = yyy[13]; + xx0[7] = yyy[14]; + xx1[7] = yyy[15]; + } + + inline void ssePerfectShuffle(span input0, span input1, span output) + { + assert(input0.size() == input1.size()); + assert(input0.size() == (output.size() + 1) / 2); + u64 n1024 = output.size() / sizeof(std::array); + + auto in0 = (oc::block*)input0.data(); + auto in1 = (oc::block*)input1.data(); + auto out = (oc::block*)output.data(); + for (u64 i = 0; i < n1024; ++i) + { + ssePerfectShuffle(in0, in1, out); + in0 += 4; + in1 += 4; + out += 8; + } + + auto n64 = n1024 * 16; + auto n8 = n64 * sizeof(u64); + auto rem = input0.size() - n8 / 2; + while (rem) + { + auto min = std::min(rem, sizeof(u64)); + u64 x0 = 0, x1 = 0; + std::copy(input0.data() + n8 / 2, input0.data() + n8 / 2 + min, (u8*)&x0); + std::copy(input1.data() + n8 / 2, input1.data() + n8 / 2 + min, (u8*)&x1); + //m emcpy(&x0, &input0[n8 / 2], min); + //m emcpy(&x1, &input1[n8 / 2], min); + rem -= min; + + auto t = ssePerfectShuffle(x0, x1); + + auto min2 = std::min(output.size() - n8, sizeof(oc::block)); + std::copy((u8*)&t, (u8*)&t + min2, output.data() + n8); + //m emcpy(&output[n8], &t, min2); + n8 += min2; + } + } + + + inline void ssePerfectUnshuffle(span input, span output0, span output1) + { + assert(output0.size() == output1.size()); + assert(output0.size() == (input.size() + 1) / 2); + + u64 n1024 = input.size() / sizeof(std::array); + + auto out0 = (oc::block*)output0.data(); + auto out1 = (oc::block*)output1.data(); + auto in = (oc::block*)input.data(); + for (u64 i = 0; i < n1024; ++i) + { + assert((u8*)(in + 8) <= input.data() + input.size()); + assert((u8*)(out0 + 4) <= output0.data() + output0.size()); + assert((u8*)(out1 + 4) <= output1.data() + output1.size()); + ssePerfectUnshuffle(in, out0, out1); + + in += 8; + out0 += 4; + out1 += 4; + } + + + auto n64 = n1024 * 16; + auto n8 = n64 * sizeof(u64); + //auto n8 = n32 * sizeof(u32); + while (input.size() != n8) + { + auto rem = input.size() - n8; + auto min = std::min(rem, sizeof(oc::block)); + oc::block t = oc::ZeroBlock; + // m emcpy(&t, &input[n8], min); + std::copy(&input[n8], &input[n8] + min, (u8*)&t); + + auto r = ssePerfectUnshuffle(t); + + auto min2 = std::min(output0.size() - n8 / 2, sizeof(u64)); + // m emcpy(&output0[n8 / 2], &r[0], min2); + std::copy((u8*)&r[0], (u8*)&r[0] + min2, output0.data() + n8 / 2); + //m emcpy(&output1[n8 / 2], &r[1], min2); + std::copy((u8*)&r[1], (u8*)&r[1] + min2, output1.data() + n8 / 2); + + n8 += min; + } + } +#endif + + inline void perfectShuffle(span input0, span input1, span output) + { +#ifdef ENABLE_SSE + ssePerfectShuffle(input0, input1, output); +#else + cPerfectShuffle(input0, input1, output); +#endif + } + + inline void perfectUnshuffle(span input, span output0, span output1) + { +#ifdef ENABLE_SSE + ssePerfectUnshuffle(input, output0, output1); +#else + cPerfectUnshuffle(input, output0, output1); +#endif + } +} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp similarity index 92% rename from libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp rename to libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp index d10d7bc7..957dd7a2 100644 --- a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp @@ -6,11 +6,11 @@ #include #include -#include "libOTe/Tools/Foliage/fft/FoliageFft.h" +#include "libOTe/Tools/Foleage/fft/FoleageFft.h" #include "cryptoTools/Common/Aligned.h" #include "cryptoTools/Crypto/PRNG.h" -#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" #define NUMVARS 16 @@ -18,7 +18,7 @@ namespace osuCrypto { - double Foliage_FFT64_bench() + double Foleage_FFT64_bench() { size_t num_vars = NUMVARS; size_t num_coeffs = ipow(3, num_vars); @@ -41,7 +41,7 @@ namespace osuCrypto return time_taken; } - double Foliage_FFT32_bench() + double Foleage_FFT32_bench() { size_t num_vars = NUMVARS; size_t num_coeffs = ipow(3, num_vars); @@ -66,7 +66,7 @@ namespace osuCrypto return time_taken; } - double Foliage_FFT8_bench() + double Foleage_FFT8_bench() { size_t num_vars = NUMVARS; size_t num_coeffs = ipow(3, num_vars); @@ -100,7 +100,7 @@ namespace osuCrypto printf("Testing FFT (uint8 packing)\n"); for (int i = 0; i < testTrials; i++) { - time += Foliage_FFT8_bench(); + time += Foleage_FFT8_bench(); printf("Done with trial %i of %i\n", i + 1, testTrials); } printf("******************************************\n"); @@ -113,7 +113,7 @@ namespace osuCrypto time = 0; for (int i = 0; i < testTrials; i++) { - time += Foliage_FFT32_bench(); + time += Foleage_FFT32_bench(); printf("Done with trial %i of %i\n", i + 1, testTrials); } printf("******************************************\n"); @@ -126,7 +126,7 @@ namespace osuCrypto time = 0; for (int i = 0; i < testTrials; i++) { - time += Foliage_FFT64_bench(); + time += Foleage_FFT64_bench(); printf("Done with trial %i of %i\n", i + 1, testTrials); } printf("******************************************\n"); diff --git a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h new file mode 100644 index 00000000..bc05951f --- /dev/null +++ b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h @@ -0,0 +1,13 @@ +#pragma once + + +namespace osuCrypto +{ + + double Foleage_FFT8_bench(); + double Foleage_FFT32_bench(); + double Foleage_FFT64_bench(); + + + +} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp new file mode 100644 index 00000000..62cb9042 --- /dev/null +++ b/libOTe/Tools/Foleage/fft/FoleageFft.cpp @@ -0,0 +1,856 @@ +#include +#include +#include "libOTe/Tools/Foleage/fft/FoleageFft.h" +#include "libOTe/Tools/Foleage/PerfectShuffle.h" +namespace osuCrypto { + + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint64( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint64( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint64( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint64_t tL, tM; + uint64_t mult, xor_h, xor_l; + + uint64_t* coeffsL = coeffs.data() + 0; + uint64_t* coeffsM = coeffs.data() + num_coeffs; + uint64_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; + const uint64_t mask_h = pattern; // 0b101010101010101001010 + const uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint32( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint32( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint32( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint32_t tL, tM; + uint32_t mult, xor_h, xor_l; + + uint32_t* coeffsL = coeffs.data() + 0; + uint32_t* coeffsM = coeffs.data() + num_coeffs; + uint32_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint32_t pattern = 0xaaaaaaaa; + const uint32_t mask_h = pattern; // 0b101010101010101001010 + const uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint16( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint16( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint16( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint16_t tL, tM; + uint16_t mult, xor_h, xor_l; + + uint16_t* coeffsL = coeffs.data() + 0; + uint16_t* coeffsM = coeffs.data() + num_coeffs; + uint16_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint16_t pattern = 0xaaaa; + const uint16_t mask_h = pattern; // 0b101010101010101001010 + const uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint8( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint8( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint8( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint8_t tL, tM; + uint8_t mult, xor_h, xor_l; + + uint8_t* coeffsL = coeffs.data() + 0; + uint8_t* coeffsM = coeffs.data() + num_coeffs; + uint8_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint8_t pattern = 0xaa; + const uint8_t mask_h = pattern; // 0b101010101010101001010 + const uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + + void foleageFFT2( + uint8_t* lsb, + uint8_t* msb, + const size_t num_vars, + const size_t num_coeffs) + { + if (num_vars > 1) + { + // apply FFT on all left coefficients + foleageFFT2( + lsb, msb, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + foleageFFT2( + lsb + num_coeffs, + msb + num_coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + foleageFFT2( + lsb + 2 * num_coeffs, + msb + 2 * num_coeffs, + num_vars - 1, + num_coeffs / 3); + } + + uint8_t* __restrict ptrL0 = lsb + 0; + uint8_t* __restrict ptrL1 = msb + 0; + uint8_t* __restrict ptrM0 = lsb + num_coeffs; + uint8_t* __restrict ptrM1 = msb + num_coeffs; + uint8_t* __restrict ptrR0 = lsb + 2 * num_coeffs; + uint8_t* __restrict ptrR1 = msb + 2 * num_coeffs; + + for (size_t j = 0; j < num_coeffs; j++) + { + + auto coeffsL0 = *ptrL0; + auto coeffsL1 = *ptrL1; + auto coeffsM0 = *ptrM0; + auto coeffsM1 = *ptrM1; + auto coeffsR0 = *ptrR0; + auto coeffsR1 = *ptrR1; + + auto xor_h = coeffsM1 ^ coeffsR1; + auto xor_l = coeffsM0 ^ coeffsR0; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + auto mult0 = xor_h ^ xor_l; + auto mult1 = xor_l; + + // tL coefficient obtained by evaluating on X_i=1 + auto tL0 = coeffsL0 ^ coeffsM0 ^ coeffsR0; + auto tL1 = coeffsL1 ^ coeffsM1 ^ coeffsR1; + + // tM coefficient obtained by evaluating on X_i=\alpha + auto tM0 = coeffsL0 ^ coeffsR0 ^ mult0; + auto tM1 = coeffsL1 ^ coeffsR1 ^ mult1; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + *ptrR0 = coeffsL0 ^ coeffsM0 ^ mult0; + *ptrR1 = coeffsL1 ^ coeffsM1 ^ mult1; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + *ptrL0 = tL0; + *ptrL1 = tL1; + *ptrM0 = tM0; + *ptrM1 = tM1; + + ++ptrL0; + ++ptrL1; + ++ptrM0; + ++ptrM1; + ++ptrR0; + ++ptrR1; + } + } + + template + void foleageFFTLevel( + u8* lsb, + u8* msb, + BlockSize blockSize, + u64 regions + ) + { + //static_assert(depth); + //u64 blockSize = ipow(3, depth - 1); + + for (u64 r = 0; r < regions; ++r) + { + + uint8_t* __restrict ptrL0 = lsb + r * 3 * blockSize + 0; + uint8_t* __restrict ptrL1 = msb + r * 3 * blockSize + 0; + uint8_t* __restrict ptrM0 = lsb + r * 3 * blockSize + blockSize; + uint8_t* __restrict ptrM1 = msb + r * 3 * blockSize + blockSize; + uint8_t* __restrict ptrR0 = lsb + r * 3 * blockSize + 2 * blockSize; + uint8_t* __restrict ptrR1 = msb + r * 3 * blockSize + 2 * blockSize; + + constexpr u64 width = 1; + auto main = blockSize / (width * 16); + for (u64 k = 0; k < main; ++k) + { + + block coeffsL0[width]; + block coeffsL1[width]; + block coeffsM0[width]; + block coeffsM1[width]; + block coeffsR0[width]; + block coeffsR1[width]; + + //{ constexpr u64 VAR = 1; STATEMENT; }\ + //{ constexpr u64 VAR = 2; STATEMENT; }\ + //{ constexpr u64 VAR = 3; STATEMENT; }\ + //{ constexpr u64 VAR = 4; STATEMENT; }\ + //{ constexpr u64 VAR = 5; STATEMENT; }\ + //{ constexpr u64 VAR = 6; STATEMENT; }\ + //{ constexpr u64 VAR = 7; STATEMENT; }\ + +#define SIMD8(VAR, STATEMENT) \ + { constexpr u64 VAR = 0; STATEMENT; }\ + do{}while(0) + + + SIMD8(q, coeffsL0[q] = _mm_loadu_epi8(ptrL0 + q * 16)); + SIMD8(q, coeffsL1[q] = _mm_loadu_epi8(ptrL1 + q * 16)); + SIMD8(q, coeffsM0[q] = _mm_loadu_epi8(ptrM0 + q * 16)); + SIMD8(q, coeffsM1[q] = _mm_loadu_epi8(ptrM1 + q * 16)); + SIMD8(q, coeffsR0[q] = _mm_loadu_epi8(ptrR0 + q * 16)); + SIMD8(q, coeffsR1[q] = _mm_loadu_epi8(ptrR1 + q * 16)); + + + + block xor_h[width], xor_l[width]; + SIMD8(j, xor_h[j] = coeffsM1[j] ^ coeffsR1[j]); + SIMD8(j, xor_l[j] = coeffsM0[j] ^ coeffsR0[j]); + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + block mult0[width];// , mult1[width]; + SIMD8(j, mult0[j] = xor_h[j] ^ xor_l[j]); + //SIMD8(j, mult1[j] = xor_l[j]); + + // tL coefficient obtained by evaluating on X_i=1 + block tL0[width], tL1[width]; + SIMD8(j, tL0[j] = coeffsL0[j] ^ coeffsM0[j] ^ coeffsR0[j]); + SIMD8(j, tL1[j] = coeffsL1[j] ^ coeffsM1[j] ^ coeffsR1[j]); + + // tM coefficient obtained by evaluating on X_i=\alpha + block tM0[width], tM1[width]; + SIMD8(j, tM0[j] = coeffsL0[j] ^ coeffsR0[j] ^ mult0[j]); + SIMD8(j, tM1[j] = coeffsL1[j] ^ coeffsR1[j] ^ xor_l[j]); + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + SIMD8(j, coeffsR0[j] = coeffsL0[j] ^ coeffsM0[j] ^ mult0[j]); + SIMD8(j, coeffsR1[j] = coeffsL1[j] ^ coeffsM1[j] ^ xor_l[j]); + + SIMD8(j, _mm_storeu_epi8(ptrR0 + j * 16, coeffsR0[j])); + SIMD8(j, _mm_storeu_epi8(ptrR1 + j * 16, coeffsR1[j])); + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + SIMD8(j, _mm_storeu_epi8(ptrL0 + j * 16, tL0[j])); + SIMD8(j, _mm_storeu_epi8(ptrL1 + j * 16, tL1[j])); + + SIMD8(j, _mm_storeu_epi8(ptrM0 + j * 16, tM0[j])); + SIMD8(j, _mm_storeu_epi8(ptrM1 + j * 16, tM1[j])); + + + ptrL0 += width * 16; + ptrL1 += width * 16; + ptrM0 += width * 16; + ptrM1 += width * 16; + ptrR0 += width * 16; + ptrR1 += width * 16; + } + +#undef SIMD8 + + for (size_t j = main * width * 16; j < blockSize; j++) + { + + auto coeffsL0 = *ptrL0; + auto coeffsL1 = *ptrL1; + auto coeffsM0 = *ptrM0; + auto coeffsM1 = *ptrM1; + auto coeffsR0 = *ptrR0; + auto coeffsR1 = *ptrR1; + + auto xor_h = coeffsM1 ^ coeffsR1; + auto xor_l = coeffsM0 ^ coeffsR0; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + auto mult0 = xor_h ^ xor_l; + auto mult1 = xor_l; + + // tL coefficient obtained by evaluating on X_i=1 + auto tL0 = coeffsL0 ^ coeffsM0 ^ coeffsR0; + auto tL1 = coeffsL1 ^ coeffsM1 ^ coeffsR1; + + // tM coefficient obtained by evaluating on X_i=\alpha + auto tM0 = coeffsL0 ^ coeffsR0 ^ mult0; + auto tM1 = coeffsL1 ^ coeffsR1 ^ mult1; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + *ptrR0 = coeffsL0 ^ coeffsM0 ^ mult0; + *ptrR1 = coeffsL1 ^ coeffsM1 ^ mult1; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + *ptrL0 = tL0; + *ptrL1 = tL1; + *ptrM0 = tM0; + *ptrM1 = tM1; + + ++ptrL0; + ++ptrL1; + ++ptrM0; + ++ptrM1; + ++ptrR0; + ++ptrR1; + } + } + } + + template + void foleageFFTL1L2( + u8* lsb, + u8* msb, + u64 regions + ) + { + //static_assert(depth); + //u64 blockSize = ipow(3, depth - 1); + u64 r = 0; + if constexpr (0 && stride == 2) + { + constexpr auto stepSize = 24; + auto main = regions / stepSize; + block tempLsb[9]; + block tempMsb[9]; + + for (u64 i = 0; i < main; ++i, r += stepSize) + { + auto lsb0 = lsb + r * stride; + auto lsb1 = lsb + r * stride + 16; + auto lsb2 = lsb + r * stride + 32; + + auto msb0 = msb + r * stride; + auto msb1 = msb + r * stride + 16; + auto msb2 = msb + r * stride + 32; + + // 0 1 2 3 4 5 6 7 + // 8 9 10 11 12 13 14 15 + // 16 17 18 19 20 21 22 23 + foleageTransposeLeaf(lsb0, (__m128i*) & tempLsb[0]); + foleageTransposeLeaf(lsb1, (__m128i*) & tempLsb[1]); + foleageTransposeLeaf(lsb2, (__m128i*) & tempLsb[2]); + + foleageTransposeLeaf(msb0, (__m128i*) & tempMsb[0]); + foleageTransposeLeaf(msb1, (__m128i*) & tempMsb[1]); + foleageTransposeLeaf(msb2, (__m128i*) & tempMsb[2]); + + + foleageFFTOne<1>( + &tempLsb[0], &tempMsb[0], + &tempLsb[1], &tempMsb[1], + &tempLsb[2], &tempMsb[2] + ); + + foleageFFTOne<1>( + &tempLsb[3], &tempMsb[3], + &tempLsb[4], &tempMsb[4], + &tempLsb[5], &tempMsb[5] + ); + + foleageFFTOne<1>( + &tempLsb[6], &tempMsb[6], + &tempLsb[7], &tempMsb[7], + &tempLsb[8], &tempMsb[8] + ); + + foleageTranspose((u8*)&tempLsb[0], (__m128i*)lsb0); + + foleageTranspose((u8*)&tempMsb[0], (__m128i*)msb0); + + foleageFFTOne<3>( + (block*)lsb0, (block*)msb0, + (block*)lsb1, (block*)msb1, + (block*)lsb2, (block*)msb2 + ); + } + } + + for (; r < regions; ++r) + { + constexpr u8 blockSize = 3 * stride; + uint8_t* __restrict ptrL0 = lsb + r * 3 * blockSize + 0; + uint8_t* __restrict ptrL1 = msb + r * 3 * blockSize + 0; + uint8_t* __restrict ptrM0 = lsb + r * 3 * blockSize + blockSize; + uint8_t* __restrict ptrM1 = msb + r * 3 * blockSize + blockSize; + uint8_t* __restrict ptrR0 = lsb + r * 3 * blockSize + 2 * blockSize; + uint8_t* __restrict ptrR1 = msb + r * 3 * blockSize + 2 * blockSize; + + + for (u64 j = 0; j < 9; j += 3) + { + + foleageFFTOne( + ptrL0 + (j + 0) * stride, ptrL1 + (j + 0) * stride, + ptrL0 + (j + 1) * stride, ptrL1 + (j + 1) * stride, + ptrL0 + (j + 2) * stride, ptrL1 + (j + 2) * stride + ); + } + + //foleageFFTOne( + // ptrL0 + 0 * stride, ptrL1 + 0 * stride, + // ptrL0 + 1 * stride, ptrL1 + 1 * stride, + // ptrL0 + 2 * stride, ptrL1 + 2 * stride + //); + //foleageFFTOne( + // ptrM0 + 0 * stride, ptrM1 + 0 * stride, + // ptrM0 + 1 * stride, ptrM1 + 1 * stride, + // ptrM0 + 2 * stride, ptrM1 + 2 * stride + //); + + //foleageFFTOne( + // ptrR0 + 0 * stride, ptrR1 + 0 * stride, + // ptrR0 + 1 * stride, ptrR1 + 1 * stride, + // ptrR0 + 2 * stride, ptrR1 + 2 * stride + //); + + + foleageFFTOne( + ptrL0, ptrL1, + ptrM0, ptrM1, + ptrR0, ptrR1 + ); + //foleageFFTOne( + // ptrL0 + 1 * stride, ptrL1 + 1 * stride, + // ptrM0 + 1 * stride, ptrM1 + 1 * stride, + // ptrR0 + 1 * stride, ptrR1 + 1 * stride + //); + //foleageFFTOne( + // ptrL0 + 2 * stride, ptrL1 + 2 * stride, + // ptrM0 + 2 * stride, ptrM1 + 2 * stride, + // ptrR0 + 2 * stride, ptrR1 + 2 * stride + //); + } + } + + template + void foleageFFTL1L2L3( + u8* lsb, + u8* msb, + u64 regions + ) + { + + for (u64 r = 0; r < regions; ++r) + { + constexpr u8 L4blockSize = 27 * stride; + + // L3 has 3 blocks of size L3blockSize + constexpr u8 L3blockSize = 9 * stride; + constexpr u8 L2blockSize = 3 * stride; + constexpr u8 L1blockSize = 1 * stride; + + uint8_t* baseLsb = lsb + r * L4blockSize; + uint8_t* baseMsb = msb + r * L4blockSize; + + + for (u64 k = 0; k < 3; ++k) + { + // left 1/3 + uint8_t* __restrict ptrL0 = baseLsb + k * L3blockSize + 0 * L2blockSize; + uint8_t* __restrict ptrL1 = baseMsb + k * L3blockSize + 0 * L2blockSize; + // middle 1/3 + uint8_t* __restrict ptrM0 = baseLsb + k * L3blockSize + 1 * L2blockSize; + uint8_t* __restrict ptrM1 = baseMsb + k * L3blockSize + 1 * L2blockSize; + // right 1/3 + uint8_t* __restrict ptrR0 = baseLsb + k * L3blockSize + 2 * L2blockSize; + uint8_t* __restrict ptrR1 = baseMsb + k * L3blockSize + 2 * L2blockSize; + + + for (u64 j = 0; j < 9; j += 3) + { + foleageFFTOne( + ptrL0 + (j + 0) * stride, ptrL1 + (j + 0) * stride, + ptrL0 + (j + 1) * stride, ptrL1 + (j + 1) * stride, + ptrL0 + (j + 2) * stride, ptrL1 + (j + 2) * stride + ); + } + + foleageFFTOne( + ptrL0, ptrL1, + ptrM0, ptrM1, + ptrR0, ptrR1 + ); + } + + + foleageFFTOne( + baseLsb + 0 * L3blockSize, baseMsb + 0 * L3blockSize, + baseLsb + 1 * L3blockSize, baseMsb + 1 * L3blockSize, + baseLsb + 2 * L3blockSize, baseMsb + 2 * L3blockSize + ); + + } + } + + void foleageFFT( + uint8_t* lsb, + uint8_t* msb, + const size_t num_vars, + const size_t num_coeffs) + { + //assert(lsb.size() == msb.size()); + //assert(lsb.size() % stride == 0); + //assert(blockSize == 1 || blockSize % 3 == 0); + //assert(blockSize < lsb.size() / stride); + + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + //u64 stepSize = blockSize * stride; + + + if (num_vars > 1) + { + // apply FFT on all left coefficients + foleageFFT( + lsb, msb, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + foleageFFT( + lsb + num_coeffs, + msb + num_coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + foleageFFT( + lsb + 2 * num_coeffs, + msb + 2 * num_coeffs, + num_vars - 1, + num_coeffs / 3); + } + + + foleageFFTLevel(lsb, msb, num_coeffs, 1); + } + + template + void foleageFFT2( + span lsb, + span msb) + { + auto n = lsb.size() / stride; + + auto log3N = log3Ceil(n); + if (n != ipow(3, log3N)) + throw RTE_LOC; + if (lsb.size() != n * stride) + throw RTE_LOC; + if (lsb.size() != msb.size()) + throw RTE_LOC; + for (u64 i = 1; i <= log3N; ++i) + { + auto regionSize = ipow(3, i); + auto regions = n / regionSize; + + switch (i) + { + case 1: + if(log3N == 1) + foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); + break; + case 2: + // foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); + //if (log3N == 2) + foleageFFTL1L2(lsb.data(), msb.data(), regions); + break; + case 3: + foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); + //foleageFFTL1L2L3(lsb.data(), msb.data(), regions); + break; + case 4: + foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); + break; + default: + u64 blockSize = regionSize / 3 * stride; + foleageFFTLevel(lsb.data(), msb.data(), blockSize, regions); + break; + } + } + + } + + template + void foleageFFT2<2>( + span lsb, + span msb); + template + void foleageFFT2<1>( + span lsb, + span msb); +} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.h b/libOTe/Tools/Foleage/fft/FoleageFft.h new file mode 100644 index 00000000..3f005ce2 --- /dev/null +++ b/libOTe/Tools/Foleage/fft/FoleageFft.h @@ -0,0 +1,388 @@ +#pragma once + +#include +#include +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/MatrixView.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" + +//#include "libOTe/Tools/Foleage/utils.h" +namespace osuCrypto { + + //typedef __int128 int128_t; + //typedef unsigned __int128 uint128_t; + + // FFT for (up to) 32 polynomials over F4 + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 16 polynomials over F4 + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 8 polynomials over F4 + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 4 polynomials over F4 + void fft_recursive_uint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + + + void foleageFFT( + uint8_t* lsb, + uint8_t* msb, + const size_t num_vars, + const size_t num_coeffs); + + inline void printShuffle1(const u16* ptr) + { + + for (u64 j = 0; j < 8; ++j) + { + auto v = ptr[j]; + std::cout << std::setw(2) << std::setfill(' ') << v << " "; + } + } + inline void printShuffle3(const u16* ptr) + { + for (u64 i = 0; i < 3; ++i) + { + printShuffle1(ptr + i * 8); + std::cout << std::endl; + } + } + + inline void printShuffle9(const u16* ptr) + { + for (u64 i = 0; i < 3; ++i) + { + printShuffle1(ptr + i * 24); + printShuffle1(ptr + i * 24 + 8); + printShuffle1(ptr + i * 24 + 16); + std::cout << std::endl; + } + } + // shuffles 3 blocks or 48 bytes + template + void foleageTransposeLeaf(u8* src, __m128i* dst) + { + + if constexpr (stride == 2) + { + // input: + // 0 1 2 3 4 5 6 7 + // 8 9 10 11 12 13 14 15 + // 16 17 18 19 20 21 22 23 + // + // output: + // 0 3 6 9 12 15 18 21 + // 1 4 7 10 13 16 19 22 + // 2 5 8 11 14 17 20 23 + + if (1) + { + // 0 6 12 18 + auto a0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(18, 12, 6, 0), 2); + // 3 9 15 21 + auto a1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(20, 14, 8, 2), 2); + // 0 3 6 9 12 15 18 21 + dst[0] = _mm_blendv_epi8(a0, a1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + // 1 7 13 19 + auto b0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(19, 13, 7, 1), 2); + // 4 10 16 22 + auto b1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(21, 15, 9, 3), 2); + // 1 4 7 10 13 16 19 22 + dst[1] = _mm_blendv_epi8(b0, b1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + // 2 8 14 20 + auto c0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(20, 14, 8, 2), 2); + // 5 11 17 23 + auto c1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(22, 16, 10, 4), 2); + // 2 5 8 11 14 17 20 23 + dst[2] = _mm_blendv_epi8(c0, c1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + } + else + { + + + // 0 1 2 3 4 5 6 7 + auto v0 = _mm_loadu_si128((__m128i*)src); + + // 8 9 10 11 12 13 14 15 + auto v1 = _mm_loadu_si128((__m128i*)(src + 16)); + + // 16 17 18 19 20 21 22 23 + auto v2 = _mm_loadu_si128((__m128i*)(src + 32)); + + // 0 3 6 1 4 7 2 5 + // 0 0c 0d 0e 0f, 1a 1b 1c 1d 1e 1f, 2a 2b 2c 2d + v0 = _mm_shuffle_epi8(v0, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); + + // 8 11 14 9 12 15 10 13 + // 2e ef 2g 2h 2i ej, 0g 0h 0i 0j 0k 0l, 1g 1h 1i 1j + v1 = _mm_shuffle_epi8(v1, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); + + // 16 19 22 17 20 23 18 21 + // 1k 1l 1m 1n 1o 1p, 2k 2l 2m 2n 2o 2p, 0m 0n 0o 0p + v2 = _mm_shuffle_epi8(v2, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); + + // 0 3 6 9 12 15 18 21 + // 0 0c 0d 0e 0f, 0g 0h 0i 0j 0k 0l, 1g 1h 1i 1j + auto u0 = _mm_blendv_epi8(v0, v1, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); + + // 0 3 6 17 20 23 18 21 + // 0 0c 0d 0e 0f, 0g 0h 0i 0j 0k 0l, 0m 0n 0o 0p + u0 = _mm_blendv_epi8(u0, v2, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); + + // 16 19 22 1 4 7 2 5 + // 1k 1l 1m 1n 1o 1p, 1a 1b 1c 1d 1e 1f, 2a 2b 2c 2d + auto u1 = _mm_blendv_epi8(v2, v0, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); + + // 16 19 22 1 4 7 10 13 + // 1k 1l 1m 1n 1o 1p, 1a 1b 1c 1d 1e 1f, 1g 1h 1i 1j + u1 = _mm_blendv_epi8(u1, v1, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); + + // 1 4 7 10 13 16 19 22 + // 1a 1b 1c 1d 1e 1f 1g 1h 1i 1j 1k 1l 1m 1n 1o 1p + u1 = _mm_shuffle_epi8(u1, _mm_set_epi8(5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6)); + + + // 8 11 14 17 20 23 18 21 + // 2e ef 2g 2h 2i ej, 2k 2l 2m 2n 2o 2p, 0m 0n 0o 0p + auto u2 = _mm_blendv_epi8(v1, v2, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); + + // 8 11 14 17 20 23 2 5 + // 2e ef 2g 2h 2i ej, 2k 2l 2m 2n 2o 2p, 2a 2b 2c 2d + u2 = _mm_blendv_epi8(u2, v0, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); + + // 2 5 8 11 14 17 20 23 + // 2a 2b 2c 2d 2e ef 2g 2h 2i ej 2k 2l 2m 2n 2o 2p, + u2 = _mm_shuffle_epi8(u2, _mm_set_epi8(11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12)); + + + _mm_store_si128(dst, u0); + _mm_store_si128(dst + 1, u1); + _mm_store_si128(dst + 2, u2); + } + } + else + { + throw RTE_LOC; + } + } + + // src points at the input data. Logically, there are 3 rows and 24 columns. + // each element is of stride bytes. The output is 3 rows and 8 columns. Each + // element is of stride * 3 bytes. The i'th element in the output are the + // three elements in the i'th column of the input. + // + // the input has 8 columns of row 0, then 8 columns row 1, 8 columns row 2, then repeates. + template + void foleageTranspose(u8* __restrict src, __m128i* __restrict dst) + { + if constexpr (stride == 2) + { + // input data: + // 0 1 2 3 4 5 6 7 + // 8 9 10 11 12 13 14 15 + // 16 17 18 19 20 21 22 23 + // + // 24 25 26 27 28 29 30 31 + // 32 33 34 35 36 37 38 39 + // 40 41 42 43 44 45 46 47 + // + // 48 49 50 51 52 53 54 55 + // 56 57 58 59 60 61 62 63 + // 64 65 66 67 68 69 70 71 + // + + // the input comes in 16 byte chunks. chunks {0,3,6},{1,4,7},{2,5,8} each belong to the same FFT position {0,1,2}. If we lay out the data + // logically we get: + // | | | | | | | + // 0 1 2 3 4 5 6 7 24 25 26 27 28 29 30 31 48 49 50 51 52 53 54 55 + // 8 9 10 11 12 13 14 15 32 33 34 35 36 37 38 39 56 57 58 59 60 61 62 63 + // 16 17 18 19 20 21 22 23 40 41 42 43 44 45 46 47 64 65 66 67 68 69 70 71 + // | | | | | | | + // + // at the previous FFT level, each column corresponds to a FFT instance, e.g. sub blocks {0,8,16}, {1,9,17}, ... + // + // We now want to merge these sub blocks into a single block. This corresponds + // to doing a 3x3 sub block transpose. + // + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 + // | | | | | | | + // 0 8 16 3 11 19 6 14 22 25 33 41 28 36 44 31 39 47 50 58 66 53 61 69 + // 1 9 17 4 12 20 7 15 23 26 34 42 29 37 45 48 56 64 51 59 67 54 62 70 + // 2 10 18 5 13 21 24 32 40 27 35 43 30 38 46 49 57 65 52 60 68 55 63 71 + // | | | | | | | + // + // We are going to transpose using the i32gather instruction. We want the output to be stored with + // each row being contiguous, e.g. "0 8 ... 69" should all be next to eachother. + // Each position takes up stride=2 bytes. But the i32gather instruction works on 4 byte chunks. + // So we will split the 8 gathered values across two instructions. One to gather the even + // positions and one the odd. We will then blend these two together to get the final output. eg: + + //0 8 16 3 11 19 6 14 | 22 25 33 41 28 36 44 31 |39 47 50 58 66 53 61 69 + //0 16 11 6 | 22 33 28 44 |39 50 66 61 + // 8 3 19 14 | 25 41 36 31 | 47 58 53 69 + // + // For row 0 we want to select + // * blend(gather(0,16,11,6),gatherHigh(8,3,19,14)) + // * blend(gather(22,33,28,44),gatherHigh(35,41,36,31)) + // * blend(gather(39,50,66,61),gatherHigh(47,58,53,69)) + // + // where gatherHigh(a,b,c,d) = gather(a-1,b-1,c-1,d-1), + // and blend(...) takes every other 16 bits. + // + // The other rows follow the same logic. + // + // the final set of indices are (each 4 are in reverse order to match _mm_set_epi32): + // + // 6,11,16, 0 | 44,28,33,22 | 61,66,50,39 + // 13,18, 2, 7 | 30,35,40,34 | 68,52,57,46 + // + //* 7,12,17,1 | 45,29,34,23 | 62,67,51,56 + //* 14,19, 3,8 | 47,36,41,25 | 69,53,58,63 + // + //* 24,13,18,2 | 46,30,35,40 | 63,68,52,57 + //* 31,20, 4,9 | 48,37,42,26 | 70,54,59,64 + + // 0 0 + auto a00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(6, 11, 16, 0), 2); + auto a01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(13, 18, 2, 7), 2); + dst[0] = _mm_blendv_epi8(a00, a01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto a10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(44, 28, 33, 22), 2); + auto a11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(30, 35, 40, 24), 2); + dst[1] = _mm_blendv_epi8(a10, a11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto a20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(61, 66, 50, 39), 2); + auto a21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(68, 52, 57, 46), 2); + dst[2] = _mm_blendv_epi8(a20, a21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + + auto b00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(7, 12, 17, 1), 2); + auto b01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(14, 19, 3, 8), 2); + dst[3] = _mm_blendv_epi8(b00, b01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto b10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(45, 29, 34, 23), 2); + auto b11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(47, 36, 41, 25), 2); + dst[4] = _mm_blendv_epi8(b10, b11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto b20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(62, 67, 51, 56), 2); + auto b21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(69, 53, 58, 63), 2); + dst[5] = _mm_blendv_epi8(b20, b21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + + auto c00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(24, 13, 18, 2), 2); + auto c01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(31, 20, 4, 9), 2); + dst[6] = _mm_blendv_epi8(c00, c01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto c10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(46, 30, 35, 40), 2); + auto c11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(48, 37, 42, 26), 2); + dst[7] = _mm_blendv_epi8(c10, c11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + auto c20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(63, 68, 52, 57), 2); + auto c21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(70, 54, 59, 64), 2); + dst[8] = _mm_blendv_epi8(c20, c21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); + + } + else + { + throw RTE_LOC; + } + } + + template + void foliageUnTranspose(u8* src, __m128i* dst) + { + constexpr std::array inv{ + 0, 1, 2, 24, 25, 26, 48, 49, 50, + 3, 4, 5, 27, 28, 29, 51, 52, 53, + 6, 7, 8, 30, 31, 32, 54, 55, 56, + 9, 10, 11, 33, 34, 35, 57, 58, 59, + 12, 13, 14, 36, 37, 38, 60, 61, 62, + 15, 16, 17, 39, 40, 41, 63, 64, 65, + 18, 19, 20, 42, 43, 44, 66, 67, 68, + 21, 22, 23, 45, 46, 47, 69, 70, 71 + }; + + auto dstPtr = (u8*)dst; + for (u64 i = 0; i < inv.size(); ++i) + { + memcpy(dstPtr, src + inv[i] * stride, stride); + dstPtr += stride; + } + + + // 0 1 2 24 25 26 48 49 50 3 4 5 27 28 29 51 52 53 6 7 8 30 31 32 + // 54 55 56 9 10 11 33 34 35 57 58 59 12 13 14 36 37 38 60 61 62 15 16 17 + // 39 40 41 63 64 65 18 19 20 42 43 44 66 67 68 21 22 23 45 46 47 69 70 71 + } + + + + template + OC_FORCEINLINE void foleageFFTOne( + T* __restrict coeffsL0, + T* __restrict coeffsL1, + T* __restrict coeffsM0, + T* __restrict coeffsM1, + T* __restrict coeffsR0, + T* __restrict coeffsR1) + { + +#pragma unroll(stride) + for (u64 i = 0; i < stride; ++i) + { + + auto xor_h = coeffsM1[i] ^ coeffsR1[i]; + auto xor_l = coeffsM0[i] ^ coeffsR0[i]; + + auto mult0 = xor_h ^ xor_l; + auto mult1 = xor_l; + + // tL coefficient obtained by evaluating on X_i=1 + auto tL0 = coeffsL0[i] ^ xor_l; + auto tL1 = coeffsL1[i] ^ xor_h; + auto tM0 = coeffsL0[i] ^ coeffsR0[i] ^ mult0; + auto tM1 = coeffsL1[i] ^ coeffsR1[i] ^ mult1; + coeffsR0[i] = coeffsL0[i] ^ coeffsM0[i] ^ mult0; + coeffsR1[i] = coeffsL1[i] ^ coeffsM1[i] ^ mult1; + coeffsL0[i] = tL0; + coeffsL1[i] = tL1; + coeffsM0[i] = tM0; + coeffsM1[i] = tM1; + } + } + + + + inline void foleageFFT( + MatrixView lsb, + MatrixView msb) + { + if (lsb.rows() != msb.rows()) + throw RTE_LOC; + if (lsb.cols() != msb.cols()) + throw RTE_LOC; + auto numCoeffs = lsb.rows(); + if (numCoeffs % 3) + throw RTE_LOC; + auto numVars = log3Ceil(numCoeffs); + foleageFFT(lsb.data(), msb.data(), numVars, lsb.size() / 3); + } + + + template + void foleageFFT2( + span lsb, + span msb); + +} diff --git a/libOTe/Tools/Foliage/spfss_test.cpp b/libOTe/Tools/Foleage/spfss_test.cpp similarity index 98% rename from libOTe/Tools/Foliage/spfss_test.cpp rename to libOTe/Tools/Foleage/spfss_test.cpp index d5111f15..9e0b301d 100644 --- a/libOTe/Tools/Foliage/spfss_test.cpp +++ b/libOTe/Tools/Foleage/spfss_test.cpp @@ -2,8 +2,8 @@ #include #include -#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" -#include "FoliageUtils.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +#include "FoleageUtils.h" #define SUMT 730 // sum of T DPFs diff --git a/libOTe/Tools/Foliage/tri-dpf/.gitignore b/libOTe/Tools/Foleage/tri-dpf/.gitignore similarity index 100% rename from libOTe/Tools/Foliage/tri-dpf/.gitignore rename to libOTe/Tools/Foleage/tri-dpf/.gitignore diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp similarity index 99% rename from libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp rename to libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp index 81fc77ef..359f7243 100644 --- a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.cpp +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp @@ -1,7 +1,7 @@ -#include "FoliageDpf.h" +#include "FoleageDpf.h" -#include "libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h" +#include "libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h" //#include diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h similarity index 84% rename from libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h rename to libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h index 1ca23568..0e3f96eb 100644 --- a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h @@ -3,8 +3,8 @@ #include #include -#include "libOTe/Tools/Foliage/FoliageUtils.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" namespace osuCrypto diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp similarity index 97% rename from libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp rename to libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp index 30ded078..ce57ed9d 100644 --- a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.cpp +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp @@ -6,8 +6,8 @@ #include #include -#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" -//#include "libOTe/Tools/Foliage/tri-dpf/FoliageHalfDpf.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +//#include "libOTe/Tools/Foleage/tri-dpf/FoleageHalfDpf.h" #include #define FULLEVALDOMAIN 14 diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h similarity index 100% rename from libOTe/Tools/Foliage/tri-dpf/FoliageDpf_test.h rename to libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h diff --git a/libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h b/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h similarity index 97% rename from libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h rename to libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h index 0d69e0a4..363ac56e 100644 --- a/libOTe/Tools/Foliage/tri-dpf/FoliagePrf.h +++ b/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h @@ -4,7 +4,7 @@ #include #include "cryptoTools/Crypto/AES.h" //#include "utils.h" -#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" namespace osuCrypto { diff --git a/libOTe/Tools/Foliage/tri-dpf/LICENSE b/libOTe/Tools/Foleage/tri-dpf/LICENSE similarity index 100% rename from libOTe/Tools/Foliage/tri-dpf/LICENSE rename to libOTe/Tools/Foleage/tri-dpf/LICENSE diff --git a/libOTe/Tools/Foliage/tri-dpf/README.md b/libOTe/Tools/Foleage/tri-dpf/README.md similarity index 100% rename from libOTe/Tools/Foliage/tri-dpf/README.md rename to libOTe/Tools/Foleage/tri-dpf/README.md diff --git a/libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h b/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h similarity index 96% rename from libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h rename to libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h index 7910b05a..d6648a1d 100644 --- a/libOTe/Tools/Foliage/tri-dpf/TriDpfUtils.h +++ b/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h @@ -3,7 +3,7 @@ #include #include -#include "libOTe/Tools/Foliage/FoliageUtils.h" +#include "libOTe/Tools/Foleage/FoleageUtils.h" #include "cryptoTools/Common/BitIterator.h" namespace osuCrypto diff --git a/libOTe/Tools/Foliage/uint128.h b/libOTe/Tools/Foleage/uint128.h similarity index 100% rename from libOTe/Tools/Foliage/uint128.h rename to libOTe/Tools/Foleage/uint128.h diff --git a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h b/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h deleted file mode 100644 index 922b1a57..00000000 --- a/libOTe/Tools/Foliage/fft/FoliageFFT_bench.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - - -namespace osuCrypto -{ - - double Foliage_FFT8_bench(); - double Foliage_FFT32_bench(); - double Foliage_FFT64_bench(); - - - -} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFft.cpp b/libOTe/Tools/Foliage/fft/FoliageFft.cpp deleted file mode 100644 index 54d8f615..00000000 --- a/libOTe/Tools/Foliage/fft/FoliageFft.cpp +++ /dev/null @@ -1,311 +0,0 @@ -#include -#include -#include "libOTe/Tools/Foliage/fft/FoliageFft.h" - -namespace osuCrypto { - - void fft_recursive_uint64( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint64( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint64( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint64( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint64_t tL, tM; - uint64_t mult, xor_h, xor_l; - - uint64_t* coeffsL = &coeffs[0]; - uint64_t* coeffsM = &coeffs[num_coeffs]; - uint64_t* coeffsR = &coeffs[2 * num_coeffs]; - - const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; - const uint64_t mask_h = pattern; // 0b101010101010101001010 - const uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void fft_recursive_uint32( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint32( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint32( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint32( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint32_t tL, tM; - uint32_t mult, xor_h, xor_l; - - uint32_t* coeffsL = &coeffs[0]; - uint32_t* coeffsM = &coeffs[num_coeffs]; - uint32_t* coeffsR = &coeffs[2 * num_coeffs]; - - const uint32_t pattern = 0xaaaaaaaa; - const uint32_t mask_h = pattern; // 0b101010101010101001010 - const uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void fft_recursive_uint16( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint16( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint16( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint16( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint16_t tL, tM; - uint16_t mult, xor_h, xor_l; - - uint16_t* coeffsL = &coeffs[0]; - uint16_t* coeffsM = &coeffs[num_coeffs]; - uint16_t* coeffsR = &coeffs[2 * num_coeffs]; - - const uint16_t pattern = 0xaaaa; - const uint16_t mask_h = pattern; // 0b101010101010101001010 - const uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void fft_recursive_uint8( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint8( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint8( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint8( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint8_t tL, tM; - uint8_t mult, xor_h, xor_l; - - uint8_t* coeffsL = &coeffs[0]; - uint8_t* coeffsM = &coeffs[num_coeffs]; - uint8_t* coeffsR = &coeffs[2 * num_coeffs]; - - const uint8_t pattern = 0xaa; - const uint8_t mask_h = pattern; // 0b101010101010101001010 - const uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - -} \ No newline at end of file diff --git a/libOTe/Tools/Foliage/fft/FoliageFft.h b/libOTe/Tools/Foliage/fft/FoliageFft.h deleted file mode 100644 index ffbff8e9..00000000 --- a/libOTe/Tools/Foliage/fft/FoliageFft.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include -#include -#include "cryptoTools/Common/Defines.h" - -//#include "libOTe/Tools/Foliage/utils.h" -namespace osuCrypto { - - //typedef __int128 int128_t; - //typedef unsigned __int128 uint128_t; - - // FFT for (up to) 32 polynomials over F4 - void fft_recursive_uint64( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 16 polynomials over F4 - void fft_recursive_uint32( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 8 polynomials over F4 - void fft_recursive_uint16( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 4 polynomials over F4 - void fft_recursive_uint8( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - -} diff --git a/libOTe_Tests/CMakeLists.txt b/libOTe_Tests/CMakeLists.txt index e92b0e08..8ee65aa7 100644 --- a/libOTe_Tests/CMakeLists.txt +++ b/libOTe_Tests/CMakeLists.txt @@ -14,7 +14,7 @@ set(SRCS TungstenCode_Tests.cpp UnitTests.cpp Vole_Tests.cpp - Foliage_Tests.cpp + Foleage_Tests.cpp ) add_library(libOTe_Tests STATIC ${SRCS}) diff --git a/libOTe_Tests/Foliage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp similarity index 76% rename from libOTe_Tests/Foliage_Tests.cpp rename to libOTe_Tests/Foleage_Tests.cpp index 922f7016..5066ec67 100644 --- a/libOTe_Tests/Foliage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1,12 +1,14 @@ -#include "Foliage_Tests.h" -#include "libOTe/Tools/Foliage/tri-dpf/FoliageDpf.h" -#include "libOTe/Tools/Foliage/fft/FoliageFft.h" -//#include "libOTe/Tools/Foliage/tri-dpf/FoliageHalfDpf.h" -#include "libOTe/Tools/Foliage/F4Ops.h" +#include "Foleage_Tests.h" +#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +#include "libOTe/Tools/Foleage/fft/FoleageFft.h" +//#include "libOTe/Tools/Foleage/tri-dpf/FoleageHalfDpf.h" +#include "libOTe/Tools/Foleage/F4Ops.h" #include "cryptoTools/Common/Matrix.h" -#include "libOTe/Tools/Foliage/FoliagePcg.h" +#include "libOTe/Tools/Foleage/FoleagePcg.h" #include "coproto/Socket/LocalAsyncSock.h" +#include "libOTe/Tools/Foleage/PerfectShuffle.h" +#include "cryptoTools/Common/Timer.h" namespace osuCrypto { //u8 extractF4(const uint128_t& val, u8 idx) @@ -149,7 +151,411 @@ namespace osuCrypto } } - void foliage_spfss_test() + void foleage_transpose_test(const oc::CLP& cmd) + { + { + + std::vector v(3 * 8); + std::vector v2(3 * 8); + + for (u64 i = 0; i < v.size(); ++i) + { + v[i] = i; + } + + + // input: + // 0 1 2 3 4 5 6 7 + // 8 9 10 11 12 13 14 15 + // 16 17 18 19 20 21 22 23 + // + // output: + // 0 3 6 9 12 15 18 21 + // 1 4 7 10 13 16 19 22 + // 2 5 8 11 14 17 20 23 + //printShuffle3(v.data()); + foleageTransposeLeaf<2>((u8*)v.data(), (__m128i*)v2.data()); + //printShuffle3(v2.data()); + + for (u64 i = 0; i < v2.size(); ++i) + { + auto e = i * 3 % 24 + (i / 8); + if (v2[i] != e) + throw RTE_LOC; + + } + + } + + { + int randomize = 1;// 241234123; // set to 1 to make debuggable + + std::vector v(9 * 8); + std::vector v2(9 * 8); + + + for (u64 i = 0; i < v.size(); ++i) + { + v[i] = i * randomize; + } + + + //std::cout << "\n"; + //printShuffle3(v.data()); + //std::cout << "\n"; + //printShuffle3(v.data() + 3 * 8); + //std::cout << "\n"; + //printShuffle3(v.data() + 6 * 8); + //std::cout << "--------------\n"; + + ////dst[i * 3 + j] = a0; + foleageTranspose<2>((u8*)v.data(), (__m128i*)v2.data()); + + + //printShuffle9(v2.data()); + //std::cout << "\n"; + + + // 0 1 2 3 4 5 6 7 + // 8 9 10 11 12 13 14 15 + // 16 17 18 19 20 21 22 23 + // + // 24 25 26 27 28 29 30 31 + // 32 33 34 35 36 37 38 39 + // 40 41 42 43 44 45 46 47 + // + // 48 49 50 51 52 53 54 55 + // 56 57 58 59 60 61 62 63 + // 64 65 66 67 68 69 70 71 + + + // 0 8 16 3 11 19 6 14 22 25 33 41 28 36 44 31 39 47 50 58 66 53 61 69 + // 1 9 17 4 12 20 7 15 23 26 34 42 29 37 45 48 56 64 51 59 67 54 62 70 + // 2 10 18 5 13 21 24 32 40 27 45 43 30 38 46 49 57 65 52 60 68 55 63 71 + //std::cout << std::endl; + + std::vector> exp(3); + for (u64 i = 0, k = 0; i < 3; ++i) + { + for (u64 j = 0; j < 24; ++j, ++k) + { + auto row = j / 8; + exp[row].push_back(k * randomize); + } + } + //std::cout << "before\n"; + //for (u64 i = 0; i < 3; ++i) + //{ + // for (u64 j = 0; j < 24; ++j) + // { + // //std::cout << v2[i * 24 + j] << " "; + // std::cout << std::setw(2) << std::setfill(' ') << exp[i][j] << " "; + // } + // std::cout << std::endl; + //} + + + for (u64 i = 0; i < 3; ++i) + { + for (u64 j = 0; j < 8; ++j) + { + auto b = j * 3; + for (u64 k = 0; k < 3; ++k) + { + for (u64 l = 0; l < k; ++l) + { + std::swap(exp[k][b + l], exp[l][b + k]); + } + } + } + } + //std::cout << "after\n"; + //for (u64 i = 0; i < 3; ++i) + //{ + // for (u64 j = 0; j < 24; ++j) + // { + // //std::cout << v2[i * 24 + j] << " "; + // std::cout << std::setw(2) << std::setfill(' ') << exp[i][j] << " "; + // } + // std::cout << std::endl; + //} + + for (u64 i = 0; i < 3; ++i) + { + for (u64 j = 0; j < 24; ++j) + { + if (exp[i][j] != v2[i * 24 + j]) + throw RTE_LOC; + } + } + + //printShuffle9(v.data()); + //foleageTranspose<2>((u8*)v2.data(), (__m128i*)v.data()); + + //for (u64 i = 0; i < v.size(); ++i) + //{ + // if(v[i] != i * randomize) + // throw RTE_LOC; + //} + } + + { + int randomize = 241234123; // set to 1 to make debuggable + + std::vector v(9 * 8); + std::vector v2(9 * 8); + + + for (u64 i = 0; i < v.size(); ++i) + { + v[i] = i * randomize; + } + //std::cout << "in\n" << std::endl; + //printShuffle9(v.data()); + + + foleageTransposeLeaf<2>((u8*)&v[0], (__m128i*)& v2[0]); + foleageTransposeLeaf<2>((u8*)&v[3 * 8], (__m128i*)& v2[3 * 8]); + foleageTransposeLeaf<2>((u8*)&v[6 * 8], (__m128i*)& v2[6 * 8]); + + //std::cout << "l1\n" << std::endl; + //printShuffle9(v2.data()); + + foleageTranspose<2>((u8*)v2.data(), (__m128i*)v.data()); + + + //std::cout << "l2\n" << std::endl; + //printShuffle9(v.data()); + + //std::vector inverse(v.size()); + //for (u64 i = 0; i < v.size(); ++i) + //{ + // inverse[v[i]] = i; + //} + + + foliageUnTranspose<2>((u8*)v.data(), (__m128i*)v2.data()); + + //std::cout << "inv\n" << std::endl; + //for (u64 i = 0; i < inverse.size(); ++i) + //{ + // std::cout << std::setw(2) << std::setfill(' ') << inverse[i] << ", "; + //} + ////printShuffle9(inverse.data()); + //std::cout << "f\n"; + //printShuffle9(v2.data()); + + for (u64 i = 0; i < v.size(); ++i) + { + if (v2[i] != u16(i * randomize)) + throw RTE_LOC; + } + } + + if(0) + { + + u64 trials = 1000000; + int randomize = 241234123; // set to 1 to make debuggable + + u64 ss = 9; + std::vector lsb(ss * trials), msb(ss * trials); + std::vector lsb2(ss * trials), msb2(ss * trials); + + PRNG prng(block(342134213421, 2341234123421)); + prng.get(lsb.data(), lsb.size()); + prng.get(msb.data(), msb.size()); + + + //for (u64 i = 0; i < 3 * 24; ++i) + //{ + // ((u16*)lsb.data())[i] = i * randomize; + // ((u16*)msb.data())[i] = i * randomize ^ 2134123423; + //} + std::cout << "in\n" << std::endl; + //printShuffle9(v.data()); + Timer t; + t.setTimePoint("b"); + + auto l = (u16*)lsb.data(); + auto m = (u16*)msb.data(); + for (u64 i = 0; i < trials * 8; ++i) + { + for (u64 j = 0; j < 3; ++j) + { + foleageFFTOne<1>( + &l[i * ss + j * 3 + 0], &m[i * ss + j * 3 + 0], + &l[i * ss + j * 3 + 1], &m[i * ss + j * 3 + 1], + &l[i * ss + j * 3 + 2], &m[i * ss + j * 3 + 2] + ); + } + + + for (u64 j = 0; j < 3; ++j) + { + foleageFFTOne<2>( + &l[i * ss + 0 * 3 + j], &m[i * ss + 0 * 3 + j], + &l[i * ss + 1 * 3 + j], &m[i * ss + 1 * 3 + j], + &l[i * ss + 2 * 3 + j], &m[i * ss + 2 * 3 + j] + ); + } + } + + t.setTimePoint("o"); + for (u64 i = 0; i < trials; ++i) + { + auto bLsb = lsb.data() + i * ss; + auto bMsb = msb.data() + i * ss; + auto bLsb2 = lsb2.data() + i * ss; + auto bMsb2 = msb2.data() + i * ss; + for (u64 j = 0; j < 3; ++j) + { + + foleageTransposeLeaf<2>((u8*)& bLsb[j * 3], (__m128i*)& bLsb[j * 3]); + foleageTransposeLeaf<2>((u8*)& bMsb[j * 3], (__m128i*)& bMsb[j * 3]); + foleageFFTOne<1>( + &bLsb2[j * 3 + 0], &bLsb2[j * 3 + 0], + &bLsb2[j * 3 + 1], &bLsb2[j * 3 + 1], + &bLsb2[j * 3 + 2], &bLsb2[j * 3 + 2] + ); + } + + foleageTranspose<2>((u8*)&bLsb2[0], (__m128i*)bLsb); + + foleageTranspose<2>((u8*)&bMsb2[0], (__m128i*)bMsb); + + foleageFFTOne<3,block>( + &bLsb[0], &bMsb[0], + &bLsb[3], &bMsb[3], + &bLsb[6], &bMsb[6] + ); + + } + t.setTimePoint("e"); + + std::cout << t << std::endl; + + } + } + + void foleage_fft_test(const oc::CLP& cmd) + { + PRNG prng(block(342134213421, 2341234123421)); + u64 nn = 14; + u64 n = ipow(3, nn); + Timer timer; + u64 trials = cmd.getOr("trials", 1); + + if (0) + { + + std::vector a(n); + std::vector lsb(n); + std::vector msb(n); + + prng.get(a.data(), a.size()); + for (u64 i = 0; i < n; ++i) + { + lsb[i] = + (a[i] >> 0) & 1 | + (a[i] >> 1) & 2 | + (a[i] >> 2) & 4 | + (a[i] >> 3) & 8; + auto m = a[i] >> 1; + msb[i] = + (m >> 0) & 1 | + (m >> 1) & 2 | + (m >> 2) & 4 | + (m >> 3) & 8; + } + + timer.setTimePoint("begin"); + fft_recursive_uint8(a, nn, n / 3); + timer.setTimePoint("fft_recursive_uint8"); + foleageFFT(lsb.data(), msb.data(), nn, n / 3); + timer.setTimePoint("foleageFFT 8 bit"); + + for (u64 i = 0; i < n; ++i) + { + auto a0 = + (a[i] >> 0) & 1 | + (a[i] >> 1) & 2 | + (a[i] >> 2) & 4 | + (a[i] >> 3) & 8; + auto m = a[i] >> 1; + auto a1 = + (m >> 0) & 1 | + (m >> 1) & 2 | + (m >> 2) & 4 | + (m >> 3) & 8; + + if (a0 != lsb[i] || a1 != msb[i]) + throw RTE_LOC; + } + + } + { + + std::vector a(n), a2(n); + oc::Matrix lsb(n, 2); + oc::Matrix msb(n, 2); + + prng.get(a.data(), a.size()); + + auto av = span((u8*)a.data(), n * 4); + auto av2 = span((u8*)a2.data(), n * 4); + + perfectUnshuffle(av, lsb, msb); + + auto lsb2 = lsb; + auto msb2 = msb; + + timer.setTimePoint("beign"); + for (u64 i = 0; i < trials; ++i) + fft_recursive_uint32(a, nn, n / 3); + timer.setTimePoint("fft_recursive_uint32"); + + + if (0) + { + + for (u64 i = 0; i < trials; ++i) + foleageFFT(lsb.data(), msb.data(), nn, 2 * n / 3); + timer.setTimePoint("foleageFFT 32bit"); + + perfectShuffle(lsb, msb, av2); + for (u64 i = 0; i < n; ++i) + { + if (a[i] != a2[i]) + throw RTE_LOC; + } + timer.setTimePoint("foleageFFT 32bit check"); + } + + if (1) + { + + for (u64 i = 0; i < trials; ++i) + foleageFFT2<2>(lsb2, msb2); + + timer.setTimePoint("foleageFFT2 32bit"); + + perfectShuffle(lsb2, msb2, av2); + for (u64 i = 0; i < n; ++i) + { + if (a[i] != a2[i]) + throw RTE_LOC; + } + timer.setTimePoint("foleageFFT2 32bit check"); + } + + std::cout << timer << std::endl; + + } + + } + + void foleage_spfss_test() { size_t SUMT = 730;// sum of T DPFs @@ -221,7 +627,7 @@ namespace osuCrypto } - void foliage_dpf_test() + void foleage_dpf_test() { const size_t size = 14; // evaluation will result in 3^size points const size_t msg_len = 2; @@ -295,7 +701,7 @@ namespace osuCrypto // This test evaluates the full PCG.Expand for both parties and // checks correctness of the resulting OLE correlation. - void foliage_pcg_test(const CLP& cmd) + void foleage_pcg_test(const CLP& cmd) { bool check = !cmd.isSet("noCheck"); auto N = 12; // 3^N number of OLEs generated in total @@ -981,17 +1387,20 @@ namespace osuCrypto // This test evaluates the full PCG.Expand for both parties and // checks correctness of the resulting OLE correlation. - void foliage_F4ole_test(const CLP& cmd) + void foleage_F4ole_test(const CLP& cmd) { - std::array oles; + std::array oles; auto logn = 12; u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); + bool verbose = cmd.isSet("v"); + //PRNG prng(block(342342)); PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); - + Timer timer; + oles[0].init(0, n, prng0); oles[1].init(1, n, prng1); auto sock = coproto::LocalAsyncSocket::makePair(); @@ -1005,6 +1414,9 @@ namespace osuCrypto C1Lsb(blocks), C1Msb(blocks); + if(verbose) + oles[0].setTimer(timer); + auto r = macoro::sync_wait(macoro::when_all_ready( oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]), oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]))); @@ -1019,7 +1431,7 @@ namespace osuCrypto auto aMsb = C0Msb[i] ^ C1Msb[i]; block mLsb, mMsb; f4Mult( - ALsb[i], AMsb[i], + ALsb[i], AMsb[i], BLsb[i], BMsb[i], mLsb, mMsb); @@ -1028,5 +1440,8 @@ namespace osuCrypto if (aMsb != mMsb) throw RTE_LOC; } + + if (verbose) + std::cout << "Time taken: \n" << timer << std::endl; } } \ No newline at end of file diff --git a/libOTe_Tests/Foleage_Tests.h b/libOTe_Tests/Foleage_Tests.h new file mode 100644 index 00000000..33b8cfa5 --- /dev/null +++ b/libOTe_Tests/Foleage_Tests.h @@ -0,0 +1,14 @@ +#pragma once +#include "cryptoTools/Common/CLP.h" +namespace osuCrypto +{ + void foleage_transpose_test(const oc::CLP& cmd); + void foleage_fft_test(const oc::CLP& cmd); + + void foleage_spfss_test(); + void foleage_dpf_test(); + void foleage_pcg_test(const CLP& cmd); + void foleage_F4ole_test(const CLP& cmd); + + +} \ No newline at end of file diff --git a/libOTe_Tests/Foliage_Tests.h b/libOTe_Tests/Foliage_Tests.h deleted file mode 100644 index 46b69053..00000000 --- a/libOTe_Tests/Foliage_Tests.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once -#include "cryptoTools/Common/CLP.h" -namespace osuCrypto -{ - - void foliage_spfss_test(); - void foliage_dpf_test(); - void foliage_pcg_test(const CLP& cmd); - void foliage_F4ole_test(const CLP& cmd); - - -} \ No newline at end of file diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index bccb9aef..d31b2060 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -4,6 +4,8 @@ #include "libOTe/Tools/Dpf/SparseDpf.h" #include #include +#include "libOTe/Tools/Dpf/TriDpf.h" + using namespace oc; void RegularDpf_Multiply_Test(const CLP& cmd) @@ -284,3 +286,81 @@ void SparseDpf_Proto_Test(const oc::CLP& cmd) } } } + +void TritDpf_Proto_Test(const oc::CLP& cmd) +{ + + PRNG prng(block(231234, 321312)); + u64 depth = 3; + u64 domain = ipow(3,depth); + u64 numPoints = 11; + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + for (u64 i = 0; i < numPoints; ++i) + { + points1[i] = prng.get(); + points0[i] = (prng.get() % domain) ^ points1[i]; + values0[i] = prng.get(); + values1[i] = prng.get(); + } + + std::array dpf; + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + std::array, 2> tags; + output[0].resize(numPoints, domain); + output[1].resize(numPoints, domain); + tags[0].resize(numPoints, domain); + tags[1].resize(numPoints, domain); + + auto sock = coproto::LocalAsyncSocket::makePair(); + macoro::sync_wait(macoro::when_all_ready( + dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t; }, prng, sock[0]), + dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t; }, prng, sock[1]) + )); + + + for (u64 i = 0; i < domain; ++i) + { + for (u64 k = 0; k < numPoints; ++k) + { + auto p = points0[k] ^ points1[k]; + auto act = output[0][k][i] ^ output[1][k][i]; + auto t = i == p ? 1 : 0; + auto tAct = tags[0][k][i] ^ tags[1][k][i]; + auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; + if (exp != act) + throw RTE_LOC; + if (t != tAct) + throw RTE_LOC; + } + } + +} diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 53ffe1aa..3bb9115c 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -17,7 +17,7 @@ #include "libOTe_Tests/Pprf_Tests.h" #include "libOTe_Tests/TungstenCode_Tests.h" #include "libOTe_Tests/RegularDpf_Tests.h" -#include "libOTe_Tests/Foliage_Tests.h" +#include "libOTe_Tests/Foleage_Tests.h" using namespace osuCrypto; namespace tests_libOTe @@ -63,10 +63,13 @@ namespace tests_libOTe tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); - tc.add("foliage_dpf_test ", foliage_dpf_test); - tc.add("foliage_spfss_test ", foliage_spfss_test); - tc.add("foliage_pcg_test ", foliage_pcg_test); - tc.add("foliage_F4ole_test ", foliage_F4ole_test); + + tc.add("foleage_transpose_test ", foleage_transpose_test); + tc.add("foleage_fft_test ", foleage_fft_test); + tc.add("foleage_dpf_test ", foleage_dpf_test); + tc.add("foleage_spfss_test ", foleage_spfss_test); + tc.add("foleage_pcg_test ", foleage_pcg_test); + tc.add("foleage_F4ole_test ", foleage_F4ole_test); tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); From d1d92e50889d69f32c44838133a8dd99650620e7 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sat, 8 Feb 2025 17:47:06 -0800 Subject: [PATCH 07/48] partial --- CMakePresets.json | 4 +- cryptoTools | 2 +- frontend/ExampleTwoChooseOne.cpp | 2 +- libOTe/Tools/Dpf/RegularDpf.h | 2 +- libOTe/Tools/Dpf/TriDpf.h | 405 +++++++++++++++--------- libOTe/Tools/Foleage/FoleagePcg.cpp | 4 +- libOTe/Tools/Foleage/fft/FoleageFft.cpp | 2 +- libOTe/Tools/Foleage/fft/FoleageFft.h | 2 +- libOTe_Tests/Foleage_Tests.cpp | 2 +- libOTe_Tests/RegularDpf_Tests.cpp | 10 +- libOTe_Tests/RegularDpf_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 1 + 12 files changed, 279 insertions(+), 158 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index fc21c0f1..23a7fc71 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -64,8 +64,8 @@ "ENABLE_SIMPLESTOT": "ON", "ENABLE_GMP": false, "ENABLE_RELIC": false, - "ENABLE_SODIUM": false, - "ENABLE_BOOST": false, + "ENABLE_SODIUM": true, + "ENABLE_BOOST": true, "ENABLE_BITPOLYMUL": true, "FETCH_AUTO": "ON", "ENABLE_CIRCUITS": true, diff --git a/cryptoTools b/cryptoTools index d92336ec..2bf5fe84 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit d92336ecde55fcd0918def7efda948e41b510965 +Subproject commit 2bf5fe84e19cadd9aeea5c191a08ac59e65b54e7 diff --git a/frontend/ExampleTwoChooseOne.cpp b/frontend/ExampleTwoChooseOne.cpp index 1a1fb152..d1c7e5b2 100644 --- a/frontend/ExampleTwoChooseOne.cpp +++ b/frontend/ExampleTwoChooseOne.cpp @@ -70,7 +70,7 @@ namespace osuCrypto if (totalOTs == 0) totalOTs = 1 << 20; - bool randomOT = true; + bool randomOT = false; // get up the networking auto chl = cp::asioConnect(ip, role == Role::Sender); diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index b121edd3..62af644f 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -40,7 +40,7 @@ namespace osuCrypto if (!numPoints) throw RTE_LOC; - mDepth = oc::log3ceil(domain); + mDepth = log2ceil(domain); mPartyIdx = partyIdx; mDomain = domain; mNumPoints = numPoints; diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index af4e23a8..fd2b90dc 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -12,6 +12,82 @@ namespace osuCrypto { + // a value representing (Z_3)^32. + // The value is stored in 2 bits per Z_3 element. + struct Trit32 + { + u64 mVal = 0; + + Trit32() = default; + Trit32(const Trit32&) = default; + + Trit32(u64 v) + { + fromInt(v); + } + + Trit32& operator=(const Trit32&) = default; + + Trit32 operator+(const Trit32& t) const + { + u64 msbMask, lsbMask; + setBytes(msbMask, 0b10101010); + setBytes(lsbMask, 0b01010101); + + auto x0 = mVal; + auto x1 = mVal >> 1; + auto y0 = t.mVal; + auto y1 = t.mVal >> 1; + + + auto x1x0 = x1 ^ x0; + auto z1 = (y0 ^ x0) & ~(x1x0 ^ y1); + auto z0 = (x1 ^ y1) & ~(x1x0 ^ y0); + + Trit32 r; + r.mVal = ((z1 << 1) & msbMask) | (z0 & lsbMask); + + for (u64 i = 0; i < 32; ++i) + { + auto a = (mVal >> (i * 2)) & 3; + auto b = (mVal >> (i * 2)) & 3; + auto c = (a + b) % 3; + if (c != ((r.mVal >> (i * 2)) & 3)) + throw RTE_LOC; + } + return r; + } + + u64 toInt() const + { + u64 r = 0; + for (u64 i = 31; i < 32; --i) + { + r *= 3; + r |= (mVal >> (i * 2)) & 3; + } + + return r; + } + + void fromInt(u64 v) + { + mVal = 0; + for (u64 i = 0; i < 32; ++i) + { + mVal |= (v % 3) << (i * 2); + v /= 3; + } + } + + + // returns the i'th Z_3 element. + u8 operator[](u64 i) + { + return (mVal >> (i*2)) & 3; + } + }; + struct TriDpf { enum class OutputFormat @@ -36,13 +112,6 @@ namespace osuCrypto u64 mNumPoints = 0; - //DpfMult mMultiplier; - - u8 lsb(const block& b) - { - return b.get(0) & 1; - } - void init( u64 partyIdx, u64 domain, @@ -62,6 +131,15 @@ namespace osuCrypto //mMultiplier.init(partyIdx, numPoints * mDepth); } + // returns something similar to b % 3. + u8 trit(block b) + { + auto v = b.get(0); + return + static_cast(v > 6148914691236517205ull) + + static_cast(v > 12297829382473034410ull); + } + #define SIMD8(VAR, STATEMENT) \ { constexpr u64 VAR = 0; STATEMENT; }\ { constexpr u64 VAR = 1; STATEMENT; }\ @@ -77,7 +155,7 @@ namespace osuCrypto typename Output > macoro::task<> expand( - span points, + span points, span values, Output&& output, PRNG& prng, @@ -97,14 +175,14 @@ namespace osuCrypto for (u64 i = 0; i < mNumPoints; ++i) { - u64 v = points[i]; + u64 v = points[i].mVal; for (u64 j = 0; j < mDepth; ++j) { if ((v & 3) == 3) throw std::runtime_error("TriDpf: invalid point sharing. Expects the input points to be shared over Z_3^D where each Z_3 elements takes up 2 bits of a the value. " LOCATION); v >>= 2; } - if(v) + if (v) throw std::runtime_error("TriDpf: invalid point sharing. point is larger than 3^D " LOCATION); } @@ -131,32 +209,46 @@ namespace osuCrypto #else auto getRow = [](auto&& m, u64 i) {return m[i]; }; #endif - std::array, 2> tau; + std::array, 3> tau; tau[0].resize(mNumPoints); tau[1].resize(mNumPoints); - - std::array, 2> z; + tau[2].resize(mNumPoints); + std::array, 3> z; z[0].resize(mNumPoints); z[1].resize(mNumPoints); - AlignedUnVector sigma(mNumPoints); - BitVector negAlphaj(mNumPoints); - AlignedUnVector diff(mNumPoints); + z[2].resize(mNumPoints); + std::array, 3> v; + v[0].resize(mNumPoints); + v[1].resize(mNumPoints); + v[2].resize(mNumPoints); + std::array, 3> sigma; + sigma[0].resize(mNumPoints); + sigma[1].resize(mNumPoints); + sigma[2].resize(mNumPoints); { // we skip level 0 and set level 1 to be random auto sc0 = s[1][0]; auto sc1 = s[1][1]; + auto sc2 = s[1][2]; for (u64 k = 0; k < numPoints; ++k) { sc0[k] = prng.get(); sc1[k] = prng.get(); + sc2[k] = prng.get(); z[0][k] = sc0[k]; z[1][k] = sc1[k]; + z[2][k] = sc2[k]; } } + std::array aes{ + AES(block(324532455457855483,3575765667434524523)), + AES(block(456475435444364534,9923458239234989843)), + AES(block(324532450985209453,5387987243989842789)) }; + // at each iteration we first correct the parent level. // The parent level has two syblings which are random. // We need to correct the inactive child so that both parties @@ -166,123 +258,143 @@ namespace osuCrypto // We compute left and right sums for the children. for (u64 iter = 1; iter <= mDepth; ++iter) { - // the grand parent level - auto& tp = t[(iter - 1) & 1]; - // the parent level - auto& sc = s[iter & 1]; - auto& tc = t[iter & 1]; + auto& parentSeedBase = t[(iter - 1) & 1]; - // the child level - auto& sg = s[(iter + 1) & 1]; + // current level + auto& seedBase = s[iter & 1]; + auto& tagBase = t[iter & 1]; - auto size = 1ull << iter; + // the child level + auto& childSeedBase = s[(iter + 1) & 1]; + //auto& childTagBase = t[(iter + 1) & 1]; - // - for (u64 k = 0; k < mNumPoints; ++k) - { - auto alphaj = *oc::BitIterator(&points[k], mDepth - iter); - tau[0][k] = lsb(z[0][k]) ^ alphaj ^ mPartyIdx; - tau[1][k] = lsb(z[1][k]) ^ alphaj; - diff[k] = z[0][k] ^ z[1][k]; - negAlphaj[k] = alphaj ^ mPartyIdx; - } + auto size = ipow(3, iter); - co_await mMultiplier.multiply(negAlphaj, diff, diff, sock); - // sigma = z[1^alpha[j]] - for (u64 k = 0; k < mNumPoints; ++k) - sigma[k] = diff[k] ^ z[0][k]; - - // reveal sigma and tau - u64 buffSize = sigma.size() * 16 + divCeil(mNumPoints * 2, 8); - AlignedUnVector sendBuff(buffSize), recvBuff(buffSize); - copyBytesMin(sendBuff, sigma); - auto sendBitIter = BitIterator(&sendBuff[numPoints * 16]); - auto recvBitIter = BitIterator(&recvBuff[numPoints * 16]); - for (u64 i = 0; i < mNumPoints; ++i) - { - *sendBitIter++ = tau[0][i]; - *sendBitIter++ = tau[1][i]; - } - co_await sock.send(std::move(sendBuff)); - co_await sock.recv(recvBuff); - for (u64 k = 0; k < mNumPoints; ++k) { - block sk = *(block*)&recvBuff[k * sizeof(block)]; - sigma[k] ^= sk; - tau[0][k] ^= *recvBitIter++; - tau[1][k] ^= *recvBitIter++; - } + std::vector alphaj(numPoints); + std::vector zz(numPoints * 3); auto zzIter = zz.begin(); + std::vector vv(numPoints * 3); auto vvIter = vv.begin(); + for (u64 k = 0; k < numPoints; ++k) + { + alphaj[k] = points[k][mDepth - iter]; + } + for (u64 k = 0; k < 3; ++k) + { + copyBytes(span(zzIter, zzIter + numPoints), z[k]); zzIter += numPoints; + copyBytes(span(vvIter, vvIter + numPoints), v[k]); vvIter += numPoints; + } - if (iter != mDepth) - { + co_await sock.send(coproto::copy(alphaj)); + co_await sock.send(coproto::copy(zz)); + co_await sock.send(coproto::copy(vv)); - setBytes(z[0], 0); - setBytes(z[1], 0); + auto recvAlphaj = co_await sock.recv>(); + co_await sock.recv(zz); + co_await sock.recv(vv); - for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) + zzIter = zz.begin(); + vvIter = vv.begin(); + for (u64 k = 0; k < 3; ++k) { - // parent control bits - auto tpl = getRow(tp, L); - - // child seed - std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; - - // child control bit - std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + for (u64 i = 0; i < numPoints; ++i) + { + sigma[k][i] = z[k][i] ^ *zzIter++; + tau[k][i] = v[k][i] ^ *vvIter++ ^ 1; + assert(v[k][i] < 2); + } + } - // grandchild seeds - std::array sgl{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; + for (u64 i = 0; i < numPoints; ++i) + { + assert(recvAlphaj[i] < 3); + alphaj[i] = (alphaj[i] + recvAlphaj[i]) % 3; - for (u64 k = 0; k < numPoints8; k += 8) - { - block temp[8]; - SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, scl[0][k + q] ^= temp[q]); + sigma[alphaj[i]][i] ^= oc::mAesFixedKey.ecbEncBlock(block(iter, i)); + tau[alphaj[i]][i] ^= 1; + } + } + if (iter != mDepth) + { + for (u64 i = 0; i < 3; ++i) + { + setBytes(z[i], 0); + setBytes(v[i], 0); + } - mAesFixedKey.ecbEncBlocks<8>(&scl[0][k], &sgl[1][k]); - SIMD8(q, sgl[0][k + q] = AES::roundEnc(sgl[1][k + q], scl[0][k + q])); - SIMD8(q, sgl[1][k + q] = sgl[1][k + q] + scl[0][k + q]); + // we iterate over each parent control bit. + // The parent has 3 "current" children. + // We will expand these three children into 9 gradchildren + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 3, L4 += 9) + { + // parent control bits, one for each tree. + auto parentTag = getRow(parentSeedBase, L); - SIMD8(q, z[0][k + q] ^= sgl[0][k + q]); - SIMD8(q, z[1][k + q] ^= sgl[1][k + q]); + // child seed, three for each tree. + std::array seed{ getRow(seedBase, L2 + 0), getRow(seedBase, L2 + 1) , getRow(seedBase, L2 + 2) }; - SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl[1][k + q] ^= temp[q]); + // child control bit, tree for each tree. + std::array tag{ getRow(tagBase, L2 + 0), getRow(tagBase, L2 + 1), getRow(tagBase, L2 + 2) }; - mAesFixedKey.ecbEncBlocks<8>(&scl[1][k], &sgl[3][k]); - SIMD8(q, sgl[2][k + q] = AES::roundEnc(sgl[3][k + q], scl[1][k + q])); - SIMD8(q, sgl[3][k + q] = sgl[3][k + q] + scl[1][k + q]); - SIMD8(q, z[0][k + q] ^= sgl[2][k + q]); - SIMD8(q, z[1][k + q] ^= sgl[3][k + q]); - } + // grandchild seeds, nine for each tree. + std::array childSeed; + for (u64 i = 0; i < 9; ++i) + childSeed[i] = getRow(childSeedBase, L4 + i); - for (u64 k = numPoints8; k < mNumPoints; ++k) - { - auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + //for (u64 k = 0; k < numPoints8; k += 8) + //{ + // block temp[8]; + // SIMD8(q, temp[q] = block::allSame(-parentTag[k + q]) & sigma[k + q]); + // SIMD8(q, tag[0][k + q] = lsb(seed[0][k + q]) ^ parentTag[k + q] & tau[0][k + q]); + // SIMD8(q, seed[0][k + q] ^= temp[q]); - tcl[0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; - scl[0][k] ^= temp; - sgl[1][k] = mAesFixedKey.ecbEncBlock(scl[0][k]); - sgl[0][k] = AES::roundEnc(sgl[1][k], scl[0][k]); - sgl[1][k] = sgl[1][k] + scl[0][k]; + // mAesFixedKey.ecbEncBlocks<8>(&seed[0][k], &childSeed[1][k]); + // SIMD8(q, childSeed[0][k + q] = AES::roundEnc(childSeed[1][k + q], seed[0][k + q])); + // SIMD8(q, childSeed[1][k + q] = childSeed[1][k + q] + seed[0][k + q]); - z[0][k] ^= sgl[0][k]; - z[1][k] ^= sgl[1][k]; + // SIMD8(q, z[0][k + q] ^= childSeed[0][k + q]); + // SIMD8(q, z[1][k + q] ^= childSeed[1][k + q]); - tcl[1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; - scl[1][k] ^= temp; + // SIMD8(q, tag[1][k + q] = lsb(seed[1][k + q]) ^ parentTag[k + q] & tau[1][k + q]); + // SIMD8(q, seed[1][k + q] ^= temp[q]); - sgl[3][k] = mAesFixedKey.ecbEncBlock(scl[1][k]); - sgl[2][k] = AES::roundEnc(sgl[3][k], scl[1][k]); - sgl[3][k] = sgl[3][k] + scl[1][k]; + // mAesFixedKey.ecbEncBlocks<8>(&seed[1][k], &childSeed[3][k]); + // SIMD8(q, childSeed[2][k + q] = AES::roundEnc(childSeed[3][k + q], seed[1][k + q])); + // SIMD8(q, childSeed[3][k + q] = childSeed[3][k + q] + seed[1][k + q]); + // SIMD8(q, z[0][k + q] ^= childSeed[2][k + q]); + // SIMD8(q, z[1][k + q] ^= childSeed[3][k + q]); + //} + //auto& sigmaL = sigma[L % 3]; - z[0][k] ^= sgl[2][k]; - z[1][k] ^= sgl[3][k]; + for (u64 k = 0; k < mNumPoints; ++k) + { + for (u64 j = 0; j < 3; ++j) + { + // (s,t) = (s,t) ^ q * sigma_j + tag[j][k] = trit(seed[j][k]) ^ parentTag[k] & tau[j][k]; + seed[j][k] ^= block::allSame(-parentTag[k + 0]) & sigma[j][k + 0]; + + // + for (u64 i = 0; i < 3; ++i) + { + auto s = aes[i].hashBlock(seed[j][k]); + childSeed[j * 3 + i][k] = s; + z[i][k] ^= s; + } + } + + //tag[1][k] = lsb(seed[1][k]) ^ parentTag[k] & tau[1][k]; + //seed[1][k] ^= temp; + + //childSeed[3][k] = mAesFixedKey.ecbEncBlock(seed[1][k]); + //childSeed[2][k] = AES::roundEnc(childSeed[3][k], seed[1][k]); + //childSeed[3][k] = childSeed[3][k] + seed[1][k]; + + //z[0][k] ^= childSeed[2][k]; + //z[1][k] ^= childSeed[3][k]; } } } @@ -291,39 +403,45 @@ namespace osuCrypto // fixing the last layer { - auto size = 1ull << mDepth; + auto size = ipow(3, mDepth); - auto& tp = t[(mDepth - 1) & 1]; - auto& sc = s[mDepth & 1]; - auto& tc = t[mDepth & 1]; - for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) + auto& parentTag = t[(mDepth - 1) & 1]; + auto& curSeed = s[mDepth & 1]; + auto& curTag = t[mDepth & 1]; + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 3) { // parent control bits - auto tpl = getRow(tp, L); + auto tpl = getRow(parentTag, L); // child seed - std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1) }; // child control bit - std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; - - for (u64 k = 0; k < numPoints8; k += 8) + std::array tcl{ getRow(curTag, L2 + 0), getRow(curTag, L2 + 1) }; + + //for (u64 k = 0; k < numPoints8; k += 8) + //{ + // block temp[8]; + // SIMD8(q, temp[q] = block::allSame(-parentTag[k + q]) & sigma[k + q]); + // SIMD8(q, tag[0][k + q] = lsb(seed[0][k + q]) ^ parentTag[k + q] & tau[0][k + q]); + // SIMD8(q, tag[1][k + q] = lsb(seed[1][k + q]) ^ parentTag[k + q] & tau[1][k + q]); + // SIMD8(q, seed[0][k + q] ^= temp[q]); + // SIMD8(q, seed[1][k + q] ^= temp[q]); + //} + + for (u64 k = 0; k < mNumPoints; ++k) { - block temp[8]; - SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl[0][k + q] ^= temp[q]); - SIMD8(q, scl[1][k + q] ^= temp[q]); - } + for (u64 j = 0; j < 3; ++j) + { + curTag[L2 + j][k] = trit(scl[j][k]) ^ tpl[k] & tau[j][k]; + curSeed[L2 + j][k] ^= block::allSame(-tpl[k]) & sigma[j][k];; + } - for (u64 k = numPoints8; k < mNumPoints; ++k) - { - auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; - tc[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; - tc[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; - sc[L2 + 0][k] ^= temp; - sc[L2 + 1][k] ^= temp; + + //curTag[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; + //curTag[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; + //curSeed[L2 + 0][k] ^= block::allSame(-tpl[k + 0]) & sigma[k % 3][k + 0];; + //curSeed[L2 + 1][k] ^= temp; } } } @@ -331,7 +449,7 @@ namespace osuCrypto if (values.size()) { - AlignedUnVector gamma(mNumPoints); + AlignedUnVector gamma(mNumPoints), diff(mNumPoints); for (u64 k = 0; k < mNumPoints; ++k) { diff[k] = z[0][k] ^ z[1][k] ^ values[k]; @@ -350,14 +468,13 @@ namespace osuCrypto auto sdi = getRow(sd, i); auto tdi = getRow(td, i); - for (u64 k = 0; k < numPoints8; k += 8) - { - block T[8]; - - SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); - SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k + q])); - } - for (u64 k = numPoints8; k < mNumPoints; ++k) + //for (u64 k = 0; k < numPoints8; k += 8) + //{ + // block T[8]; + // SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); + // SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k + q])); + //} + for (u64 k = 0; k < mNumPoints; ++k) { auto T = block::allSame(-tdi[k]) & gamma[k]; output(k, i, sdi[k] ^ T, tdi[k]); @@ -386,7 +503,8 @@ namespace osuCrypto u64 baseOtCount() const { - return mMultiplier.baseOtCount(); + throw RTE_LOC; + //return mMultiplier.baseOtCount(); } void setBaseOts( @@ -394,7 +512,8 @@ namespace osuCrypto span recvBaseOts, const oc::BitVector& baseChoices) { - mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); + throw RTE_LOC; + //mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); } diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 5e0a9e4d..e829ca9d 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -12,13 +12,13 @@ namespace osuCrypto void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) { mPartyIdx = partyIdx; - mLog3N = log3Ceil(n); + mLog3N = log2ceil(n); mN = ipow(3, mLog3N); if (mT != ipow(3, mLog3T)) throw RTE_LOC; - mDpfDomainDepth = std::max(1, log3Ceil(divCeil(mN, mT * 256))); + mDpfDomainDepth = std::max(1, log2ceil(divCeil(mN, mT * 256))); mDpfBlockSize = 4 * ipow(3, mDpfDomainDepth); mBlockSize = mN / mT; diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp index 62cb9042..146bb99a 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFft.cpp @@ -806,7 +806,7 @@ namespace osuCrypto { { auto n = lsb.size() / stride; - auto log3N = log3Ceil(n); + auto log3N = log2ceil(n); if (n != ipow(3, log3N)) throw RTE_LOC; if (lsb.size() != n * stride) diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.h b/libOTe/Tools/Foleage/fft/FoleageFft.h index 3f005ce2..c73b7d01 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.h +++ b/libOTe/Tools/Foleage/fft/FoleageFft.h @@ -375,7 +375,7 @@ namespace osuCrypto { auto numCoeffs = lsb.rows(); if (numCoeffs % 3) throw RTE_LOC; - auto numVars = log3Ceil(numCoeffs); + auto numVars = log3ceil(numCoeffs); foleageFFT(lsb.data(), msb.data(), numVars, lsb.size() / 3); } diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 5066ec67..12fbaf65 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -743,7 +743,7 @@ namespace osuCrypto // We pack L=256 coefficients of F4 into each DPF output (note that larger // packing values are also okay, but they will do increase key size). //************************************************************************ - size_t dpf_domain_bits = log3Ceil(divCeil(poly_size, t * 256.0)); + size_t dpf_domain_bits = log3ceil(divCeil(poly_size, t * 256.0)); if (dpf_domain_bits == 0) dpf_domain_bits = 1; diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index d31b2060..c34017b6 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -294,14 +294,14 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) u64 depth = 3; u64 domain = ipow(3,depth); u64 numPoints = 11; - std::vector points0(numPoints); - std::vector points1(numPoints); + std::vector points0(numPoints); + std::vector points1(numPoints); std::vector values0(numPoints); std::vector values1(numPoints); for (u64 i = 0; i < numPoints; ++i) { - points1[i] = prng.get(); - points0[i] = (prng.get() % domain) ^ points1[i]; + points1[i] = Trit32(prng.get() % domain); + points0[i] = Trit32(prng.get() % domain) + points1[i]; values0[i] = prng.get(); values1[i] = prng.get(); } @@ -351,7 +351,7 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) { for (u64 k = 0; k < numPoints; ++k) { - auto p = points0[k] ^ points1[k]; + auto p = (points0[k] + points1[k]).toInt(); auto act = output[0][k][i] ^ output[1][k][i]; auto t = i == p ? 1 : 0; auto tAct = tags[0][k][i] ^ tags[1][k][i]; diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/RegularDpf_Tests.h index b2304d9d..ccc85417 100644 --- a/libOTe_Tests/RegularDpf_Tests.h +++ b/libOTe_Tests/RegularDpf_Tests.h @@ -5,3 +5,4 @@ void RegularDpf_Multiply_Test(const oc::CLP& cmd); void RegularDpf_Proto_Test(const oc::CLP& cmd); void SparseDpf_Proto_Test(const oc::CLP& cmd); +void TritDpf_Proto_Test(const oc::CLP& cmd); \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 3bb9115c..cf23a7e2 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -62,6 +62,7 @@ namespace tests_libOTe tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); + tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); tc.add("foleage_transpose_test ", foleage_transpose_test); From 56b46da8ec7a7c9751b815e85ccd0bf345a8d0fb Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 10 Feb 2025 17:32:49 -0800 Subject: [PATCH 08/48] tridpf partially working --- libOTe/Tools/Dpf/TriDpf.h | 334 +++++++++++++++++------------- libOTe_Tests/RegularDpf_Tests.cpp | 21 +- 2 files changed, 202 insertions(+), 153 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index fd2b90dc..6d534c80 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -30,34 +30,58 @@ namespace osuCrypto Trit32 operator+(const Trit32& t) const { - u64 msbMask, lsbMask; - setBytes(msbMask, 0b10101010); - setBytes(lsbMask, 0b01010101); + //u64 msbMask, lsbMask; + //setBytes(msbMask, 0b10101010); + //setBytes(lsbMask, 0b01010101); - auto x0 = mVal; - auto x1 = mVal >> 1; - auto y0 = t.mVal; - auto y1 = t.mVal >> 1; + //auto x0 = mVal; + //auto x1 = mVal >> 1; + //auto y0 = t.mVal; + //auto y1 = t.mVal >> 1; - auto x1x0 = x1 ^ x0; - auto z1 = (y0 ^ x0) & ~(x1x0 ^ y1); - auto z0 = (x1 ^ y1) & ~(x1x0 ^ y0); + //auto x1x0 = x1 ^ x0; + //auto z1 = (y0 ^ x0) & ~(x1x0 ^ y1); + //auto z0 = (x1 ^ y1) & ~(x1x0 ^ y0); - Trit32 r; - r.mVal = ((z1 << 1) & msbMask) | (z0 & lsbMask); + //r.mVal = ((z1 << 1) & msbMask) | (z0 & lsbMask); + Trit32 r; for (u64 i = 0; i < 32; ++i) { - auto a = (mVal >> (i * 2)) & 3; - auto b = (mVal >> (i * 2)) & 3; + auto a = t[i]; + auto b = (*this)[i]; auto c = (a + b) % 3; - if (c != ((r.mVal >> (i * 2)) & 3)) - throw RTE_LOC; + + r.mVal |= u64(c) << (i * 2); + //if (c != ((r.mVal >> (i * 2)) & 3)) + //throw RTE_LOC; + } + return r; + } + + + Trit32 operator-(const Trit32& t) const + { + Trit32 r; + for (u64 i = 0; i < 32; ++i) + { + auto a = t[i]; + auto b = (*this)[i]; + auto c = (b + 3 - a) % 3; + + r.mVal |= u64(c) << (i * 2); } return r; } + + bool operator==(const Trit32& t) const + { + return mVal == t.mVal; + } + + u64 toInt() const { u64 r = 0; @@ -82,12 +106,33 @@ namespace osuCrypto // returns the i'th Z_3 element. - u8 operator[](u64 i) + u8 operator[](u64 i) const { - return (mVal >> (i*2)) & 3; + return (mVal >> (i * 2)) & 3; } }; + std::ostream& operator<<(std::ostream& o, const Trit32& t) + { + u64 m = 0; + u64 v = t.mVal; + while (v) + { + ++m; + v >>= 2; + } + if (!m) + o << "0"; + else + { + for (u64 i = m - 1; i < m; --i) + { + o << int(t[i]); + } + } + return o; + } + struct TriDpf { enum class OutputFormat @@ -131,13 +176,18 @@ namespace osuCrypto //mMultiplier.init(partyIdx, numPoints * mDepth); } - // returns something similar to b % 3. - u8 trit(block b) + //// returns something similar to b % 3. + //u8 trit(block b) + //{ + // auto v = b.get(0); + // return + // static_cast(v > 6148914691236517205ull) + + // static_cast(v > 12297829382473034410ull); + //} + + u8 lsb(const block& b) { - auto v = b.get(0); - return - static_cast(v > 6148914691236517205ull) + - static_cast(v > 12297829382473034410ull); + return b.get(0) & 1; } #define SIMD8(VAR, STATEMENT) \ @@ -191,17 +241,19 @@ namespace osuCrypto // shares of S' - auto pow2 = 1ull << log2ceil(mDomain); - std::array, 2> s; - s[mDepth & 1].resize(pow2, numPoints, oc::AllocType::Uninitialized); - s[(mDepth & 1) ^ 1].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); + auto pow3 = ipow(3, mDepth); + std::array, 3> s; + auto last = mDepth % 3; + s[last].resize(pow3, numPoints, oc::AllocType::Uninitialized); + s[(last + 2) % 3].resize(pow3 / 3, numPoints, oc::AllocType::Uninitialized); + s[(last + 1) % 3].resize(pow3 / 9, numPoints, oc::AllocType::Uninitialized); // share of t - std::array, 2> t; - t[0].resize(s[0].rows(), s[0].cols()); - t[1].resize(s[1].rows(), s[1].cols()); - for (u64 i = 0; i < numPoints; ++i) - t[0](0, i) = mPartyIdx; + //std::array, 2> t; + //t[0].resize(s[0].rows(), s[0].cols()); + //t[1].resize(s[1].rows(), s[1].cols()); + //for (u64 i = 0; i < numPoints; ++i) + // t[0](0, i) = mPartyIdx; #if defined(NDEBUG) @@ -209,18 +261,18 @@ namespace osuCrypto #else auto getRow = [](auto&& m, u64 i) {return m[i]; }; #endif - std::array, 3> tau; - tau[0].resize(mNumPoints); - tau[1].resize(mNumPoints); - tau[2].resize(mNumPoints); + //std::array, 3> tau; + //tau[0].resize(mNumPoints); + //tau[1].resize(mNumPoints); + //tau[2].resize(mNumPoints); std::array, 3> z; z[0].resize(mNumPoints); z[1].resize(mNumPoints); z[2].resize(mNumPoints); - std::array, 3> v; - v[0].resize(mNumPoints); - v[1].resize(mNumPoints); - v[2].resize(mNumPoints); + //std::array, 3> v; + //v[0].resize(mNumPoints); + //v[1].resize(mNumPoints); + //v[2].resize(mNumPoints); std::array, 3> sigma; sigma[0].resize(mNumPoints); sigma[1].resize(mNumPoints); @@ -229,11 +281,13 @@ namespace osuCrypto { // we skip level 0 and set level 1 to be random + auto t = s[0][0]; auto sc0 = s[1][0]; auto sc1 = s[1][1]; auto sc2 = s[1][2]; for (u64 k = 0; k < numPoints; ++k) { + t[k] = block::allSame(-mPartyIdx); sc0[k] = prng.get(); sc1[k] = prng.get(); sc2[k] = prng.get(); @@ -241,6 +295,8 @@ namespace osuCrypto z[0][k] = sc0[k]; z[1][k] = sc1[k]; z[2][k] = sc2[k]; + + //std::cout << "seed " << sc0[k] << " " << sc1[k] << " " << sc2[k] << std::endl; } } @@ -258,23 +314,9 @@ namespace osuCrypto // We compute left and right sums for the children. for (u64 iter = 1; iter <= mDepth; ++iter) { - // the parent level - auto& parentSeedBase = t[(iter - 1) & 1]; - - // current level - auto& seedBase = s[iter & 1]; - auto& tagBase = t[iter & 1]; - - // the child level - auto& childSeedBase = s[(iter + 1) & 1]; - //auto& childTagBase = t[(iter + 1) & 1]; - - auto size = ipow(3, iter); - { std::vector alphaj(numPoints); std::vector zz(numPoints * 3); auto zzIter = zz.begin(); - std::vector vv(numPoints * 3); auto vvIter = vv.begin(); for (u64 k = 0; k < numPoints; ++k) { alphaj[k] = points[k][mDepth - iter]; @@ -283,45 +325,73 @@ namespace osuCrypto for (u64 k = 0; k < 3; ++k) { copyBytes(span(zzIter, zzIter + numPoints), z[k]); zzIter += numPoints; - copyBytes(span(vvIter, vvIter + numPoints), v[k]); vvIter += numPoints; } co_await sock.send(coproto::copy(alphaj)); co_await sock.send(coproto::copy(zz)); - co_await sock.send(coproto::copy(vv)); auto recvAlphaj = co_await sock.recv>(); co_await sock.recv(zz); - co_await sock.recv(vv); zzIter = zz.begin(); - vvIter = vv.begin(); for (u64 k = 0; k < 3; ++k) { for (u64 i = 0; i < numPoints; ++i) { - sigma[k][i] = z[k][i] ^ *zzIter++; - tau[k][i] = v[k][i] ^ *vvIter++ ^ 1; - assert(v[k][i] < 2); + //if (v[k][i] > 2) + //{ + // throw RTE_LOC; + //} + //std::cout << "sigma = " << z[k][i] << " + " << *zzIter; + sigma[k][i] = z[k][i] ^ *zzIter++;//^ OneBlock; + + //std::cout << " = " << sigma[k][i] << std::endl; + //tau[k][i] = lsb(sigma[k][i]) ^ 1; } } for (u64 i = 0; i < numPoints; ++i) { assert(recvAlphaj[i] < 3); - alphaj[i] = (alphaj[i] + recvAlphaj[i]) % 3; + auto a = (alphaj[i] + recvAlphaj[i]) % 3; + //std::cout << "alpha[" << (mDepth - iter) << "] = " << int(a) << " = " << int(alphaj[i]) << " + " << int(recvAlphaj[i]) << std::endl; - sigma[alphaj[i]][i] ^= oc::mAesFixedKey.ecbEncBlock(block(iter, i)); - tau[alphaj[i]][i] ^= 1; + auto r = (oc::mAesFixedKey.ecbEncBlock(block(iter, i)) & ~OneBlock); + sigma[a][i] ^= r ^ OneBlock; } } + + //for (u64 k = 0; k < 3; ++k) + //{ + // for (u64 i = 0; i < numPoints; ++i) + // { + // tau[k][i] = lsb(sigma[k][i]); + // } + //} + + //std::cout << "sigma[" << iter << "] " << sigma[0][0] << " " << sigma[1][0] << " " << sigma[2][0] << std::endl; + //std::cout << "tau[" << iter << "] " << int(tau[0][0]) << " " << int(tau[1][0]) << " " << int(tau[2][0]) << std::endl; + + // the parent level + auto& parentBase = s[(iter - 1) % 3]; + + // current level + auto& seedBase = s[iter % 3]; + //auto& tagBase = t[iter % 3]; + + // the child level + auto& childBase = s[(iter + 1) % 3]; + //auto& childTagBase = t[(iter + 1) & 1]; + + auto size = ipow(3, iter); + if (iter != mDepth) { for (u64 i = 0; i < 3; ++i) { setBytes(z[i], 0); - setBytes(v[i], 0); + //setBytes(v[i], 0); } // we iterate over each parent control bit. @@ -330,52 +400,32 @@ namespace osuCrypto for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 3, L4 += 9) { // parent control bits, one for each tree. - auto parentTag = getRow(parentSeedBase, L); + auto parentTag = getRow(parentBase, L); + //auto parentSeed = getRow(seedBase, L); // child seed, three for each tree. std::array seed{ getRow(seedBase, L2 + 0), getRow(seedBase, L2 + 1) , getRow(seedBase, L2 + 2) }; // child control bit, tree for each tree. - std::array tag{ getRow(tagBase, L2 + 0), getRow(tagBase, L2 + 1), getRow(tagBase, L2 + 2) }; + //std::array tag{ getRow(tagBase, L2 + 0), getRow(tagBase, L2 + 1), getRow(tagBase, L2 + 2) }; // grandchild seeds, nine for each tree. - std::array childSeed; + std::array childSeed; for (u64 i = 0; i < 9; ++i) - childSeed[i] = getRow(childSeedBase, L4 + i); - - //for (u64 k = 0; k < numPoints8; k += 8) - //{ - // block temp[8]; - // SIMD8(q, temp[q] = block::allSame(-parentTag[k + q]) & sigma[k + q]); - // SIMD8(q, tag[0][k + q] = lsb(seed[0][k + q]) ^ parentTag[k + q] & tau[0][k + q]); - // SIMD8(q, seed[0][k + q] ^= temp[q]); - - - // mAesFixedKey.ecbEncBlocks<8>(&seed[0][k], &childSeed[1][k]); - // SIMD8(q, childSeed[0][k + q] = AES::roundEnc(childSeed[1][k + q], seed[0][k + q])); - // SIMD8(q, childSeed[1][k + q] = childSeed[1][k + q] + seed[0][k + q]); - - // SIMD8(q, z[0][k + q] ^= childSeed[0][k + q]); - // SIMD8(q, z[1][k + q] ^= childSeed[1][k + q]); + childSeed[i] = getRow(childBase, L4 + i); - // SIMD8(q, tag[1][k + q] = lsb(seed[1][k + q]) ^ parentTag[k + q] & tau[1][k + q]); - // SIMD8(q, seed[1][k + q] ^= temp[q]); - // mAesFixedKey.ecbEncBlocks<8>(&seed[1][k], &childSeed[3][k]); - // SIMD8(q, childSeed[2][k + q] = AES::roundEnc(childSeed[3][k + q], seed[1][k + q])); - // SIMD8(q, childSeed[3][k + q] = childSeed[3][k + q] + seed[1][k + q]); - // SIMD8(q, z[0][k + q] ^= childSeed[2][k + q]); - // SIMD8(q, z[1][k + q] ^= childSeed[3][k + q]); - //} - //auto& sigmaL = sigma[L % 3]; - - for (u64 k = 0; k < mNumPoints; ++k) + for (u64 j = 0; j < 3; ++j) { - for (u64 j = 0; j < 3; ++j) + for (u64 k = 0; k < mNumPoints; ++k) { + // (s,t) = (s,t) ^ q * sigma_j - tag[j][k] = trit(seed[j][k]) ^ parentTag[k] & tau[j][k]; - seed[j][k] ^= block::allSame(-parentTag[k + 0]) & sigma[j][k + 0]; + //tag[j][k] = lsb(seed[j][k]) ^ parentTag[k] & tau[j][k]; + //seed[j][k] ^= block::allSame(-i64(parentTag[k + 0])) & sigma[j][k + 0]; + seed[j][k] ^= parentTag[k] & sigma[j][k]; + //tag[j][k] = lsb(seed[j][k]); + //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << seed[j][k] << " " << int(lsb(seed[j][k])) <<" " << parentTag[k] << std::endl; // for (u64 i = 0; i < 3; ++i) @@ -384,89 +434,75 @@ namespace osuCrypto childSeed[j * 3 + i][k] = s; z[i][k] ^= s; } - } - - //tag[1][k] = lsb(seed[1][k]) ^ parentTag[k] & tau[1][k]; - //seed[1][k] ^= temp; - - //childSeed[3][k] = mAesFixedKey.ecbEncBlock(seed[1][k]); - //childSeed[2][k] = AES::roundEnc(childSeed[3][k], seed[1][k]); - //childSeed[3][k] = childSeed[3][k] + seed[1][k]; - //z[0][k] ^= childSeed[2][k]; - //z[1][k] ^= childSeed[3][k]; + // replace the seed with the tag. + seed[j][k] = block::allSame(-lsb(seed[j][k])); + } } } } } - - + AlignedUnVector sums(mNumPoints); + Matrix t(ipow(3, mDepth), mNumPoints); // fixing the last layer { auto size = ipow(3, mDepth); - auto& parentTag = t[(mDepth - 1) & 1]; - auto& curSeed = s[mDepth & 1]; - auto& curTag = t[mDepth & 1]; + auto& parentTag = s[(mDepth - 1) % 3]; + auto& curSeed = s[mDepth % 3]; + //auto& curTag = t[mDepth & 1]; + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 3) { // parent control bits auto tpl = getRow(parentTag, L); // child seed - std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1) }; + std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1), getRow(curSeed, L2 + 2) }; // child control bit - std::array tcl{ getRow(curTag, L2 + 0), getRow(curTag, L2 + 1) }; + //std::array tcl{ getRow(curTag, L2 + 0), getRow(curTag, L2 + 1) , getRow(curTag, L2 + 2) }; - //for (u64 k = 0; k < numPoints8; k += 8) - //{ - // block temp[8]; - // SIMD8(q, temp[q] = block::allSame(-parentTag[k + q]) & sigma[k + q]); - // SIMD8(q, tag[0][k + q] = lsb(seed[0][k + q]) ^ parentTag[k + q] & tau[0][k + q]); - // SIMD8(q, tag[1][k + q] = lsb(seed[1][k + q]) ^ parentTag[k + q] & tau[1][k + q]); - // SIMD8(q, seed[0][k + q] ^= temp[q]); - // SIMD8(q, seed[1][k + q] ^= temp[q]); - //} - - for (u64 k = 0; k < mNumPoints; ++k) + for (u64 j = 0; j < 3; ++j) { - for (u64 j = 0; j < 3; ++j) + for (u64 k = 0; k < mNumPoints; ++k) { - curTag[L2 + j][k] = trit(scl[j][k]) ^ tpl[k] & tau[j][k]; - curSeed[L2 + j][k] ^= block::allSame(-tpl[k]) & sigma[j][k];; + //curTag[L2 + j][k] = lsb(scl[j][k]) ^ tpl[k] & tau[j][k]; + auto s = curSeed[L2 + j][k] ^ tpl[k] & sigma[j][k]; + t[L2 + j][k] = lsb(s); + curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s); + sums[k] = sums[k] ^ curSeed[L2 + j][k]; + //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << curSeed[L2 + j][k] << " " << int(curTag[L2 + j][k]) << std::endl; } - - - //curTag[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; - //curTag[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; - //curSeed[L2 + 0][k] ^= block::allSame(-tpl[k + 0]) & sigma[k % 3][k + 0];; - //curSeed[L2 + 1][k] ^= temp; } } } + //std::cout << "----------" << std::endl; if (values.size()) { AlignedUnVector gamma(mNumPoints), diff(mNumPoints); + setBytes(diff, 0); + auto& curSeed = s[mDepth % 3]; + for (u64 k = 0; k < mNumPoints; ++k) { - diff[k] = z[0][k] ^ z[1][k] ^ values[k]; + diff[k] = sums[k] ^ values[k]; } co_await sock.send(std::move(diff)); co_await sock.recv(gamma); for (u64 k = 0; k < mNumPoints; ++k) { - gamma[k] = z[0][k] ^ z[1][k] ^ values[k] ^ gamma[k]; + gamma[k] = sums[k] ^ values[k] ^ gamma[k]; } - auto& sd = s[mDepth & 1]; - auto& td = t[mDepth & 1]; + auto& sd = s[mDepth % 3]; + //auto& td = t[mDepth & 1]; for (u64 i = 0; i < mDomain; ++i) { auto sdi = getRow(sd, i); - auto tdi = getRow(td, i); + auto tdi = getRow(t, i); //for (u64 k = 0; k < numPoints8; k += 8) //{ @@ -476,15 +512,18 @@ namespace osuCrypto //} for (u64 k = 0; k < mNumPoints; ++k) { - auto T = block::allSame(-tdi[k]) & gamma[k]; - output(k, i, sdi[k] ^ T, tdi[k]); + auto T = block::allSame(-tdi[k]) & gamma[k]; + auto V = sdi[k] ^ T; + //std::cout << mPartyIdx << " " << Trit32(i) << " " << sdi[k] << " " << int(tdi[k]) << std::endl; + + output(k, i, V, tdi[k]); } } } else { - auto& sd = s[mDepth & 1]; - auto& td = t[mDepth & 1]; + auto& sd = s[mDepth % 3]; + auto& td = t;// [mDepth & 1] ; for (u64 i = 0; i < mDomain; ++i) { auto sdi = getRow(sd, i); @@ -503,7 +542,8 @@ namespace osuCrypto u64 baseOtCount() const { - throw RTE_LOC; + return mDepth * mNumPoints; + //throw RTE_LOC; //return mMultiplier.baseOtCount(); } @@ -512,7 +552,7 @@ namespace osuCrypto span recvBaseOts, const oc::BitVector& baseChoices) { - throw RTE_LOC; + //throw RTE_LOC; //mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); } diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index c34017b6..3e25ec2d 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -291,17 +291,21 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); - u64 depth = 3; - u64 domain = ipow(3,depth); - u64 numPoints = 11; + u64 depth = 4; + u64 domain = ipow(3,depth) - 3; + u64 numPoints = 17; std::vector points0(numPoints); std::vector points1(numPoints); + std::vector points(numPoints); std::vector values0(numPoints); std::vector values1(numPoints); for (u64 i = 0; i < numPoints; ++i) { + points[i] = Trit32(prng.get() % domain); points1[i] = Trit32(prng.get() % domain); - points0[i] = Trit32(prng.get() % domain) + points1[i]; + points0[i] = points[i] - points1[i]; + + //std::cout << points[i] << " = " << points0[i] <<" + "<< points1[i] << std::endl; values0[i] = prng.get(); values1[i] = prng.get(); } @@ -349,15 +353,20 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) for (u64 i = 0; i < domain; ++i) { + Trit32 I(i); for (u64 k = 0; k < numPoints; ++k) { - auto p = (points0[k] + points1[k]).toInt(); auto act = output[0][k][i] ^ output[1][k][i]; - auto t = i == p ? 1 : 0; + auto t = I == points[k] ? 1 : 0; auto tAct = tags[0][k][i] ^ tags[1][k][i]; auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; if (exp != act) + { + std::cout << "i " << i << "="<< Trit32(i)<<" " << t << std::endl; + std::cout << "exp " << exp << std::endl; + std::cout << "act " << act << std::endl; throw RTE_LOC; + } if (t != tAct) throw RTE_LOC; } From c7bd266b7d7d1ca997db7ecb69d1d21fdc48cc0e Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 12 Feb 2025 08:37:48 -0800 Subject: [PATCH 09/48] triDpf Working --- libOTe/Tools/Dpf/TriDpf.h | 392 ++++++++++++++++++---------- libOTe/Tools/Foleage/FoleagePcg.cpp | 4 +- libOTe_Tests/RegularDpf_Tests.cpp | 4 +- 3 files changed, 258 insertions(+), 142 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 6d534c80..3dfbe988 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -135,20 +135,6 @@ namespace osuCrypto struct TriDpf { - enum class OutputFormat - { - // The i'th row holds the i'th leaf for all trees. - // The j'th tree is in the j'th column. - ByLeafIndex, - - // The i'th row holds the i'th tree. - // The j'th leaf is in the j'th column. - ByTreeIndex, - - }; - - OutputFormat mOutputFormat = OutputFormat::ByLeafIndex; - u64 mPartyIdx = 0; u64 mDomain = 0; @@ -157,6 +143,12 @@ namespace osuCrypto u64 mNumPoints = 0; + u64 mOtIdx = 0; + + std::vector> mBaseSendOts; + std::vector mBaseRecvOts; + std::vector mBaseChoice; + void init( u64 partyIdx, u64 domain, @@ -173,6 +165,7 @@ namespace osuCrypto mPartyIdx = partyIdx; mDomain = domain; mNumPoints = numPoints; + mOtIdx = 0; //mMultiplier.init(partyIdx, numPoints * mDepth); } @@ -236,47 +229,23 @@ namespace osuCrypto throw std::runtime_error("TriDpf: invalid point sharing. point is larger than 3^D " LOCATION); } - u64 numPoints = points.size(); - u64 numPoints8 = numPoints / 8 * 8; - + u64 numPoints8 = mNumPoints / 8 * 8; // shares of S' auto pow3 = ipow(3, mDepth); std::array, 3> s; auto last = mDepth % 3; - s[last].resize(pow3, numPoints, oc::AllocType::Uninitialized); - s[(last + 2) % 3].resize(pow3 / 3, numPoints, oc::AllocType::Uninitialized); - s[(last + 1) % 3].resize(pow3 / 9, numPoints, oc::AllocType::Uninitialized); - - // share of t - //std::array, 2> t; - //t[0].resize(s[0].rows(), s[0].cols()); - //t[1].resize(s[1].rows(), s[1].cols()); - //for (u64 i = 0; i < numPoints; ++i) - // t[0](0, i) = mPartyIdx; - + s[last].resize(pow3, mNumPoints, oc::AllocType::Uninitialized); + s[(last + 2) % 3].resize(pow3 / 3, mNumPoints, oc::AllocType::Uninitialized); + s[(last + 1) % 3].resize(pow3 / 9, mNumPoints, oc::AllocType::Uninitialized); #if defined(NDEBUG) auto getRow = [](auto&& m, u64 i) {return m.data(i); }; #else auto getRow = [](auto&& m, u64 i) {return m[i]; }; #endif - //std::array, 3> tau; - //tau[0].resize(mNumPoints); - //tau[1].resize(mNumPoints); - //tau[2].resize(mNumPoints); - std::array, 3> z; - z[0].resize(mNumPoints); - z[1].resize(mNumPoints); - z[2].resize(mNumPoints); - //std::array, 3> v; - //v[0].resize(mNumPoints); - //v[1].resize(mNumPoints); - //v[2].resize(mNumPoints); - std::array, 3> sigma; - sigma[0].resize(mNumPoints); - sigma[1].resize(mNumPoints); - sigma[2].resize(mNumPoints); + Matrix z(3, mNumPoints); + Matrix sigma(3, mNumPoints); { @@ -285,7 +254,7 @@ namespace osuCrypto auto sc0 = s[1][0]; auto sc1 = s[1][1]; auto sc2 = s[1][2]; - for (u64 k = 0; k < numPoints; ++k) + for (u64 k = 0; k < mNumPoints; ++k) { t[k] = block::allSame(-mPartyIdx); sc0[k] = prng.get(); @@ -306,7 +275,7 @@ namespace osuCrypto AES(block(324532450985209453,5387987243989842789)) }; // at each iteration we first correct the parent level. - // The parent level has two syblings which are random. + // The parent level has two siblings which are random. // We need to correct the inactive child so that both parties // hold the same seed (a sharing of zero). // @@ -314,61 +283,8 @@ namespace osuCrypto // We compute left and right sums for the children. for (u64 iter = 1; iter <= mDepth; ++iter) { - { - std::vector alphaj(numPoints); - std::vector zz(numPoints * 3); auto zzIter = zz.begin(); - for (u64 k = 0; k < numPoints; ++k) - { - alphaj[k] = points[k][mDepth - iter]; - } - - for (u64 k = 0; k < 3; ++k) - { - copyBytes(span(zzIter, zzIter + numPoints), z[k]); zzIter += numPoints; - } - - co_await sock.send(coproto::copy(alphaj)); - co_await sock.send(coproto::copy(zz)); - - auto recvAlphaj = co_await sock.recv>(); - co_await sock.recv(zz); - - zzIter = zz.begin(); - for (u64 k = 0; k < 3; ++k) - { - for (u64 i = 0; i < numPoints; ++i) - { - //if (v[k][i] > 2) - //{ - // throw RTE_LOC; - //} - //std::cout << "sigma = " << z[k][i] << " + " << *zzIter; - sigma[k][i] = z[k][i] ^ *zzIter++;//^ OneBlock; - - //std::cout << " = " << sigma[k][i] << std::endl; - //tau[k][i] = lsb(sigma[k][i]) ^ 1; - } - } - - for (u64 i = 0; i < numPoints; ++i) - { - assert(recvAlphaj[i] < 3); - auto a = (alphaj[i] + recvAlphaj[i]) % 3; - //std::cout << "alpha[" << (mDepth - iter) << "] = " << int(a) << " = " << int(alphaj[i]) << " + " << int(recvAlphaj[i]) << std::endl; - - auto r = (oc::mAesFixedKey.ecbEncBlock(block(iter, i)) & ~OneBlock); - sigma[a][i] ^= r ^ OneBlock; - } - } - - //for (u64 k = 0; k < 3; ++k) - //{ - // for (u64 i = 0; i < numPoints; ++i) - // { - // tau[k][i] = lsb(sigma[k][i]); - // } - //} + co_await correctionWord(points, z, sigma, iter, sock); //std::cout << "sigma[" << iter << "] " << sigma[0][0] << " " << sigma[1][0] << " " << sigma[2][0] << std::endl; //std::cout << "tau[" << iter << "] " << int(tau[0][0]) << " " << int(tau[1][0]) << " " << int(tau[2][0]) << std::endl; @@ -378,11 +294,9 @@ namespace osuCrypto // current level auto& seedBase = s[iter % 3]; - //auto& tagBase = t[iter % 3]; // the child level auto& childBase = s[(iter + 1) % 3]; - //auto& childTagBase = t[(iter + 1) & 1]; auto size = ipow(3, iter); @@ -391,7 +305,6 @@ namespace osuCrypto for (u64 i = 0; i < 3; ++i) { setBytes(z[i], 0); - //setBytes(v[i], 0); } // we iterate over each parent control bit. @@ -401,15 +314,11 @@ namespace osuCrypto { // parent control bits, one for each tree. auto parentTag = getRow(parentBase, L); - //auto parentSeed = getRow(seedBase, L); - // child seed, three for each tree. + // current seed, three for each tree. std::array seed{ getRow(seedBase, L2 + 0), getRow(seedBase, L2 + 1) , getRow(seedBase, L2 + 2) }; - // child control bit, tree for each tree. - //std::array tag{ getRow(tagBase, L2 + 0), getRow(tagBase, L2 + 1), getRow(tagBase, L2 + 2) }; - - // grandchild seeds, nine for each tree. + // child seeds, nine for each tree. std::array childSeed; for (u64 i = 0; i < 9; ++i) childSeed[i] = getRow(childBase, L4 + i); @@ -419,24 +328,17 @@ namespace osuCrypto { for (u64 k = 0; k < mNumPoints; ++k) { + auto seedjk = seed[j][k] ^ parentTag[k] & sigma[j][k]; - // (s,t) = (s,t) ^ q * sigma_j - //tag[j][k] = lsb(seed[j][k]) ^ parentTag[k] & tau[j][k]; - //seed[j][k] ^= block::allSame(-i64(parentTag[k + 0])) & sigma[j][k + 0]; - seed[j][k] ^= parentTag[k] & sigma[j][k]; - //tag[j][k] = lsb(seed[j][k]); - //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << seed[j][k] << " " << int(lsb(seed[j][k])) <<" " << parentTag[k] << std::endl; - - // for (u64 i = 0; i < 3; ++i) { - auto s = aes[i].hashBlock(seed[j][k]); + auto s = aes[i].hashBlock(seedjk); childSeed[j * 3 + i][k] = s; z[i][k] ^= s; } // replace the seed with the tag. - seed[j][k] = block::allSame(-lsb(seed[j][k])); + seed[j][k] = block::allSame(-lsb(seedjk)); } } } @@ -448,27 +350,22 @@ namespace osuCrypto { auto size = ipow(3, mDepth); - auto& parentTag = s[(mDepth - 1) % 3]; + auto& parentTags = s[(mDepth - 1) % 3]; auto& curSeed = s[mDepth % 3]; - //auto& curTag = t[mDepth & 1]; for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 3) { // parent control bits - auto tpl = getRow(parentTag, L); + auto parentTag = getRow(parentTags, L); // child seed std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1), getRow(curSeed, L2 + 2) }; - // child control bit - //std::array tcl{ getRow(curTag, L2 + 0), getRow(curTag, L2 + 1) , getRow(curTag, L2 + 2) }; - for (u64 j = 0; j < 3; ++j) { for (u64 k = 0; k < mNumPoints; ++k) { - //curTag[L2 + j][k] = lsb(scl[j][k]) ^ tpl[k] & tau[j][k]; - auto s = curSeed[L2 + j][k] ^ tpl[k] & sigma[j][k]; + auto s = curSeed[L2 + j][k] ^ parentTag[k] & sigma[j][k]; t[L2 + j][k] = lsb(s); curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s); sums[k] = sums[k] ^ curSeed[L2 + j][k]; @@ -481,7 +378,6 @@ namespace osuCrypto if (values.size()) { - AlignedUnVector gamma(mNumPoints), diff(mNumPoints); setBytes(diff, 0); auto& curSeed = s[mDepth % 3]; @@ -504,7 +400,7 @@ namespace osuCrypto auto sdi = getRow(sd, i); auto tdi = getRow(t, i); - //for (u64 k = 0; k < numPoints8; k += 8) + //for (u64 k = 0; k < mNumPoints8; k += 8) //{ // block T[8]; // SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); @@ -514,8 +410,6 @@ namespace osuCrypto { auto T = block::allSame(-tdi[k]) & gamma[k]; auto V = sdi[k] ^ T; - //std::cout << mPartyIdx << " " << Trit32(i) << " " << sdi[k] << " " << int(tdi[k]) << std::endl; - output(k, i, V, tdi[k]); } } @@ -541,10 +435,219 @@ namespace osuCrypto } - u64 baseOtCount() const { - return mDepth * mNumPoints; - //throw RTE_LOC; - //return mMultiplier.baseOtCount(); + // we are going to create 3 ot message + // + // m0, m1, m2 + // + // such that m_{-a0} = r || 1 for some random r. + // + // the receiver will use choice a1. + macoro::task<> correctionWord(span points, MatrixView z, MatrixView sigma, u64 iter, coproto::Socket& sock) + { + //{ + // char x = 0; + // co_await sock.send(char{ x }); + // co_await sock.recv(x); + //} + //std::cout << "=======" << iter << "======== " << std::endl; + + Matrix sigmaShares(3, mNumPoints); + AlignedUnVector> mask(mNumPoints); + + std::array socks; + socks[0] = sock; + socks[1] = sock.fork(); + if (mPartyIdx) + std::swap(socks[0], socks[1]); + + + auto H = [](const block& a, const block& b) -> block { + RandomOracle ro(sizeof(block)); + ro.Update(a); + ro.Update(b); + block r; + ro.Final(r); + return r; + }; + + auto sender = [&]() -> macoro::task<> { + PRNG prng(block(234134, 21452345 * mPartyIdx)); + + BitVector correction(mNumPoints * 2); + AlignedUnVector> buffer(mNumPoints * 3); + co_await socks[0].recv(correction); + //auto sendIter = mBaseSendOts.begin() + mOtIdx; + for (u64 i = 0; i < mNumPoints; ++i) + { + auto keys0 = mBaseSendOts[mOtIdx + i * 2 + 0]; + auto keys1 = mBaseSendOts[mOtIdx + i * 2 + 1]; + std::array k;// , m; + //std::cout << "p" << mPartyIdx << std::endl;// "\n " << k[0] << "\n " << k[1] << "\n " << k[2] << std::endl; + for (u64 j = 0; j < 3; ++j) + { + auto j0 = j & 1; + auto j1 = j >> 1; + + auto b0 = j0 ^ correction[i * 2 + 0]; + auto b1 = j1 ^ correction[i * 2 + 1]; + auto k0 = keys0[b0]; + auto k1 = keys1[b1]; + + k[j] = H(k0, k1); + //std::cout << "k" << j << " " << k[j] << " = H( " + // << std::hex << k0.get(0) << " " << b0 << " " + // << std::hex << k1.get(0) << " " << b1 << " ) " << std::endl; + } + + block r = prng.get(); + *BitIterator(&r) = mPartyIdx; + + //std::array mask;// = prng.get(); + setBytes(mask[i], 0); + auto a = points[i][mDepth - iter]; + //std::cout << "a0 " << int(a) << std::endl; + + for (u64 j = 0; j < 3; ++j) + { + buffer[i * 3 + j] = PRNG(k[j], 3).get(); + buffer[i * 3 + j][0] ^= mask[i][0]; + buffer[i * 3 + j][1] ^= mask[i][1]; + buffer[i * 3 + j][2] ^= mask[i][2]; + buffer[i * 3 + j][(j + a) % 3] ^= r; + + //std::cout << "buffer " << j << std::endl + // << " " << buffer[i * 3 + j][0] << "\n" + // << " " << buffer[i * 3 + j][1] << "\n" + // << " " << buffer[i * 3 + j][2] << "\n"; + } + } + + co_await socks[0].send(std::move(buffer)); + + co_await socks[0].recv(sigmaShares); + + }; + + auto recver = [&]() -> macoro::task<> { + BitVector correction(mNumPoints * 2); + for (u64 i = 0; i < mNumPoints; ++i) + { + auto a = points[i][mDepth - iter]; + correction[i * 2 + 0] = ((a >> 0) & 1) ^ mBaseChoice[mOtIdx + i * 2 + 0]; + correction[i * 2 + 1] = ((a >> 1) & 1) ^ mBaseChoice[mOtIdx + i * 2 + 1]; + } + co_await socks[1].send(std::move(correction)); + AlignedUnVector> buffer(mNumPoints * 3); + + co_await socks[1].recv(buffer); + + for (u64 i = 0; i < mNumPoints; ++i) + { + auto a = points[i][mDepth - iter]; + + auto k = H( + mBaseRecvOts[mOtIdx + i * 2 + 0], + mBaseRecvOts[mOtIdx + i * 2 + 1]); + //std::cout << "p" << mPartyIdx << " ka " << k << " = H( " + // << std::hex << mBaseRecvOts[i * 2 + 0].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " " + // << std::hex << mBaseRecvOts[i * 2 + 1].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " )" << " a1 " << int(a) << std::endl; + + //std::cout << "buffer " << std::endl + // << " " << buffer[i * 3 + a][0] << "\n" + // << " " << buffer[i * 3 + a][1] << "\n" + // << " " << buffer[i * 3 + a][2] << "\n"; + std::array ka = PRNG(k, 3).get(); + sigma[0][i] = ka[0] ^ buffer[i * 3 + a][0] ^ z[0][i]; + sigma[1][i] = ka[1] ^ buffer[i * 3 + a][1] ^ z[1][i]; + sigma[2][i] = ka[2] ^ buffer[i * 3 + a][2] ^ z[2][i]; + + //std::cout << "sigma " << std::endl + // << " " << sigma[0][i] << " = " << std::hex << ka[0].get(0) << " + " << std::hex << buffer[i * 3 + a][0].get(0) << " + " << std::hex << z[0][i].get(0) << "\n" + // << " " << sigma[1][i] << " = " << std::hex << ka[1].get(0) << " + " << std::hex << buffer[i * 3 + a][1].get(0) << " + " << std::hex << z[1][i].get(0) << "\n" + // << " " << sigma[2][i] << " = " << std::hex << ka[2].get(0) << " + " << std::hex << buffer[i * 3 + a][2].get(0) << " + " << std::hex << z[2][i].get(0) << "\n\n"; + } + + co_await socks[1].send(Matrix(sigma)); + }; + co_await macoro::when_all_ready( + sender(), + recver() + ); + + for (u64 i = 0; i < mNumPoints; ++i) + { + for (u64 j = 0; j < 3; ++j) + { + //std::cout << "sigma = " << (sigma[j][i] ^ sigmaShares[j][i]) << " = " << sigma[j][i] << " ^ " << sigmaShares[j][i] << std::endl; + sigma[j][i] ^= sigmaShares[j][i] ^ mask[i][j]; + } + } + + mOtIdx += mNumPoints * 2; + + if (0) + { + std::vector alphaj(mNumPoints); + std::vector zz(mNumPoints * 3); auto zzIter = zz.begin(); + for (u64 k = 0; k < mNumPoints; ++k) + { + alphaj[k] = points[k][mDepth - iter]; + } + + for (u64 k = 0; k < 3; ++k) + { + copyBytes(span(zzIter, zzIter + mNumPoints), z[k]); zzIter += mNumPoints; + } + + co_await sock.send(coproto::copy(alphaj)); + co_await sock.send(coproto::copy(zz)); + + auto recvAlphaj = co_await sock.recv>(); + co_await sock.recv(zz); + + zzIter = zz.begin(); + Matrix sigma2(3, mNumPoints); + for (u64 k = 0; k < 3; ++k) + { + std::cout << "sigma2 \n"; + for (u64 i = 0; i < mNumPoints; ++i) + { + std::cout << " " << (z[k][i] ^ *zzIter) << " = " << std::hex << z[k][i].get(0) << " ^ " << std::hex << zzIter->get(0) << std::endl; + sigma2[k][i] = z[k][i] ^ *zzIter++; + //sigma2[k][i] = ZeroBlock; + } + } + + for (u64 i = 0; i < mNumPoints; ++i) + { + assert(recvAlphaj[i] < 3); + auto a = (alphaj[i] + recvAlphaj[i]) % 3; + //auto r = (oc::mAesFixedKey.ecbEncBlock(block(iter, i)) | OneBlock); + //sigma[a][i] ^= r; + + std::cout << sigma[0][i] << (a == 0 ? '<' : ' ') << std::endl; + std::cout << sigma[1][i] << (a == 1 ? '<' : ' ') << std::endl; + std::cout << sigma[2][i] << (a == 2 ? '<' : ' ') << std::endl; + + auto a1 = (a + 1) % 3; + auto a2 = (a + 2) % 3; + + //if ((sigma[a][i].get(0) & 1) == 0) + // throw RTE_LOC; + if (sigma[a1][i] != sigma2[a1][i]) + { + std::cout << "sigma[" << a1 << "][" << i << "] " << sigma[a1][i] << " != exp " << sigma2[a1][i] << std::endl; + throw RTE_LOC; + } + if (sigma[a2][i] != sigma2[a2][i]) + throw RTE_LOC; + } + } + } + + u64 baseOtCount() const + { + return mDepth * mNumPoints * 2; } void setBaseOts( @@ -552,11 +655,24 @@ namespace osuCrypto span recvBaseOts, const oc::BitVector& baseChoices) { - //throw RTE_LOC; - //mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); - } + if (baseSendOts.size() != baseOtCount()) + throw RTE_LOC; + if (recvBaseOts.size() != baseOtCount()) + throw RTE_LOC; + if (baseChoices.size() != baseOtCount()) + throw RTE_LOC; + mBaseSendOts.resize(baseOtCount()); + mBaseRecvOts.resize(mBaseSendOts.size()); + mBaseChoice.resize(mBaseSendOts.size()); + for (u64 i = 0; i < mBaseSendOts.size(); ++i) + { + mBaseSendOts[i] = baseSendOts[i]; + mBaseRecvOts[i] = recvBaseOts[i]; + mBaseChoice[i] = baseChoices[i]; + } + } }; } diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index e829ca9d..8dfb4050 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -12,13 +12,13 @@ namespace osuCrypto void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) { mPartyIdx = partyIdx; - mLog3N = log2ceil(n); + mLog3N = log3ceil(n); mN = ipow(3, mLog3N); if (mT != ipow(3, mLog3T)) throw RTE_LOC; - mDpfDomainDepth = std::max(1, log2ceil(divCeil(mN, mT * 256))); + mDpfDomainDepth = std::max(1, log3ceil(divCeil(mN, mT * 256))); mDpfBlockSize = 4 * ipow(3, mDpfDomainDepth); mBlockSize = mN / mT; diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 3e25ec2d..ce215eda 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -291,9 +291,9 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); - u64 depth = 4; + u64 depth = cmd.getOr("depth", 3); u64 domain = ipow(3,depth) - 3; - u64 numPoints = 17; + u64 numPoints = cmd.getOr("numPoints", 17); std::vector points0(numPoints); std::vector points1(numPoints); std::vector points(numPoints); From ba0b0b31bd6e6a76d5f944c138680610de769eab Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 12 Feb 2025 13:03:37 -0800 Subject: [PATCH 10/48] optimized triDpf --- libOTe/Tools/Dpf/TriDpf.h | 119 +++++++++++++++------------- libOTe/Tools/Foleage/FoleagePcg.cpp | 7 ++ libOTe/Tools/Foleage/FoleagePcg.h | 1 + 3 files changed, 74 insertions(+), 53 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 3dfbe988..faa5e716 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -166,18 +166,8 @@ namespace osuCrypto mDomain = domain; mNumPoints = numPoints; mOtIdx = 0; - //mMultiplier.init(partyIdx, numPoints * mDepth); } - //// returns something similar to b % 3. - //u8 trit(block b) - //{ - // auto v = b.get(0); - // return - // static_cast(v > 6148914691236517205ull) + - // static_cast(v > 12297829382473034410ull); - //} - u8 lsb(const block& b) { return b.get(0) & 1; @@ -367,7 +357,7 @@ namespace osuCrypto { auto s = curSeed[L2 + j][k] ^ parentTag[k] & sigma[j][k]; t[L2 + j][k] = lsb(s); - curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s); + curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s);//AES::roundFn is used to get rid of the correlation in the LSB. sums[k] = sums[k] ^ curSeed[L2 + j][k]; //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << curSeed[L2 + j][k] << " " << int(curTag[L2 + j][k]) << std::endl; } @@ -453,6 +443,7 @@ namespace osuCrypto Matrix sigmaShares(3, mNumPoints); AlignedUnVector> mask(mNumPoints); + AlignedUnVector> recvBuffer(mNumPoints * 2); std::array socks; socks[0] = sock; @@ -474,7 +465,7 @@ namespace osuCrypto PRNG prng(block(234134, 21452345 * mPartyIdx)); BitVector correction(mNumPoints * 2); - AlignedUnVector> buffer(mNumPoints * 3); + AlignedUnVector> sendBuffer(mNumPoints * 2); co_await socks[0].recv(correction); //auto sendIter = mBaseSendOts.begin() + mOtIdx; for (u64 i = 0; i < mNumPoints; ++i) @@ -503,17 +494,31 @@ namespace osuCrypto *BitIterator(&r) = mPartyIdx; //std::array mask;// = prng.get(); - setBytes(mask[i], 0); + //mask[i] = prng.get(); auto a = points[i][mDepth - iter]; //std::cout << "a0 " << int(a) << std::endl; - for (u64 j = 0; j < 3; ++j) { - buffer[i * 3 + j] = PRNG(k[j], 3).get(); - buffer[i * 3 + j][0] ^= mask[i][0]; - buffer[i * 3 + j][1] ^= mask[i][1]; - buffer[i * 3 + j][2] ^= mask[i][2]; - buffer[i * 3 + j][(j + a) % 3] ^= r; + + // sendBuffer[i * 3 + 0] = kj ^ mask ^ unitVec(r, a); + // 0 = kj ^ mask ^ unitVec(r, a); + // mask = kj ^ unitVec(r, a); + + mask[i] = PRNG(k[0], 3).get(); + //setBytes(mask[i], 0); + mask[i][a] ^= r; + } + + for (u64 j = 0; j < 2; ++j) + { + std::array kj = PRNG(k[j+1], 3).get(); + //setBytes(kj, 0); + + //sendBuffer[i * 3 + j] = PRNG(k[j], 3).get(); + sendBuffer[i * 2 + j][0] = kj[0] ^ mask[i][0]; + sendBuffer[i * 2 + j][1] = kj[1] ^ mask[i][1]; + sendBuffer[i * 2 + j][2] = kj[2] ^ mask[i][2]; + sendBuffer[i * 2 + j][(j+1 + a) % 3] ^= r; //std::cout << "buffer " << j << std::endl // << " " << buffer[i * 3 + j][0] << "\n" @@ -522,9 +527,8 @@ namespace osuCrypto } } - co_await socks[0].send(std::move(buffer)); + co_await socks[0].send(std::move(sendBuffer)); - co_await socks[0].recv(sigmaShares); }; @@ -537,55 +541,64 @@ namespace osuCrypto correction[i * 2 + 1] = ((a >> 1) & 1) ^ mBaseChoice[mOtIdx + i * 2 + 1]; } co_await socks[1].send(std::move(correction)); - AlignedUnVector> buffer(mNumPoints * 3); - - co_await socks[1].recv(buffer); - - for (u64 i = 0; i < mNumPoints; ++i) - { - auto a = points[i][mDepth - iter]; - - auto k = H( - mBaseRecvOts[mOtIdx + i * 2 + 0], - mBaseRecvOts[mOtIdx + i * 2 + 1]); - //std::cout << "p" << mPartyIdx << " ka " << k << " = H( " - // << std::hex << mBaseRecvOts[i * 2 + 0].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " " - // << std::hex << mBaseRecvOts[i * 2 + 1].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " )" << " a1 " << int(a) << std::endl; - - //std::cout << "buffer " << std::endl - // << " " << buffer[i * 3 + a][0] << "\n" - // << " " << buffer[i * 3 + a][1] << "\n" - // << " " << buffer[i * 3 + a][2] << "\n"; - std::array ka = PRNG(k, 3).get(); - sigma[0][i] = ka[0] ^ buffer[i * 3 + a][0] ^ z[0][i]; - sigma[1][i] = ka[1] ^ buffer[i * 3 + a][1] ^ z[1][i]; - sigma[2][i] = ka[2] ^ buffer[i * 3 + a][2] ^ z[2][i]; - - //std::cout << "sigma " << std::endl - // << " " << sigma[0][i] << " = " << std::hex << ka[0].get(0) << " + " << std::hex << buffer[i * 3 + a][0].get(0) << " + " << std::hex << z[0][i].get(0) << "\n" - // << " " << sigma[1][i] << " = " << std::hex << ka[1].get(0) << " + " << std::hex << buffer[i * 3 + a][1].get(0) << " + " << std::hex << z[1][i].get(0) << "\n" - // << " " << sigma[2][i] << " = " << std::hex << ka[2].get(0) << " + " << std::hex << buffer[i * 3 + a][2].get(0) << " + " << std::hex << z[2][i].get(0) << "\n\n"; - } - co_await socks[1].send(Matrix(sigma)); + co_await socks[1].recv(recvBuffer); }; + co_await macoro::when_all_ready( sender(), recver() ); + for (u64 i = 0; i < mNumPoints; ++i) + { + auto a = points[i][mDepth - iter]; + + auto k = H( + mBaseRecvOts[mOtIdx + i * 2 + 0], + mBaseRecvOts[mOtIdx + i * 2 + 1]); + //std::cout << "p" << mPartyIdx << " ka " << k << " = H( " + // << std::hex << mBaseRecvOts[i * 2 + 0].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " " + // << std::hex << mBaseRecvOts[i * 2 + 1].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " )" << " a1 " << int(a) << std::endl; + + //std::cout << "buffer " << std::endl + // << " " << buffer[i * 3 + a][0] << "\n" + // << " " << buffer[i * 3 + a][1] << "\n" + // << " " << buffer[i * 3 + a][2] << "\n"; + std::array ka = PRNG(k, 3).get(); + //setBytes(ka, 0); + + sigma[0][i] = ka[0] ^ mask[i][0] ^ z[0][i]; + sigma[1][i] = ka[1] ^ mask[i][1] ^ z[1][i]; + sigma[2][i] = ka[2] ^ mask[i][2] ^ z[2][i]; + if (a) + { + sigma[0][i] ^= recvBuffer[i * 2 + a - 1][0]; + sigma[1][i] ^= recvBuffer[i * 2 + a - 1][1]; + sigma[2][i] ^= recvBuffer[i * 2 + a - 1][2]; + } + + //std::cout << "sigma " << std::endl + // << " " << sigma[0][i] << " = " << std::hex << ka[0].get(0) << " + " << std::hex << buffer[i * 3 + a][0].get(0) << " + " << std::hex << z[0][i].get(0) << "\n" + // << " " << sigma[1][i] << " = " << std::hex << ka[1].get(0) << " + " << std::hex << buffer[i * 3 + a][1].get(0) << " + " << std::hex << z[1][i].get(0) << "\n" + // << " " << sigma[2][i] << " = " << std::hex << ka[2].get(0) << " + " << std::hex << buffer[i * 3 + a][2].get(0) << " + " << std::hex << z[2][i].get(0) << "\n\n"; + } + co_await sock.send(Matrix(sigma)); + + co_await sock.recv(sigmaShares); + for (u64 i = 0; i < mNumPoints; ++i) { for (u64 j = 0; j < 3; ++j) { //std::cout << "sigma = " << (sigma[j][i] ^ sigmaShares[j][i]) << " = " << sigma[j][i] << " ^ " << sigmaShares[j][i] << std::endl; - sigma[j][i] ^= sigmaShares[j][i] ^ mask[i][j]; + sigma[j][i] ^= sigmaShares[j][i];//^ mask[i][j]; } } mOtIdx += mNumPoints * 2; - if (0) + if (1) { std::vector alphaj(mNumPoints); std::vector zz(mNumPoints * 3); auto zzIter = zz.begin(); diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 8dfb4050..4b0feae6 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -384,6 +384,9 @@ namespace osuCrypto } setTimePoint("dpfKeyEval"); + + co_await dpfEval(prng, sock); + //std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; @@ -445,5 +448,9 @@ namespace osuCrypto } + macoro::task<> FoleageF4Ole::dpfEval(PRNG& prng, coproto::Socket& sock) + { + co_return; + } } \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 67edb6c2..d25500c0 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -61,6 +61,7 @@ namespace osuCrypto span CMsb, PRNG& prng, coproto::Socket& sock); + macoro::task<> dpfEval(PRNG& prng, coproto::Socket& sock); void sampleA(block seed); }; From 6b38a80163cd376bb5f04beab550e855b6c272f1 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Thu, 13 Feb 2025 18:04:25 -0800 Subject: [PATCH 11/48] dpf refactor and bugfix --- libOTe/Tools/Dpf/RegularDpf.h | 291 ++++++++++++++++++---------- libOTe/Tools/Dpf/SparseDpf.h | 2 +- libOTe/Tools/Dpf/TriDpf.h | 4 +- libOTe/Tools/Foleage/FoleagePcg.cpp | 7 +- libOTe/Tools/Foleage/FoleagePcg.h | 4 +- libOTe_Tests/RegularDpf_Tests.cpp | 11 +- 6 files changed, 204 insertions(+), 115 deletions(-) diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index 62af644f..3d26bce7 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -28,6 +28,14 @@ namespace osuCrypto return b.get(0) & 1; } + // extracts the lsb of b and returns a block saturated with that bit. + block tagBit(const block& b) + { + auto bit = b & block(0, 1); + auto mask = _mm_sub_epi64(_mm_set1_epi64x(0), bit); + return _mm_unpacklo_epi64(mask, mask); + } + void init( u64 partyIdx, u64 domain, @@ -86,16 +94,17 @@ namespace osuCrypto // shares of S' auto pow2 = 1ull << log2ceil(mDomain); - std::array, 2> s; - s[mDepth & 1].resize(pow2, numPoints, oc::AllocType::Uninitialized); - s[(mDepth & 1) ^ 1].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); + std::array, 3> s; + s[mDepth % 3].resize(pow2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth + 2) % 3].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth + 1) % 3].resize(pow2 / 4, numPoints, oc::AllocType::Uninitialized); // share of t - std::array, 2> t; - t[0].resize(s[0].rows(), s[0].cols()); - t[1].resize(s[1].rows(), s[1].cols()); - for (u64 i = 0; i < numPoints; ++i) - t[0](0, i) = mPartyIdx; + //std::array, 2> t; + //t[0].resize(s[0].rows(), s[0].cols()); + //t[1].resize(s[1].rows(), s[1].cols()); + //for (u64 i = 0; i < numPoints; ++i) + // t[0](0, i) = mPartyIdx; #if defined(NDEBUG) @@ -103,27 +112,34 @@ namespace osuCrypto #else auto getRow = [](auto&& m, u64 i) {return m[i]; }; #endif - std::array, 2> tau; - tau[0].resize(mNumPoints); - tau[1].resize(mNumPoints); + //std::array, 2> tau; + //tau[0].resize(mNumPoints); + //tau[1].resize(mNumPoints); std::array, 2> z; z[0].resize(mNumPoints); z[1].resize(mNumPoints); - AlignedUnVector sigma(mNumPoints); + std::array, 2> sigma; + sigma[0].resize(mNumPoints); + sigma[1].resize(mNumPoints); + AlignedUnVector sigmaMult(mNumPoints); BitVector negAlphaj(mNumPoints); AlignedUnVector diff(mNumPoints); - + std::array temp; { // we skip level 0 and set level 1 to be random auto sc0 = s[1][0]; auto sc1 = s[1][1]; + + auto tag = s[0][0]; for (u64 k = 0; k < numPoints; ++k) { sc0[k] = prng.get(); sc1[k] = prng.get(); + tag[k] = block::allSame(-mPartyIdx); + z[0][k] = sc0[k]; z[1][k] = sc1[k]; } @@ -139,163 +155,231 @@ namespace osuCrypto for (u64 iter = 1; iter <= mDepth; ++iter) { // the grand parent level - auto& tp = t[(iter - 1) & 1]; + auto& tp = s[(iter - 1) % 3]; // the parent level - auto& sc = s[iter & 1]; - auto& tc = t[iter & 1]; + auto& sc = s[iter % 3]; + //auto& tc = t[iter & 1]; // the child level - auto& sg = s[(iter + 1) & 1]; + auto& sg = s[(iter + 1) % 3]; auto size = 1ull << iter; // for (u64 k = 0; k < mNumPoints; ++k) { - auto alphaj = *oc::BitIterator(&points[k], mDepth - iter); - tau[0][k] = lsb(z[0][k]) ^ alphaj ^ mPartyIdx; - tau[1][k] = lsb(z[1][k]) ^ alphaj; + u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); diff[k] = z[0][k] ^ z[1][k]; + *BitIterator(&diff[k]) = 0; + negAlphaj[k] = alphaj ^ mPartyIdx; } co_await mMultiplier.multiply(negAlphaj, diff, diff, sock); // sigma = z[1^alpha[j]] + std::vector buff(mNumPoints + divCeil(mNumPoints, 128)); + auto z1LsbIter = BitIterator(&buff[mNumPoints]); for (u64 k = 0; k < mNumPoints; ++k) - sigma[k] = diff[k] ^ z[0][k]; - - // reveal sigma and tau - u64 buffSize = sigma.size() * 16 + divCeil(mNumPoints * 2, 8); - AlignedUnVector sendBuff(buffSize), recvBuff(buffSize); - copyBytesMin(sendBuff, sigma); - auto sendBitIter = BitIterator(&sendBuff[numPoints * 16]); - auto recvBitIter = BitIterator(&recvBuff[numPoints * 16]); - for (u64 i = 0; i < mNumPoints; ++i) { - *sendBitIter++ = tau[0][i]; - *sendBitIter++ = tau[1][i]; + u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); + + // sigmaMult[k] = na * msbs(z0+z1) + z0 + na + // = msbs(z_na) + lsb(z0) + na + sigmaMult[k] = diff[k] ^ z[0][k] ^ block(0, mPartyIdx ^ alphaj); + + buff[k] = sigmaMult[k]; + + // lsb(z1) + a + *z1LsbIter++ = lsb(z[1][k]) ^ alphaj; } - co_await sock.send(std::move(sendBuff)); - co_await sock.recv(recvBuff); + //sigma[0] = msbs(z[alpha^1]) || + //sigma[1] = z[alpha^1] ^ unitVec(alpha, lsb(z[0]) ^ lsb(z[1]) ^ 1)[1] + + // reveal sigma and tau + co_await sock.send(coproto::copy(buff)); + co_await sock.recv(buff); + z1LsbIter = BitIterator(&buff[mNumPoints]); for (u64 k = 0; k < mNumPoints; ++k) { - block sk = *(block*)&recvBuff[k * sizeof(block)]; - sigma[k] ^= sk; - tau[0][k] ^= *recvBitIter++; - tau[1][k] ^= *recvBitIter++; + //std::cout << "sigma[0][k] = " << (sigmaMult[k] ^ diff[k]) << " = " << sigmaMult[k] << " ^ " << diff[k] << std::endl; + u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); + + sigma[0][k] = sigmaMult[k] ^ buff[k]; + sigma[1][k] = sigma[0][k]; + *BitIterator(&sigma[1][k]) = *z1LsbIter++ ^ lsb(z[1][k]) ^ alphaj; + } + if (1) + { + co_await sock.send(coproto::copy(negAlphaj)); + co_await sock.send(coproto::copy(z[0])); + co_await sock.send(coproto::copy(z[1])); + BitVector negAlphaj2(mNumPoints); + + std::array, 2> z2; + z2[0].resize(mNumPoints); + z2[1].resize(mNumPoints); + + co_await sock.recv(negAlphaj2); + co_await sock.recv(z2[0]); + co_await sock.recv(z2[1]); + + auto negA = negAlphaj ^ negAlphaj2; + for (u64 i = 0; i < mNumPoints; ++i) + { + auto na = negA[i]; + auto a = na ^ 1; + block exp[2], zz[2]; + zz[0] = z[0][i] ^ z2[0][i]; + zz[1] = z[1][i] ^ z2[1][i]; + + exp[0] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[0]) ^ na); + exp[1] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[1]) ^ a); + //std::cout << "a " << int(a) << std::endl; + //std::cout + // << "z[0] " << zz[0] << " " << int(lsb(zz[0])) + // << "\nz[1] " << zz[1] << " " << int(lsb(zz[1])) << std::endl; + + + //exp[negA[i]] ^= block(0, 1); + //std::cout << "e[0] " << exp[0] << "\ne[1] " << exp[1] << std::endl; + //std::cout << "s[0] " << sigma[0][i] << "\ns[1] " << sigma[1][i] << std::endl; + + if (sigma[0][i] != exp[0]) + { + std::cout << "exp " << exp[0] << " act " << sigma[0][i] << std::endl; + std::cout << "a " << (1 ^ negA[i]) << std::endl; + throw RTE_LOC; + } + if (sigma[1][i] != exp[1]) + { + std::cout << "exp " << exp[1] << " act " << sigma[1][i] << std::endl; + std::cout << "a " << (1 ^ negA[i]) << std::endl; + throw RTE_LOC; + } + } + + } if (iter != mDepth) { + //std::cout << std::endl; setBytes(z[0], 0); setBytes(z[1], 0); + // we iterate over the parent tags. Each has two children. We expend + // these two children into 4 grandchildren. for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) { // parent control bits - auto tpl = getRow(tp, L); + auto parentTag = getRow(tp, L); // child seed - std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; - - // child control bit - std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + std::array currentSeed{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; // grandchild seeds - std::array sgl{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; + std::array childSeed{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; for (u64 k = 0; k < numPoints8; k += 8) { - block temp[8]; - SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, scl[0][k + q] ^= temp[q]); - + // for each child + for (u64 j = 0; j < 2; ++j) + { + // update seed with correction + SIMD8(q, currentSeed[j][k + q] ^= parentTag[k + q] & sigma[j][k + q]); - mAesFixedKey.ecbEncBlocks<8>(&scl[0][k], &sgl[1][k]); - SIMD8(q, sgl[0][k + q] = AES::roundEnc(sgl[1][k + q], scl[0][k + q])); - SIMD8(q, sgl[1][k + q] = sgl[1][k + q] + scl[0][k + q]); + // (s0', s1') = H(s) + mAesFixedKey.ecbEncBlocks<8>(¤tSeed[j][k], &temp[0]); + SIMD8(q, childSeed[j * 2 + 0][k + q] = AES::roundEnc(temp[q], childSeed[j * 2 + 1][k + q])); + SIMD8(q, childSeed[j * 2 + 1][k + q] = childSeed[j * 2 + 1][k + q] + temp[q]); - SIMD8(q, z[0][k + q] ^= sgl[0][k + q]); - SIMD8(q, z[1][k + q] ^= sgl[1][k + q]); + // z = z ^ s' + SIMD8(q, z[0][k + q] ^= childSeed[j * 2 + 0][k + q]); + SIMD8(q, z[1][k + q] ^= childSeed[j * 2 + 1][k + q]); - SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl[1][k + q] ^= temp[q]); + // extract the tag from the seed + SIMD8(q, currentSeed[j][k + q] = tagBit(currentSeed[j][k + q])); + } - mAesFixedKey.ecbEncBlocks<8>(&scl[1][k], &sgl[3][k]); - SIMD8(q, sgl[2][k + q] = AES::roundEnc(sgl[3][k + q], scl[1][k + q])); - SIMD8(q, sgl[3][k + q] = sgl[3][k + q] + scl[1][k + q]); - SIMD8(q, z[0][k + q] ^= sgl[2][k + q]); - SIMD8(q, z[1][k + q] ^= sgl[3][k + q]); } for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; + for (u64 j = 0; j < 2; ++j) + { + //std::cout << "s[" << iter << "][" << L2 + j << "] " << currentSeed[j][k] << " -> "; - tcl[0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; - scl[0][k] ^= temp; + currentSeed[j][k] ^= parentTag[k] & sigma[j][k]; - sgl[1][k] = mAesFixedKey.ecbEncBlock(scl[0][k]); - sgl[0][k] = AES::roundEnc(sgl[1][k], scl[0][k]); - sgl[1][k] = sgl[1][k] + scl[0][k]; + //std::cout << currentSeed[j][k]<<" " << int(lsb(currentSeed[j][k])) << " via " << (parentTag[k] & sigma[j][k]) << std::endl; - z[0][k] ^= sgl[0][k]; - z[1][k] ^= sgl[1][k]; + temp[0] = mAesFixedKey.ecbEncBlock(currentSeed[j][k]); + childSeed[j * 2 + 0][k] = AES::roundEnc(temp[0], currentSeed[j][k]); + childSeed[j * 2 + 1][k] = temp[0] + currentSeed[j][k]; - tcl[1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; - scl[1][k] ^= temp; + z[0][k] ^= childSeed[j * 2 + 0][k]; + z[1][k] ^= childSeed[j * 2 + 1][k]; - sgl[3][k] = mAesFixedKey.ecbEncBlock(scl[1][k]); - sgl[2][k] = AES::roundEnc(sgl[3][k], scl[1][k]); - sgl[3][k] = sgl[3][k] + scl[1][k]; + //std::cout << "z1 += " << childSeed[j * 2 + 1][k] << std::endl; - z[0][k] ^= sgl[2][k]; - z[1][k] ^= sgl[3][k]; + currentSeed[j][k] = tagBit(currentSeed[j][k]); + } } } } } + auto size = roundUpTo(mDomain, 2); + Matrix tags(size, mNumPoints); + setBytes(diff, 0); // fixing the last layer { - auto size = 1ull << mDepth; - auto& tp = t[(mDepth - 1) & 1]; - auto& sc = s[mDepth & 1]; - auto& tc = t[mDepth & 1]; + auto& tp = s[(mDepth - 1) % 3]; + auto& sc = s[mDepth % 3]; + auto& tc = tags; + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) { // parent control bits - auto tpl = getRow(tp, L); + auto parentTag = getRow(tp, L); // child seed - std::array scl{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + std::array currentSeed{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; // child control bit - std::array tcl{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + std::array tag{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; for (u64 k = 0; k < numPoints8; k += 8) { - block temp[8]; - SIMD8(q, temp[q] = block::allSame(-tpl[k + q]) & sigma[k + q]); - SIMD8(q, tcl[0][k + q] = lsb(scl[0][k + q]) ^ tpl[k + q] & tau[0][k + q]); - SIMD8(q, tcl[1][k + q] = lsb(scl[1][k + q]) ^ tpl[k + q] & tau[1][k + q]); - SIMD8(q, scl[0][k + q] ^= temp[q]); - SIMD8(q, scl[1][k + q] ^= temp[q]); + for (u64 j = 0; j < 2; ++j) + { + + SIMD8(q, temp[q] = currentSeed[j][k + q] ^ parentTag[k + q] & sigma[j][k + q]); + SIMD8(q, tag[j][k + q] = tagBit(temp[q])); + + SIMD8(q, currentSeed[j][k + q] = AES::roundFn(temp[q], temp[q])); + SIMD8(q, diff[k+q] ^= currentSeed[j][k+q]); + + } } for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto temp = block::allSame(-tpl[k + 0]) & sigma[k + 0]; - tc[L2 + 0][k] = lsb(scl[0][k]) ^ tpl[k] & tau[0][k]; - tc[L2 + 1][k] = lsb(scl[1][k]) ^ tpl[k] & tau[1][k]; - sc[L2 + 0][k] ^= temp; - sc[L2 + 1][k] ^= temp; + for (u64 j = 0; j < 2; ++j) + { + //std::cout << "s[" << mDepth << "][" << L2 + j << "] " << currentSeed[j][k] << " -> "; + temp[0] = currentSeed[j][k] ^ parentTag[k] & sigma[j][k]; + tag[j][k] = tagBit(temp[0]); + currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); + diff[k] ^= currentSeed[j][k]; + + //std::cout << currentSeed[j][k] << " " << int(lsb(currentSeed[j][k])) << " via " << (parentTag[k] & sigma[j][k]) << std::endl; + } } } } @@ -306,17 +390,17 @@ namespace osuCrypto AlignedUnVector gamma(mNumPoints); for (u64 k = 0; k < mNumPoints; ++k) { - diff[k] = z[0][k] ^ z[1][k] ^ values[k]; + diff[k] ^= values[k]; } - co_await sock.send(std::move(diff)); + co_await sock.send(coproto::copy(diff)); co_await sock.recv(gamma); for (u64 k = 0; k < mNumPoints; ++k) { - gamma[k] = z[0][k] ^ z[1][k] ^ values[k] ^ gamma[k]; + gamma[k] ^= diff[k]; } - auto& sd = s[mDepth & 1]; - auto& td = t[mDepth & 1]; + auto& sd = s[mDepth % 3]; + auto& td = tags; for (u64 i = 0; i < mDomain; ++i) { auto sdi = getRow(sd, i); @@ -325,13 +409,12 @@ namespace osuCrypto for (u64 k = 0; k < numPoints8; k += 8) { block T[8]; - - SIMD8(q, T[q] = block::allSame(-tdi[k + q]) & gamma[k + q]); - SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k+q])); + SIMD8(q, T[q] = tdi[k + q] & gamma[k + q]); + SIMD8(q, output(k + q, i, sdi[k + q] ^ T[q], tdi[k + q])); } for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto T = block::allSame(-tdi[k]) & gamma[k]; + auto T = tdi[k] & gamma[k]; output(k, i, sdi[k] ^ T, tdi[k]); } } @@ -339,7 +422,7 @@ namespace osuCrypto else { auto& sd = s[mDepth & 1]; - auto& td = t[mDepth & 1]; + auto& td = tags; for (u64 i = 0; i < mDomain; ++i) { auto sdi = getRow(sd, i); @@ -357,7 +440,7 @@ namespace osuCrypto } - u64 baseOtCount() const { + u64 baseOtCount() const { return mMultiplier.baseOtCount(); } diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Tools/Dpf/SparseDpf.h index 2804fa2a..7487f2d9 100644 --- a/libOTe/Tools/Dpf/SparseDpf.h +++ b/libOTe/Tools/Dpf/SparseDpf.h @@ -275,7 +275,7 @@ namespace osuCrypto Matrix tags(points.size(), 1ull << mDenseDepth); co_await mRegDpf.expand(densePoints, {}, [&](auto treeIdx, auto leafIdx, auto seed, auto tag) { seeds(treeIdx, leafIdx) = seed; - tags(treeIdx, leafIdx) = tag; + tags(treeIdx, leafIdx) = tag.get(0)&1; }, prng, sock); for (u64 r = 0; r < sparsePoints.rows(); ++r) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index faa5e716..0e1a8081 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -112,7 +112,7 @@ namespace osuCrypto } }; - std::ostream& operator<<(std::ostream& o, const Trit32& t) + inline std::ostream& operator<<(std::ostream& o, const Trit32& t) { u64 m = 0; u64 v = t.mVal; @@ -598,7 +598,7 @@ namespace osuCrypto mOtIdx += mNumPoints * 2; - if (1) + if (0) { std::vector alphaj(mNumPoints); std::vector zz(mNumPoints * 3); auto zzIter = zz.begin(); diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 4b0feae6..25481d44 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -5,6 +5,7 @@ #include "cryptoTools/Common/BitIterator.h" #include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" #include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" +#include "libOTe/Tools/Dpf/TriDpf.h" namespace osuCrypto { @@ -190,6 +191,7 @@ namespace osuCrypto std::vector prodPolyCoefficient(mC * mC * mT * mT); std::vector prodPolyPosition(mC * mC * mT * mT); + std::vector prodPolyPositionTrit(mC * mC * mT * mT); std::vector tritABlk(mLog3T), tritBBlk(mLog3T), tritsBlk(mLog3T); std::vector tritAPos(mLog3N - mLog3T), tritBPos(mLog3N - mLog3T), tritsPos(mLog3N - mLog3T); @@ -254,6 +256,7 @@ namespace osuCrypto size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; prodPolyCoefficient[idx] = mult_f4(vA, vB); prodPolyPosition[idx] = subblock_pos; + prodPolyPositionTrit[idx] = subblock_pos; } } @@ -385,7 +388,7 @@ namespace osuCrypto setTimePoint("dpfKeyEval"); - co_await dpfEval(prng, sock); + co_await dpfEval(prodPolyPositionTrit, prodPolyCoefficient,prng, sock); //std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; @@ -448,7 +451,7 @@ namespace osuCrypto } - macoro::task<> FoleageF4Ole::dpfEval(PRNG& prng, coproto::Socket& sock) + macoro::task<> FoleageF4Ole::dpfEval(span points, span coeffs, PRNG& prng, coproto::Socket& sock) { co_return; } diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index d25500c0..51b49bc5 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -5,7 +5,7 @@ #include "coproto/Socket/Socket.h" #include "cryptoTools/Crypto/PRNG.h" #include "cryptoTools/Common/Timer.h" - +#include "libOTe/Tools/Dpf/TriDpf.h" namespace osuCrypto { @@ -61,7 +61,7 @@ namespace osuCrypto span CMsb, PRNG& prng, coproto::Socket& sock); - macoro::task<> dpfEval(PRNG& prng, coproto::Socket& sock); + macoro::task<> dpfEval(span points, span coeffs, PRNG& prng, coproto::Socket& sock); void sampleA(block seed); }; diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index ce215eda..077df56d 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -115,8 +115,8 @@ void RegularDpf_Multiply_Test(const CLP& cmd) void RegularDpf_Proto_Test(const CLP& cmd) { PRNG prng(block(231234, 321312)); - u64 domain = 131; - u64 numPoints = 11; + u64 domain = cmd.getOr("domain", 211); + u64 numPoints = cmd.getOr("numPoints", 11); std::vector points0(numPoints); std::vector points1(numPoints); std::vector values0(numPoints); @@ -165,8 +165,8 @@ void RegularDpf_Proto_Test(const CLP& cmd) auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t; }, prng, sock[0]), - dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t; }, prng, sock[1]) + dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0)&1; }, prng, sock[0]), + dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, prng, sock[1]) )); @@ -180,7 +180,10 @@ void RegularDpf_Proto_Test(const CLP& cmd) auto tAct = tags[0][k][i] ^ tags[1][k][i]; auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; if (exp != act) + { + throw RTE_LOC; + } if (t != tAct) throw RTE_LOC; } From 9690295ef3245e18dccbeb56d4e092ac0e85343c Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Fri, 14 Feb 2025 01:00:40 -0800 Subject: [PATCH 12/48] dpf refactor and keygen --- frontend/benchmark.h | 4 +- libOTe/Tools/Dpf/RegularDpf.h | 856 ++++++++++++++++++++---------- libOTe/Tools/Dpf/SparseDpf.h | 4 +- libOTe_Tests/RegularDpf_Tests.cpp | 91 +++- libOTe_Tests/RegularDpf_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 1 + 6 files changed, 670 insertions(+), 287 deletions(-) diff --git a/frontend/benchmark.h b/frontend/benchmark.h index b24b8ee1..431f613d 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -750,8 +750,8 @@ namespace osuCrypto timer.setTimePoint("start"); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k,i) = v; }, prng, sock[0]), - dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; }, prng, sock[1]) + dpf[0].expand(points0, values0, prng.get(), [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; }, sock[0]), + dpf[1].expand(points1, values1, prng.get(), [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; }, sock[1]) )); timer.setTimePoint("finish"); diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index 3d26bce7..42212a57 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -11,6 +11,47 @@ namespace osuCrypto { + struct RegularDpfKey + { + void resize(u64 domain, u64 numTrees) + { + auto depth = log2ceil(domain); + if (depth == 0) + throw RTE_LOC; + + mCorrectionWords.resize(depth, numTrees); + mCorrectionBits.resize(depth, numTrees); + } + block mSeed; + Matrix mCorrectionWords; + Matrix mCorrectionBits; + std::vector mLeafVals; + + bool operator==(const RegularDpfKey& o) const + { + return + mSeed == o.mSeed && + mCorrectionWords == o.mCorrectionWords && + mCorrectionBits == o.mCorrectionBits && + mLeafVals == o.mLeafVals; + } + }; + + inline std::ostream& operator<<(std::ostream& o, const RegularDpfKey& k) + { + o << k.mSeed << std::endl; + for (u64 i = 0; i < k.mCorrectionWords.size(); ++i) + { + o << k.mCorrectionWords(i) << " " << int(k.mCorrectionBits(i)) << " "; + } + o << std::endl; + for (u64 i = 0; i < k.mLeafVals.size(); ++i) + o << k.mLeafVals[i] << " "; + o << std::endl; + + return o; + } + struct RegularDpf { u64 mPartyIdx = 0; @@ -23,150 +64,299 @@ namespace osuCrypto DpfMult mMultiplier; - u8 lsb(const block& b) + // used to initialize the interactive protocols. + void init( + u64 partyIdx, + u64 domain, + u64 numPoints); + + // returns the number of OTs required for the protocol. + // each party must have this many OTs as the sender and + // as the receiver. + u64 baseOtCount() const; + + // set the base OTs. + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices); + + // perform interactive full domain eval. + // - points should be a secret sharing of the locations. + // - values should be a secret sarhing of the values. + // - seed should be a random seed. + // - output should be a lambda of the form [](treeIdx, leadIdx, value, tag){...} + // this will be called for each leaf value produced. tag is a zero/one secret sharing + // indicating if this is the active leaf. + // - sock is the network socket to the other party. + template + macoro::task<> expand( + span points, + span values, + block seed, + Output&& output, + coproto::Socket& sock); + + + // perform interactive key generation. + // - points should be a secret sharing of the locations. + // - values should be a secret sarhing of the values. + // - seed should be a random seed. + // - outputKey is where the result is written to. + // - sock is the network socket to the other party. + macoro::task<> keyGen( + span points, + span values, + block seed, + RegularDpfKey& outputKey, + coproto::Socket& sock); + + + // A static function that can generate a pair of keys. + // - domain is the number of leaf values. + // - points is the plaintext list of locations. + // - values is the plaintext list of values. + // - prng is the source of randomness. + // - keys is a list of two keys where the result is written. + static void keyGen( + u64 domain, + span points, + span values, + PRNG& prng, + span keys); + + // A static function that performs non-interative + // full domain evaluation. + // - partyIdx is this partie's index, 0 or 1. + // - domain is the number of leaf values. + // - key is the share of the FSS key. + // - output should be a lambda of the form [](treeIdx, leadIdx, value, tag){...} + // this will be called for each leaf value produced. tag is a zero/one secret sharing + // indicating if this is the active leaf. + template + static void expand( + u64 partyIdx, + u64 domain, + RegularDpfKey& key, + Output&& output); + + + // the internal implementation. This function can be called with + // different parameters. + // + // For distributed keygen, points, values should be shared and seed is some + // random see. inputKey == nullptr, output = anything, and outputKey + // should point to valid object. + // + // For interactive expand (without an existing key), the parameters are the same + // except that outputKey should be null and output should be a lambda of the form + // [](treeIdx, leadIdx, value, tag){...} + // + // For non-interactive expand (with an existing key), points, values, seed are + // all ignored. inputKey should point to an existing dpf key. sock is ignored. + // output should be a lambda as above. + // + template + macoro::task<> implExpand( + span points, + span values, + block seed, + RegularDpfKey* inputKey, + Output&& output, + coproto::Socket& sock, + RegularDpfKey* outputKey); + + + static u8 lsb(const block& b) { return b.get(0) & 1; } // extracts the lsb of b and returns a block saturated with that bit. - block tagBit(const block& b) + static block tagBit(const block& b) { auto bit = b & block(0, 1); auto mask = _mm_sub_epi64(_mm_set1_epi64x(0), bit); return _mm_unpacklo_epi64(mask, mask); } + }; - void init( - u64 partyIdx, - u64 domain, - u64 numPoints) - { - if (partyIdx > 1) - throw RTE_LOC; - if (domain < 2) - throw RTE_LOC; - if (!numPoints) - throw RTE_LOC; - mDepth = log2ceil(domain); - mPartyIdx = partyIdx; - mDomain = domain; - mNumPoints = numPoints; - mMultiplier.init(partyIdx, numPoints * mDepth); - } -#define SIMD8(VAR, STATEMENT) \ - { constexpr u64 VAR = 0; STATEMENT; }\ - { constexpr u64 VAR = 1; STATEMENT; }\ - { constexpr u64 VAR = 2; STATEMENT; }\ - { constexpr u64 VAR = 3; STATEMENT; }\ - { constexpr u64 VAR = 4; STATEMENT; }\ - { constexpr u64 VAR = 5; STATEMENT; }\ - { constexpr u64 VAR = 6; STATEMENT; }\ - { constexpr u64 VAR = 7; STATEMENT; }\ - do{}while(0) + inline void RegularDpf::init( + u64 partyIdx, + u64 domain, + u64 numPoints) + { + if (partyIdx > 1) + throw RTE_LOC; + if (domain < 2) + throw RTE_LOC; + if (!numPoints) + throw RTE_LOC; + + mDepth = log2ceil(domain); + mPartyIdx = partyIdx; + mDomain = domain; + mNumPoints = numPoints; + mMultiplier.init(partyIdx, numPoints * mDepth); + } + + + template + macoro::task<> RegularDpf::expand( + span points, + span values, + block seed, + Output&& output, + coproto::Socket& sock) + { + return implExpand(points, values, seed, nullptr, output, sock, nullptr); + } - template< - typename Output - > - macoro::task<> expand( - span points, - span values, - Output&& output, - PRNG& prng, - coproto::Socket& sock) + + + // distributed keygen, points, values should be shared and seed is some + // random see. inputKey == nullptr, output = anything, and outputKey + // should point to valid object. Base OTs must be set. + inline macoro::task<> RegularDpf::keyGen( + span points, + span values, + block seed, + RegularDpfKey& outputKey, + coproto::Socket& sock) + { + return implExpand(points, values, seed, nullptr, [](auto, auto, auto, auto) {}, sock, &outputKey); + } + + // the internal implementation. This function can be called with + // different parameters. + // + // For distributed keygen, points, values should be shared and seed is some + // random see. inputKey == nullptr, output = anything, and outputKey + // should point to valid object. + // + // For interactive expand (without an existing key), the parameters are the same + // except that outputKey should be null and output should be a lambda of the form + // [](treeIdx, leadIdx, value, tag){...} + // + // For non-interactive expand (with an existing key), points, values, seed are + // all ignored. inputKey should point to an existing dpf key. sock is ignored. + // output should be a lambda as above. + // + template + macoro::task<> RegularDpf::implExpand( + span points, + span values, + block seed, + RegularDpfKey* inputKey, + Output&& output, + coproto::Socket& sock, + RegularDpfKey* outputKey) + { + if (inputKey == nullptr) { - if constexpr (std::is_same, Matrix>::value) - { - if (output.rows() != mNumPoints) - throw RTE_LOC; - if (output.cols() != mDomain) - throw RTE_LOC; - } if (points.size() != mNumPoints) throw RTE_LOC; if (values.size() && values.size() != mNumPoints) throw RTE_LOC; + } + else + { + if (outputKey) + throw RTE_LOC; + } - u64 numPoints = points.size(); - u64 numPoints8 = numPoints / 8 * 8; - - - // shares of S' - auto pow2 = 1ull << log2ceil(mDomain); - std::array, 3> s; - s[mDepth % 3].resize(pow2, numPoints, oc::AllocType::Uninitialized); - s[(mDepth + 2) % 3].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); - s[(mDepth + 1) % 3].resize(pow2 / 4, numPoints, oc::AllocType::Uninitialized); + u64 numPoints = mNumPoints; + u64 numPoints8 = numPoints / 8 * 8; - // share of t - //std::array, 2> t; - //t[0].resize(s[0].rows(), s[0].cols()); - //t[1].resize(s[1].rows(), s[1].cols()); - //for (u64 i = 0; i < numPoints; ++i) - // t[0](0, i) = mPartyIdx; + // shares of S' + auto pow2 = 1ull << log2ceil(mDomain); + std::array, 3> s; + s[mDepth % 3].resize(pow2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth + 2) % 3].resize(pow2 / 2, numPoints, oc::AllocType::Uninitialized); + s[(mDepth + 1) % 3].resize(pow2 / 4, numPoints, oc::AllocType::Uninitialized); #if defined(NDEBUG) - auto getRow = [](auto&& m, u64 i) {return m.data(i); }; + auto getRow = [](auto&& m, u64 i) {return m.data(i); }; #else - auto getRow = [](auto&& m, u64 i) {return m[i]; }; + auto getRow = [](auto&& m, u64 i) {return m[i]; }; #endif - //std::array, 2> tau; - //tau[0].resize(mNumPoints); - //tau[1].resize(mNumPoints); - - std::array, 2> z; - z[0].resize(mNumPoints); - z[1].resize(mNumPoints); - std::array, 2> sigma; - sigma[0].resize(mNumPoints); - sigma[1].resize(mNumPoints); - AlignedUnVector sigmaMult(mNumPoints); - BitVector negAlphaj(mNumPoints); - AlignedUnVector diff(mNumPoints); - std::array temp; - { - // we skip level 0 and set level 1 to be random - auto sc0 = s[1][0]; - auto sc1 = s[1][1]; + if (outputKey) + { + outputKey->resize(mDomain, numPoints); + } - auto tag = s[0][0]; - for (u64 k = 0; k < numPoints; ++k) - { - sc0[k] = prng.get(); - sc1[k] = prng.get(); + std::array, 2> z; + z[0].resize(mNumPoints); + z[1].resize(mNumPoints); + std::array, 2> sigma; + sigma[0].resize(mNumPoints); + sigma[1].resize(mNumPoints); + AlignedUnVector sigmaMult(mNumPoints); + BitVector negAlphaj(mNumPoints); + AlignedUnVector diff(mNumPoints); + std::array temp; - tag[k] = block::allSame(-mPartyIdx); + { + if (inputKey) + seed = inputKey->mSeed; + + // we skip level 0 and set level 1 to be random + if (outputKey) + outputKey->mSeed = seed; + + auto sc0 = s[1][0]; + auto sc1 = s[1][1]; + auto tag = s[0][0]; + PRNG basePeng(seed); + for (u64 k = 0; k < numPoints; ++k) + { + sc0[k] = basePeng.get(); + sc1[k] = basePeng.get(); - z[0][k] = sc0[k]; - z[1][k] = sc1[k]; - } + tag[k] = block::allSame(-mPartyIdx); + + z[0][k] = sc0[k]; + z[1][k] = sc1[k]; } + } - // at each iteration we first correct the parent level. - // The parent level has two syblings which are random. - // We need to correct the inactive child so that both parties - // hold the same seed (a sharing of zero). - // - // we then expand the parent to level to get the children level. - // We compute left and right sums for the children. - for (u64 iter = 1; iter <= mDepth; ++iter) - { - // the grand parent level - auto& tp = s[(iter - 1) % 3]; + // at each iteration we first correct the parent level. + // The parent level has two syblings which are random. + // We need to correct the inactive child so that both parties + // hold the same seed (a sharing of zero). + // + // we then expand the parent to level to get the children level. + // We compute left and right sums for the children. + for (u64 iter = 1; iter <= mDepth; ++iter) + { + // the grand parent level + auto& tp = s[(iter - 1) % 3]; - // the parent level - auto& sc = s[iter % 3]; - //auto& tc = t[iter & 1]; + // the parent level + auto& sc = s[iter % 3]; + //auto& tc = t[iter & 1]; - // the child level - auto& sg = s[(iter + 1) % 3]; + // the child level + auto& sg = s[(iter + 1) % 3]; - auto size = 1ull << iter; + auto size = 1ull << iter; - // + if (inputKey) + { + for (u64 k = 0; k < mNumPoints; ++k) + { + sigma[0][k] = inputKey->mCorrectionWords(iter - 1, k); + sigma[1][k] = sigma[0][k]; + *BitIterator(&sigma[1][k]) = inputKey->mCorrectionBits(iter - 1, k); + } + + } + else + { for (u64 k = 0; k < mNumPoints; ++k) { u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); @@ -183,18 +373,10 @@ namespace osuCrypto for (u64 k = 0; k < mNumPoints; ++k) { u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); - - // sigmaMult[k] = na * msbs(z0+z1) + z0 + na - // = msbs(z_na) + lsb(z0) + na sigmaMult[k] = diff[k] ^ z[0][k] ^ block(0, mPartyIdx ^ alphaj); - buff[k] = sigmaMult[k]; - - // lsb(z1) + a *z1LsbIter++ = lsb(z[1][k]) ^ alphaj; } - //sigma[0] = msbs(z[alpha^1]) || - //sigma[1] = z[alpha^1] ^ unitVec(alpha, lsb(z[0]) ^ lsb(z[1]) ^ 1)[1] // reveal sigma and tau co_await sock.send(coproto::copy(buff)); @@ -202,148 +384,80 @@ namespace osuCrypto z1LsbIter = BitIterator(&buff[mNumPoints]); for (u64 k = 0; k < mNumPoints; ++k) { - //std::cout << "sigma[0][k] = " << (sigmaMult[k] ^ diff[k]) << " = " << sigmaMult[k] << " ^ " << diff[k] << std::endl; u8 alphaj = *oc::BitIterator(&points[k], mDepth - iter); - - sigma[0][k] = sigmaMult[k] ^ buff[k]; + auto sigma1Bit = *z1LsbIter++ ^ lsb(z[1][k]) ^ alphaj; + sigma[0][k] = buff[k] ^ sigmaMult[k]; sigma[1][k] = sigma[0][k]; - *BitIterator(&sigma[1][k]) = *z1LsbIter++ ^ lsb(z[1][k]) ^ alphaj; - - } - - if (1) - { - co_await sock.send(coproto::copy(negAlphaj)); - co_await sock.send(coproto::copy(z[0])); - co_await sock.send(coproto::copy(z[1])); - BitVector negAlphaj2(mNumPoints); - - std::array, 2> z2; - z2[0].resize(mNumPoints); - z2[1].resize(mNumPoints); - - co_await sock.recv(negAlphaj2); - co_await sock.recv(z2[0]); - co_await sock.recv(z2[1]); - - auto negA = negAlphaj ^ negAlphaj2; - for (u64 i = 0; i < mNumPoints; ++i) + *BitIterator(&sigma[1][k]) = sigma1Bit; + if (outputKey) { - auto na = negA[i]; - auto a = na ^ 1; - block exp[2], zz[2]; - zz[0] = z[0][i] ^ z2[0][i]; - zz[1] = z[1][i] ^ z2[1][i]; - - exp[0] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[0]) ^ na); - exp[1] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[1]) ^ a); - //std::cout << "a " << int(a) << std::endl; - //std::cout - // << "z[0] " << zz[0] << " " << int(lsb(zz[0])) - // << "\nz[1] " << zz[1] << " " << int(lsb(zz[1])) << std::endl; - - - //exp[negA[i]] ^= block(0, 1); - //std::cout << "e[0] " << exp[0] << "\ne[1] " << exp[1] << std::endl; - //std::cout << "s[0] " << sigma[0][i] << "\ns[1] " << sigma[1][i] << std::endl; - - if (sigma[0][i] != exp[0]) - { - std::cout << "exp " << exp[0] << " act " << sigma[0][i] << std::endl; - std::cout << "a " << (1 ^ negA[i]) << std::endl; - throw RTE_LOC; - } - if (sigma[1][i] != exp[1]) - { - std::cout << "exp " << exp[1] << " act " << sigma[1][i] << std::endl; - std::cout << "a " << (1 ^ negA[i]) << std::endl; - throw RTE_LOC; - } + outputKey->mCorrectionWords(iter - 1, k) = sigma[0][k]; + outputKey->mCorrectionBits(iter - 1, k) = sigma1Bit; } - } + } - if (iter != mDepth) - { - //std::cout << std::endl; - - setBytes(z[0], 0); - setBytes(z[1], 0); - - // we iterate over the parent tags. Each has two children. We expend - // these two children into 4 grandchildren. - for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) - { - // parent control bits - auto parentTag = getRow(tp, L); - - // child seed - std::array currentSeed{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; - - // grandchild seeds - std::array childSeed{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; - - for (u64 k = 0; k < numPoints8; k += 8) - { - // for each child - for (u64 j = 0; j < 2; ++j) - { - // update seed with correction - SIMD8(q, currentSeed[j][k + q] ^= parentTag[k + q] & sigma[j][k + q]); - - // (s0', s1') = H(s) - mAesFixedKey.ecbEncBlocks<8>(¤tSeed[j][k], &temp[0]); - SIMD8(q, childSeed[j * 2 + 0][k + q] = AES::roundEnc(temp[q], childSeed[j * 2 + 1][k + q])); - SIMD8(q, childSeed[j * 2 + 1][k + q] = childSeed[j * 2 + 1][k + q] + temp[q]); - - // z = z ^ s' - SIMD8(q, z[0][k + q] ^= childSeed[j * 2 + 0][k + q]); - SIMD8(q, z[1][k + q] ^= childSeed[j * 2 + 1][k + q]); - - // extract the tag from the seed - SIMD8(q, currentSeed[j][k + q] = tagBit(currentSeed[j][k + q])); - } - - } - - for (u64 k = numPoints8; k < mNumPoints; ++k) - { - for (u64 j = 0; j < 2; ++j) - { - //std::cout << "s[" << iter << "][" << L2 + j << "] " << currentSeed[j][k] << " -> "; - - currentSeed[j][k] ^= parentTag[k] & sigma[j][k]; + if (0) + { + co_await sock.send(coproto::copy(negAlphaj)); + co_await sock.send(coproto::copy(z[0])); + co_await sock.send(coproto::copy(z[1])); + BitVector negAlphaj2(mNumPoints); - //std::cout << currentSeed[j][k]<<" " << int(lsb(currentSeed[j][k])) << " via " << (parentTag[k] & sigma[j][k]) << std::endl; + std::array, 2> z2; + z2[0].resize(mNumPoints); + z2[1].resize(mNumPoints); - temp[0] = mAesFixedKey.ecbEncBlock(currentSeed[j][k]); - childSeed[j * 2 + 0][k] = AES::roundEnc(temp[0], currentSeed[j][k]); - childSeed[j * 2 + 1][k] = temp[0] + currentSeed[j][k]; + co_await sock.recv(negAlphaj2); + co_await sock.recv(z2[0]); + co_await sock.recv(z2[1]); - z[0][k] ^= childSeed[j * 2 + 0][k]; - z[1][k] ^= childSeed[j * 2 + 1][k]; + auto negA = negAlphaj ^ negAlphaj2; + for (u64 i = 0; i < mNumPoints; ++i) + { + auto na = negA[i]; + auto a = na ^ 1; + block exp[2], zz[2]; + zz[0] = z[0][i] ^ z2[0][i]; + zz[1] = z[1][i] ^ z2[1][i]; - //std::cout << "z1 += " << childSeed[j * 2 + 1][k] << std::endl; + exp[0] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[0]) ^ na); + exp[1] = (zz[na] & ~OneBlock) ^ block(0, lsb(zz[1]) ^ a); - currentSeed[j][k] = tagBit(currentSeed[j][k]); - } - } + if (sigma[0][i] != exp[0]) + { + std::cout << "exp " << exp[0] << " act " << sigma[0][i] << std::endl; + std::cout << "a " << (1 ^ negA[i]) << std::endl; + throw RTE_LOC; + } + if (sigma[1][i] != exp[1]) + { + std::cout << "exp " << exp[1] << " act " << sigma[1][i] << std::endl; + std::cout << "a " << (1 ^ negA[i]) << std::endl; + throw RTE_LOC; } } } - auto size = roundUpTo(mDomain, 2); - Matrix tags(size, mNumPoints); - setBytes(diff, 0); +#define SIMD8(VAR, STATEMENT) \ + { constexpr u64 VAR = 0; STATEMENT; }\ + { constexpr u64 VAR = 1; STATEMENT; }\ + { constexpr u64 VAR = 2; STATEMENT; }\ + { constexpr u64 VAR = 3; STATEMENT; }\ + { constexpr u64 VAR = 4; STATEMENT; }\ + { constexpr u64 VAR = 5; STATEMENT; }\ + { constexpr u64 VAR = 6; STATEMENT; }\ + { constexpr u64 VAR = 7; STATEMENT; }\ + do{}while(0) - // fixing the last layer + if (iter != mDepth) { + setBytes(z[0], 0); + setBytes(z[1], 0); - auto& tp = s[(mDepth - 1) % 3]; - auto& sc = s[mDepth % 3]; - auto& tc = tags; - - for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) + // we iterate over the parent tags. Each has two children. We expend + // these two children into 4 grandchildren. + for (u64 L = 0, L2 = 0, L4 = 0; L2 < size; ++L, L2 += 2, L4 += 4) { // parent control bits auto parentTag = getRow(tp, L); @@ -351,43 +465,109 @@ namespace osuCrypto // child seed std::array currentSeed{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; - // child control bit - std::array tag{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; + // grandchild seeds + std::array childSeed{ getRow(sg, L4 + 0), getRow(sg, L4 + 1), getRow(sg, L4 + 2), getRow(sg, L4 + 3) }; for (u64 k = 0; k < numPoints8; k += 8) { + // for each child for (u64 j = 0; j < 2; ++j) { + // update seed with correction + SIMD8(q, currentSeed[j][k + q] ^= parentTag[k + q] & sigma[j][k + q]); - SIMD8(q, temp[q] = currentSeed[j][k + q] ^ parentTag[k + q] & sigma[j][k + q]); - SIMD8(q, tag[j][k + q] = tagBit(temp[q])); + // (s0', s1') = H(s) + mAesFixedKey.ecbEncBlocks<8>(¤tSeed[j][k], &temp[0]); + SIMD8(q, childSeed[j * 2 + 0][k + q] = AES::roundEnc(temp[q], currentSeed[j][k + q])); + SIMD8(q, childSeed[j * 2 + 1][k + q] = temp[q] + currentSeed[j][k + q]); - SIMD8(q, currentSeed[j][k + q] = AES::roundFn(temp[q], temp[q])); - SIMD8(q, diff[k+q] ^= currentSeed[j][k+q]); + // z = z ^ s' + SIMD8(q, z[0][k + q] ^= childSeed[j * 2 + 0][k + q]); + SIMD8(q, z[1][k + q] ^= childSeed[j * 2 + 1][k + q]); + // extract the tag from the seed + SIMD8(q, currentSeed[j][k + q] = tagBit(currentSeed[j][k + q])); } + } for (u64 k = numPoints8; k < mNumPoints; ++k) { for (u64 j = 0; j < 2; ++j) { - //std::cout << "s[" << mDepth << "][" << L2 + j << "] " << currentSeed[j][k] << " -> "; - temp[0] = currentSeed[j][k] ^ parentTag[k] & sigma[j][k]; - tag[j][k] = tagBit(temp[0]); - currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); - diff[k] ^= currentSeed[j][k]; + currentSeed[j][k] ^= parentTag[k] & sigma[j][k]; + + temp[0] = mAesFixedKey.ecbEncBlock(currentSeed[j][k]); + childSeed[j * 2 + 0][k] = AES::roundEnc(temp[0], currentSeed[j][k]); + childSeed[j * 2 + 1][k] = temp[0] + currentSeed[j][k]; + + z[0][k] ^= childSeed[j * 2 + 0][k]; + z[1][k] ^= childSeed[j * 2 + 1][k]; - //std::cout << currentSeed[j][k] << " " << int(lsb(currentSeed[j][k])) << " via " << (parentTag[k] & sigma[j][k]) << std::endl; + currentSeed[j][k] = tagBit(currentSeed[j][k]); } } } } + } - if (values.size()) + if (!values.size() && outputKey) + co_return; + + auto size = roundUpTo(mDomain, 2); + Matrix tags(size, mNumPoints); + setBytes(diff, 0); + + // fixing the last layer + { + auto& tp = s[(mDepth - 1) % 3]; + auto& sc = s[mDepth % 3]; + auto& tc = tags; + + for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 2) { + // parent control bits + auto parentTag = getRow(tp, L); + + // child seed + std::array currentSeed{ getRow(sc, L2 + 0), getRow(sc, L2 + 1) }; + + // child control bit + std::array tag{ getRow(tc, L2 + 0), getRow(tc, L2 + 1) }; - AlignedUnVector gamma(mNumPoints); + for (u64 k = 0; k < numPoints8; k += 8) + { + for (u64 j = 0; j < 2; ++j) + { + SIMD8(q, temp[q] = currentSeed[j][k + q] ^ parentTag[k + q] & sigma[j][k + q]); + SIMD8(q, tag[j][k + q] = tagBit(temp[q])); + SIMD8(q, currentSeed[j][k + q] = AES::roundFn(temp[q], temp[q])); + SIMD8(q, diff[k + q] ^= currentSeed[j][k + q]); + } + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + for (u64 j = 0; j < 2; ++j) + { + temp[0] = currentSeed[j][k] ^ parentTag[k] & sigma[j][k]; + tag[j][k] = tagBit(temp[0]); + currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); + diff[k] ^= currentSeed[j][k]; + } + } + } + } + + if (values.size()) + { + AlignedUnVector gamma(mNumPoints); + if (inputKey) + { + std::copy(inputKey->mLeafVals.begin(), inputKey->mLeafVals.end(), gamma.begin()); + } + else + { for (u64 k = 0; k < mNumPoints; ++k) { diff[k] ^= values[k]; @@ -398,9 +578,16 @@ namespace osuCrypto { gamma[k] ^= diff[k]; } + } + if (outputKey) + { + outputKey->mLeafVals.insert(outputKey->mLeafVals.end(), gamma.begin(), gamma.end()); + } + else + { auto& sd = s[mDepth % 3]; - auto& td = tags; + auto& td = tags; for (u64 i = 0; i < mDomain; ++i) { auto sdi = getRow(sd, i); @@ -419,41 +606,152 @@ namespace osuCrypto } } } - else + } + else + { + auto& sd = s[mDepth & 1]; + auto& td = tags; + for (u64 i = 0; i < mDomain; ++i) { - auto& sd = s[mDepth & 1]; - auto& td = tags; - for (u64 i = 0; i < mDomain; ++i) + auto sdi = getRow(sd, i); + auto tdi = getRow(td, i); + for (u64 k = 0; k < numPoints8; k += 8) { - auto sdi = getRow(sd, i); - auto tdi = getRow(td, i); - for (u64 k = 0; k < numPoints8; k += 8) - { - SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); - } - for (u64 k = numPoints8; k < mNumPoints; ++k) - { - output(k, i, sdi[k], tdi[k]); - } + SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); + } + for (u64 k = numPoints8; k < mNumPoints; ++k) + { + output(k, i, sdi[k], tdi[k]); } } } + } - u64 baseOtCount() const { - return mMultiplier.baseOtCount(); - } + inline u64 RegularDpf::baseOtCount() const { + return mMultiplier.baseOtCount(); + } - void setBaseOts( - span> baseSendOts, - span recvBaseOts, - const oc::BitVector& baseChoices) + inline void RegularDpf::setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); + } + + + inline void RegularDpf::keyGen( + u64 domain, + span points, + span values, + PRNG& prng, + span keys) + { + if (keys.size() != 2) + throw RTE_LOC; + if (values.size() != points.size() && values.size() != 0) + throw RTE_LOC; + + auto depth = log2ceil(domain); + keys[0].resize(domain, values.size()); + keys[1].resize(domain, values.size()); + + auto seed0 = prng.get(); + auto seed1 = prng.get(); + std::array prngs{ seed0, seed1 }; + keys[0].mSeed = prngs[0].getSeed(); + keys[1].mSeed = prngs[1].getSeed(); + for (u64 i = 0; i < values.size(); ++i) { - mMultiplier.setBaseOts(baseSendOts, recvBaseOts, baseChoices); - } + std::array parentTags; + std::array, 2> seeds; + for (u64 p = 0; p < 2; ++p) + { + prngs[p].get(seeds[p].data(), seeds[p].size()); + parentTags[p] = block::allSame(-p); + } - }; + for (u64 iter = 1; iter <= depth; ++iter) + { + auto a = *BitIterator(&points[i], depth - iter); + auto na = a ^ 1; + + auto diff = seeds[0][na] ^ seeds[1][na]; + u8 tau[2]; + tau[0] = lsb(seeds[0][0] ^ seeds[1][0]) ^ na; + tau[1] = lsb(seeds[0][1] ^ seeds[1][1]) ^ a; + + // we want diff || lsbs[0] ^ na || lsbs[1] ^ a + *BitIterator(&diff) = tau[0]; + + block sigma[2]; + sigma[0] = diff; + sigma[1] = diff; + *BitIterator(&sigma[1]) = tau[1]; + + for (u64 p = 0; p < 2; ++p) + { + keys[p].mCorrectionWords(iter - 1, i) = diff; + keys[p].mCorrectionBits(iter - 1, i) = tau[1]; + + seeds[p][0] ^= sigma[0] & parentTags[p]; + seeds[p][1] ^= sigma[1] & parentTags[p]; + parentTags[p] = tagBit(seeds[p][a]); + } + + if (seeds[0][na] != seeds[1][na]) + throw RTE_LOC; + if (lsb(seeds[0][a] ^ seeds[1][a]) != 1) + throw RTE_LOC; + if ((parentTags[0] ^ parentTags[1]) != AllOneBlock) + throw RTE_LOC; + + for (u64 p = 0; p < 2; ++p) + { + if (iter != depth) + { + auto seed = seeds[p][a]; + auto temp = mAesFixedKey.ecbEncBlock(seed); + seeds[p][0] = AES::roundEnc(temp, seed); + seeds[p][1] = temp + seed; + } + } + } + + if (values.size()) + { + auto a = *BitIterator(&points[i], 0); + auto na = a ^ 1; + + if (seeds[0][na] != seeds[1][na]) + throw RTE_LOC; + + for (u64 p = 0; p < 2; ++p) + { + seeds[p][a] = AES::roundFn(seeds[p][a], seeds[p][a]); + } + + auto diff = seeds[0][a] ^ seeds[1][a]; + auto gamma = diff ^ values[i]; + keys[0].mLeafVals.push_back(gamma); + keys[1].mLeafVals.push_back(gamma); + } + } + } + + template + void RegularDpf::expand( + u64 partyIdx, + u64 domain, + RegularDpfKey& key, + Output&& output) + { + RegularDpf d; + d.init(partyIdx, domain, key.mCorrectionBits.cols()); + return macoro::sync_wait(d.implExpand({}, {}, {}, key, output, {}, nullptr)); + } } diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Tools/Dpf/SparseDpf.h index 7487f2d9..311b2099 100644 --- a/libOTe/Tools/Dpf/SparseDpf.h +++ b/libOTe/Tools/Dpf/SparseDpf.h @@ -273,10 +273,10 @@ namespace osuCrypto densePoints[i] = points[i] >> depth; Matrix seeds(points.size(), 1ull << mDenseDepth); Matrix tags(points.size(), 1ull << mDenseDepth); - co_await mRegDpf.expand(densePoints, {}, [&](auto treeIdx, auto leafIdx, auto seed, auto tag) { + co_await mRegDpf.expand(densePoints, {}, prng.get(), [&](auto treeIdx, auto leafIdx, auto seed, auto tag) { seeds(treeIdx, leafIdx) = seed; tags(treeIdx, leafIdx) = tag.get(0)&1; - }, prng, sock); + }, sock); for (u64 r = 0; r < sparsePoints.rows(); ++r) { diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 077df56d..9931664e 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -165,8 +165,8 @@ void RegularDpf_Proto_Test(const CLP& cmd) auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(points0, values0, [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0)&1; }, prng, sock[0]), - dpf[1].expand(points1, values1, [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, prng, sock[1]) + dpf[0].expand(points0, values0, prng.get(), [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), + dpf[1].expand(points1, values1, prng.get(), [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) )); @@ -190,6 +190,89 @@ void RegularDpf_Proto_Test(const CLP& cmd) } } +void RegularDpf_keyGen_Test(const oc::CLP& cmd) +{ + + PRNG prng(block(231234, 321312)); + u64 domain = cmd.getOr("domain", 211); + u64 numPoints = cmd.getOr("numPoints", 11); + std::vector points(numPoints); + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector values(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + for (u64 i = 0; i < numPoints; ++i) + { + points[i] = prng.get() % domain; + points1[i] = prng.get(); + points0[i] = points[i] ^ points1[i]; + values0[i] = prng.get(); + values1[i] = prng.get(); + values[i] = values0[i] ^ values1[i]; + } + + std::array dpf; + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + std::array, 2> tags; + output[0].resize(numPoints, domain); + output[1].resize(numPoints, domain); + tags[0].resize(numPoints, domain); + tags[1].resize(numPoints, domain); + + std::array key, key2; + + auto sock = coproto::LocalAsyncSocket::makePair(); + + prng.SetSeed(block(214234, 2341234)); + block seed0 = prng.get(); + block seed1 = prng.get(); + + macoro::sync_wait(macoro::when_all_ready( + dpf[0].keyGen(points0, values0, seed0, key[0], sock[0]), + dpf[1].keyGen(points1, values1, seed1, key[1], sock[1]) + )); + + prng.SetSeed(block(214234, 2341234)); + RegularDpf::keyGen(domain, span(points), span(values), prng, span(key2)); + + + if (key[0] != key2[0]) + { + std::cout << key[0] << std::endl; + std::cout << key2[0] << std::endl; + throw RTE_LOC; + } + if (key[1] != key2[1]) + throw RTE_LOC; +} + void SparseDpf_Proto_Test(const oc::CLP& cmd) { PRNG prng(block(32324, 2342)); @@ -295,7 +378,7 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) PRNG prng(block(231234, 321312)); u64 depth = cmd.getOr("depth", 3); - u64 domain = ipow(3,depth) - 3; + u64 domain = ipow(3, depth) - 3; u64 numPoints = cmd.getOr("numPoints", 17); std::vector points0(numPoints); std::vector points1(numPoints); @@ -365,7 +448,7 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; if (exp != act) { - std::cout << "i " << i << "="<< Trit32(i)<<" " << t << std::endl; + std::cout << "i " << i << "=" << Trit32(i) << " " << t << std::endl; std::cout << "exp " << exp << std::endl; std::cout << "act " << act << std::endl; throw RTE_LOC; diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/RegularDpf_Tests.h index ccc85417..7edad019 100644 --- a/libOTe_Tests/RegularDpf_Tests.h +++ b/libOTe_Tests/RegularDpf_Tests.h @@ -4,5 +4,6 @@ void RegularDpf_Multiply_Test(const oc::CLP& cmd); void RegularDpf_Proto_Test(const oc::CLP& cmd); +void RegularDpf_keyGen_Test(const oc::CLP& cmd); void SparseDpf_Proto_Test(const oc::CLP& cmd); void TritDpf_Proto_Test(const oc::CLP& cmd); \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index cf23a7e2..44332570 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -61,6 +61,7 @@ namespace tests_libOTe tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); + tc.add("RegularDpf_keyGen_Test ", RegularDpf_keyGen_Test); tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); From 11729362d9a95f19e606a627d75957960054667e Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Fri, 14 Feb 2025 09:26:22 -0800 Subject: [PATCH 13/48] dpf noninteractive eval fix --- libOTe/Tools/Dpf/RegularDpf.h | 5 +++-- libOTe_Tests/RegularDpf_Tests.cpp | 21 +++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index 42212a57..08080ff7 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -559,7 +559,7 @@ namespace osuCrypto } } - if (values.size()) + if (values.size() || inputKey && inputKey->mLeafVals.size()) { AlignedUnVector gamma(mNumPoints); if (inputKey) @@ -750,7 +750,8 @@ namespace osuCrypto { RegularDpf d; d.init(partyIdx, domain, key.mCorrectionBits.cols()); - return macoro::sync_wait(d.implExpand({}, {}, {}, key, output, {}, nullptr)); + coproto::Socket sock; + return macoro::sync_wait(d.implExpand({}, {}, {}, &key, output, sock, nullptr)); } } diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 9931664e..d31a90d8 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -271,6 +271,27 @@ void RegularDpf_keyGen_Test(const oc::CLP& cmd) } if (key[1] != key2[1]) throw RTE_LOC; + RegularDpf::expand(0, domain, key2[0], [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }); + RegularDpf::expand(1, domain, key2[1], [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }); + + for (u64 i = 0; i < domain; ++i) + { + for (u64 k = 0; k < numPoints; ++k) + { + auto p = points0[k] ^ points1[k]; + auto act = output[0][k][i] ^ output[1][k][i]; + auto t = i == p ? 1 : 0; + auto tAct = tags[0][k][i] ^ tags[1][k][i]; + auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; + if (exp != act) + { + + throw RTE_LOC; + } + if (t != tAct) + throw RTE_LOC; + } + } } void SparseDpf_Proto_Test(const oc::CLP& cmd) From b158839ef495f9dcdf7cd499ffd3e604fbb4f713 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Fri, 14 Feb 2025 22:56:20 -0800 Subject: [PATCH 14/48] triDpf PCG working --- libOTe/Tools/Foleage/FoleagePcg.cpp | 313 +++++++++++++++++------- libOTe/Tools/Foleage/FoleagePcg.h | 14 +- libOTe/Tools/Foleage/fft/FoleageFft.cpp | 2 +- libOTe_Tests/Foleage_Tests.cpp | 10 +- 4 files changed, 241 insertions(+), 98 deletions(-) diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 25481d44..e9aa7973 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -19,11 +19,14 @@ namespace osuCrypto if (mT != ipow(3, mLog3T)) throw RTE_LOC; - mDpfDomainDepth = std::max(1, log3ceil(divCeil(mN, mT * 256))); + // + // log3( (mN / mT) / 256 ) + // log3N - log3(mT * 256) + mBlockSize = mN / mT; + mDpfDomainDepth = std::max(1, log3ceil(divCeil(mBlockSize, 256))); mDpfBlockSize = 4 * ipow(3, mDpfDomainDepth); - mBlockSize = mN / mT; - if (mBlockSize < 8) + if (mBlockSize < 2) throw RTE_LOC; sampleA(block(431234234, 213434234123)); @@ -190,6 +193,7 @@ namespace osuCrypto std::vector prodPolyCoefficient(mC * mC * mT * mT); + std::vector prodPolyCoefficient2(mC * mC * mT * mT); std::vector prodPolyPosition(mC * mC * mT * mT); std::vector prodPolyPositionTrit(mC * mC * mT * mT); @@ -203,10 +207,11 @@ namespace osuCrypto co_await sock.recv(otherSparseCoefficients); co_await sock.recv(otherSparsePositions); setTimePoint("sendRecv"); + std::vector positionMap(mC * mC * mT * mT); u64 polyOffset = 0; u8 vA, vB; - for (u64 iA = 0; iA < mC; ++iA) + for (u64 iA = 0, pointIdx = 0; iA < mC; ++iA) { for (u64 iB = 0; iB < mC; ++iB) { @@ -214,7 +219,7 @@ namespace osuCrypto for (u64 jA = 0; jA < mT; ++jA) { - for (u64 jB = 0; jB < mT; ++jB) + for (u64 jB = 0; jB < mT; ++jB, ++pointIdx) { int_to_trits(jA, tritABlk); int_to_trits(jB, tritBBlk); @@ -246,17 +251,31 @@ namespace osuCrypto int_to_trits(posA_, tritAPos); int_to_trits(posB_, tritBPos); - for(u64 k = 0; k < tritBPos.size(); ++k) + for (u64 k = 0; k < tritBPos.size(); ++k) { tritsPos[k] = (tritAPos[k] + tritBPos[k]) % 3; } auto subblock_pos = trits_to_int(tritsPos); - + + //positionMap[pointIdx] = + size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; prodPolyCoefficient[idx] = mult_f4(vA, vB); prodPolyPosition[idx] = subblock_pos; - prodPolyPositionTrit[idx] = subblock_pos; + + if (mPartyIdx) + { + prodPolyPositionTrit[idx] = Trit32(234 % mBlockSize); + prodPolyCoefficient2[idx] = block(42314342, 234123); + + } + else + { + prodPolyPositionTrit[idx] = Trit32(subblock_pos) - Trit32(234 % mBlockSize); + prodPolyCoefficient2[idx] = block(0, mult_f4(vA, vB)) ^ block(42314342, 234123); + + } } } @@ -275,124 +294,231 @@ namespace osuCrypto PRFKeys prf_keys; PRNG prfSeedPrng(block(3412342134, 56453452362346)); prf_keys.gen(prfSeedPrng); + size_t packedBlockSize = divCeil(mBlockSize, 64); // Sample DPF keys for each of the t errors in the t blocks - u64 index = 0; PRNG genPrng; - //oc::RandomOracle dpfHash(16); - - for (u64 i = 0; i < mC; i++) + ////oc::RandomOracle dpfHash(16); + Matrix blocks(mC * mC * mT, packedBlockSize); + if (0) { - for (u64 j = 0; j < mC; j++) + + + for (u64 i = 0, index = 0; i < mC; i++) { - for (u64 k = 0; k < mT; k++) + for (u64 j = 0; j < mC; j++) { - for (u64 l = 0; l < mT; l++, ++index) + for (u64 k = 0; k < mT; k++) { - //size_t index = i * c * t * t + j * t * t + k * t + l; + for (u64 l = 0; l < mT; l++, ++index) + { + //size_t index = i * c * t * t + j * t * t + k * t + l; - // Parse the index into the right format - size_t alpha = prodPolyPosition[index]; + // Parse the index into the right format + size_t alpha = prodPolyPosition[index]; - // Output message index in the DPF output space - // which consists of 256 F4 elements - size_t alpha_0 = alpha / 256; + // Output message index in the DPF output space + // which consists of 256 F4 elements + size_t alpha_0 = alpha / 256; - // Coeff index in the block of 256 coefficients - size_t alpha_1 = alpha % 256; + // Coeff index in the block of 256 coefficients + size_t alpha_1 = alpha % 256; - // Coeff index in the uint128_t output (64 elements of F4) - size_t packed_idx = alpha_1 / 64; + // Coeff index in the uint128_t output (64 elements of F4) + size_t packed_idx = alpha_1 / 64; - // Bit index in the uint128_t ouput - size_t bit_idx = alpha_1 % 64; + // Bit index in the uint128_t ouput + size_t bit_idx = alpha_1 % 64; - // Set the DPF message to the coefficient - uint128_t coeff = uint128_t(prodPolyCoefficient[index]); + // Set the DPF message to the coefficient + uint128_t coeff = uint128_t(prodPolyCoefficient[index]); - // Position coefficient into the block - std::array beta; // init to zero - setBytes(beta, 0); - beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + // Position coefficient into the block + std::array beta; // init to zero + setBytes(beta, 0); + //beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + beta[packed_idx] = coeff << (2 * (bit_idx)); - // Message (beta) is of size 4 blocks of 128 bits - genPrng.SetSeed(block(index, 542345234)); - DPFKey _; - if (mPartyIdx) - { - DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); - } - else - { - DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); - } + // Message (beta) is of size 4 blocks of 128 bits + genPrng.SetSeed(block(index, 542345234)); + DPFKey _; + if (mPartyIdx) + { + DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); + } + else + { + DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); + } - //dpfHash.Update(Dpfs[index].k.data(), Dpfs[index].k.size()); - //dpfHash.Update(Dpfs[index].msg_len); - //dpfHash.Update(Dpfs[index].size); + //dpfHash.Update(Dpfs[index].k.data(), Dpfs[index].k.size()); + //dpfHash.Update(Dpfs[index].msg_len); + //dpfHash.Update(Dpfs[index].size); + } } } } - } - setTimePoint("dpfKeyGen"); + setTimePoint("dpfKeyGen"); - //block dpfHashVal; - //dpfHash.Final(dpfHashVal); - //std::cout << "dpf " << dpfHashVal << std::endl; + //block dpfHashVal; + //dpfHash.Final(dpfHashVal); + //std::cout << "dpf " << dpfHashVal << std::endl; - std::vector shares(mDpfBlockSize); - std::vector cache(mDpfBlockSize); + std::vector shares(mDpfBlockSize); + std::vector cache(mDpfBlockSize); - size_t packedBlockSize = divCeil(mBlockSize, 64); - Matrix blocks(mC * mC * mT, packedBlockSize); + Matrix blocks(mC * mC * mT, packedBlockSize); - std::vector fft(mN), fftRes(mN); - auto dpfIter = Dpfs.begin(); - //auto dpf_keys_B_iter = dpf_keys_B.begin(); - - for (size_t i = 0; i < mC; i++) - { - for (size_t j = 0; j < mC; j++) + auto dpfIter = Dpfs.begin(); + //Matrix expPos(mC* mC* mT, mT); + //Matrix expCoeff(mC* mC* mT, mT); + //auto dpf_keys_B_iter = dpf_keys_B.begin(); + for (size_t i = 0, q = 0; i < mC; i++) { - const size_t poly_index = i * mC + j; - - oc::MatrixView packed_polyA_(blocks.data(poly_index * mT), mT, blocks.cols()); - - for (size_t k = 0; k < mT; k++) + for (size_t j = 0; j < mC; j++) { - span poly_blockA = packed_polyA_[k]; - - for (size_t l = 0; l < mT; l++) - { + const size_t poly_index = i * mC + j; - DPFKey& dpf = *dpfIter++; + oc::MatrixView packed_polyA_(blocks.data(poly_index * mT), mT, blocks.cols()); - DPFFullDomainEval(dpf, cache, shares); + for (size_t k = 0; k < mT; k++) + { + span poly_blockA = packed_polyA_[k]; - // Sum all the DPFs for the current block together - // note that there is some extra "garbage" in the last - // block of uint128_t since 64 does not divide block_size. - // We deal with this slack later when packing the outputs - // into the parallel FFT matrix. - for (size_t w = 0; w < packedBlockSize; w++) + for (size_t l = 0; l < mT; l++, ++q) { - poly_blockA[w] ^= shares[w]; + DPFFullDomainEval(*dpfIter++, cache, shares); + + // Sum all the DPFs for the current block together + // note that there is some extra "garbage" in the last + // block of uint128_t since 64 does not divide block_size. + // We deal with this slack later when packing the outputs + // into the parallel FFT matrix. + for (size_t w = 0; w < packedBlockSize; w++) + { + poly_blockA[w] ^= shares[w]; + } } } } } } - setTimePoint("dpfKeyEval"); + else + { - co_await dpfEval(prodPolyPositionTrit, prodPolyCoefficient,prng, sock); + { + mDpf.init(mPartyIdx, mBlockSize, prodPolyPositionTrit.size()); + auto numOTs = mDpf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + { + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + } - //std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; + mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + } + co_await mDpf.expand(prodPolyPositionTrit, prodPolyCoefficient2, [&](u64 treeIdx, u64 leafIdx, block v, u8 t) { + // treeIdx in [0, mC^2 * mT^2] + auto row = treeIdx / mT; + blocks(row, leafIdx / 64) ^= uint128_t(v.get(0)) << (2 * (leafIdx % 64)); + }, prng, sock); + } + setTimePoint("dpfKeyEval"); + + //if (1) + //{ + // auto F4Print = [](uint128_t v)->std::string + // { + // std::stringstream ss; + // for (u64 i = 0; i < 64; ++i) + // { + // auto lsb = *BitIterator(&v, i * 2); + // auto msb = *BitIterator(&v, i * 2 + 1); + // ss << (lsb + 2 * msb); + // } + // return ss.str(); + // }; + + // co_await sock.send(coproto::copy(blocks)); + // co_await sock.send(coproto::copy(blocks2)); + + // Matrix rBlocks(mC * mC * mT, packedBlockSize); + // Matrix rBlocks2(mC * mC * mT, packedBlockSize); + + // co_await sock.recv(rBlocks); + // co_await sock.recv(rBlocks2); + // for (u64 i = 0; i < rBlocks.rows(); ++i) + // { + // std::vector exp(packedBlockSize); + // std::vector exp2(packedBlockSize); + + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + // exp[j] = blocks(i, j) ^ rBlocks(i, j); + // } + // auto points = span(prodPolyPosition.data() + i * mT, mT); + // auto coeffs = span(prodPolyCoefficient.data() + i * mT, mT); + + // for (u64 j = 0; j < mT; ++j) + // { + // auto blk = points[j] / 64; + // auto offset = (2 * (points[j] % 64)); + // exp2[blk] ^= uint128_t(coeffs[j]) << offset; + // } + + // if (exp != exp2) + // { + // std::cout << i << std::endl << "exp\n "; + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + // std::cout << F4Print(exp[j])<< " "; + // } + // std::cout << std::endl << "exp2\n "; + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + // std::cout << F4Print(exp2[j]) << " "; + // } + + // throw RTE_LOC; + // } + + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + + // auto act = blocks2(i, j) ^ rBlocks2(i, j); + // if (exp[j] != act) + // { + // std::cout << i << std::endl << "exp\n "; + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + // auto v = (blocks(i, j) ^ rBlocks(i, j)); + // std::cout << *(block*)&v << " "; + // } + // std::cout << std::endl << "act\n "; + // for (u64 j = 0; j < packedBlockSize; ++j) + // { + // auto v = (blocks2(i, j) ^ rBlocks2(i, j)); + // std::cout << *(block*)&v << " "; + // } + // throw RTE_LOC; + // } + // } + // } + //} + ////std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; + + std::vector fft(mN), fftRes(mN); for (size_t j = 0; j < mC; j++) { for (size_t k = 0; k < mC; k++) @@ -411,7 +537,8 @@ namespace osuCrypto for (u64 element_idx = 0; element_idx < e; ++element_idx) { - fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); + fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); + //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); ++i; } } @@ -451,9 +578,17 @@ namespace osuCrypto } - macoro::task<> FoleageF4Ole::dpfEval(span points, span coeffs, PRNG& prng, coproto::Socket& sock) - { - co_return; - } + //macoro::task<> FoleageF4Ole::dpfEval( + // u64 domain, + // span points, + // span coeffs, + // MatrixView output, + // PRNG& prng, + // coproto::Socket& sock) + //{ + + + // co_return; + //} } \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 51b49bc5..11a6b873 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -18,9 +18,9 @@ namespace osuCrypto u64 mLog3N = 0; // the number of noisy positions per polynomial - u64 mT = 27; + u64 mT = 3; - u64 mLog3T = 3; + u64 mLog3T = 1; // the number of polynomials u64 mC = 4; @@ -52,6 +52,8 @@ namespace osuCrypto // the i'th row containts the coeffs for the i'th poly. Matrix mSparsePositions; + TriDpf mDpf; + void init(u64 partyIdx, u64 n, PRNG& prng); macoro::task<> expand( @@ -61,7 +63,13 @@ namespace osuCrypto span CMsb, PRNG& prng, coproto::Socket& sock); - macoro::task<> dpfEval(span points, span coeffs, PRNG& prng, coproto::Socket& sock); + //macoro::task<> dpfEval( + // u64 domain, + // span points, + // span coeffs, + // MatrixView output, + // PRNG& prng, + // coproto::Socket& sock); void sampleA(block seed); }; diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp index 146bb99a..2e13bbf5 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFft.cpp @@ -806,7 +806,7 @@ namespace osuCrypto { { auto n = lsb.size() / stride; - auto log3N = log2ceil(n); + auto log3N = log3ceil(n); if (n != ipow(3, log3N)) throw RTE_LOC; if (lsb.size() != n * stride) diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 12fbaf65..7885f973 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1391,7 +1391,7 @@ namespace osuCrypto { std::array oles; - auto logn = 12; + auto logn = 4; u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); bool verbose = cmd.isSet("v"); @@ -1427,17 +1427,17 @@ namespace osuCrypto // the test otherwise. for (size_t i = 0; i < blocks; i++) { - auto aLsb = C0Lsb[i] ^ C1Lsb[i]; - auto aMsb = C0Msb[i] ^ C1Msb[i]; + auto Lsb = C0Lsb[i] ^ C1Lsb[i]; + auto Msb = C0Msb[i] ^ C1Msb[i]; block mLsb, mMsb; f4Mult( ALsb[i], AMsb[i], BLsb[i], BMsb[i], mLsb, mMsb); - if (aLsb != mLsb) + if (Lsb != mLsb) throw RTE_LOC; - if (aMsb != mMsb) + if (Msb != mMsb) throw RTE_LOC; } From 58ca103dc46466a63af63141cf2478df91e3a570 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sat, 15 Feb 2025 12:17:24 -0800 Subject: [PATCH 15/48] foliage 243 elem leaf --- libOTe/Tools/CoeffCtx.h | 19 +++ libOTe/Tools/Dpf/TriDpf.h | 151 +++++++++++++-------- libOTe/Tools/Foleage/FoleagePcg.cpp | 200 ++++++++++++++++++---------- libOTe/Tools/Foleage/FoleagePcg.h | 23 +++- libOTe/Tools/Foleage/FoleageUtils.h | 39 ++++++ libOTe_Tests/Foleage_Tests.cpp | 4 +- libOTe_Tests/RegularDpf_Tests.cpp | 2 +- 7 files changed, 300 insertions(+), 138 deletions(-) diff --git a/libOTe/Tools/CoeffCtx.h b/libOTe/Tools/CoeffCtx.h index a73dec18..de711967 100644 --- a/libOTe/Tools/CoeffCtx.h +++ b/libOTe/Tools/CoeffCtx.h @@ -296,6 +296,25 @@ namespace osuCrypto { return ss.str(); } + template + void mask(F& ret, const F& x, const block& mask) + { + static_assert(std::is_trivially_copyable::value, "memset is used so must be trivially_copyable."); + if constexpr (sizeof(F) <= sizeof(block)) + { + ret = x & mask.get(0); + } + else + { + static_assert(sizeof(F) % sizeof(block) == 0, "we assume F is a multiple of block"); + block temp[sizeof(F) / sizeof(block)]; + memcpy(&temp, &x, sizeof(F)); + for (u64 i = 0; i < sizeof(F) / sizeof(block); ++i) + temp[i] &= mask; + memcpy(&ret, &temp, sizeof(F)); + } + } + }; diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 0e1a8081..c9bf8534 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -9,6 +9,7 @@ #include "DpfMult.h" #include "libOTe/Tools/Foleage/FoleageUtils.h" +#include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto { @@ -30,22 +31,6 @@ namespace osuCrypto Trit32 operator+(const Trit32& t) const { - //u64 msbMask, lsbMask; - //setBytes(msbMask, 0b10101010); - //setBytes(lsbMask, 0b01010101); - - //auto x0 = mVal; - //auto x1 = mVal >> 1; - //auto y0 = t.mVal; - //auto y1 = t.mVal >> 1; - - - //auto x1x0 = x1 ^ x0; - //auto z1 = (y0 ^ x0) & ~(x1x0 ^ y1); - //auto z0 = (x1 ^ y1) & ~(x1x0 ^ y0); - - //r.mVal = ((z1 << 1) & msbMask) | (z0 & lsbMask); - Trit32 r; for (u64 i = 0; i < 32; ++i) { @@ -54,8 +39,6 @@ namespace osuCrypto auto c = (a + b) % 3; r.mVal |= u64(c) << (i * 2); - //if (c != ((r.mVal >> (i * 2)) & 3)) - //throw RTE_LOC; } return r; } @@ -104,7 +87,6 @@ namespace osuCrypto } } - // returns the i'th Z_3 element. u8 operator[](u64 i) const { @@ -133,8 +115,16 @@ namespace osuCrypto return o; } + + template< + typename F, + typename CoeffCtx = DefaultCoeffCtx + > struct TriDpf { + using VecF = typename CoeffCtx::template Vec; + + u64 mPartyIdx = 0; u64 mDomain = 0; @@ -156,7 +146,7 @@ namespace osuCrypto { if (partyIdx > 1) throw RTE_LOC; - if (domain < 2) + if (domain == 0) throw RTE_LOC; if (!numPoints) throw RTE_LOC; @@ -184,28 +174,48 @@ namespace osuCrypto { constexpr u64 VAR = 7; STATEMENT; }\ do{}while(0) - template< - typename Output - > + template macoro::task<> expand( span points, - span values, + Fs&& values, Output&& output, PRNG& prng, - coproto::Socket& sock) + coproto::Socket& sock, + CoeffCtx ctx = {}) { - if constexpr (std::is_same, Matrix>::value) - { - if (output.rows() != mNumPoints) - throw RTE_LOC; - if (output.cols() != mDomain) - throw RTE_LOC; - } + static_assert(std::is_same_v>, "values must be a vector like type of F."); + static_assert( + std::is_invocable_v || + std::is_invocable_v + , "output must be a callback/lambda that callable with (u64 treeIdx, u64 leafIdx, F value, u8 tag) or (u64 treeIdx, u64 leafIdx, F value) "); + if (points.size() != mNumPoints) throw RTE_LOC; if (values.size() && values.size() != mNumPoints) throw RTE_LOC; + if (mDomain == 1) + { + // trivial case where the domain is 1. + if (values.size()) + { + for (u64 i = 0; i < mNumPoints; ++i) + output(i, 0, values[i], mPartyIdx); + } + else + { + VecF rand; + ctx.resize(rand, 1); + for (u64 i = 0; i < mNumPoints; ++i) + { + ctx.fromBlock(rand[0], prng.get()); + output(i, 0, rand[0], mPartyIdx); + } + } + co_return; + } + + for (u64 i = 0; i < mNumPoints; ++i) { u64 v = points[i].mVal; @@ -334,16 +344,21 @@ namespace osuCrypto } } } - AlignedUnVector sums(mNumPoints); - Matrix t(ipow(3, mDepth), mNumPoints); + //auto size = ipow(3, mDepth); + VecF sums, leafVals; + ctx.resize(sums, mNumPoints); + ctx.zero(sums.begin(), sums.end()); + ctx.resize(leafVals, mNumPoints * mDomain); + + Matrix t(mDomain, mNumPoints); // fixing the last layer { - auto size = ipow(3, mDepth); auto& parentTags = s[(mDepth - 1) % 3]; auto& curSeed = s[mDepth % 3]; + auto leafIter = leafVals.begin(); - for (u64 L = 0, L2 = 0; L2 < size; ++L, L2 += 3) + for (u64 L = 0, L2 = 0; L2 < mDomain; ++L, L2 += 3) { // parent control bits auto parentTag = getRow(parentTags, L); @@ -351,14 +366,20 @@ namespace osuCrypto // child seed std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1), getRow(curSeed, L2 + 2) }; - for (u64 j = 0; j < 3; ++j) + auto m = std::min(3, mDomain - L2); + for (u64 j = 0; j < m; ++j) { for (u64 k = 0; k < mNumPoints; ++k) { auto s = curSeed[L2 + j][k] ^ parentTag[k] & sigma[j][k]; t[L2 + j][k] = lsb(s); - curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s);//AES::roundFn is used to get rid of the correlation in the LSB. - sums[k] = sums[k] ^ curSeed[L2 + j][k]; + + ctx.fromBlock(*leafIter, AES::roundFn(s, s)); + ctx.plus(sums[k], sums[k], *leafIter); + ++leafIter; + + //curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s);//AES::roundFn is used to get rid of the correlation in the LSB. + //sums[k] = sums[k] ^ curSeed[L2 + j][k]; //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << curSeed[L2 + j][k] << " " << int(curTag[L2 + j][k]) << std::endl; } } @@ -368,26 +389,34 @@ namespace osuCrypto if (values.size()) { - AlignedUnVector gamma(mNumPoints), diff(mNumPoints); - setBytes(diff, 0); + VecF gamma, diff; + ctx.resize(gamma, mNumPoints); + ctx.resize(diff, mNumPoints); + auto& curSeed = s[mDepth % 3]; for (u64 k = 0; k < mNumPoints; ++k) { - diff[k] = sums[k] ^ values[k]; + //diff[k] = sums[k] + values[k]; + ctx.plus(diff[k], values[k], sums[k]); } co_await sock.send(std::move(diff)); co_await sock.recv(gamma); for (u64 k = 0; k < mNumPoints; ++k) { - gamma[k] = sums[k] ^ values[k] ^ gamma[k]; + //gamma[k] = reveal(sums[k] + values[k]); + ctx.plus(gamma[k], gamma[k], sums[k]); + ctx.plus(gamma[k], gamma[k], values[k]); } - auto& sd = s[mDepth % 3]; + auto leafIter = leafVals.begin(); + VecF temp; + ctx.resize(temp, 1); + //auto& sd = s[mDepth % 3]; //auto& td = t[mDepth & 1]; for (u64 i = 0; i < mDomain; ++i) { - auto sdi = getRow(sd, i); + //auto sdi = getRow(sd, i); auto tdi = getRow(t, i); //for (u64 k = 0; k < mNumPoints8; k += 8) @@ -398,27 +427,31 @@ namespace osuCrypto //} for (u64 k = 0; k < mNumPoints; ++k) { - auto T = block::allSame(-tdi[k]) & gamma[k]; - auto V = sdi[k] ^ T; - output(k, i, V, tdi[k]); + ctx.mask(temp[0], gamma[k], block::allSame(-tdi[k])); + ctx.plus(temp[0], temp[0], *leafIter++); + //auto V = sdi[k] ^ T; + if constexpr (std::is_invocable_v) + output(k, i, temp[0], tdi[k]); + else + output(k, i, temp[0]); + } } } else { - auto& sd = s[mDepth % 3]; - auto& td = t;// [mDepth & 1] ; + + auto leafIter = leafVals.begin(); + auto tagIter = t.begin(); for (u64 i = 0; i < mDomain; ++i) { - auto sdi = getRow(sd, i); - auto tdi = getRow(td, i); - for (u64 k = 0; k < numPoints8; k += 8) - { - SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); - } for (u64 k = numPoints8; k < mNumPoints; ++k) { - output(k, i, sdi[k], tdi[k]); + if constexpr (std::is_invocable_v) + output(k, i, *leafIter++, *tagIter++); + else + output(k, i, *leafIter++); + } } } @@ -511,14 +544,14 @@ namespace osuCrypto for (u64 j = 0; j < 2; ++j) { - std::array kj = PRNG(k[j+1], 3).get(); + std::array kj = PRNG(k[j + 1], 3).get(); //setBytes(kj, 0); //sendBuffer[i * 3 + j] = PRNG(k[j], 3).get(); sendBuffer[i * 2 + j][0] = kj[0] ^ mask[i][0]; sendBuffer[i * 2 + j][1] = kj[1] ^ mask[i][1]; sendBuffer[i * 2 + j][2] = kj[2] ^ mask[i][2]; - sendBuffer[i * 2 + j][(j+1 + a) % 3] ^= r; + sendBuffer[i * 2 + j][(j + 1 + a) % 3] ^= r; //std::cout << "buffer " << j << std::endl // << " " << buffer[i * 3 + j][0] << "\n" diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index e9aa7973..17214ee7 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -10,32 +10,32 @@ namespace osuCrypto { - void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) + void FoleageF4Ole::init2(u64 partyIdx, u64 n, PRNG& prng) { mPartyIdx = partyIdx; mLog3N = log3ceil(n); + mLog3T = log3ceil(mT); mN = ipow(3, mLog3N); if (mT != ipow(3, mLog3T)) throw RTE_LOC; - // - // log3( (mN / mT) / 256 ) - // log3N - log3(mT * 256) mBlockSize = mN / mT; - mDpfDomainDepth = std::max(1, log3ceil(divCeil(mBlockSize, 256))); - mDpfBlockSize = 4 * ipow(3, mDpfDomainDepth); + mBlockDepth = mLog3N - mLog3T; + mDpfLeafDepth = std::min(5, mBlockDepth); + mDpfTreeDepth = mBlockDepth - mDpfLeafDepth; - if (mBlockSize < 2) - throw RTE_LOC; + mDpfLeafSize = ipow(3, mDpfLeafDepth); + mDpfTreeSize = ipow(3, mDpfTreeDepth); - sampleA(block(431234234, 213434234123)); + _mDpfDomainDepth = std::max(1, log3ceil(divCeil(mBlockSize, 256))); + _mDpfBlockSize = 4 * ipow(3, _mDpfDomainDepth); - //std::cout << "a " << hash(mFftA.data(), mFftA.size()) << std::endl; - //std::cout << "a2 " << hash(mFftASquared.data(), mFftASquared.size()) << std::endl; - + if (mBlockSize < 2) + throw RTE_LOC; + sampleA(block(431234234, 213434234123)); } @@ -117,6 +117,8 @@ namespace osuCrypto PRNG& prng, coproto::Socket& sock) { + bool oldDpf = false; + setTimePoint("expand start"); if (divCeil(mN, 128) < ALsb.size()) @@ -193,9 +195,10 @@ namespace osuCrypto std::vector prodPolyCoefficient(mC * mC * mT * mT); - std::vector prodPolyCoefficient2(mC * mC * mT * mT); + std::vector prodPolyCoefficient2(mC * mC * mT * mT); std::vector prodPolyPosition(mC * mC * mT * mT); - std::vector prodPolyPositionTrit(mC * mC * mT * mT); + std::vector prodPolyLeafPos(mC * mC * mT * mT); + std::vector prodPolyTreePos(mC * mC * mT * mT); std::vector tritABlk(mLog3T), tritBBlk(mLog3T), tritsBlk(mLog3T); std::vector tritAPos(mLog3N - mLog3T), tritBPos(mLog3N - mLog3T), tritsPos(mLog3N - mLog3T); @@ -256,9 +259,18 @@ namespace osuCrypto tritsPos[k] = (tritAPos[k] + tritBPos[k]) % 3; } + // the position within the leaf + std::vector leafPos(tritsPos.begin(), tritsPos.begin() + mDpfLeafDepth); + + // the position within the tree + std::vector treePos(tritsPos.begin() + mDpfLeafDepth, tritsPos.begin() + mBlockDepth); + + // the index of the value within the block auto subblock_pos = trits_to_int(tritsPos); - //positionMap[pointIdx] = + auto leafPosInt = trits_to_int(leafPos); + auto treePosInt = trits_to_int(treePos); + size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; prodPolyCoefficient[idx] = mult_f4(vA, vB); @@ -266,15 +278,23 @@ namespace osuCrypto if (mPartyIdx) { - prodPolyPositionTrit[idx] = Trit32(234 % mBlockSize); - prodPolyCoefficient2[idx] = block(42314342, 234123); + prodPolyLeafPos[idx] = Trit32(73452343 % mDpfLeafSize); + prodPolyTreePos[idx] = Trit32(53423453 % mDpfTreeSize); + prodPolyCoefficient2[idx].mVal[0] = block(42314342, 234123); } else { - prodPolyPositionTrit[idx] = Trit32(subblock_pos) - Trit32(234 % mBlockSize); - prodPolyCoefficient2[idx] = block(0, mult_f4(vA, vB)) ^ block(42314342, 234123); + prodPolyLeafPos[idx] = Trit32(leafPosInt) - Trit32(53424534 % mDpfLeafSize); + prodPolyTreePos[idx] = Trit32(treePosInt) - Trit32(53423453 % mDpfTreeSize); + + auto v = prodPolyCoefficient[idx]; + auto iter = BitIterator(&prodPolyCoefficient2[idx]) + 2 * leafPosInt; + *iter++ = v & 1; v >>= 1; + *iter++ = v & 1; v >>= 1; + + prodPolyCoefficient2[idx].mVal[0] ^= block(42314342, 234123); } } } @@ -300,9 +320,11 @@ namespace osuCrypto PRNG genPrng; ////oc::RandomOracle dpfHash(16); - Matrix blocks(mC * mC * mT, packedBlockSize); - if (0) + std::vector fft(mN), fftRes(mN); + + if (oldDpf) { + Matrix blocks(mC * mC * mT, packedBlockSize); for (u64 i = 0, index = 0; i < mC; i++) @@ -345,11 +367,11 @@ namespace osuCrypto DPFKey _; if (mPartyIdx) { - DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); + DPFGen(prf_keys, _mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); } else { - DPFGen(prf_keys, mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); + DPFGen(prf_keys, _mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); } //dpfHash.Update(Dpfs[index].k.data(), Dpfs[index].k.size()); @@ -366,10 +388,9 @@ namespace osuCrypto //dpfHash.Final(dpfHashVal); //std::cout << "dpf " << dpfHashVal << std::endl; - std::vector shares(mDpfBlockSize); - std::vector cache(mDpfBlockSize); + std::vector shares(_mDpfBlockSize); + std::vector cache(_mDpfBlockSize); - Matrix blocks(mC * mC * mT, packedBlockSize); auto dpfIter = Dpfs.begin(); @@ -405,34 +426,100 @@ namespace osuCrypto } } } - } - else - { + for (size_t j = 0; j < mC; j++) { - mDpf.init(mPartyIdx, mBlockSize, prodPolyPositionTrit.size()); - auto numOTs = mDpf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) + for (size_t k = 0; k < mC; k++) { - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + size_t poly_index = (j * mC + k); + + oc::MatrixView poly(blocks.data(poly_index * mT), mT, packedBlockSize); + + u64 i = 0; + for (u64 block_idx = 0; block_idx < mT; ++block_idx) + { + for (u64 packed_idx = 0; packed_idx < packedBlockSize; ++packed_idx) + { + auto coeff = extractF4(poly(block_idx, packed_idx)); + auto e = std::min(mBlockSize - packed_idx * 64, 64); + + for (u64 element_idx = 0; element_idx < e; ++element_idx) + { + fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); + //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); + ++i; + } + } + } } + } - mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + setTimePoint("transpose"); + } + else + { + Matrix blocks512(mC * mC * mT, mDpfTreeSize); + + //if (mDpfTreeSize == 1) + //{ + // //std::copy(prodPolyCoefficient2.begin(), prodPolyCoefficient2.end(), blocks512.data()); + // for(u64 i = 0; i < prodPolyCoefficient2.size(); ++i) + // blocks512(i/mT) += prodPolyCoefficient2[i]; + //} + //else + //{ + mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); + auto numOTs = mDpf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + { + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; } - co_await mDpf.expand(prodPolyPositionTrit, prodPolyCoefficient2, [&](u64 treeIdx, u64 leafIdx, block v, u8 t) { + mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + + co_await mDpf.expand(prodPolyLeafPos, prodPolyCoefficient2, [&](u64 treeIdx, u64 leafIdx, block512 v, u8 t) { // treeIdx in [0, mC^2 * mT^2] auto row = treeIdx / mT; - blocks(row, leafIdx / 64) ^= uint128_t(v.get(0)) << (2 * (leafIdx % 64)); + blocks512(row, leafIdx) += v; }, prng, sock); + //} + + + for (size_t j = 0; j < mC; j++) + { + for (size_t k = 0; k < mC; k++) + { + size_t poly_index = (j * mC + k); + + oc::MatrixView poly(blocks512.data(poly_index * mT), mT, mDpfTreeSize); + + u64 i = 0; + for (u64 block_idx = 0; block_idx < mT; ++block_idx) + { + for (u64 packed_idx = 0; packed_idx < mDpfTreeSize; ++packed_idx) + { + auto coeff = extractF4(poly(block_idx, packed_idx)); + auto e = std::min(mBlockSize - packed_idx * mDpfLeafSize, mDpfLeafSize); + + for (u64 element_idx = 0; element_idx < e; ++element_idx) + { + fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); + //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); + ++i; + } + } + } + } + } + } setTimePoint("dpfKeyEval"); @@ -518,35 +605,6 @@ namespace osuCrypto //} ////std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; - std::vector fft(mN), fftRes(mN); - for (size_t j = 0; j < mC; j++) - { - for (size_t k = 0; k < mC; k++) - { - size_t poly_index = (j * mC + k); - - oc::MatrixView poly(blocks.data(poly_index * mT), mT, packedBlockSize); - - u64 i = 0; - for (u64 block_idx = 0; block_idx < mT; ++block_idx) - { - for (u64 packed_idx = 0; packed_idx < packedBlockSize; ++packed_idx) - { - auto coeff = extractF4(poly(block_idx, packed_idx)); - auto e = std::min(mBlockSize - packed_idx * 64, 64); - - for (u64 element_idx = 0; element_idx < e; ++element_idx) - { - fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); - //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); - ++i; - } - } - } - } - } - - setTimePoint("transpose"); //std::cout << "CIn " << hash(fft.data(), fft.size()) << std::endl; diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 11a6b873..e800087b 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -9,6 +9,7 @@ namespace osuCrypto { + class FoleageF4Ole : public TimerAdapter { public: @@ -20,7 +21,7 @@ namespace osuCrypto // the number of noisy positions per polynomial u64 mT = 3; - u64 mLog3T = 1; + u64 mLog3T = 0; // the number of polynomials u64 mC = 4; @@ -36,14 +37,22 @@ namespace osuCrypto AlignedVector mFftASquared; // depth of 3-ary DPF with 256 F4 values per leaf. - u64 mDpfDomainDepth = 0; + u64 _mDpfDomainDepth = 0; - u64 mDpfBlockSize = 0; + u64 _mDpfBlockSize = 0; // the number of F4 values per block. Each block will have 1 non-zero. // A polynomial will have mT blocks. i.e. mN = mT * mBlockSize. u64 mBlockSize = 0; + u64 mBlockDepth = 0; + + u64 mDpfLeafDepth = 0; + u64 mDpfTreeDepth = 0; + u64 mDpfTreeSize = 0; + + u64 mDpfLeafSize = 0; + // the coefficient of the sparse polynomial. // the i'th row containts the coeffs for the i'th poly. Matrix mSparseCoefficients; @@ -52,9 +61,13 @@ namespace osuCrypto // the i'th row containts the coeffs for the i'th poly. Matrix mSparsePositions; - TriDpf mDpf; + // a dpf used to construct the leaf value of the larger DPF. + TriDpf mDpfLeaf; + + // the main DPF + TriDpf mDpf; - void init(u64 partyIdx, u64 n, PRNG& prng); + void init2(u64 partyIdx, u64 n, PRNG& prng); macoro::task<> expand( span ALsb, diff --git a/libOTe/Tools/Foleage/FoleageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h index a1f62a97..dc7ef852 100644 --- a/libOTe/Tools/Foleage/FoleageUtils.h +++ b/libOTe/Tools/Foleage/FoleageUtils.h @@ -231,6 +231,8 @@ namespace osuCrypto // Converts an array of trits (not packed) into their integer representation. inline size_t trits_to_int(span trits) { + if (trits.size() == 0) + return 0; reverse_uint8_array(trits); size_t result = 0; for (size_t i = 0; i < trits.size(); i++) @@ -320,4 +322,41 @@ namespace osuCrypto return ret; } + struct block512 + { + std::array mVal; + + block512 operator+(const block512& o) const + { + block512 r; + r.mVal[0] = mVal[0] ^ o.mVal[0]; + r.mVal[1] = mVal[1] ^ o.mVal[1]; + r.mVal[2] = mVal[2] ^ o.mVal[2]; + r.mVal[3] = mVal[3] ^ o.mVal[3]; + return r; + } + block512 operator-(const block512& o) const { return *this + o; } + block512& operator+=(const block512& o) + { + mVal[0] = mVal[0] ^ o.mVal[0]; + mVal[1] = mVal[1] ^ o.mVal[1]; + mVal[2] = mVal[2] ^ o.mVal[2]; + mVal[3] = mVal[3] ^ o.mVal[3]; + return *this; + } + }; + + inline std::array extractF4(const block512& val) + { + std::array ret; + const char* ptr = (const char*)&val; + for (u8 i = 0; i < 64; ++i) + { + ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + } + return ret; + } } \ No newline at end of file diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 7885f973..16c75582 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1401,8 +1401,8 @@ namespace osuCrypto PRNG prng1(block(6474567454546, 567546754674345444)); Timer timer; - oles[0].init(0, n, prng0); - oles[1].init(1, n, prng1); + oles[0].init2(0, n, prng0); + oles[1].init2(1, n, prng1); auto sock = coproto::LocalAsyncSocket::makePair(); std::vector ALsb(blocks), diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index d31a90d8..edd86971 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -417,7 +417,7 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) values1[i] = prng.get(); } - std::array dpf; + std::array, 2> dpf; dpf[0].init(0, domain, numPoints); dpf[1].init(1, domain, numPoints); From 25af2133722c5e27fc1a8fe4633f2dd7f32c2d73 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sat, 15 Feb 2025 19:39:19 -0800 Subject: [PATCH 16/48] foleage dpf for leaf working --- libOTe/Tools/Dpf/TriDpf.h | 16 ++++- libOTe/Tools/Foleage/FoleagePcg.cpp | 94 +++++++++++++++-------------- libOTe/Tools/Foleage/FoleagePcg.h | 4 +- libOTe/Tools/Foleage/FoleageUtils.h | 15 ++++- libOTe_Tests/Foleage_Tests.cpp | 2 +- libOTe_Tests/RegularDpf_Tests.cpp | 32 +++++++--- 6 files changed, 102 insertions(+), 61 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index c9bf8534..40fd859b 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -17,7 +17,7 @@ namespace osuCrypto // The value is stored in 2 bits per Z_3 element. struct Trit32 { - u64 mVal = 0; + u64 mVal; Trit32() = default; Trit32(const Trit32&) = default; @@ -32,6 +32,7 @@ namespace osuCrypto Trit32 operator+(const Trit32& t) const { Trit32 r; + r.mVal = 0; for (u64 i = 0; i < 32; ++i) { auto a = t[i]; @@ -47,6 +48,7 @@ namespace osuCrypto Trit32 operator-(const Trit32& t) const { Trit32 r; + r.mVal = 0; for (u64 i = 0; i < 32; ++i) { auto a = t[i]; @@ -200,7 +202,12 @@ namespace osuCrypto if (values.size()) { for (u64 i = 0; i < mNumPoints; ++i) - output(i, 0, values[i], mPartyIdx); + { + if constexpr(std::is_invocable_v) + output(i, 0, values[i], mPartyIdx); + else + output(i, 0, values[i]); + } } else { @@ -209,7 +216,10 @@ namespace osuCrypto for (u64 i = 0; i < mNumPoints; ++i) { ctx.fromBlock(rand[0], prng.get()); - output(i, 0, rand[0], mPartyIdx); + if constexpr (std::is_invocable_v) + output(i, 0, rand[0], mPartyIdx); + else + output(i, 0, rand[0]); } } co_return; diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 17214ee7..a592dda1 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -195,7 +195,9 @@ namespace osuCrypto std::vector prodPolyCoefficient(mC * mC * mT * mT); - std::vector prodPolyCoefficient2(mC * mC * mT * mT); + std::vector prodPolyCoefficientShare(mC * mC * mT * mT); + //std::vector prodPolyCoefficient2(mC * mC * mT * mT); + std::vector prodPolyCoefficient3(mC * mC * mT * mT); std::vector prodPolyPosition(mC * mC * mT * mT); std::vector prodPolyLeafPos(mC * mC * mT * mT); std::vector prodPolyTreePos(mC * mC * mT * mT); @@ -280,21 +282,16 @@ namespace osuCrypto { prodPolyLeafPos[idx] = Trit32(73452343 % mDpfLeafSize); prodPolyTreePos[idx] = Trit32(53423453 % mDpfTreeSize); - prodPolyCoefficient2[idx].mVal[0] = block(42314342, 234123); + prodPolyCoefficientShare[idx] = (idx%4); } else { - prodPolyLeafPos[idx] = Trit32(leafPosInt) - Trit32(53424534 % mDpfLeafSize); + prodPolyLeafPos[idx] = Trit32(leafPosInt) - Trit32(73452343 % mDpfLeafSize); prodPolyTreePos[idx] = Trit32(treePosInt) - Trit32(53423453 % mDpfTreeSize); auto v = prodPolyCoefficient[idx]; - auto iter = BitIterator(&prodPolyCoefficient2[idx]) + 2 * leafPosInt; - - *iter++ = v & 1; v >>= 1; - *iter++ = v & 1; v >>= 1; - - prodPolyCoefficient2[idx].mVal[0] ^= block(42314342, 234123); + prodPolyCoefficientShare[idx] = v ^ (idx % 4); } } } @@ -308,22 +305,24 @@ namespace osuCrypto setTimePoint("sparseProductCompute"); - std::vector Dpfs(mC * mC * mT * mT); - - // Sample PRF keys for the DPFs - PRFKeys prf_keys; - PRNG prfSeedPrng(block(3412342134, 56453452362346)); - prf_keys.gen(prfSeedPrng); - size_t packedBlockSize = divCeil(mBlockSize, 64); - - // Sample DPF keys for each of the t errors in the t blocks - PRNG genPrng; ////oc::RandomOracle dpfHash(16); std::vector fft(mN), fftRes(mN); + if (oldDpf) { + std::vector Dpfs(mC * mC * mT * mT); + + // Sample PRF keys for the DPFs + PRFKeys prf_keys; + PRNG prfSeedPrng(block(3412342134, 56453452362346)); + prf_keys.gen(prfSeedPrng); + size_t packedBlockSize = divCeil(mBlockSize, 64); + + // Sample DPF keys for each of the t errors in the t blocks + PRNG genPrng; + Matrix blocks(mC * mC * mT, packedBlockSize); @@ -461,37 +460,44 @@ namespace osuCrypto { Matrix blocks512(mC * mC * mT, mDpfTreeSize); - //if (mDpfTreeSize == 1) - //{ - // //std::copy(prodPolyCoefficient2.begin(), prodPolyCoefficient2.end(), blocks512.data()); - // for(u64 i = 0; i < prodPolyCoefficient2.size(); ++i) - // blocks512(i/mT) += prodPolyCoefficient2[i]; - //} - //else - //{ - mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); - auto numOTs = mDpf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) { - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + mDpfLeaf.init(mPartyIdx, mDpfLeafSize, prodPolyLeafPos.size()); + auto numOTs = mDpfLeaf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); } - mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - - co_await mDpf.expand(prodPolyLeafPos, prodPolyCoefficient2, [&](u64 treeIdx, u64 leafIdx, block512 v, u8 t) { - // treeIdx in [0, mC^2 * mT^2] - auto row = treeIdx / mT; - blocks512(row, leafIdx) += v; + co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyCoefficientShare, [&](u64 treeIdx, u64 leafIdx, u8 v) { + *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; + *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; }, prng, sock); - //} + { + mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); + auto numOTs = mDpf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + } + + co_await mDpf.expand(prodPolyTreePos, prodPolyCoefficient3, [&](u64 treeIdx, u64 leafIdx, block512 v) { + auto row = treeIdx / mT; + blocks512(row, leafIdx) ^= v; + }, prng, sock); for (size_t j = 0; j < mC; j++) { diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index e800087b..320f1756 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -62,10 +62,10 @@ namespace osuCrypto Matrix mSparsePositions; // a dpf used to construct the leaf value of the larger DPF. - TriDpf mDpfLeaf; + TriDpf mDpfLeaf; // the main DPF - TriDpf mDpf; + TriDpf mDpf; void init2(u64 partyIdx, u64 n, PRNG& prng); diff --git a/libOTe/Tools/Foleage/FoleageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h index dc7ef852..823190f9 100644 --- a/libOTe/Tools/Foleage/FoleageUtils.h +++ b/libOTe/Tools/Foleage/FoleageUtils.h @@ -326,7 +326,7 @@ namespace osuCrypto { std::array mVal; - block512 operator+(const block512& o) const + block512 operator^(const block512& o) const { block512 r; r.mVal[0] = mVal[0] ^ o.mVal[0]; @@ -335,8 +335,8 @@ namespace osuCrypto r.mVal[3] = mVal[3] ^ o.mVal[3]; return r; } - block512 operator-(const block512& o) const { return *this + o; } - block512& operator+=(const block512& o) + //block512 operator-(const block512& o) const { return *this + o; } + block512& operator^=(const block512& o) { mVal[0] = mVal[0] ^ o.mVal[0]; mVal[1] = mVal[1] ^ o.mVal[1]; @@ -344,6 +344,15 @@ namespace osuCrypto mVal[3] = mVal[3] ^ o.mVal[3]; return *this; } + + bool operator==(const block512& o) const + { + return + mVal[0] == o.mVal[0] && + mVal[1] == o.mVal[1] && + mVal[2] == o.mVal[2] && + mVal[3] == o.mVal[3]; + } }; inline std::array extractF4(const block512& val) diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 16c75582..c532210d 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1391,7 +1391,7 @@ namespace osuCrypto { std::array oles; - auto logn = 4; + auto logn = 10; u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); bool verbose = cmd.isSet("v"); diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index edd86971..9a9415a5 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -394,7 +394,8 @@ void SparseDpf_Proto_Test(const oc::CLP& cmd) } } -void TritDpf_Proto_Test(const oc::CLP& cmd) +template +void TritDpf_Proto_Test_(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); @@ -404,20 +405,21 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) std::vector points0(numPoints); std::vector points1(numPoints); std::vector points(numPoints); - std::vector values0(numPoints); - std::vector values1(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + Ctx ctx; for (u64 i = 0; i < numPoints; ++i) { points[i] = Trit32(prng.get() % domain); points1[i] = Trit32(prng.get() % domain); points0[i] = points[i] - points1[i]; - //std::cout << points[i] << " = " << points0[i] <<" + "<< points1[i] << std::endl; values0[i] = prng.get(); values1[i] = prng.get(); + //ctx.minus(points0[i], points[i], points1[i];) } - std::array, 2> dpf; + std::array, 2> dpf; dpf[0].init(0, domain, numPoints); dpf[1].init(1, domain, numPoints); @@ -444,7 +446,7 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); - std::array, 2> output; + std::array, 2> output; std::array, 2> tags; output[0].resize(numPoints, domain); output[1].resize(numPoints, domain); @@ -463,10 +465,16 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) Trit32 I(i); for (u64 k = 0; k < numPoints; ++k) { - auto act = output[0][k][i] ^ output[1][k][i]; + F act; + ctx.plus(act, output[0][k][i], output[1][k][i]); auto t = I == points[k] ? 1 : 0; auto tAct = tags[0][k][i] ^ tags[1][k][i]; - auto exp = t ? (values0[k] ^ values1[k]) : ZeroBlock; + F exp; + if (t) + ctx.plus(exp, values0[k], values1[k]); + else + ctx.zero(&exp, &exp + 1); + if (exp != act) { std::cout << "i " << i << "=" << Trit32(i) << " " << t << std::endl; @@ -480,3 +488,11 @@ void TritDpf_Proto_Test(const oc::CLP& cmd) } } +void TritDpf_Proto_Test(const oc::CLP& cmd) +{ + TritDpf_Proto_Test_(cmd); + TritDpf_Proto_Test_(cmd); + //TritDpf_Proto_Test_(cmd); + +} + From 2f8b2bcdc1034c84c5e13847ed487fee76f9874d Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sat, 15 Feb 2025 21:40:02 -0800 Subject: [PATCH 17/48] shared foleage position done and tensor coeff started --- libOTe/Tools/Dpf/TriDpf.h | 13 + libOTe/Tools/Foleage/FoleagePcg.cpp | 458 ++++++---------------------- libOTe/Tools/Foleage/FoleagePcg.h | 1 + 3 files changed, 105 insertions(+), 367 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 40fd859b..3c7eb994 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -89,6 +89,19 @@ namespace osuCrypto } } + Trit32 lower(u64 digits) + { + Trit32 r; + r.mVal = mVal & ((1ull << (2 * digits)) - 1); + return r; + } + Trit32 upper(u64 digits) + { + Trit32 r; + r.mVal = mVal >> (2 * digits); + return r; + } + // returns the i'th Z_3 element. u8 operator[](u64 i) const { diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index a592dda1..567dcf5f 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -191,31 +191,18 @@ namespace osuCrypto A[i] = a; } setTimePoint("copyOutX"); - //std::cout << "compress " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; - - std::vector prodPolyCoefficient(mC * mC * mT * mT); std::vector prodPolyCoefficientShare(mC * mC * mT * mT); - //std::vector prodPolyCoefficient2(mC * mC * mT * mT); std::vector prodPolyCoefficient3(mC * mC * mT * mT); - std::vector prodPolyPosition(mC * mC * mT * mT); std::vector prodPolyLeafPos(mC * mC * mT * mT); std::vector prodPolyTreePos(mC * mC * mT * mT); - std::vector tritABlk(mLog3T), tritBBlk(mLog3T), tritsBlk(mLog3T); - std::vector tritAPos(mLog3N - mLog3T), tritBPos(mLog3N - mLog3T), tritsPos(mLog3N - mLog3T); - - Matrix otherSparseCoefficients(mC, mT); - Matrix otherSparsePositions(mC, mT); - co_await sock.send(coproto::copy(mSparseCoefficients)); - co_await sock.send(coproto::copy(mSparsePositions)); - co_await sock.recv(otherSparseCoefficients); - co_await sock.recv(otherSparsePositions); setTimePoint("sendRecv"); - std::vector positionMap(mC * mC * mT * mT); + + std::vector tensoredCoefficients(mC * mC * mT * mT); + co_await tensor(mSparseCoefficients, tensoredCoefficients, sock); u64 polyOffset = 0; - u8 vA, vB; for (u64 iA = 0, pointIdx = 0; iA < mC; ++iA) { for (u64 iB = 0; iB < mC; ++iB) @@ -226,73 +213,16 @@ namespace osuCrypto { for (u64 jB = 0; jB < mT; ++jB, ++pointIdx) { - int_to_trits(jA, tritABlk); - int_to_trits(jB, tritBBlk); - - for (size_t k = 0; k < mLog3T; k++) - { - tritsBlk[k] = (tritABlk[k] + tritBBlk[k]) % 3; - } - u64 blockIdx = trits_to_int(tritsBlk); - - u64 posA_; - u64 posB_; - - if (mPartyIdx == 0) - { - vA = mSparseCoefficients(iA, jA); - vB = otherSparseCoefficients(iB, jB); - posA_ = mSparsePositions(iA, jA); - posB_ = otherSparsePositions(iB, jB); - - } - else - { - vA = otherSparseCoefficients(iA, jA); - vB = mSparseCoefficients(iB, jB); - posA_ = otherSparsePositions(iA, jA); - posB_ = mSparsePositions(iB, jB); - } - int_to_trits(posA_, tritAPos); - int_to_trits(posB_, tritBPos); - - for (u64 k = 0; k < tritBPos.size(); ++k) - { - tritsPos[k] = (tritAPos[k] + tritBPos[k]) % 3; - } - - // the position within the leaf - std::vector leafPos(tritsPos.begin(), tritsPos.begin() + mDpfLeafDepth); - - // the position within the tree - std::vector treePos(tritsPos.begin() + mDpfLeafDepth, tritsPos.begin() + mBlockDepth); - - // the index of the value within the block - auto subblock_pos = trits_to_int(tritsPos); - - auto leafPosInt = trits_to_int(leafPos); - auto treePosInt = trits_to_int(treePos); - + u64 i = mPartyIdx ? iB : iA; + u64 j = mPartyIdx ? jB : jA; + auto pos = Trit32(mSparsePositions(i, j)); + auto blockPos = Trit32(jA) + Trit32(jB); + auto blockIdx = blockPos.toInt(); size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; - prodPolyCoefficient[idx] = mult_f4(vA, vB); - prodPolyPosition[idx] = subblock_pos; - - if (mPartyIdx) - { - prodPolyLeafPos[idx] = Trit32(73452343 % mDpfLeafSize); - prodPolyTreePos[idx] = Trit32(53423453 % mDpfTreeSize); - prodPolyCoefficientShare[idx] = (idx%4); - - } - else - { - prodPolyLeafPos[idx] = Trit32(leafPosInt) - Trit32(73452343 % mDpfLeafSize); - prodPolyTreePos[idx] = Trit32(treePosInt) - Trit32(53423453 % mDpfTreeSize); - - auto v = prodPolyCoefficient[idx]; - prodPolyCoefficientShare[idx] = v ^ (idx % 4); - } + prodPolyLeafPos[idx] = pos.lower(mDpfLeafDepth); + prodPolyTreePos[idx] = pos.upper(mDpfLeafDepth); + prodPolyCoefficientShare[idx] = tensoredCoefficients[pointIdx]; } } @@ -306,315 +236,77 @@ namespace osuCrypto setTimePoint("sparseProductCompute"); - ////oc::RandomOracle dpfHash(16); std::vector fft(mN), fftRes(mN); + Matrix blocks512(mC * mC * mT, mDpfTreeSize); - - if (oldDpf) { - std::vector Dpfs(mC * mC * mT * mT); + mDpfLeaf.init(mPartyIdx, mDpfLeafSize, prodPolyLeafPos.size()); + auto numOTs = mDpfLeaf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + } - // Sample PRF keys for the DPFs - PRFKeys prf_keys; - PRNG prfSeedPrng(block(3412342134, 56453452362346)); - prf_keys.gen(prfSeedPrng); - size_t packedBlockSize = divCeil(mBlockSize, 64); + co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyCoefficientShare, [&](u64 treeIdx, u64 leafIdx, u8 v) { + *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; + *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; + }, prng, sock); - // Sample DPF keys for each of the t errors in the t blocks - PRNG genPrng; - Matrix blocks(mC * mC * mT, packedBlockSize); + { + mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); + auto numOTs = mDpf.baseOtCount(); + std::vector baseRecvOts(numOTs); + std::vector> baseSendOts(numOTs); + BitVector baseChoices(numOTs); + PRNG basePrng(block(324234, 234234)); + basePrng.get(baseSendOts.data(), baseSendOts.size()); + baseChoices.randomize(basePrng); + for (u64 i = 0; i < numOTs; ++i) + baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + } + co_await mDpf.expand(prodPolyTreePos, prodPolyCoefficient3, [&](u64 treeIdx, u64 leafIdx, block512 v) { + auto row = treeIdx / mT; + blocks512(row, leafIdx) ^= v; + }, prng, sock); - for (u64 i = 0, index = 0; i < mC; i++) + for (size_t j = 0; j < mC; j++) + { + for (size_t k = 0; k < mC; k++) { - for (u64 j = 0; j < mC; j++) - { - for (u64 k = 0; k < mT; k++) - { - for (u64 l = 0; l < mT; l++, ++index) - { - //size_t index = i * c * t * t + j * t * t + k * t + l; - - // Parse the index into the right format - size_t alpha = prodPolyPosition[index]; - - // Output message index in the DPF output space - // which consists of 256 F4 elements - size_t alpha_0 = alpha / 256; - - // Coeff index in the block of 256 coefficients - size_t alpha_1 = alpha % 256; + size_t poly_index = (j * mC + k); - // Coeff index in the uint128_t output (64 elements of F4) - size_t packed_idx = alpha_1 / 64; + oc::MatrixView poly(blocks512.data(poly_index * mT), mT, mDpfTreeSize); - // Bit index in the uint128_t ouput - size_t bit_idx = alpha_1 % 64; - - // Set the DPF message to the coefficient - uint128_t coeff = uint128_t(prodPolyCoefficient[index]); - - // Position coefficient into the block - std::array beta; // init to zero - setBytes(beta, 0); - //beta[packed_idx] = coeff << (2 * (63 - bit_idx)); - beta[packed_idx] = coeff << (2 * (bit_idx)); - - // Message (beta) is of size 4 blocks of 128 bits - genPrng.SetSeed(block(index, 542345234)); - DPFKey _; - if (mPartyIdx) - { - DPFGen(prf_keys, _mDpfDomainDepth, alpha_0, beta, 4, _, Dpfs[index], genPrng); - } - else - { - DPFGen(prf_keys, _mDpfDomainDepth, alpha_0, beta, 4, Dpfs[index], _, genPrng); - } - - //dpfHash.Update(Dpfs[index].k.data(), Dpfs[index].k.size()); - //dpfHash.Update(Dpfs[index].msg_len); - //dpfHash.Update(Dpfs[index].size); - - } - } - } - } - setTimePoint("dpfKeyGen"); - - //block dpfHashVal; - //dpfHash.Final(dpfHashVal); - //std::cout << "dpf " << dpfHashVal << std::endl; - - std::vector shares(_mDpfBlockSize); - std::vector cache(_mDpfBlockSize); - - - - auto dpfIter = Dpfs.begin(); - //Matrix expPos(mC* mC* mT, mT); - //Matrix expCoeff(mC* mC* mT, mT); - //auto dpf_keys_B_iter = dpf_keys_B.begin(); - for (size_t i = 0, q = 0; i < mC; i++) - { - for (size_t j = 0; j < mC; j++) + u64 i = 0; + for (u64 block_idx = 0; block_idx < mT; ++block_idx) { - const size_t poly_index = i * mC + j; - - oc::MatrixView packed_polyA_(blocks.data(poly_index * mT), mT, blocks.cols()); - - for (size_t k = 0; k < mT; k++) + for (u64 packed_idx = 0; packed_idx < mDpfTreeSize; ++packed_idx) { - span poly_blockA = packed_polyA_[k]; - - for (size_t l = 0; l < mT; l++, ++q) - { - DPFFullDomainEval(*dpfIter++, cache, shares); - - // Sum all the DPFs for the current block together - // note that there is some extra "garbage" in the last - // block of uint128_t since 64 does not divide block_size. - // We deal with this slack later when packing the outputs - // into the parallel FFT matrix. - for (size_t w = 0; w < packedBlockSize; w++) - { - poly_blockA[w] ^= shares[w]; - } - } - } - } - } + auto coeff = extractF4(poly(block_idx, packed_idx)); + auto e = std::min(mBlockSize - packed_idx * mDpfLeafSize, mDpfLeafSize); - - for (size_t j = 0; j < mC; j++) - { - for (size_t k = 0; k < mC; k++) - { - size_t poly_index = (j * mC + k); - - oc::MatrixView poly(blocks.data(poly_index * mT), mT, packedBlockSize); - - u64 i = 0; - for (u64 block_idx = 0; block_idx < mT; ++block_idx) - { - for (u64 packed_idx = 0; packed_idx < packedBlockSize; ++packed_idx) + for (u64 element_idx = 0; element_idx < e; ++element_idx) { - auto coeff = extractF4(poly(block_idx, packed_idx)); - auto e = std::min(mBlockSize - packed_idx * 64, 64); - - for (u64 element_idx = 0; element_idx < e; ++element_idx) - { - fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); - //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); - ++i; - } + fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); + //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); + ++i; } } } } - - setTimePoint("transpose"); } - else - { - Matrix blocks512(mC * mC * mT, mDpfTreeSize); - - { - mDpfLeaf.init(mPartyIdx, mDpfLeafSize, prodPolyLeafPos.size()); - auto numOTs = mDpfLeaf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - } - - co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyCoefficientShare, [&](u64 treeIdx, u64 leafIdx, u8 v) { - *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; - *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; - }, prng, sock); - - - { - mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); - auto numOTs = mDpf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - } - - co_await mDpf.expand(prodPolyTreePos, prodPolyCoefficient3, [&](u64 treeIdx, u64 leafIdx, block512 v) { - auto row = treeIdx / mT; - blocks512(row, leafIdx) ^= v; - }, prng, sock); - - for (size_t j = 0; j < mC; j++) - { - for (size_t k = 0; k < mC; k++) - { - size_t poly_index = (j * mC + k); - - oc::MatrixView poly(blocks512.data(poly_index * mT), mT, mDpfTreeSize); - - u64 i = 0; - for (u64 block_idx = 0; block_idx < mT; ++block_idx) - { - for (u64 packed_idx = 0; packed_idx < mDpfTreeSize; ++packed_idx) - { - auto coeff = extractF4(poly(block_idx, packed_idx)); - auto e = std::min(mBlockSize - packed_idx * mDpfLeafSize, mDpfLeafSize); - - for (u64 element_idx = 0; element_idx < e; ++element_idx) - { - fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); - //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); - ++i; - } - } - } - } - } - } setTimePoint("dpfKeyEval"); - //if (1) - //{ - // auto F4Print = [](uint128_t v)->std::string - // { - // std::stringstream ss; - // for (u64 i = 0; i < 64; ++i) - // { - // auto lsb = *BitIterator(&v, i * 2); - // auto msb = *BitIterator(&v, i * 2 + 1); - // ss << (lsb + 2 * msb); - // } - // return ss.str(); - // }; - - // co_await sock.send(coproto::copy(blocks)); - // co_await sock.send(coproto::copy(blocks2)); - - // Matrix rBlocks(mC * mC * mT, packedBlockSize); - // Matrix rBlocks2(mC * mC * mT, packedBlockSize); - - // co_await sock.recv(rBlocks); - // co_await sock.recv(rBlocks2); - // for (u64 i = 0; i < rBlocks.rows(); ++i) - // { - // std::vector exp(packedBlockSize); - // std::vector exp2(packedBlockSize); - - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - // exp[j] = blocks(i, j) ^ rBlocks(i, j); - // } - // auto points = span(prodPolyPosition.data() + i * mT, mT); - // auto coeffs = span(prodPolyCoefficient.data() + i * mT, mT); - - // for (u64 j = 0; j < mT; ++j) - // { - // auto blk = points[j] / 64; - // auto offset = (2 * (points[j] % 64)); - // exp2[blk] ^= uint128_t(coeffs[j]) << offset; - // } - - // if (exp != exp2) - // { - // std::cout << i << std::endl << "exp\n "; - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - // std::cout << F4Print(exp[j])<< " "; - // } - // std::cout << std::endl << "exp2\n "; - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - // std::cout << F4Print(exp2[j]) << " "; - // } - - // throw RTE_LOC; - // } - - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - - // auto act = blocks2(i, j) ^ rBlocks2(i, j); - // if (exp[j] != act) - // { - // std::cout << i << std::endl << "exp\n "; - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - // auto v = (blocks(i, j) ^ rBlocks(i, j)); - // std::cout << *(block*)&v << " "; - // } - // std::cout << std::endl << "act\n "; - // for (u64 j = 0; j < packedBlockSize; ++j) - // { - // auto v = (blocks2(i, j) ^ rBlocks2(i, j)); - // std::cout << *(block*)&v << " "; - // } - // throw RTE_LOC; - // } - // } - // } - //} - ////std::cout << "block " << hash(blocks.data(), blocks.size()) << std::endl; - - - //std::cout << "CIn " << hash(fft.data(), fft.size()) << std::endl; - - fft_recursive_uint32(fft, mLog3N, mN / 3); //std::cout << "Cfft " << hash(fft.data(), fft.size()) << std::endl; multiply_fft_32(mFftASquared, fft, fftRes, mN); @@ -655,4 +347,36 @@ namespace osuCrypto // co_return; //} + + macoro::task<> FoleageF4Ole::tensor(span coeffs, span prod, coproto::Socket& sock) + { + if (coeffs.size() * coeffs.size() != prod.size()) + throw RTE_LOC; + std::vector other(coeffs.size()); + co_await sock.send(coproto::copy(coeffs)); + co_await sock.recv(other); + + span A = coeffs, B = other; + if (mPartyIdx) + std::swap(A, B); + + for (u64 iA = 0, pointIdx = 0; iA < mC; ++iA) + { + for (u64 iB = 0; iB < mC; ++iB) + { + for (u64 jA = 0; jA < mT; ++jA) + { + for (u64 jB = 0; jB < mT; ++jB, ++pointIdx) + { + auto pos = iA * mT + jA; + auto pos2 = iB * mT + jB; + prod[pointIdx] = + (mult_f4(A[pos], B[pos2]) * mPartyIdx) ^ + (pointIdx % 4); + } + } + } + } + } + } \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 320f1756..a845feba 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -75,6 +75,7 @@ namespace osuCrypto span CLsb, span CMsb, PRNG& prng, coproto::Socket& sock); + macoro::task<> tensor(span coeffs, span prod, coproto::Socket& sock); //macoro::task<> dpfEval( // u64 domain, From f732a6e6c97a62b0da1f93af757e06357f2eda60 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 16 Feb 2025 20:04:14 -0800 Subject: [PATCH 18/48] distributed foliage working --- libOTe/Tools/Dpf/TriDpf.h | 1 - libOTe/Tools/Foleage/FoleagePcg.cpp | 270 +++++++++++++++++++--------- libOTe/Tools/Foleage/FoleagePcg.h | 73 +++++++- libOTe_Tests/Foleage_Tests.cpp | 88 ++++++++- libOTe_Tests/Foleage_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 216 +++++++++++----------- 6 files changed, 449 insertions(+), 200 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 3c7eb994..9ff6333f 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -585,7 +585,6 @@ namespace osuCrypto co_await socks[0].send(std::move(sendBuffer)); - }; auto recver = [&]() -> macoro::task<> { diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 567dcf5f..30966542 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -10,7 +10,7 @@ namespace osuCrypto { - void FoleageF4Ole::init2(u64 partyIdx, u64 n, PRNG& prng) + void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) { mPartyIdx = partyIdx; mLog3N = log3ceil(n); @@ -32,6 +32,9 @@ namespace osuCrypto _mDpfBlockSize = 4 * ipow(3, _mDpfDomainDepth); + mDpfLeaf.init(mPartyIdx, mDpfLeafSize, mC * mC * mT * mT); + mDpf.init(mPartyIdx, mDpfTreeSize, mC * mC * mT * mT); + if (mBlockSize < 2) throw RTE_LOC; @@ -128,49 +131,39 @@ namespace osuCrypto mSparseCoefficients.resize(mC, mT); mSparsePositions.resize(mC, mT); + + std::vector tensoredCoefficients(mC * mC * mT * mT); + co_await tensor(mSparseCoefficients, tensoredCoefficients, sock); + for (u64 i = 0; i < mC * mT; ++i) { - while (mSparseCoefficients(i) == 0) - mSparseCoefficients(i) = prng.get() & 3; + //while (mSparseCoefficients(i) == 0) + // mSparseCoefficients(i) = prng.get() & 3; mSparsePositions(i) = prng.get() % mBlockSize; } - - //std::cout << "pos " << hash(mSparsePositions.data(), mSparsePositions.size()) << std::endl; - //std::cout << "coeff " << hash(mSparseCoefficients.data(), mSparseCoefficients.size()) << std::endl; - - if (mC != 4) throw RTE_LOC; // we pack 4 FFTs into a single u8. std::vector fftSparsePoly(mN); - //std::vector fftSparsePolyLsb(mN), fftSparsePolyMsb(mN); for (u64 i = 0; i < mT; ++i) { for (u64 j = 0; j < mC; ++j) { auto pos = i * mBlockSize + mSparsePositions(j, i); fftSparsePoly[pos] |= mSparseCoefficients(j, i) << (2 * j); - - //fftSparsePolyLsb[pos] |= (mSparseCoefficients(j, i) & 1) << j; - //fftSparsePolyMsb[pos] |= ((mSparseCoefficients(j, i) >> 1) & 1) << j; } } setTimePoint("sparsePolySample"); - //std::cout << "sparse " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; - // switch from polynomial to FFT form fft_recursive_uint8(fftSparsePoly, mLog3N, mN / 3); - //foleageFFT2<1>(fftSparsePolyLsb, fftSparsePolyMsb); - // multiply by the packed A polynomial multiply_fft_8(mFftA, fftSparsePoly, fftSparsePoly, mN); - //std::cout << "mult " << hash(fftSparsePoly.data(), fftSparsePoly.size()) << std::endl; setTimePoint("sparsePolyMul"); @@ -199,13 +192,10 @@ namespace osuCrypto setTimePoint("sendRecv"); - std::vector tensoredCoefficients(mC * mC * mT * mT); - co_await tensor(mSparseCoefficients, tensoredCoefficients, sock); - u64 polyOffset = 0; - for (u64 iA = 0, pointIdx = 0; iA < mC; ++iA) + for (u64 iA = 0, pointIdx = 0, polyOffset = 0; iA < mC; ++iA) { - for (u64 iB = 0; iB < mC; ++iB) + for (u64 iB = 0; iB < mC; ++iB, polyOffset += mT * mT) { std::vector nextIdx(mT); @@ -222,14 +212,16 @@ namespace osuCrypto size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; prodPolyLeafPos[idx] = pos.lower(mDpfLeafDepth); prodPolyTreePos[idx] = pos.upper(mDpfLeafDepth); - prodPolyCoefficientShare[idx] = tensoredCoefficients[pointIdx]; + + + auto coeffIdx = (iA * mT + jA) * mC * mT + iB * mT + jB; + + prodPolyCoefficientShare[idx] = tensoredCoefficients[coeffIdx]; } } if (nextIdx != std::vector(mT, mT)) throw RTE_LOC; - - polyOffset += mT * mT; } } @@ -239,19 +231,18 @@ namespace osuCrypto std::vector fft(mN), fftRes(mN); Matrix blocks512(mC * mC * mT, mDpfTreeSize); - { - mDpfLeaf.init(mPartyIdx, mDpfLeafSize, prodPolyLeafPos.size()); - auto numOTs = mDpfLeaf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - } + //{ + // auto numOTs = mDpfLeaf.baseOtCount(); + // std::vector baseRecvOts(numOTs); + // std::vector> baseSendOts(numOTs); + // BitVector baseChoices(numOTs); + // PRNG basePrng(block(324234, 234234)); + // basePrng.get(baseSendOts.data(), baseSendOts.size()); + // baseChoices.randomize(basePrng); + // for (u64 i = 0; i < numOTs; ++i) + // baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + // mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + //} co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyCoefficientShare, [&](u64 treeIdx, u64 leafIdx, u8 v) { *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; @@ -259,19 +250,18 @@ namespace osuCrypto }, prng, sock); - { - mDpf.init(mPartyIdx, mDpfTreeSize, prodPolyLeafPos.size()); - auto numOTs = mDpf.baseOtCount(); - std::vector baseRecvOts(numOTs); - std::vector> baseSendOts(numOTs); - BitVector baseChoices(numOTs); - PRNG basePrng(block(324234, 234234)); - basePrng.get(baseSendOts.data(), baseSendOts.size()); - baseChoices.randomize(basePrng); - for (u64 i = 0; i < numOTs; ++i) - baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - } + //{ + // auto numOTs = mDpf.baseOtCount(); + // std::vector baseRecvOts(numOTs); + // std::vector> baseSendOts(numOTs); + // BitVector baseChoices(numOTs); + // PRNG basePrng(block(324234, 234234)); + // basePrng.get(baseSendOts.data(), baseSendOts.size()); + // baseChoices.randomize(basePrng); + // for (u64 i = 0; i < numOTs; ++i) + // baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; + // mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); + //} co_await mDpf.expand(prodPolyTreePos, prodPolyCoefficient3, [&](u64 treeIdx, u64 leafIdx, block512 v) { auto row = treeIdx / mT; @@ -297,7 +287,6 @@ namespace osuCrypto for (u64 element_idx = 0; element_idx < e; ++element_idx) { fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); - //fft[i] |= u32{ coeff[63 - element_idx] } << (2 * poly_index); ++i; } } @@ -334,49 +323,168 @@ namespace osuCrypto } - //macoro::task<> FoleageF4Ole::dpfEval( - // u64 domain, - // span points, - // span coeffs, - // MatrixView output, - // PRNG& prng, - // coproto::Socket& sock) - //{ - - - // co_return; - //} - macoro::task<> FoleageF4Ole::tensor(span coeffs, span prod, coproto::Socket& sock) { + //if (coeffs.size() != mC * mT) + // throw RTE_LOC; + if (coeffs.size() * coeffs.size() != prod.size()) throw RTE_LOC; - std::vector other(coeffs.size()); - co_await sock.send(coproto::copy(coeffs)); - co_await sock.recv(other); - span A = coeffs, B = other; - if (mPartyIdx) - std::swap(A, B); + if (0) + { + PRNG prng(CCBlock); + std::array, 2> s; + s[0].resize(coeffs.size()); + s[1].resize(coeffs.size()); + //prng.get(s0.data(), s0.size()); + for (u64 i = 0; i < s[0].size(); ++i) + { + s[0][i] = prng.get() % 4; + s[1][i] = prng.get() % 4; + } + std::copy(s[mPartyIdx].begin(), s[mPartyIdx].end(), coeffs.begin()); - for (u64 iA = 0, pointIdx = 0; iA < mC; ++iA) + for (u64 iA = 0, pointIdx = 0; iA < s[0].size(); ++iA) + { + for (u64 iB = 0; iB < s[1].size(); ++iB, ++pointIdx) + { + prod[pointIdx] = + (mult_f4(s[0][iA], s[1][iB]) * mPartyIdx);// ^ + //(prng.get() % 4); + } + } + } + else { - for (u64 iB = 0; iB < mC; ++iB) + + auto expand = [](block k, span diff) { + AES aes(k); + for (u64 i = 0; i < diff.size(); ++i) + diff[i] = aes.ecbEncBlock(block(i)); + }; + + if (divCeil(coeffs.size(), 128) != 1) + throw RTE_LOC; // not impl + auto size = 2 * divCeil(coeffs.size(), 128); + + + if (mPartyIdx) { - for (u64 jA = 0; jA < mT; ++jA) + if (mSendOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. + // b * a = (b0 * a + b1 * (2 * a)) + //auto getDiff = [](block k0, block k1, span diff) { + // AES aes0(k0); + // AES aes1(k1); + // for (u64 i = 0; i < diff.size(); ++i) + // diff[i] = aes0.ecbEncBlock(block(i)) ^ aes1.ecbEncBlock(block(i) * 2); + // }; + std::array, 2> a; a[0].resize(size), a[1].resize(size); + std::vector t0(size), t1(size); + expand(mSendOts[0][0], t0); + expand(mSendOts[0][1], t1); + for (u64 i = 0; i < size; ++i) + a[0][i] = t0[i] ^ t1[i]; + + // a[1] = 2 * a[0] + f4Mult(a[0][0], a[0][1], ZeroBlock, AllOneBlock, a[1][0], a[1][1]); + { - for (u64 jB = 0; jB < mT; ++jB, ++pointIdx) + auto lsbIter = BitIterator(&a[0][0]); + auto msbIter = BitIterator(&a[0][1]); + for (u64 i = 0; i < coeffs.size(); ++i) + coeffs[i] = (*lsbIter++ & 1) | ((*msbIter++ & 1) << 1); + } + + { + setBytes(prod, 0); + auto prodIter = prod.begin(); + auto lsbIter = BitIterator(&t0[0]); + auto msbIter = BitIterator(&t0[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); + } + + + std::vector buffer((2 * coeffs.size() - 1) * size); + auto buffIter = buffer.begin(); + for (u64 i = 1; i < 2 * coeffs.size(); ++i) + { + auto b = i & 1; + auto idx = i / 2; + auto prodIter = prod.begin() + idx * coeffs.size(); + + expand(mSendOts[i][0], t0); + expand(mSendOts[i][1], t1); + + // prod = mask + auto lsbIter = BitIterator(&t0[0]); + auto msbIter = BitIterator(&t0[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); + + for (u64 i = 0; i < a.size(); ++i) + { // mask key value + *buffIter++ = t0[i] ^ t1[i] ^ a[b][i]; + //*buffIter++ = diff[i]; + } + + } + + co_await sock.send(std::move(buffer)); + } + else + { + + if (mChoiceOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. + if (mRecvOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. + + for (u64 i = 0; i < coeffs.size(); ++i) + coeffs[i] = mChoiceOts[2 * i] | (u8(mChoiceOts[2 * i + 1] << 1)); + std::vector t(size); + expand(mRecvOts[0], t); + + { + setBytes(prod, 0); + auto prodIter = prod.begin(); + auto lsbIter = BitIterator(&t[0]); + auto msbIter = BitIterator(&t[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); + } + + std::vector buffer((2 * coeffs.size() - 1) * size); + co_await sock.recv(buffer); + + auto buffIter = buffer.begin(); + for (u64 i = 1; i < 2 * coeffs.size(); ++i) + { + auto idx = i / 2; + auto prodIter = prod.begin() + idx * coeffs.size(); + + expand(mRecvOts[i], t); + if (mChoiceOts[i]) { - auto pos = iA * mT + jA; - auto pos2 = iB * mT + jB; - prod[pointIdx] = - (mult_f4(A[pos], B[pos2]) * mPartyIdx) ^ - (pointIdx % 4); + for (u64 i = 0; i < size; ++i) + { + t[i] = t[i] ^ *buffIter++; + } } + else + buffIter += size; + + // prod = mask + auto lsbIter = BitIterator(&t[0]); + auto msbIter = BitIterator(&t[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); } } } - } + } } \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index a845feba..8039dc57 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -67,7 +67,70 @@ namespace osuCrypto // the main DPF TriDpf mDpf; - void init2(u64 partyIdx, u64 n, PRNG& prng); + std::vector mRecvOts; + std::vector> mSendOts; + BitVector mChoiceOts; + + void init(u64 partyIdx, u64 n, PRNG& prng); + + struct BaseOtCount + { + u64 mSendCount, mRecvCount; + }; + + BaseOtCount baseOtCount() const + { + BaseOtCount counts; + + counts.mSendCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); + counts.mRecvCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); + if(mPartyIdx) + counts.mSendCount += 2 * mC * mT; + else + counts.mRecvCount += 2 * mC * mT; + return counts; + } + + + void setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + auto baseCounts = baseOtCount(); + if (baseSendOts.size() != baseCounts.mSendCount) + throw RTE_LOC; + if (recvBaseOts.size() != baseCounts.mRecvCount) + throw RTE_LOC; + if (baseChoices.size() != baseCounts.mRecvCount) + throw RTE_LOC; + auto recvIter = recvBaseOts; + auto sendIter = baseSendOts; + auto choiceIter = baseChoices; + + auto dpfLeafCount = mDpfLeaf.baseOtCount(); + u64 offset = 0; + mDpfLeaf.setBaseOts( + sendIter.subspan(offset, dpfLeafCount), + recvIter.subspan(offset, dpfLeafCount), + BitVector(baseChoices.data(), dpfLeafCount, offset) + ); + offset += dpfLeafCount; + + auto dpfCount = mDpf.baseOtCount(); + mDpf.setBaseOts( + sendIter.subspan(offset, dpfCount), + recvIter.subspan(offset, dpfCount), + BitVector(baseChoices.data(), dpfCount, offset) + ); + offset += dpfCount; + + auto sendOts = sendIter.subspan(offset); + auto recvOts = recvIter.subspan(offset); + mSendOts.insert(mSendOts.end(), sendOts.begin(), sendOts.end()); + mRecvOts.insert(mRecvOts.end(), recvOts.begin(), recvOts.end()); + mChoiceOts = BitVector(baseChoices.data(), baseChoices.size() - offset, offset); + } macoro::task<> expand( span ALsb, @@ -77,14 +140,6 @@ namespace osuCrypto macoro::task<> tensor(span coeffs, span prod, coproto::Socket& sock); - //macoro::task<> dpfEval( - // u64 domain, - // span points, - // span coeffs, - // MatrixView output, - // PRNG& prng, - // coproto::Socket& sock); - void sampleA(block seed); }; } diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index c532210d..40db356a 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1401,8 +1401,37 @@ namespace osuCrypto PRNG prng1(block(6474567454546, 567546754674345444)); Timer timer; - oles[0].init2(0, n, prng0); - oles[1].init2(1, n, prng1); + oles[0].init(0, n, prng0); + oles[1].init(1, n, prng1); + + { + auto otCount0 = oles[0].baseOtCount(); + auto otCount1 = oles[1].baseOtCount(); + if (otCount0.mRecvCount != otCount1.mSendCount || + otCount0.mSendCount != otCount1.mRecvCount) + throw RTE_LOC; + std::array>, 2> baseSend; + baseSend[0].resize(otCount0.mSendCount); + baseSend[1].resize(otCount1.mSendCount); + std::array, 2> baseRecv; + std::array baseChoice; + + for (u64 i = 0; i < 2; ++i) + { + prng0.get(baseSend[i].data(), baseSend[i].size()); + baseRecv[1 ^ i].resize(baseSend[i].size()); + baseChoice[1^i].resize(baseSend[i].size()); + baseChoice[1 ^ i].randomize(prng0); + for (u64 j = 0; j < baseSend[i].size(); ++j) + { + baseRecv[1 ^ i][j] = baseSend[i][j][baseChoice[1 ^ i][j]]; + } + } + + oles[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + oles[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + } + auto sock = coproto::LocalAsyncSocket::makePair(); std::vector ALsb(blocks), @@ -1444,4 +1473,59 @@ namespace osuCrypto if (verbose) std::cout << "Time taken: \n" << timer << std::endl; } + void foleage_tensor_test(const CLP& cmd) + { + + std::array oles; + + bool verbose = cmd.isSet("v"); + + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); + + oles[0].init(0, 1000, prng0); + oles[1].init(1, 1000, prng1); + + u64 n = oles[0].mC* oles[0].mT; + u64 n2 = n * n; + auto sock = coproto::LocalAsyncSocket::makePair(); + std::array, 2> coeff, prod; + coeff[0].resize(n); + coeff[1].resize(n); + prod[0].resize(n2); + prod[1].resize(n2); + + oles[1].mSendOts.resize(2 * n); + oles[0].mRecvOts.resize(2 * n); + oles[0].mChoiceOts.resize(2 * n); + for (u64 i = 0; i < 2 * n; ++i) + { + oles[1].mSendOts[i] = prng0.get();; + oles[0].mChoiceOts[i] = prng0.getBit(); + oles[0].mRecvOts[i] = oles[1].mSendOts[i][oles[0].mChoiceOts[i]]; + } + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].tensor(coeff[0],prod[0], sock[0]), + oles[1].tensor(coeff[1],prod[1], sock[1]))); + std::get<0>(r).result(); + std::get<1>(r).result(); + + // Now we check that we got the correct OLE correlations and fail + // the test otherwise. + for (size_t i = 0; i < n; i++) + { + for (size_t j = 0; j < n; j++) + { + auto p = i * n + j; + + u8 ci = coeff[0][i]; + u8 cj = coeff[1][j]; + auto exp = mult_f4(ci, cj); + auto act = prod[0][p] ^ prod[1][p]; + if (exp != act) + throw RTE_LOC; + } + } + + } } \ No newline at end of file diff --git a/libOTe_Tests/Foleage_Tests.h b/libOTe_Tests/Foleage_Tests.h index 33b8cfa5..813a5b2a 100644 --- a/libOTe_Tests/Foleage_Tests.h +++ b/libOTe_Tests/Foleage_Tests.h @@ -9,6 +9,7 @@ namespace osuCrypto void foleage_dpf_test(); void foleage_pcg_test(const CLP& cmd); void foleage_F4ole_test(const CLP& cmd); + void foleage_tensor_test(const CLP& cmd); } \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 44332570..193ea1b2 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -22,111 +22,113 @@ using namespace osuCrypto; namespace tests_libOTe { - TestCollection Tests([](TestCollection& tc) - { - - - tc.add("Tools_Transpose_Test ", Tools_Transpose_Test); - tc.add("Tools_Transpose_View_Test ", Tools_Transpose_View_Test); - tc.add("Tools_Transpose_Bench ", Tools_Transpose_Bench); - - tc.add("Tools_LinearCode_Test ", Tools_LinearCode_Test); - tc.add("Tools_LinearCode_sub_Test ", Tools_LinearCode_sub_Test); - tc.add("Tools_LinearCode_rep_Test ", Tools_LinearCode_rep_Test); - - tc.add("Tools_bitShift_test ", Tools_bitShift_test); - tc.add("Tools_modp_test ", Tools_modp_test); - tc.add("Tools_bitpolymul_test ", Tools_bitpolymul_test); - tc.add("Tools_quasiCyclic_test ", Tools_quasiCyclic_test); - - tc.add("Mtx_make_test ", tests::Mtx_make_test); - tc.add("Mtx_add_test ", tests::Mtx_add_test); - tc.add("Mtx_mult_test ", tests::Mtx_mult_test); - tc.add("Mtx_invert_test ", tests::Mtx_invert_test); - - tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); - tc.add("EACode_weight_test ", EACode_weight_test); - - tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); - tc.add("ExConvCode_weight_test ", ExConvCode_weight_test); - - tc.add("TungstenCode_encode_test ", TungstenCode_encode_test); - tc.add("TungstenCode_weight_test ", TungstenCode_weight_test); - - tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); - tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); - tc.add("Tools_Pprf_ByLeafIndex_test ", Tools_Pprf_ByLeafIndex_test); - tc.add("Tools_Pprf_ByTreeIndex_test ", Tools_Pprf_ByTreeIndex_test); - tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); - - tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); - tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); - tc.add("RegularDpf_keyGen_Test ", RegularDpf_keyGen_Test); - tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); - tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); - - - tc.add("foleage_transpose_test ", foleage_transpose_test); - tc.add("foleage_fft_test ", foleage_fft_test); - tc.add("foleage_dpf_test ", foleage_dpf_test); - tc.add("foleage_spfss_test ", foleage_spfss_test); - tc.add("foleage_pcg_test ", foleage_pcg_test); - tc.add("foleage_F4ole_test ", foleage_F4ole_test); - - - tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); - tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); - - tc.add("Bot_McQuoidRR_Moeller_EKE_Test ", Bot_McQuoidRR_Moeller_EKE_Test); - tc.add("Bot_McQuoidRR_Moeller_MR_Test ", Bot_McQuoidRR_Moeller_MR_Test); - tc.add("Bot_McQuoidRR_Moeller_F_Test ", Bot_McQuoidRR_Moeller_F_Test); - tc.add("Bot_McQuoidRR_Moeller_FM_Test ", Bot_McQuoidRR_Moeller_FM_Test); - - tc.add("Bot_McQuoidRR_Ristrestto_F_Test ", Bot_McQuoidRR_Ristrestto_F_Test); - tc.add("Bot_McQuoidRR_Ristrestto_FM_Test ", Bot_McQuoidRR_Ristrestto_FM_Test); - - tc.add("Bot_MasnyRindal_Test ", Bot_MasnyRindal_Test); - tc.add("Bot_MasnyRindal_Kyber_Test ", Bot_MasnyRindal_Kyber_Test); - - tc.add("Vole_SoftSpokenSmall_Test ", Vole_SoftSpokenSmall_Test); - tc.add("DotExt_Kos_Test ", DotExt_Kos_Test); - tc.add("DotExt_Iknp_Test ", DotExt_Iknp_Test); - - - tc.add("OtExt_genBaseOts_Test ", OtExt_genBaseOts_Test); - tc.add("OtExt_Chosen_Test ", OtExt_Chosen_Test); - tc.add("OtExt_Iknp_Test ", OtExt_Iknp_Test); - tc.add("OtExt_Kos_Test ", OtExt_Kos_Test); - tc.add("OtExt_Kos_fs_Test ", OtExt_Kos_fs_Test); - tc.add("OtExt_Kos_ro_Test ", OtExt_Kos_ro_Test); - tc.add("OtExt_Silent_random_Test ", OtExt_Silent_random_Test); - tc.add("OtExt_Silent_correlated_Test ", OtExt_Silent_correlated_Test); - tc.add("OtExt_Silent_inplace_Test ", OtExt_Silent_inplace_Test); - tc.add("OtExt_Silent_paramSweep_Test ", OtExt_Silent_paramSweep_Test); - tc.add("OtExt_Silent_QuasiCyclic_Test ", OtExt_Silent_QuasiCyclic_Test); - tc.add("OtExt_Silent_Tungsten_Test ", OtExt_Silent_Tungsten_Test); - tc.add("OtExt_Silent_baseOT_Test ", OtExt_Silent_baseOT_Test); - tc.add("OtExt_Silent_mal_Test ", OtExt_Silent_mal_Test); - - tc.add("OtExt_SoftSpokenSemiHonest_Test ", OtExt_SoftSpokenSemiHonest_Test); - tc.add("OtExt_SoftSpokenSemiHonest_Split_Test ", OtExt_SoftSpokenSemiHonest_Split_Test); - //tc.add("OtExt_SoftSpokenSemiHonest21_Test ", OtExt_SoftSpokenSemiHonest21_Test); - tc.add("OtExt_SoftSpokenMalicious21_Test ", OtExt_SoftSpokenMalicious21_Test); - tc.add("OtExt_SoftSpokenMalicious21_Split_Test ", OtExt_SoftSpokenMalicious21_Split_Test); - tc.add("DotExt_SoftSpokenMaliciousLeaky_Test ", DotExt_SoftSpokenMaliciousLeaky_Test); - - tc.add("Vole_Noisy_test ", Vole_Noisy_test); - tc.add("Vole_Silent_paramSweep_test ", Vole_Silent_paramSweep_test); - tc.add("Vole_Silent_Tungsten_test ", Vole_Silent_Tungsten_test); - tc.add("Vole_Silent_QuasiCyclic_test ", Vole_Silent_QuasiCyclic_test); - tc.add("Vole_Silent_baseOT_test ", Vole_Silent_baseOT_test); - tc.add("Vole_Silent_mal_test ", Vole_Silent_mal_test); - tc.add("Vole_Silent_Rounds_test ", Vole_Silent_Rounds_test); - - tc.add("NcoOt_Kkrt_Test ", NcoOt_Kkrt_Test); - tc.add("NcoOt_Oos_Test ", NcoOt_Oos_Test); - tc.add("NcoOt_genBaseOts_Test ", NcoOt_genBaseOts_Test); - - - }); + TestCollection Tests([](TestCollection& tc) + { + + + tc.add("Tools_Transpose_Test ", Tools_Transpose_Test); + tc.add("Tools_Transpose_View_Test ", Tools_Transpose_View_Test); + tc.add("Tools_Transpose_Bench ", Tools_Transpose_Bench); + + tc.add("Tools_LinearCode_Test ", Tools_LinearCode_Test); + tc.add("Tools_LinearCode_sub_Test ", Tools_LinearCode_sub_Test); + tc.add("Tools_LinearCode_rep_Test ", Tools_LinearCode_rep_Test); + + tc.add("Tools_bitShift_test ", Tools_bitShift_test); + tc.add("Tools_modp_test ", Tools_modp_test); + tc.add("Tools_bitpolymul_test ", Tools_bitpolymul_test); + tc.add("Tools_quasiCyclic_test ", Tools_quasiCyclic_test); + + tc.add("Mtx_make_test ", tests::Mtx_make_test); + tc.add("Mtx_add_test ", tests::Mtx_add_test); + tc.add("Mtx_mult_test ", tests::Mtx_mult_test); + tc.add("Mtx_invert_test ", tests::Mtx_invert_test); + + tc.add("EACode_encode_basic_test ", EACode_encode_basic_test); + tc.add("EACode_weight_test ", EACode_weight_test); + + tc.add("ExConvCode_encode_basic_test ", ExConvCode_encode_basic_test); + tc.add("ExConvCode_weight_test ", ExConvCode_weight_test); + + tc.add("TungstenCode_encode_test ", TungstenCode_encode_test); + tc.add("TungstenCode_weight_test ", TungstenCode_weight_test); + + tc.add("Tools_Pprf_expandOne_test ", Tools_Pprf_expandOne_test); + tc.add("Tools_Pprf_inter_test ", Tools_Pprf_inter_test); + tc.add("Tools_Pprf_ByLeafIndex_test ", Tools_Pprf_ByLeafIndex_test); + tc.add("Tools_Pprf_ByTreeIndex_test ", Tools_Pprf_ByTreeIndex_test); + tc.add("Tools_Pprf_callback_test ", Tools_Pprf_callback_test); + + tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); + tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); + tc.add("RegularDpf_keyGen_Test ", RegularDpf_keyGen_Test); + tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); + tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); + + + tc.add("foleage_transpose_test ", foleage_transpose_test); + tc.add("foleage_fft_test ", foleage_fft_test); + tc.add("foleage_dpf_test ", foleage_dpf_test); + tc.add("foleage_spfss_test ", foleage_spfss_test); + tc.add("foleage_pcg_test ", foleage_pcg_test); + + tc.add("foleage_tensor_test ", foleage_tensor_test); + tc.add("foleage_F4ole_test ", foleage_F4ole_test); + + + tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); + tc.add("Bot_Simplest_asm_Test ", Bot_Simplest_asm_Test); + + tc.add("Bot_McQuoidRR_Moeller_EKE_Test ", Bot_McQuoidRR_Moeller_EKE_Test); + tc.add("Bot_McQuoidRR_Moeller_MR_Test ", Bot_McQuoidRR_Moeller_MR_Test); + tc.add("Bot_McQuoidRR_Moeller_F_Test ", Bot_McQuoidRR_Moeller_F_Test); + tc.add("Bot_McQuoidRR_Moeller_FM_Test ", Bot_McQuoidRR_Moeller_FM_Test); + + tc.add("Bot_McQuoidRR_Ristrestto_F_Test ", Bot_McQuoidRR_Ristrestto_F_Test); + tc.add("Bot_McQuoidRR_Ristrestto_FM_Test ", Bot_McQuoidRR_Ristrestto_FM_Test); + + tc.add("Bot_MasnyRindal_Test ", Bot_MasnyRindal_Test); + tc.add("Bot_MasnyRindal_Kyber_Test ", Bot_MasnyRindal_Kyber_Test); + + tc.add("Vole_SoftSpokenSmall_Test ", Vole_SoftSpokenSmall_Test); + tc.add("DotExt_Kos_Test ", DotExt_Kos_Test); + tc.add("DotExt_Iknp_Test ", DotExt_Iknp_Test); + + + tc.add("OtExt_genBaseOts_Test ", OtExt_genBaseOts_Test); + tc.add("OtExt_Chosen_Test ", OtExt_Chosen_Test); + tc.add("OtExt_Iknp_Test ", OtExt_Iknp_Test); + tc.add("OtExt_Kos_Test ", OtExt_Kos_Test); + tc.add("OtExt_Kos_fs_Test ", OtExt_Kos_fs_Test); + tc.add("OtExt_Kos_ro_Test ", OtExt_Kos_ro_Test); + tc.add("OtExt_Silent_random_Test ", OtExt_Silent_random_Test); + tc.add("OtExt_Silent_correlated_Test ", OtExt_Silent_correlated_Test); + tc.add("OtExt_Silent_inplace_Test ", OtExt_Silent_inplace_Test); + tc.add("OtExt_Silent_paramSweep_Test ", OtExt_Silent_paramSweep_Test); + tc.add("OtExt_Silent_QuasiCyclic_Test ", OtExt_Silent_QuasiCyclic_Test); + tc.add("OtExt_Silent_Tungsten_Test ", OtExt_Silent_Tungsten_Test); + tc.add("OtExt_Silent_baseOT_Test ", OtExt_Silent_baseOT_Test); + tc.add("OtExt_Silent_mal_Test ", OtExt_Silent_mal_Test); + + tc.add("OtExt_SoftSpokenSemiHonest_Test ", OtExt_SoftSpokenSemiHonest_Test); + tc.add("OtExt_SoftSpokenSemiHonest_Split_Test ", OtExt_SoftSpokenSemiHonest_Split_Test); + //tc.add("OtExt_SoftSpokenSemiHonest21_Test ", OtExt_SoftSpokenSemiHonest21_Test); + tc.add("OtExt_SoftSpokenMalicious21_Test ", OtExt_SoftSpokenMalicious21_Test); + tc.add("OtExt_SoftSpokenMalicious21_Split_Test ", OtExt_SoftSpokenMalicious21_Split_Test); + tc.add("DotExt_SoftSpokenMaliciousLeaky_Test ", DotExt_SoftSpokenMaliciousLeaky_Test); + + tc.add("Vole_Noisy_test ", Vole_Noisy_test); + tc.add("Vole_Silent_paramSweep_test ", Vole_Silent_paramSweep_test); + tc.add("Vole_Silent_Tungsten_test ", Vole_Silent_Tungsten_test); + tc.add("Vole_Silent_QuasiCyclic_test ", Vole_Silent_QuasiCyclic_test); + tc.add("Vole_Silent_baseOT_test ", Vole_Silent_baseOT_test); + tc.add("Vole_Silent_mal_test ", Vole_Silent_mal_test); + tc.add("Vole_Silent_Rounds_test ", Vole_Silent_Rounds_test); + + tc.add("NcoOt_Kkrt_Test ", NcoOt_Kkrt_Test); + tc.add("NcoOt_Oos_Test ", NcoOt_Oos_Test); + tc.add("NcoOt_genBaseOts_Test ", NcoOt_genBaseOts_Test); + + + }); } From f7c1bb1926e5a2026e6a3073bd11854b853a317f Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Sun, 16 Feb 2025 22:55:57 -0800 Subject: [PATCH 19/48] foleage cleanup --- libOTe/Tools/Dpf/TriDpf.h | 15 +- libOTe/Tools/Foleage/F4Ops.h | 2 +- libOTe/Tools/Foleage/FoleagePcg.cpp | 537 ++++++++++-------- libOTe/Tools/Foleage/FoleagePcg.h | 131 ++--- libOTe/Tools/Foleage/FoleageUtils.h | 16 +- libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp | 2 +- libOTe/Tools/Foleage/fft/FoleageFft.cpp | 8 +- libOTe/Tools/Foleage/fft/FoleageFft.h | 2 +- libOTe_Tests/Foleage_Tests.cpp | 21 +- 9 files changed, 409 insertions(+), 325 deletions(-) diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 9ff6333f..e61f074e 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -73,7 +73,7 @@ namespace osuCrypto for (u64 i = 31; i < 32; --i) { r *= 3; - r |= (mVal >> (i * 2)) & 3; + r += (mVal >> (i * 2)) & 3; } return r; @@ -509,12 +509,13 @@ namespace osuCrypto auto H = [](const block& a, const block& b) -> block { - RandomOracle ro(sizeof(block)); - ro.Update(a); - ro.Update(b); - block r; - ro.Final(r); - return r; + return mAesFixedKey.hashBlock(mAesFixedKey.hashBlock(a) ^ b) ^ a; + //RandomOracle ro(sizeof(block)); + //ro.Update(a); + //ro.Update(b); + //block r; + //ro.Final(r); + //return r; }; auto sender = [&]() -> macoro::task<> { diff --git a/libOTe/Tools/Foleage/F4Ops.h b/libOTe/Tools/Foleage/F4Ops.h index ba2925e1..49198ed7 100644 --- a/libOTe/Tools/Foleage/F4Ops.h +++ b/libOTe/Tools/Foleage/F4Ops.h @@ -42,7 +42,7 @@ namespace osuCrypto // Multiplies two packed matrices of F4 elements column-by-column. // Note that here the "columns" are packed into an element of uint8_t // resulting in a matrix with 4 columns. - inline void multiply_fft_8( + inline void F4Multiply( span a_poly, span b_poly, span res_poly, diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 30966542..dc65c785 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -10,7 +10,7 @@ namespace osuCrypto { - void FoleageF4Ole::init(u64 partyIdx, u64 n, PRNG& prng) + void FoleageF4Ole::init(u64 partyIdx, u64 n) { mPartyIdx = partyIdx; mLog3N = log3ceil(n); @@ -27,10 +27,8 @@ namespace osuCrypto mDpfLeafSize = ipow(3, mDpfLeafDepth); mDpfTreeSize = ipow(3, mDpfTreeDepth); - - _mDpfDomainDepth = std::max(1, log3ceil(divCeil(mBlockSize, 256))); - _mDpfBlockSize = 4 * ipow(3, _mDpfDomainDepth); - + //std::cout << "mLeafSize " << mDpfLeafSize << " " << mDpfLeafDepth << std::endl; + //std::cout << "mTreeSize " << mDpfTreeSize << " " << mDpfTreeDepth << std::endl; mDpfLeaf.init(mPartyIdx, mDpfLeafSize, mC * mC * mT * mT); mDpf.init(mPartyIdx, mDpfTreeSize, mC * mC * mT * mT); @@ -42,6 +40,66 @@ namespace osuCrypto } + FoleageF4Ole::BaseOtCount FoleageF4Ole::baseOtCount() const + { + BaseOtCount counts; + + counts.mSendCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); + counts.mRecvCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); + if (mPartyIdx) + counts.mSendCount += 2 * mC * mT; + else + counts.mRecvCount += 2 * mC * mT; + return counts; + } + + + void FoleageF4Ole::setBaseOts( + span> baseSendOts, + span recvBaseOts, + const oc::BitVector& baseChoices) + { + auto baseCounts = baseOtCount(); + if (baseSendOts.size() != baseCounts.mSendCount) + throw RTE_LOC; + if (recvBaseOts.size() != baseCounts.mRecvCount) + throw RTE_LOC; + if (baseChoices.size() != baseCounts.mRecvCount) + throw RTE_LOC; + auto recvIter = recvBaseOts; + auto sendIter = baseSendOts; + auto choiceIter = baseChoices; + + auto dpfLeafCount = mDpfLeaf.baseOtCount(); + u64 offset = 0; + mDpfLeaf.setBaseOts( + sendIter.subspan(offset, dpfLeafCount), + recvIter.subspan(offset, dpfLeafCount), + BitVector(baseChoices.data(), dpfLeafCount, offset) + ); + offset += dpfLeafCount; + + auto dpfCount = mDpf.baseOtCount(); + mDpf.setBaseOts( + sendIter.subspan(offset, dpfCount), + recvIter.subspan(offset, dpfCount), + BitVector(baseChoices.data(), dpfCount, offset) + ); + offset += dpfCount; + + auto sendOts = sendIter.subspan(offset); + auto recvOts = recvIter.subspan(offset); + mSendOts.insert(mSendOts.end(), sendOts.begin(), sendOts.end()); + mRecvOts.insert(mRecvOts.end(), recvOts.begin(), recvOts.end()); + mChoiceOts = BitVector(baseChoices.data(), baseChoices.size() - offset, offset); + } + + bool FoleageF4Ole::hasBaseOts() const + { + return mSendOts.size() + mRecvOts.size() > 0; + } + + void FoleageF4Ole::sampleA(block seed) { @@ -93,25 +151,11 @@ namespace osuCrypto } } } - - //{ - // std::vector fft_a(mN); - // std::vector fft_a2(mN); - // PRNG APrng(block(431234234, 213434234123)); - // sample_a_and_a2(fft_a, fft_a2, mN, mC, APrng); - - // for (u64 i = 0; i < mN; ++i) - // { - // if (fft_a[i] != mFftA[i]) - // throw RTE_LOC; - // if (fft_a2[i] != mFftASquared[i]) - // throw RTE_LOC; - // } - - //} } + + macoro::task<> FoleageF4Ole::expand( span ALsb, span AMsb, @@ -120,27 +164,39 @@ namespace osuCrypto PRNG& prng, coproto::Socket& sock) { - bool oldDpf = false; - setTimePoint("expand start"); + if (hasBaseOts() == false) + throw RTE_LOC; + if (divCeil(mN, 128) < ALsb.size()) throw RTE_LOC; - if (ALsb.size() != AMsb.size() || ALsb.size() != CLsb.size() || ALsb.size() != CMsb.size()) + if (ALsb.size() != AMsb.size() || + ALsb.size() != CLsb.size() || + ALsb.size() != CMsb.size()) throw RTE_LOC; - mSparseCoefficients.resize(mC, mT); + // the coefficient of the sparse polynomial. + // the i'th row containts the coeffs for the i'th poly. mSparsePositions.resize(mC, mT); + // The mT coefficients of the mC sparse polynomials. + Matrix sparseCoefficients(mC, mT); std::vector tensoredCoefficients(mC * mC * mT * mT); - co_await tensor(mSparseCoefficients, tensoredCoefficients, sock); - for (u64 i = 0; i < mC * mT; ++i) - { - //while (mSparseCoefficients(i) == 0) - // mSparseCoefficients(i) = prng.get() & 3; + // generate random sparseCoefficients and tensor them with + // the other parties sparse coefficients. The result is shared + // as tensoredCoefficients. Each set of (mC*mT) values in + // tensoredCoefficients are the multiplication of a single coeff + // from party 0 and all of the coefficients from party 1. + co_await tensor(sparseCoefficients, tensoredCoefficients, sock); + + //co_await checkTensor(sparseCoefficients, tensoredCoefficients, sock); + + // select random positions for the sparse polynomial. + // The i'th is the noise position in the i'th block. + for (u64 i = 0; i < mSparsePositions.size(); ++i) mSparsePositions(i) = prng.get() % mBlockSize; - } if (mC != 4) throw RTE_LOC; @@ -151,21 +207,22 @@ namespace osuCrypto { for (u64 j = 0; j < mC; ++j) { - auto pos = i * mBlockSize + mSparsePositions(j, i); - fftSparsePoly[pos] |= mSparseCoefficients(j, i) << (2 * j); + auto pos = i * mBlockSize + mSparsePositions(j, i);// .toInt(); + fftSparsePoly[pos] |= sparseCoefficients(j, i) << (2 * j); } } setTimePoint("sparsePolySample"); // switch from polynomial to FFT form - fft_recursive_uint8(fftSparsePoly, mLog3N, mN / 3); + foliageFftUint8(fftSparsePoly, mLog3N, mN / 3); - // multiply by the packed A polynomial - multiply_fft_8(mFftA, fftSparsePoly, fftSparsePoly, mN); + setTimePoint("input fft"); - setTimePoint("sparsePolyMul"); + // multiply by the packed A polynomial + F4Multiply(mFftA, fftSparsePoly, fftSparsePoly, mN); + setTimePoint("input Mult"); // compress the resume and set the output. auto outSize = std::min(mN, ALsb.size() * 128); @@ -185,12 +242,24 @@ namespace osuCrypto } setTimePoint("copyOutX"); - std::vector prodPolyCoefficientShare(mC * mC * mT * mT); - std::vector prodPolyCoefficient3(mC * mC * mT * mT); + // sharing of the F4 coefficients of the product polynomails. + // these will just be the tensored coefficients but in permuted + // order to match how they are expended in the DPF and then added + // together. + std::vector prodPolyF4Coeffs(mC * mC * mT * mT); + + // We are doing to use "early termination" on the main DPF. To do + // this we are going to construct new F4^243 coefficients where + // each prodPolyF4Coeffs is positioned at prodPolyLeafPos. This + // will allow the main DPF to be more efficient as we are outputting + // 243 F4 elements for each leaf. std::vector prodPolyLeafPos(mC * mC * mT * mT); + + // once we construct large F4^243 coefficients, we will expand them + // the main DPF to get the full shared polynomail. prodPolyTreePos + // is the location that the F4^243 coefficient should be mapped to. std::vector prodPolyTreePos(mC * mC * mT * mT); - setTimePoint("sendRecv"); for (u64 iA = 0, pointIdx = 0, polyOffset = 0; iA < mC; ++iA) @@ -206,17 +275,28 @@ namespace osuCrypto u64 i = mPartyIdx ? iB : iA; u64 j = mPartyIdx ? jB : jA; - auto pos = Trit32(mSparsePositions(i, j)); + + // the block of the product coefficient is known + // purely using the block index of the input coefficients. auto blockPos = Trit32(jA) + Trit32(jB); auto blockIdx = blockPos.toInt(); + + // We want to put all DPF that will be added together + // next to each other. We do this by using nextIdx to + // keep track of the next index for each output block. size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; + + // split the position into the portion that will position + // the F4 coefficient within the F4^243 coefficient and the + // portion that will position the F4^243 coefficient within + // the main DPF. + auto pos = Trit32(mSparsePositions(i, j)); prodPolyLeafPos[idx] = pos.lower(mDpfLeafDepth); prodPolyTreePos[idx] = pos.upper(mDpfLeafDepth); - + // get the corresponding tensored F4 coefficient. auto coeffIdx = (iA * mT + jA) * mC * mT + iB * mT + jB; - - prodPolyCoefficientShare[idx] = tensoredCoefficients[coeffIdx]; + prodPolyF4Coeffs[idx] = tensoredCoefficients[coeffIdx]; } } @@ -225,84 +305,94 @@ namespace osuCrypto } } - setTimePoint("sparseProductCompute"); + setTimePoint("dpfParams"); + // sharing of the F4^243 coefficients of the product polynomails. + // These are obtained by expanding the F4 coefficients into 243 + // elements using a "small DPF". + std::vector prodPolyF4x243Coeffs(mC * mC * mT * mT); - std::vector fft(mN), fftRes(mN); - Matrix blocks512(mC * mC * mT, mDpfTreeSize); - - //{ - // auto numOTs = mDpfLeaf.baseOtCount(); - // std::vector baseRecvOts(numOTs); - // std::vector> baseSendOts(numOTs); - // BitVector baseChoices(numOTs); - // PRNG basePrng(block(324234, 234234)); - // basePrng.get(baseSendOts.data(), baseSendOts.size()); - // baseChoices.randomize(basePrng); - // for (u64 i = 0; i < numOTs; ++i) - // baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - // mDpfLeaf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - //} - - co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyCoefficientShare, [&](u64 treeIdx, u64 leafIdx, u8 v) { - *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; - *BitIterator(&prodPolyCoefficient3[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; + // current coefficients are single F4 elements. Expand them into + // 3^5=243 elements. These will be used as the new coefficients + // in the large tree. + co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyF4Coeffs, [&](u64 treeIdx, u64 leafIdx, u8 v) { + *BitIterator(&prodPolyF4x243Coeffs[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; + *BitIterator(&prodPolyF4x243Coeffs[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; }, prng, sock); + setTimePoint("leafDpf"); + - //{ - // auto numOTs = mDpf.baseOtCount(); - // std::vector baseRecvOts(numOTs); - // std::vector> baseSendOts(numOTs); - // BitVector baseChoices(numOTs); - // PRNG basePrng(block(324234, 234234)); - // basePrng.get(baseSendOts.data(), baseSendOts.size()); - // baseChoices.randomize(basePrng); - // for (u64 i = 0; i < numOTs; ++i) - // baseRecvOts[i] = baseSendOts[i][baseChoices[i]]; - // mDpf.setBaseOts(baseSendOts, baseRecvOts, baseChoices); - //} - - co_await mDpf.expand(prodPolyTreePos, prodPolyCoefficient3, [&](u64 treeIdx, u64 leafIdx, block512 v) { - auto row = treeIdx / mT; - blocks512(row, leafIdx) ^= v; + Matrix blocks(mC * mC * mT, mDpfTreeSize); + // expand the main tree and add the mT point functions correspond + // to a block together. This will give us the coefficients of the + // the product polynomial. + co_await mDpf.expand(prodPolyTreePos, prodPolyF4x243Coeffs, + [&, count = 0, out = blocks.data(), end = blocks.data() + blocks.size()] + (u64 treeIdx, u64 leafIdx, FoleageF4x243 v) mutable { + // the callback is called in column major order but blocks + // is row major (leafIdx will be the same). So we need to compute + // the correct index. Moreover, we are adding together mT trees + // so we also need divide the treeIdx by mT. To make this more + // efficient, we use the out pointer and manually increment it. + + assert(out == &blocks(treeIdx / mT, leafIdx)); + *out ^= v; + + if (++count == mT) + { + count = 0; + out += blocks.cols(); + if (out >= end) + { + out -= blocks.size() - 1; + } + } }, prng, sock); + + setTimePoint("mainDpf"); + + + std::vector fft(mN), fftRes(mN); + + // We have mC*mC = 16 polynomials. We need to apply + // the FFT to each. We do this by packing the 16 polynomials + // into a single u32. We then apply the FFT to this u32. + // This is done for each of the mT blocks of each polynomail. + // + // The DPFs used 512 bits to represent mDpfLeafSize=243 F4 elements. + // We need to skip the last 26 bits of each FoleageF4x243. for (size_t j = 0; j < mC; j++) { for (size_t k = 0; k < mC; k++) { size_t poly_index = (j * mC + k); - oc::MatrixView poly(blocks512.data(poly_index * mT), mT, mDpfTreeSize); + oc::MatrixView poly(blocks.data(poly_index * mT), mT, mDpfTreeSize); - u64 i = 0; - for (u64 block_idx = 0; block_idx < mT; ++block_idx) + for (u64 block_idx = 0, i = 0; block_idx < mT; ++block_idx) { for (u64 packed_idx = 0; packed_idx < mDpfTreeSize; ++packed_idx) { auto coeff = extractF4(poly(block_idx, packed_idx)); auto e = std::min(mBlockSize - packed_idx * mDpfLeafSize, mDpfLeafSize); - for (u64 element_idx = 0; element_idx < e; ++element_idx) + for (u64 element_idx = 0; element_idx < e; ++element_idx, ++i) { fft[i] |= u32{ coeff[element_idx] } << (2 * poly_index); - ++i; } } } } } - - setTimePoint("dpfKeyEval"); + setTimePoint("transpose"); fft_recursive_uint32(fft, mLog3N, mN / 3); - //std::cout << "Cfft " << hash(fft.data(), fft.size()) << std::endl; + setTimePoint("product fft"); multiply_fft_32(mFftASquared, fft, fftRes, mN); + setTimePoint("product mult"); - //std::cout << "C " << hash(fftRes.data(), fftRes.size()) << std::endl; - - setTimePoint("fft"); // XOR the (packed) columns into the accumulator. // Specifically, we perform column-wise XORs to get the result. @@ -311,9 +401,6 @@ namespace osuCrypto setBytes(msbMask, 0b10101010); for (size_t i = 0; i < outSize; i++) { - //auto resA = extractF4(res_poly_mat_A[i]); - //auto resB = extractF4(res_poly_mat_B[i]); - *BitIterator(CLsb.data(), i) = popcount(fftRes[i] & lsbMask) & 1; *BitIterator(CMsb.data(), i) = popcount(fftRes[i] & msbMask) & 1; } @@ -326,165 +413,169 @@ namespace osuCrypto macoro::task<> FoleageF4Ole::tensor(span coeffs, span prod, coproto::Socket& sock) { - //if (coeffs.size() != mC * mT) - // throw RTE_LOC; - if (coeffs.size() * coeffs.size() != prod.size()) throw RTE_LOC; - if (0) + auto expand = [](block k, span diff) { + AES aes(k); + for (u64 i = 0; i < diff.size(); ++i) + diff[i] = aes.ecbEncBlock(block(i)); + }; + + if (divCeil(coeffs.size(), 128) != 1) + throw RTE_LOC; // not impl + auto size = 2 * divCeil(coeffs.size(), 128); + + + if (mPartyIdx) { - PRNG prng(CCBlock); - std::array, 2> s; - s[0].resize(coeffs.size()); - s[1].resize(coeffs.size()); - //prng.get(s0.data(), s0.size()); - for (u64 i = 0; i < s[0].size(); ++i) + if (mSendOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. + // b * a = (b0 * a + b1 * (2 * a)) + //auto getDiff = [](block k0, block k1, span diff) { + // AES aes0(k0); + // AES aes1(k1); + // for (u64 i = 0; i < diff.size(); ++i) + // diff[i] = aes0.ecbEncBlock(block(i)) ^ aes1.ecbEncBlock(block(i) * 2); + // }; + std::array, 2> a; a[0].resize(size), a[1].resize(size); + std::vector t0(size), t1(size); + expand(mSendOts[0][0], t0); + expand(mSendOts[0][1], t1); + for (u64 i = 0; i < size; ++i) + a[0][i] = t0[i] ^ t1[i]; + + // a[1] = 2 * a[0] + f4Mult(a[0][0], a[0][1], ZeroBlock, AllOneBlock, a[1][0], a[1][1]); + { - s[0][i] = prng.get() % 4; - s[1][i] = prng.get() % 4; + auto lsbIter = BitIterator(&a[0][0]); + auto msbIter = BitIterator(&a[0][1]); + for (u64 i = 0; i < coeffs.size(); ++i) + coeffs[i] = (*lsbIter++ & 1) | ((*msbIter++ & 1) << 1); } - std::copy(s[mPartyIdx].begin(), s[mPartyIdx].end(), coeffs.begin()); - for (u64 iA = 0, pointIdx = 0; iA < s[0].size(); ++iA) { - for (u64 iB = 0; iB < s[1].size(); ++iB, ++pointIdx) - { - prod[pointIdx] = - (mult_f4(s[0][iA], s[1][iB]) * mPartyIdx);// ^ - //(prng.get() % 4); - } + setBytes(prod, 0); + auto prodIter = prod.begin(); + auto lsbIter = BitIterator(&t0[0]); + auto msbIter = BitIterator(&t0[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); } - } - else - { - auto expand = [](block k, span diff) { - AES aes(k); - for (u64 i = 0; i < diff.size(); ++i) - diff[i] = aes.ecbEncBlock(block(i)); - }; - if (divCeil(coeffs.size(), 128) != 1) - throw RTE_LOC; // not impl - auto size = 2 * divCeil(coeffs.size(), 128); + std::vector buffer((2 * coeffs.size() - 1) * size); + auto buffIter = buffer.begin(); + for (u64 i = 1; i < 2 * coeffs.size(); ++i) + { + auto b = i & 1; + auto idx = i / 2; + auto prodIter = prod.begin() + idx * coeffs.size(); + expand(mSendOts[i][0], t0); + expand(mSendOts[i][1], t1); - if (mPartyIdx) - { - if (mSendOts.size() < 2 * coeffs.size() - 1) - throw RTE_LOC; //base ots not set. - // b * a = (b0 * a + b1 * (2 * a)) - //auto getDiff = [](block k0, block k1, span diff) { - // AES aes0(k0); - // AES aes1(k1); - // for (u64 i = 0; i < diff.size(); ++i) - // diff[i] = aes0.ecbEncBlock(block(i)) ^ aes1.ecbEncBlock(block(i) * 2); - // }; - std::array, 2> a; a[0].resize(size), a[1].resize(size); - std::vector t0(size), t1(size); - expand(mSendOts[0][0], t0); - expand(mSendOts[0][1], t1); - for (u64 i = 0; i < size; ++i) - a[0][i] = t0[i] ^ t1[i]; - - // a[1] = 2 * a[0] - f4Mult(a[0][0], a[0][1], ZeroBlock, AllOneBlock, a[1][0], a[1][1]); + // prod = mask + auto lsbIter = BitIterator(&t0[0]); + auto msbIter = BitIterator(&t0[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); - { - auto lsbIter = BitIterator(&a[0][0]); - auto msbIter = BitIterator(&a[0][1]); - for (u64 i = 0; i < coeffs.size(); ++i) - coeffs[i] = (*lsbIter++ & 1) | ((*msbIter++ & 1) << 1); + for (u64 i = 0; i < a.size(); ++i) + { // mask key value + *buffIter++ = t0[i] ^ t1[i] ^ a[b][i]; + //*buffIter++ = diff[i]; } - { - setBytes(prod, 0); - auto prodIter = prod.begin(); - auto lsbIter = BitIterator(&t0[0]); - auto msbIter = BitIterator(&t0[1]); - for (u64 i = 0; i < coeffs.size(); ++i) - *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); - } + } + co_await sock.send(std::move(buffer)); + } + else + { - std::vector buffer((2 * coeffs.size() - 1) * size); - auto buffIter = buffer.begin(); - for (u64 i = 1; i < 2 * coeffs.size(); ++i) - { - auto b = i & 1; - auto idx = i / 2; - auto prodIter = prod.begin() + idx * coeffs.size(); - - expand(mSendOts[i][0], t0); - expand(mSendOts[i][1], t1); - - // prod = mask - auto lsbIter = BitIterator(&t0[0]); - auto msbIter = BitIterator(&t0[1]); - for (u64 i = 0; i < coeffs.size(); ++i) - *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); - - for (u64 i = 0; i < a.size(); ++i) - { // mask key value - *buffIter++ = t0[i] ^ t1[i] ^ a[b][i]; - //*buffIter++ = diff[i]; - } + if (mChoiceOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. + if (mRecvOts.size() < 2 * coeffs.size() - 1) + throw RTE_LOC; //base ots not set. - } + for (u64 i = 0; i < coeffs.size(); ++i) + coeffs[i] = mChoiceOts[2 * i] | (u8(mChoiceOts[2 * i + 1] << 1)); + std::vector t(size); + expand(mRecvOts[0], t); - co_await sock.send(std::move(buffer)); - } - else { - - if (mChoiceOts.size() < 2 * coeffs.size() - 1) - throw RTE_LOC; //base ots not set. - if (mRecvOts.size() < 2 * coeffs.size() - 1) - throw RTE_LOC; //base ots not set. - + setBytes(prod, 0); + auto prodIter = prod.begin(); + auto lsbIter = BitIterator(&t[0]); + auto msbIter = BitIterator(&t[1]); for (u64 i = 0; i < coeffs.size(); ++i) - coeffs[i] = mChoiceOts[2 * i] | (u8(mChoiceOts[2 * i + 1] << 1)); - std::vector t(size); - expand(mRecvOts[0], t); + *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); + } - { - setBytes(prod, 0); - auto prodIter = prod.begin(); - auto lsbIter = BitIterator(&t[0]); - auto msbIter = BitIterator(&t[1]); - for (u64 i = 0; i < coeffs.size(); ++i) - *prodIter++ = (*lsbIter++) | (u8(*msbIter++) << 1); - } + std::vector buffer((2 * coeffs.size() - 1) * size); + co_await sock.recv(buffer); - std::vector buffer((2 * coeffs.size() - 1) * size); - co_await sock.recv(buffer); + auto buffIter = buffer.begin(); + for (u64 i = 1; i < 2 * coeffs.size(); ++i) + { + auto idx = i / 2; + auto prodIter = prod.begin() + idx * coeffs.size(); - auto buffIter = buffer.begin(); - for (u64 i = 1; i < 2 * coeffs.size(); ++i) + expand(mRecvOts[i], t); + if (mChoiceOts[i]) { - auto idx = i / 2; - auto prodIter = prod.begin() + idx * coeffs.size(); - - expand(mRecvOts[i], t); - if (mChoiceOts[i]) + for (u64 i = 0; i < size; ++i) { - for (u64 i = 0; i < size; ++i) - { - t[i] = t[i] ^ *buffIter++; - } + t[i] = t[i] ^ *buffIter++; } - else - buffIter += size; - - // prod = mask - auto lsbIter = BitIterator(&t[0]); - auto msbIter = BitIterator(&t[1]); - for (u64 i = 0; i < coeffs.size(); ++i) - *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); } + else + buffIter += size; + + // prod = mask + auto lsbIter = BitIterator(&t[0]); + auto msbIter = BitIterator(&t[1]); + for (u64 i = 0; i < coeffs.size(); ++i) + *prodIter++ ^= (*lsbIter++) | (u8(*msbIter++) << 1); } } - } + + //macoro::task<> FoleageF4Ole::checkTensor(span coeffs, span tensoredCoefficients, coproto::Socket& sock) + //{ + // std::array, 2> pCoeffs;// (coeffs.size()); + // pCoeffs[mPartyIdx] = std::vector(coeffs.begin(), coeffs.end()); + // pCoeffs[1 - mPartyIdx].resize(coeffs.size()); + + // Matrix pProd(coeffs.size(), coeffs.size()); + + // co_await sock.send(coproto::copy(pCoeffs[mPartyIdx])); + // co_await sock.send(coproto::copy(tensoredCoefficients)); + // co_await sock.recv(pCoeffs[1 - mPartyIdx]); + // co_await sock.recv(pProd); + + // for (u64 i = 0; i < pProd.size(); ++i) + // { + // pProd(i) ^= tensoredCoefficients[i]; + // } + + // for (u64 i = 0; i < coeffs.size(); ++i) + // { + // auto scaler = pCoeffs[0][i]; + // for (u64 j = 0; j < coeffs.size(); ++j) + // { + // u8 exp = mult_f4(scaler, pCoeffs[1][j]); + // auto prod = pProd(i, j); + // if (prod != exp) + // { + // std::cout << "tensor check failed " << i << " " << j << " exp " << int(exp) << " act " << int(prod) << std::endl; + // throw RTE_LOC; + // } + // } + // } + + //} + } \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 8039dc57..59e96580 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -15,20 +15,22 @@ namespace osuCrypto public: u64 mPartyIdx = 0; - // log3 polynomial size - u64 mLog3N = 0; - // the number of noisy positions per polynomial - u64 mT = 3; + u64 mT = 27; + // will be set to the log3 of mT. u64 mLog3T = 0; - // the number of polynomials + // the number of polynomials. u64 mC = 4; - // the size of a polynomial, 3^mLog3N + // the size of a polynomial, 3^mLog3N. + // We will produce this many OLEs. u64 mN = 0; + // log3 polynomial size + u64 mLog3N = 0; + // The A poly in FFT format. We pack mC FFTs into a single u8. The // first is hard coded to the identity polynomial. AlignedUnVector mFftA; @@ -36,110 +38,95 @@ namespace osuCrypto // The A^2 poly in FFT format. We pack mC^2 FFTs into a single u32. AlignedVector mFftASquared; - // depth of 3-ary DPF with 256 F4 values per leaf. - u64 _mDpfDomainDepth = 0; - - u64 _mDpfBlockSize = 0; - // the number of F4 values per block. Each block will have 1 non-zero. // A polynomial will have mT blocks. i.e. mN = mT * mBlockSize. u64 mBlockSize = 0; + // The log3 of mBlockSize. u64 mBlockDepth = 0; + // The number of F4 elements that are packed into a leaf + // of the main DPF. This will at most be 243. + u64 mDpfLeafSize = 0; + + // The log3 of mDpfLeafSize. This will at most be 5. u64 mDpfLeafDepth = 0; - u64 mDpfTreeDepth = 0; + + // the number of F4x243 elements that the main DPF will output. + // This will be approximately be mBlockSize / mDpfLeafSize. u64 mDpfTreeSize = 0; - u64 mDpfLeafSize = 0; + // The log3 of mDpfTreeSize. + u64 mDpfTreeDepth = 0; - // the coefficient of the sparse polynomial. - // the i'th row containts the coeffs for the i'th poly. - Matrix mSparseCoefficients; - // the locations of the non-zeros in the j'th block of the sparse polynomial. // the i'th row containts the coeffs for the i'th poly. Matrix mSparsePositions; - // a dpf used to construct the leaf value of the larger DPF. + // a dpf used to construct the F4x243 leaf value of the larger DPF. TriDpf mDpfLeaf; - // the main DPF - TriDpf mDpf; + // the main DPF which outputs 243 F4 elements for each leaf. + TriDpf mDpf; + // The base OTs used to tensor the coefficients of the sparse polynomial. std::vector mRecvOts; + + // The base OTs used to tensor the coefficients of the sparse polynomial. std::vector> mSendOts; + + // The base OTs used to tensor the coefficients of the sparse polynomial. BitVector mChoiceOts; - void init(u64 partyIdx, u64 n, PRNG& prng); + + // Intializes the protocol to generate n OLEs. Most efficient when n + // is a power of 3. Once called, baseOtCount() can be called to + // determine the required number of base OTs. + void init(u64 partyIdx, u64 n); struct BaseOtCount { - u64 mSendCount, mRecvCount; - }; + // the number of base OTs as sender. + u64 mSendCount = 0; - BaseOtCount baseOtCount() const - { - BaseOtCount counts; - - counts.mSendCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); - counts.mRecvCount = mDpfLeaf.baseOtCount() + mDpf.baseOtCount(); - if(mPartyIdx) - counts.mSendCount += 2 * mC * mT; - else - counts.mRecvCount += 2 * mC * mT; - return counts; - } + // the number of base OTs as receiver. + u64 mRecvCount = 0; + }; + // returns the number of base OTs required. + BaseOtCount baseOtCount() const; + // sets the base OTs that will be used. void setBaseOts( span> baseSendOts, span recvBaseOts, - const oc::BitVector& baseChoices) - { - auto baseCounts = baseOtCount(); - if (baseSendOts.size() != baseCounts.mSendCount) - throw RTE_LOC; - if (recvBaseOts.size() != baseCounts.mRecvCount) - throw RTE_LOC; - if (baseChoices.size() != baseCounts.mRecvCount) - throw RTE_LOC; - auto recvIter = recvBaseOts; - auto sendIter = baseSendOts; - auto choiceIter = baseChoices; - - auto dpfLeafCount = mDpfLeaf.baseOtCount(); - u64 offset = 0; - mDpfLeaf.setBaseOts( - sendIter.subspan(offset, dpfLeafCount), - recvIter.subspan(offset, dpfLeafCount), - BitVector(baseChoices.data(), dpfLeafCount, offset) - ); - offset += dpfLeafCount; - - auto dpfCount = mDpf.baseOtCount(); - mDpf.setBaseOts( - sendIter.subspan(offset, dpfCount), - recvIter.subspan(offset, dpfCount), - BitVector(baseChoices.data(), dpfCount, offset) - ); - offset += dpfCount; - - auto sendOts = sendIter.subspan(offset); - auto recvOts = recvIter.subspan(offset); - mSendOts.insert(mSendOts.end(), sendOts.begin(), sendOts.end()); - mRecvOts.insert(mRecvOts.end(), recvOts.begin(), recvOts.end()); - mChoiceOts = BitVector(baseChoices.data(), baseChoices.size() - offset, offset); - } + const oc::BitVector& baseChoices); + // returns true of the base OTs have been set. + bool hasBaseOts() const; + + // The F4 OLE protocol. This will generate n OLEs. + // the resulting OLEs are in bit decomposition form. + // A = (AMsb || ALsb), C = (CMsb || CLsb). This party will + // output (A,C) while the other outputs (A',C') such that + // A * A' = C + C'. macoro::task<> expand( span ALsb, span AMsb, span CLsb, - span CMsb, PRNG& prng, coproto::Socket& sock); + span CMsb, + PRNG& prng, + coproto::Socket& sock); + // sample random coefficients for the sparse polynomial and tensor + // them with the other parties coefficients. The result is shared + // as tensoredCoefficients. We allow the coeff to be zero. macoro::task<> tensor(span coeffs, span prod, coproto::Socket& sock); + // sample the A polynomial. This is the polynomial that will be + // multiplied the sparse polynomials by. void sampleA(block seed); + + }; } diff --git a/libOTe/Tools/Foleage/FoleageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h index 823190f9..9c28f269 100644 --- a/libOTe/Tools/Foleage/FoleageUtils.h +++ b/libOTe/Tools/Foleage/FoleageUtils.h @@ -322,21 +322,23 @@ namespace osuCrypto return ret; } - struct block512 + // A 512 bit value that is used to represent a vector of 3^5=243 F4 elements. + // We use this value because its greater than 128 bits and almost a power of 2. + // the last 26 bits are unused. + struct FoleageF4x243 { std::array mVal; - block512 operator^(const block512& o) const + FoleageF4x243 operator^(const FoleageF4x243& o) const { - block512 r; + FoleageF4x243 r; r.mVal[0] = mVal[0] ^ o.mVal[0]; r.mVal[1] = mVal[1] ^ o.mVal[1]; r.mVal[2] = mVal[2] ^ o.mVal[2]; r.mVal[3] = mVal[3] ^ o.mVal[3]; return r; } - //block512 operator-(const block512& o) const { return *this + o; } - block512& operator^=(const block512& o) + FoleageF4x243& operator^=(const FoleageF4x243& o) { mVal[0] = mVal[0] ^ o.mVal[0]; mVal[1] = mVal[1] ^ o.mVal[1]; @@ -345,7 +347,7 @@ namespace osuCrypto return *this; } - bool operator==(const block512& o) const + bool operator==(const FoleageF4x243& o) const { return mVal[0] == o.mVal[0] && @@ -355,7 +357,7 @@ namespace osuCrypto } }; - inline std::array extractF4(const block512& val) + inline std::array extractF4(const FoleageF4x243& val) { std::array ret; const char* ptr = (const char*)&val; diff --git a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp index 957dd7a2..576f4a66 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp @@ -80,7 +80,7 @@ namespace osuCrypto clock_t t; t = clock(); - fft_recursive_uint8(coeffs, num_vars, num_coeffs / 3); + foliageFftUint8(coeffs, num_vars, num_coeffs / 3); t = clock() - t; double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp index 2e13bbf5..30e571de 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFft.cpp @@ -232,7 +232,7 @@ namespace osuCrypto { } } - void fft_recursive_uint8( + void foliageFftUint8( span coeffs, const size_t num_vars, const size_t num_coeffs) @@ -242,19 +242,19 @@ namespace osuCrypto { if (num_vars > 1) { // apply FFT on all left coefficients - fft_recursive_uint8( + foliageFftUint8( coeffs, num_vars - 1, num_coeffs / 3); // apply FFT on all middle coefficients - fft_recursive_uint8( + foliageFftUint8( coeffs.subspan(num_coeffs), num_vars - 1, num_coeffs / 3); // apply FFT on all right coefficients - fft_recursive_uint8( + foliageFftUint8( coeffs.subspan(2 * num_coeffs), num_vars - 1, num_coeffs / 3); diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.h b/libOTe/Tools/Foleage/fft/FoleageFft.h index c73b7d01..6eea18b6 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.h +++ b/libOTe/Tools/Foleage/fft/FoleageFft.h @@ -31,7 +31,7 @@ namespace osuCrypto { const size_t num_coeffs); // FFT for (up to) 4 polynomials over F4 - void fft_recursive_uint8( + void foliageFftUint8( span coeffs, const size_t num_vars, const size_t num_coeffs); diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 40db356a..8bd0128a 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -470,7 +470,7 @@ namespace osuCrypto } timer.setTimePoint("begin"); - fft_recursive_uint8(a, nn, n / 3); + foliageFftUint8(a, nn, n / 3); timer.setTimePoint("fft_recursive_uint8"); foleageFFT(lsb.data(), msb.data(), nn, n / 3); timer.setTimePoint("foleageFFT 8 bit"); @@ -846,8 +846,8 @@ namespace osuCrypto // Evaluate the FFTs on the error polynomials eA and eB - fft_recursive_uint8(fft_eA, n, poly_size / 3); - fft_recursive_uint8(fft_eB, n, poly_size / 3); + foliageFftUint8(fft_eA, n, poly_size / 3); + foliageFftUint8(fft_eB, n, poly_size / 3); printf("[. ]Done with Step 1 (sampling error vectors)\n"); @@ -862,8 +862,8 @@ namespace osuCrypto // Compute the coordinate-wise multiplication over the packed FFT result std::vector res_poly_A(poly_size); std::vector res_poly_B(poly_size); - multiply_fft_8(fft_a, fft_eA, res_poly_A, poly_size); // a*eA - multiply_fft_8(fft_a, fft_eB, res_poly_B, poly_size); // a*eB + F4Multiply(fft_a, fft_eA, res_poly_A, poly_size); // a*eA + F4Multiply(fft_a, fft_eB, res_poly_B, poly_size); // a*eB //std::cout << "multA " << hash(res_poly_A.data(), res_poly_A.size()) << std::endl; @@ -1396,13 +1396,16 @@ namespace osuCrypto auto blocks = divCeil(n, 128); bool verbose = cmd.isSet("v"); + if(cmd.hasValue("t")) + oles[0].mT = oles[1].mT = cmd.get("t"); + //PRNG prng(block(342342)); PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); Timer timer; - oles[0].init(0, n, prng0); - oles[1].init(1, n, prng1); + oles[0].init(0, n); + oles[1].init(1, n); { auto otCount0 = oles[0].baseOtCount(); @@ -1483,8 +1486,8 @@ namespace osuCrypto PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); - oles[0].init(0, 1000, prng0); - oles[1].init(1, 1000, prng1); + oles[0].init(0, 1000); + oles[1].init(1, 1000); u64 n = oles[0].mC* oles[0].mT; u64 n2 = n * n; From b4f0cda37d4511952741c09d857d48384823f7e1 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Mon, 17 Feb 2025 18:28:28 -0800 Subject: [PATCH 20/48] optimizing --- frontend/H4.cpp | 918 ---------------------------- frontend/benchmark.h | 184 ++++++ frontend/main.cpp | 5 +- libOTe/Tools/Dpf/TriDpf.h | 227 ++++--- libOTe/Tools/Foleage/FoleagePcg.cpp | 19 +- libOTe/Tools/Foleage/FoleagePcg.h | 15 +- 6 files changed, 347 insertions(+), 1021 deletions(-) delete mode 100644 frontend/H4.cpp diff --git a/frontend/H4.cpp b/frontend/H4.cpp deleted file mode 100644 index 56c3afc1..00000000 --- a/frontend/H4.cpp +++ /dev/null @@ -1,918 +0,0 @@ -///////////////////////////////////////////////////////////////////////////// -//// Example source code for blog post: -//// "C++ Coroutines: Understanding Symmetric-Transfer" -//// -//// Implementation of a naive 'task' coroutine type. -// -//#include -//#include -//#include -//// using namespace std; -// -//#ifndef H4_H -//#define H4_H -// -//#define H4_VERSION "4.0.8" -// -//#ifndef H4_USERLOOP -//#define H4_USERLOOP 1 // improves performance -//#endif -//#define H4_COUNT_LOOPS 0 // DIAGNOSTICS -//#define H4_HOOK_TASKS 0 -// -//#define H4_JITTER_LO 100 // Entropy lower bound -//#define H4_JITTER_HI 350 // Entropy upper bound -//#define H4_Q_CAPACITY 10 // Default Q capacity -//#define H4_Q_ABS_MIN 6 // Absolute minimum Q capacity -// -//#define H4_DEBUG 0 -// -//#define H4_SAFETY_TIME 200 // ms, the time space where h4 could fix rollover issue, -// // too long might let more functions called earler if these falls just between millis() rollover and this period, just after the rollover, -// // too tight might cause missing it (if the h4.loop() didn't take control at the short period) -// -//#if H4_DEBUG -//#define H4_Pirntf(f_, ...) Serial.printf((f_), ##__VA_ARGS__) -//#else -//#define H4_Pirntf(f_, ...) -//#endif -// -// -//enum { -// H4_CHUNKER_ID = 90, -// H4AT_SCAVENGER_ID, -// H4AS_SSE_KA_ID, -// H4AS_WS_KA_ID, -// H4AMC_RCX_ID, -// H4AMC_KA_ID -//}; // must not grow past 99! -// -//// #include -// -//#include -//#include -//#include -//#include -//#include -//#include -//#include -//#include -//#include -//#define __PRETTY_FUNCTION__ __FUNCSIG__ -//#ifdef ARDUINO_ARCH_RP2040 -//#define h4rebootCore rp2040.restart -//#elif defined(ARDUINO) -//#define h4rebootCore ESP.restart -//#else -//void somef() {} -//#define h4rebootCore somef -//#endif -//#define H4_BOARD ARDUINO_BOARD -// -//uint32_t globMillis; -//uint32_t millis() { -// return globMillis; -//} -// -//void debugFunction(std::string f) { std::cout << f << std::endl; } -//void h4reboot(); -// -//void HAL_enableInterrupts(); -//void HAL_disableInterrupts(); -// -//uint64_t millis64(); -// -// -//class task; -//using H4_TASK_PTR = task*; -//using H4_TIMER = H4_TASK_PTR; -// -//class H4Delay; -//struct H4Coroutine {}; -// -//using H4_FN_COUNT = std::function; -//using H4_FN_TASK = std::function; -//using H4_FN_TIF = std::function; -//using H4_FN_VOID = std::function; -//using H4_FN_COROUTINE = std::function; -//using H4_FN_RTPTR = H4_FN_COUNT; -//// -//using H4_INT_MAP = std::unordered_map; -//using H4_TIMER_MAP = std::unordered_map; -//// -// -//#define CSTR(x) x.c_str() -//#define ME H4::context -//#define MY(x) H4::context->x -//#define TAG(x) (u+((x)*100)) -// -//extern H4_TASK_PTR& H4_context; -// -//class H4Countdown { -//public: -// uint32_t count; -// H4Countdown(uint32_t start = 1) { count = start; } -// uint32_t operator()() { return --count; } -//}; -// -//class H4Random : public H4Countdown { -//public: -// H4Random(uint32_t tmin = 0, uint32_t tmax = 0); -//}; -// -//// -//// T A S K -//// -//class task { -// bool harakiri = false; -// -// void _chain(); -// void _destruct(); -// friend class H4Delay; -//public: -// uint64_t id; -// H4_FN_VOID f; -// H4_FN_COROUTINE fcoro; -// uint32_t rmin = 0; -// uint32_t rmax = 0; -// H4_FN_COUNT reaper; -// H4_FN_VOID chain; -// // H4_FN_COROUTINE chaincoro; -// uint32_t uid = 0; -// bool singleton = false; -// H4_FN_VOID lastRites = [] {}; -// size_t len = 0; -// uint64_t at; -// uint32_t nrq = 0; -// void* partial = NULL; -// -// bool operator()(const task* lhs, const task* rhs) const; -// void operator()(); -// -// task() {} // only for comparison operator -// -// task( -// H4_FN_VOID _f, -// uint32_t _m, -// uint32_t _x, -// H4_FN_COUNT _r, -// H4_FN_VOID _c, -// uint32_t _u = 0, -// bool _s = false -// ); -// -// task( -// H4_FN_COROUTINE _f, -// uint32_t _m, -// uint32_t _x, -// H4_FN_COUNT _r, -// H4_FN_VOID _c, -// uint32_t _u = 0, -// bool _s = false -// ); -// -// ~task() {}//H4_Pirntf("T=%u TASK DTOR %p\n",millis(),this); } -// -// static void cancelSingleton(uint32_t id); -// uint32_t cleardown(uint32_t t); -// // The many ways to die... :) -// uint32_t endF(); // finalise: finishEarly -// uint32_t endU(); // unconditional finishNow; -// uint32_t endC(H4_FN_TIF); // conditional -// uint32_t endK(); // kill, chop etc -// // -// void createPartial(void* d, size_t l); -// void getPartial(void* d) { memcpy(d, partial, len); } -// void putPartial(void* d) { memcpy(partial, d, len); } -// void requeue(); -// void schedule(); -// static uint32_t randomRange(uint32_t lo, uint32_t hi); // move to h4 -//}; -// -//class H4Coroutine -//{ -// -// // task* owner; -// uint32_t duration; -// task* owner = nullptr; -// task* resumer = nullptr; -//public: -// class promise_type { -// // uint32_t duration; -// task* owner = nullptr; -// task* resumer = nullptr; -// friend class H4Coroutine; -// public: -// H4Coroutine get_return_object() noexcept; -// std::suspend_never initial_suspend() noexcept; -// void return_void() noexcept; -// void unhandled_exception() noexcept; -// struct final_awaiter; -// final_awaiter final_suspend() noexcept; -// -// void cancel(); -// }; -// std::coroutine_handle _coro; -// -// explicit H4Coroutine(std::coroutine_handle h) : _coro(h) { debugFunction(__PRETTY_FUNCTION__); printf("this=%p\n", this); printf("h=%p\n", h.address()); } -// ~H4Coroutine() { -// debugFunction(__PRETTY_FUNCTION__); -// // printf("this=%p\n", this); -// // printf("_coro=%p\tduration=%u\towner=%p\tresumer=%p\n", _coro,duration,owner,resumer); -// // if (_coro) _coro.destroy(); -// } -//}; -// -//class H4Delay { -// task* owner; -// uint32_t duration; -//public: -// -// explicit H4Delay(uint32_t duration, task* caller = H4_context) : duration(duration), owner(caller) { -// // debugFunction(__PRETTY_FUNCTION__); printf("this=%p\n", this); -// // printf("_coro=%p\tduration=%u\towner=%p\tresumer=%p\n", _coro,duration,owner,resumer); -// } -// -// bool await_ready() noexcept; -// void await_suspend(const std::coroutine_handle h) noexcept; -// void await_resume() noexcept; -//}; -//// -//// H 4 -//// -// -//class H4 : public std::priority_queue, task> { // H4P 35500 - 35700 -// friend class task; -// H4_TIMER_MAP singles; -// std::vector loopChain; -//public: -// std::unordered_map unloadables; -// static H4_TASK_PTR context; -// static std::map> suspendedTasks; -// -// -// void loop(); -// void setup(); -// -// H4(uint32_t baud = 0, size_t qSize = H4_Q_CAPACITY) { -// reserve(qSize); -// if (baud) { -// // Serial.begin(baud); -// H4_Pirntf("\nH4 RUNNING %s\n", H4_VERSION); -// } -// } -// -// H4_TASK_PTR every(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR nTimes(uint32_t n, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR nTimesRandom(uint32_t n, uint32_t msec, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR once(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR queueFunction(H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t msec, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR repeatWhile(H4_FN_COUNT w, uint32_t msec, H4_FN_VOID fn = []() {}, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR repeatWhileEver(H4_FN_COUNT w, uint32_t msec, H4_FN_VOID fn = []() {}, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// -// H4_TASK_PTR every(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR nTimes(uint32_t n, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR nTimesRandom(uint32_t n, uint32_t msec, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR once(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR queueFunction(H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t msec, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR repeatWhile(H4_FN_COUNT w, uint32_t msec, H4_FN_COROUTINE fn = [](H4Coroutine) -> H4Delay { return H4Delay(0); }, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// H4_TASK_PTR repeatWhileEver(H4_FN_COUNT w, uint32_t msec, H4_FN_COROUTINE fn = [](H4Coroutine) -> H4Delay { return H4Delay(0); }, H4_FN_VOID fnc = nullptr, uint32_t u = 0, bool s = false); -// -// H4_TASK_PTR cancel(H4_TASK_PTR t = context) { return endK(t); } // ? rv ? -// void cancel(std::initializer_list l) { for (auto const t : l) cancel(t); } -// void cancelAll(H4_FN_VOID fn = nullptr); -// void cancelSingleton(uint32_t s) { task::cancelSingleton(s); } -// void cancelSingleton(std::initializer_list l) { for (auto const i : l) cancelSingleton(i); } -// uint32_t finishEarly(H4_TASK_PTR t = context) { return endF(t); } -// uint32_t finishNow(H4_TASK_PTR t = context) { return endU(t); } -// bool finishIf(H4_TASK_PTR t, H4_FN_TIF f) { return endC(t, f); } -// // syscall only -// size_t _capacity() { return c.capacity(); } -// std::vector _copyQ(); -// void _hookLoop(H4_FN_VOID f, uint32_t subid); -// bool _unHook(uint32_t token); -// -// // protected: -// uint32_t gpFramed(task* t, std::function f); -// bool has(task* t) { return find(c.begin(), c.end(), t) != c.end(); } -// uint32_t endF(task* t); -// uint32_t endU(task* t); -// bool endC(task* t, H4_FN_TIF f); -// task* endK(task* t); -// void qt(task* t); -// void reserve(size_t n) { c.reserve(n); } -// H4_FN_TASK taskEvent = [](task*, uint32_t) {}; -// // -//#if H4_HOOK_TASKS -// static H4_FN_TASK taskHook; -// -// void _hookTask(H4_FN_TASK f) { taskHook = f; } -// static std::string dumpTask(task* t, uint32_t faze); -// static void addTaskNames(H4_INT_MAP names); -// static std::string getTaskType(uint32_t t); -// static const char* getTaskName(uint32_t t); -//#else -// static void addTaskNames(H4_INT_MAP names) {} -//#endif -// static void dumpQ(); -// // public: -// task* add(H4_FN_VOID _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u = 0, bool _s = false); -// task* add(H4_FN_COROUTINE _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u = 0, bool _s = false); -//}; -// -//template -//class pr { -// size_t size = sizeof(T); -// -// template -// T2 put(T2 v) { -// memcpy(MY(partial), reinterpret_cast(&v), size); -// return get(); -// } -// template -// T2 get() { return (*(reinterpret_cast(MY(partial)))); } -// -//public: -// pr(T v) { -// if (!MY(partial)) { -// MY(partial) = reinterpret_cast(malloc(size)); -// put(v); -// } -// } -// -// pr operator=(const T other) { return put(other); } -// -// operator T() { return get(); } -// -// T operator +(T v) { return get() + v; } -// -// T operator +=(T v) { return put(get() + v); } -// -// T* operator->() const { -// return reinterpret_cast(MY(partial)); -// } -//}; -// -//extern H4 h4; -// -//template -//static void h4Chunker(T& x, std::function fn, uint32_t lo = H4_JITTER_LO, uint32_t hi = H4_JITTER_HI, H4_FN_VOID final = nullptr) { -// H4_TIMER p = h4.repeatWhile( -// H4Countdown(x.size()), -// task::randomRange(lo, hi), // arbitrary -// [=]() { -// typename T::iterator thunk; -// ME->getPartial(&thunk); -// fn(thunk++); -// ME->putPartial((void*)&thunk); -// // yield(); -// }, -// final, -// H4_CHUNKER_ID); -// typename T::iterator chunkIt = x.begin(); -// p->createPartial((void*)&chunkIt, sizeof(typename T::iterator)); -// p->lastRites = [=] { -// free(p->partial); -// p->partial = nullptr; -// }; -//} -// -//#endif // H4_H -// -//#define __attribute__(X) -// -//////////////////////////// H4.cpp ///////////////////////////// -//#ifdef ARDUINO_ARCH_ESP32 -//portMUX_TYPE h4_mutex = portMUX_INITIALIZER_UNLOCKED; -//void HAL_enableInterrupts() { portEXIT_CRITICAL(&h4_mutex); } -//void HAL_disableInterrupts() { portENTER_CRITICAL(&h4_mutex); } -//#else -//void HAL_enableInterrupts() { /* interrupts(); */ } -//void HAL_disableInterrupts() { /* noInterrupts(); */ } -//#endif -//// -//// and ...here we go! -//// -//void __attribute__((weak)) h4setup(); -//void __attribute__((weak)) h4UserLoop(); -// -//H4_TIMER H4::context = nullptr; -//H4_TASK_PTR& H4_context = H4::context; -// -//std::map> H4::suspendedTasks; -// -//void h4reboot() { h4rebootCore(); } -// -//H4Random::H4Random(uint32_t rmin, uint32_t rmax) { count = task::randomRange(rmin, rmax); } -// -//__attribute__((weak)) H4_INT_MAP h4TaskNames = {}; -// -//#if H4_COUNT_LOOPS -//uint32_t h4Nloops = 0; -//#endif -// -//H4Delay H4Delay::promise_type::get_return_object() noexcept { -// debugFunction(__PRETTY_FUNCTION__); -// return H4Delay(std::coroutine_handle::from_promise(*this)); -//} -//std::suspend_never H4Delay::promise_type::initial_suspend() noexcept { debugFunction(__PRETTY_FUNCTION__); return {}; } -//void H4Delay::promise_type::return_void() noexcept { debugFunction(__PRETTY_FUNCTION__); } -//void H4Delay::promise_type::unhandled_exception() noexcept { debugFunction(__PRETTY_FUNCTION__); std::terminate(); } -//struct H4Delay::promise_type::final_awaiter { -// bool await_ready() noexcept { debugFunction(__PRETTY_FUNCTION__); return false; } -// bool await_suspend(std::coroutine_handle h) noexcept { -// debugFunction(__PRETTY_FUNCTION__); -// printf("h=%p\n", h.address()); -// auto owner = h.promise().owner; -// if (owner) owner->_destruct(); -// H4::suspendedTasks.erase(owner); -// // [ ] IF NOT IMMEDIATEREQUEUE: MANAGE REQUEUE AND CHAIN CALLS. -// return false; -// } -// void await_resume() noexcept { debugFunction(__PRETTY_FUNCTION__); } -//}; -//H4Delay::promise_type::final_awaiter H4Delay::promise_type::final_suspend() noexcept { return {}; } -// -//bool H4Delay::await_ready() noexcept { debugFunction(__PRETTY_FUNCTION__); return false; } -// -//void H4Delay::await_suspend(const std::coroutine_handle h) noexcept { -// debugFunction(__PRETTY_FUNCTION__); -// printf("h=%p\n", h.address()); -// // Schedule the resumer. -// _coro = h; -// resumer = h4.once(duration, [this] { -// -// debugFunction(__PRETTY_FUNCTION__); -// _coro.resume(); -// }); -// h.promise().owner = owner; -// h.promise().resumer = resumer; -// H4::suspendedTasks[owner] = _coro; -//} -// -//void H4Delay::await_resume() noexcept { -// debugFunction(__PRETTY_FUNCTION__); -// resumer = nullptr; -//} -// -// -//void H4Delay::promise_type::cancel() { -// debugFunction(__PRETTY_FUNCTION__); -// auto _coro = std::coroutine_handle::from_promise(*this); -// printf("_coro=%p\n", _coro.address()); -// if (_coro) { -// // _coro.promise().owner = nullptr; -// _coro.destroy(); -// } -// if (resumer) { -// h4.cancel(resumer); -// resumer = nullptr; -// } -// H4::suspendedTasks.erase(owner); -// owner = nullptr; -//} -// -// -//void H4::dumpQ() {} -// -//uint64_t millis64() { -// static volatile uint64_t overflow = 0; -// static volatile uint32_t lastSample = 0; -// static const uint64_t kOverflowIncrement = static_cast(0x100000000); -// -// uint64_t overflowSample; -// uint32_t sample; -// -// // Tracking timer wrap assumes that this function gets called with -// // a period that is less than 1/2 the timer range. -// HAL_disableInterrupts(); -// sample = millis(); -// -// if (lastSample > sample) -// { -// overflow = overflow + kOverflowIncrement; -// } -// -// lastSample = sample; -// overflowSample = overflow; -// HAL_enableInterrupts(); -// -// return (overflowSample | static_cast(sample)); -//} -//// -//// task -//// -//task::task( -// H4_FN_VOID _f, -// uint32_t _m, -// uint32_t _x, -// H4_FN_COUNT _r, -// H4_FN_VOID _c, -// uint32_t _u, -// bool _s -//) : -// f{ _f }, -// rmin{ _m }, -// rmax{ _x }, -// reaper{ _r }, -// chain{ _c }, -// uid{ _u }, -// singleton{ _s } -//{ -// static uint64_t count = 0; -// count++; -// id = count; -// if (_s) { -// uint32_t id = _u % 100; -// if (h4.singles.count(id)) h4.singles[id]->endK(); -// h4.singles[id] = this; -// } -// schedule(); -//} -//task::task( -// H4_FN_COROUTINE _f, -// uint32_t _m, -// uint32_t _x, -// H4_FN_COUNT _r, -// H4_FN_VOID _c, -// uint32_t _u, -// bool _s -//) : -// fcoro{ _f }, -// rmin{ _m }, -// rmax{ _x }, -// reaper{ _r }, -// chain{ _c }, -// uid{ _u }, -// singleton{ _s } -//{ -// static uint64_t count = 0; -// count++; -// id = count; -// if (_s) { -// uint32_t id = _u % 100; -// if (h4.singles.count(id)) h4.singles[id]->endK(); -// h4.singles[id] = this; -// } -// schedule(); -//} -// -//bool task::operator() (const task* lhs, const task* rhs) const { return ((lhs->at > rhs->at) || (lhs->at == rhs->at && lhs->id > rhs->id)) ? true : false; } -//H4Coroutine h4dummy; -//void task::operator()() { -// if (harakiri) _destruct(); // for clean exits -// else { -// std::cout << "CALLING " << (f ? "F" : fcoro ? "FCORO" : "UNDEFINED") << std::endl; -// if (f) f(); -// else fcoro(h4dummy); -// // f(); -// bool thisis_suspended = H4::suspendedTasks.count(this); -// // CURRENTLY: THIS ONLY PREVENTS DESTRUCTION AT THIS POINT, IN FUTURE: RELAY REQUEUE & CHAIN .. -// if (reaper) { // it's finite -// if (!(reaper())) { // ...and it just ended -// _chain(); // run chain function if there is one -// if ((rmin == rmax) && rmin) { -// rmin = 86400000; // reque in +24 hrs -// rmax = 0; -// reaper = nullptr; // and every day after -// requeue(); -// } -// else if (!thisis_suspended) _destruct(); -// } -// else requeue(); -// } -// else requeue(); -// } -//} -// -//void task::_chain() { if (chain) h4.add(chain, 0, 0, H4Countdown(1), nullptr, uid); } // prevents tag rescaling during the pass -// -//void task::cancelSingleton(uint32_t s) { if (h4.singles.count(s)) h4.singles[s]->endK(); } -// -//uint32_t task::cleardown(uint32_t pass) { -// if (singleton) { -// uint32_t id = uid % 100; -// h4.singles.erase(id); -// } -// return pass; -//} -// -//void task::_destruct() { -// debugFunction(__PRETTY_FUNCTION__); -//#if H4_HOOK_TASKS -// H4::taskHook(this, 4); -//#endif -// lastRites(); -// if (partial) free(partial); -// delete this; -//} -//// The many ways to die... :) -//uint32_t task::endF() { -// // H4_Pirntf("ENDF %p\n",this); -// reaper = H4Countdown(1); -// at = 0; -// return cleardown(1 + nrq); -//} -// -//uint32_t task::endU() { -// // H4_Pirntf("ENDU %p\n",this); -// _chain(); -// return nrq + endK(); -//} -// -//uint32_t task::endC(H4_FN_TIF f) { -// bool rv = f(this); -// if (rv) return endF(); -// return rv; -//} -// -//uint32_t task::endK() { -// debugFunction(__PRETTY_FUNCTION__); -// // H4_Pirntf("ENDK %p\n",this); -// auto it = std::find_if(H4::suspendedTasks.begin(), H4::suspendedTasks.end(), [this](const std::pair> p) { return p.first == this; }); -// bool thisiscoro = it != H4::suspendedTasks.end(); -// std::cout << "\tthisiscoro=" << thisiscoro << std::endl; -// if (thisiscoro) { -// it->second.promise().cancel(); -// } -// harakiri = true; -// return cleardown(at = 0); -//} -// -//uint32_t task::randomRange(uint32_t rmin, uint32_t rmax) { return rmax > rmin ? (rand() % (rmax - rmin)) + rmin : rmin; } -// -//void task::requeue() { -// nrq++; -// schedule(); -// h4.qt(this); -//} -// -//void task::schedule() { at = millis64() + randomRange(rmin, rmax); } -// -//void task::createPartial(void* d, size_t l) { -// partial = malloc(l); -// memcpy(partial, d, l); -// len = l; -//} -//// -//// H4 -//// -//task* H4::add(H4_FN_VOID _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u, bool _s) { -// task* t = new task(_f, _m, _x, _r, _c, _u, _s); -//#if H4_HOOK_TASKS -// H4::taskHook(t, 1); -//#endif -// qt(t); -// return t; -//} -//task* H4::add(H4_FN_COROUTINE _f, uint32_t _m, uint32_t _x, H4_FN_COUNT _r, H4_FN_VOID _c, uint32_t _u, bool _s) { -// task* t = new task(_f, _m, _x, _r, _c, _u, _s); -//#if H4_HOOK_TASKS -// H4::taskHook(t, 1); -//#endif -// qt(t); -// return t; -//} -// -//uint32_t H4::gpFramed(task* t, H4_FN_RTPTR f) { -// uint32_t rv = 0; -// printf("t=%p, f=%p\n", t, f); -// if (t) { -// HAL_disableInterrupts(); -// if (has(t) || (t == H4::context) || H4::suspendedTasks.count(t)) rv = f(); // fix bug where context = 0! -// HAL_enableInterrupts(); -// } -// return rv; -//} -// -//uint32_t H4::endF(task* t) { return gpFramed(t, [=] { return t->endF(); }); } -// -//uint32_t H4::endU(task* t) { return gpFramed(t, [=] { return t->endU(); }); } -// -//bool H4::endC(task* t, H4_FN_TIF f) { return gpFramed(t, [=] { return t->endC(f); }); } -// -//task* H4::endK(task* t) { -// debugFunction(__PRETTY_FUNCTION__); -// return reinterpret_cast(gpFramed(t, [=] { return t->endK(); })); } -// -//void H4::qt(task* t) { -// HAL_disableInterrupts(); -// push(t); -// HAL_enableInterrupts(); -//#if H4_HOOK_TASKS -// H4::taskHook(t, 2); -//#endif -//} -//// -//extern void h4setup(); -// -//std::vector H4::_copyQ() { -// std::vector t; -// HAL_disableInterrupts(); -// t = c; -// HAL_enableInterrupts(); -// return t; -//} -// -//void H4::_hookLoop(H4_FN_VOID f, uint32_t subid) { -// if (f) { -// unloadables[subid] = loopChain.size(); -// loopChain.push_back(f); -// } -//} -// -//bool H4::_unHook(uint32_t subid) { -// if (unloadables.count(subid)) { -// loopChain.erase(loopChain.begin() + unloadables[subid]); -// unloadables.erase(subid); -// return true; -// } -// return false; -//} -// -//void setup() { -// h4.setup(); -// h4setup(); -//} -// -//void loop() { -// h4.loop(); -//} -// -//void H4::cancelAll(H4_FN_VOID f) { -// HAL_disableInterrupts(); -// while (!empty()) { -// top()->endK(); -// pop(); -// } -// HAL_enableInterrupts(); -// if (f) f(); -//} -// -//H4_TASK_PTR H4::every(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, nullptr, fnc, TAG(3), s); } -// -//H4_TASK_PTR H4::everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, nullptr, fnc, TAG(4), s); } -// -//H4_TASK_PTR H4::nTimes(uint32_t n, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(n), fnc, TAG(5), s); } -// -//H4_TASK_PTR H4::nTimesRandom(uint32_t n, uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(n), fnc, TAG(6), s); } -// -//H4_TASK_PTR H4::once(uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(1), fnc, TAG(7), s); } -// -//H4_TASK_PTR H4::onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(1), fnc, TAG(8), s); } -// -//H4_TASK_PTR H4::queueFunction(H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, 0, 0, H4Countdown(1), fnc, TAG(9), s); } -// -//H4_TASK_PTR H4::randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Random(tmin, tmax), fnc, TAG(10), s); } -// -//H4_TASK_PTR H4::randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t Rmin, uint32_t Rmax, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Random(tmin, tmax), fnc, TAG(11), s); } -// -//H4_TASK_PTR H4::repeatWhile(H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, fncd, fnc, TAG(12), s); } -// -//H4_TASK_PTR H4::repeatWhileEver(H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { -// return add(fn, msec, 0, fncd, -// std::bind([this](H4_FN_COUNT fncd, uint32_t msec, H4_FN_VOID fn, H4_FN_VOID fnc, uint32_t u, bool s) { -// fnc(); -// repeatWhileEver(fncd, msec, fn, fnc, u, s); -// }, fncd, msec, fn, fnc, u, s), -// TAG(13), s); -//} -// -//H4_TASK_PTR H4::every(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, nullptr, fnc, TAG(3), s); } -// -//H4_TASK_PTR H4::everyRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, nullptr, fnc, TAG(4), s); } -// -//H4_TASK_PTR H4::nTimes(uint32_t n, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(n), fnc, TAG(5), s); } -// -//H4_TASK_PTR H4::nTimesRandom(uint32_t n, uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(n), fnc, TAG(6), s); } -// -//H4_TASK_PTR H4::once(uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Countdown(1), fnc, TAG(7), s); } -// -//H4_TASK_PTR H4::onceRandom(uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Countdown(1), fnc, TAG(8), s); } -// -//H4_TASK_PTR H4::queueFunction(H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, 0, 0, H4Countdown(1), fnc, TAG(9), s); } -// -//H4_TASK_PTR H4::randomTimes(uint32_t tmin, uint32_t tmax, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, H4Random(tmin, tmax), fnc, TAG(10), s); } -// -//H4_TASK_PTR H4::randomTimesRandom(uint32_t tmin, uint32_t tmax, uint32_t Rmin, uint32_t Rmax, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, Rmin, Rmax, H4Random(tmin, tmax), fnc, TAG(11), s); } -// -//H4_TASK_PTR H4::repeatWhile(H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { return add(fn, msec, 0, fncd, fnc, TAG(12), s); } -// -//H4_TASK_PTR H4::repeatWhileEver(H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { -// return add(fn, msec, 0, fncd, -// std::bind([this](H4_FN_COUNT fncd, uint32_t msec, H4_FN_COROUTINE fn, H4_FN_VOID fnc, uint32_t u, bool s) { -// fnc(); -// repeatWhileEver(fncd, msec, fn, fnc, u, s); -// }, fncd, msec, fn, fnc, u, s), -// TAG(13), s); -//} -// -//void H4::setup() { -//} -// -//void H4::loop() { -// task* t = nullptr; -// uint64_t now = millis64(); -// HAL_disableInterrupts(); -// if (size()) { -// if (((int64_t)(top()->at - now)) < 1) { -// t = top(); -// pop(); -// } -// } -// HAL_enableInterrupts(); -// if (t) { // H4P 35000 35100 -// H4::context = t; -// // H4_Pirntf("T=%u H4context <-- %p\n",millis(),t); -// (*t)(); -// // H4_Pirntf("T=%u H4context --> %p\n",millis(),t); -// // dumpQ(); -// }; -// // -// for (auto const& f : loopChain) f(); -//#if H4_USERLOOP -// h4UserLoop(); -//#endif -//#if H4_COUNT_LOOPS -// h4Nloops++; -//#endif -//} -// -//H4 h4(0); -//int H4main() { -// setup(); -// // Emulating while(1) loop. -// while (millis() < 10000) { -// if (!(millis() % 5)) -// std::cout << " T= " << millis() << "ms" << std::endl; -// // Each millisecond runs thousands of iterations, simulate a few: -// for (auto i = 0; i < 20; i++) -// loop(); -// globMillis++; -// } -// return 0; -//} -// -//H4Delay someF() { -// debugFunction(__PRETTY_FUNCTION__); -// printf("on 500, awaiting 400 ms:\n"); -// // auto currentContext = H4::context; -// // h4.once(100, [currentContext]{ debugFunction(__PRETTY_FUNCTION__); h4.cancel(currentContext); }); -// co_await H4Delay(400); -// printf("400ms awaited!\n"); -//} -//void h4setup() { -// // h4.once(1000, []{ printf("1000ms elapsed\n"); }); -// /* h4.queueFunction([]() ->H4Delay { -// // for (auto i=0 ; i<20; i++) { -// // printf("i=%d\n", i); -// co_await H4Delay(5); -// // } -// }); */ -// /* h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4Chunker(vs,[](std::vector::iterator it){ printf("Processing [%s]\n", *it.data());}, 100,200); -// std::vector vs {"Hello", "World"}; -// for (auto &v : vs) { -// printf ("Processing [%s]\n", v.data()); -// co_await H4Delay(task::randomRange(100,200)); -// } -// }); -// h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4.nTimes(20, 5, []{ printf("i=%d\n", ME->nrq);}); -// for (auto i = 0; i < 20; i++) { -// printf("i=%d\n", i); -// co_await H4Delay(5); // Delay asynchronously :) -// } -// printf("Chain Function\n"); -// }); -// -// h4.queueFunction([](H4Coroutine) -> H4Delay { // Replacement to h4.every(100, []{printf("Some processing\n"); }); -// while (true) { -// printf("Some processing\n"); -// co_await H4Delay(100); -// } -// }); */ -// auto context = h4.once(500, someF); -// h4.once(1000, [context] { debugFunction(__PRETTY_FUNCTION__); h4.cancel(context); }); -//} -//void h4UserLoop() { -// -//} -// -///* -// Coroutines: -// - co_await H4Delay({$Time}); -// - co_await H4Delay(0) does queue the continuation to the next loop iteration. -// - The function signature should return H4Delay type instead of void, and can accepts H4Coroutine Parameter. -// - Finishing the timer can be done by h4.cancel($task) or h4.FinishNow/h4.FinishIf/h4.cancel. Where they all destroy the coroutine handle. -// - FinishNow/FinishIf would call the chain function, cancel does not. -// - The chain or requeue of the coroutine gets called immediately, (Not after the coroutine function itself finishes), therefore if some h4.nTimes() function gets called it'd be rescheduled once it call the coroutine function. Also the chain would be scheduled just after calling the last function even if it's a coroutine. and it would co_await. -// -// -// */ \ No newline at end of file diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 431f613d..dc9cfc36 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -16,6 +16,8 @@ #include "libOTe/Tools/TungstenCode/TungstenCode.h" #include "libOTe/Tools/ExConvCodeOld/ExConvCodeOld.h" #include "libOTe/Tools/Dpf/RegularDpf.h" +#include "libOTe/Tools/Dpf/TriDpf.h" +#include "libOTe/Tools/Foleage/FoleagePcg.h" namespace osuCrypto { @@ -764,4 +766,186 @@ namespace osuCrypto if (cmd.isSet("v")) std::cout << timer << std::endl; } + + + + void TriDpfBenchmark(const oc::CLP& cmd) + { + //using F = FoleageF4x243; + //using Ctx = FoleageCoeffCtx; + using F = block; + using Ctx = CoeffCtxGF2; + Timer timer; + + PRNG prng(block(231234, 321312)); + u64 depth = cmd.getOr("depth", 3); + u64 domain = ipow(3, depth) - 3; + u64 numPoints = cmd.getOr("numPoints", 1000); + u64 trials = cmd.getOr("trials", 1); + + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector points(numPoints); + std::vector values0(numPoints); + std::vector values1(numPoints); + Ctx ctx; + for (u64 i = 0; i < numPoints; ++i) + { + points[i] = Trit32(prng.get() % domain); + points1[i] = Trit32(prng.get() % domain); + points0[i] = points[i] - points1[i]; + //std::cout << points[i] << " = " << points0[i] <<" + "<< points1[i] << std::endl; + values0[i] = prng.get(); + values1[i] = prng.get(); + //ctx.minus(points0[i], points[i], points1[i];) + } + + + for (u64 i = 0; i < trials; ++i) + { + + std::array, 2> dpf; + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + //std::array, 2> tags; + output[0].resize(domain, numPoints, AllocType::Uninitialized); + output[1].resize(domain, numPoints, AllocType::Uninitialized); + // tags[0].resize(numPoints, domain, AllocType::Uninitialized); + // tags[1].resize(numPoints, domain, AllocType::Uninitialized); + + auto sock = coproto::LocalAsyncSocket::makePair(); + + timer.setTimePoint("beign"); + auto out0 = output[0].data(); + auto out1 = output[1].data(); + + macoro::sync_wait(macoro::when_all_ready( + dpf[0].expand(points0, values0, [&](auto k, auto i, auto v) { *out0++ = v; }, prng, sock[0]), + dpf[1].expand(points1, values1, [&](auto k, auto i, auto v) { *out1++ = v; }, prng, sock[1]) + )); + timer.setTimePoint("end"); + } + + std::cout << timer << std::endl; + + } + + + + + // This test evaluates the full PCG.Expand for both parties and + // checks correctness of the resulting OLE correlation. + void FoleageBenchmark(const CLP& cmd) + { + + auto logn = cmd.getOr("nn", 10); + u64 n = ipow(3, logn); + auto blocks = divCeil(n, 128); + bool verbose = cmd.isSet("v"); + + + u64 trials = cmd.getOr("trials", 1); + + //PRNG prng(block(342342)); + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); + Timer timer; + + macoro::thread_pool::work work; + macoro::thread_pool pool(2, work); + auto sock = coproto::LocalAsyncSocket::makePair(); + sock[0].setExecutor(pool); + sock[1].setExecutor(pool); + + for (u64 ii = 0; ii < trials; ++ii) + { + + std::array oles; + if (cmd.hasValue("t")) + oles[0].mT = oles[1].mT = cmd.get("t"); + if (cmd.hasValue("c")) + oles[0].mC = oles[1].mC = cmd.get("c"); + + oles[0].init(0, n); + oles[1].init(1, n); + + std::cout << "leaf " << oles[0].mDpfLeafDepth << " main " << oles[0].mDpfTreeDepth << std::endl; + + { + auto otCount0 = oles[0].baseOtCount(); + auto otCount1 = oles[1].baseOtCount(); + if (otCount0.mRecvCount != otCount1.mSendCount || + otCount0.mSendCount != otCount1.mRecvCount) + throw RTE_LOC; + std::array>, 2> baseSend; + baseSend[0].resize(otCount0.mSendCount); + baseSend[1].resize(otCount1.mSendCount); + std::array, 2> baseRecv; + std::array baseChoice; + + for (u64 i = 0; i < 2; ++i) + { + prng0.get(baseSend[i].data(), baseSend[i].size()); + baseRecv[1 ^ i].resize(baseSend[i].size()); + baseChoice[1 ^ i].resize(baseSend[i].size()); + baseChoice[1 ^ i].randomize(prng0); + for (u64 j = 0; j < baseSend[i].size(); ++j) + { + baseRecv[1 ^ i][j] = baseSend[i][j][baseChoice[1 ^ i][j]]; + } + } + + oles[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + oles[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + } + + std::vector + ALsb(blocks), + AMsb(blocks), + BLsb(blocks), + BMsb(blocks), + C0Lsb(blocks), + C0Msb(blocks), + C1Lsb(blocks), + C1Msb(blocks); + + + oles[0].setTimer(timer); + timer.setTimePoint("start"); + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]) | macoro::start_on(pool), + oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]) | macoro::start_on(pool))); + timer.setTimePoint("end"); + std::get<0>(r).result(); + std::get<1>(r).result(); + + } + work = {}; + std::cout << "n="< allocation(allocSize); + auto allocIter = allocation.data(); + + auto makeMatrix = [&](u64 rows, u64 cols, T) -> MatrixView + { + auto ret = MatrixView((T*)allocIter, rows, cols); + allocIter += sizeof(T) * ret.size(); + if (allocIter > allocation.data() + allocSize) + throw std::runtime_error("TriDpf: allocation error. " LOCATION); + return ret; + }; // shares of S' - auto pow3 = ipow(3, mDepth); - std::array, 3> s; + std::array, 3> s; auto last = mDepth % 3; - s[last].resize(pow3, mNumPoints, oc::AllocType::Uninitialized); - s[(last + 2) % 3].resize(pow3 / 3, mNumPoints, oc::AllocType::Uninitialized); - s[(last + 1) % 3].resize(pow3 / 9, mNumPoints, oc::AllocType::Uninitialized); + s[(last + 0) % 3] = makeMatrix(pow3 / 1, mNumPoints, block{}); + s[(last + 2) % 3] = makeMatrix(pow3 / 3, mNumPoints, block{}); + s[(last + 1) % 3] = makeMatrix(pow3 / 9, mNumPoints, block{}); -#if defined(NDEBUG) - auto getRow = [](auto&& m, u64 i) {return m.data(i); }; -#else auto getRow = [](auto&& m, u64 i) {return m[i]; }; -#endif - Matrix z(3, mNumPoints); - Matrix sigma(3, mNumPoints); + auto z = makeMatrix(3, mNumPoints, block{}); + auto sigma = makeMatrix(3, mNumPoints, block{}); { @@ -349,7 +362,30 @@ namespace osuCrypto for (u64 j = 0; j < 3; ++j) { - for (u64 k = 0; k < mNumPoints; ++k) + auto sig = sigma[j].data(); + auto seedj = seed[j].data(); + + for (u64 k = 0; k < numPoints8; k += 8) + { + block s[8];//, w[8]; + SIMD8(q, s[q] = seedj[k + q] ^ parentTag[k + q] & sig[k + q]); + + for (u64 i = 0; i < 3; ++i) + { + auto child = childSeed[j * 3 + i].data(); + auto zi = z.data(i); + + auto w = &child[k]; + aes[i].hashBlocks<8>(s, w); + //SIMD8(q, child[k + q] = w[q]); + SIMD8(q, zi[k + q] ^= w[q]); + } + + // replace the seed with the tag. + SIMD8(q, seed[j][k+q] = tagBit(s[q])); + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) { auto seedjk = seed[j][k] ^ parentTag[k] & sigma[j][k]; @@ -367,47 +403,61 @@ namespace osuCrypto } } } + + //auto size = ipow(3, mDepth); VecF sums, leafVals; ctx.resize(sums, mNumPoints); ctx.zero(sums.begin(), sums.end()); ctx.resize(leafVals, mNumPoints * mDomain); + Matrix t(mDomain, mNumPoints, AllocType::Uninitialized); + //auto t = makeMatrix(mDomain, mNumPoints, u8{}); - Matrix t(mDomain, mNumPoints); // fixing the last layer { auto& parentTags = s[(mDepth - 1) % 3]; auto& curSeed = s[mDepth % 3]; - auto leafIter = leafVals.begin(); + auto leafIter = leafVals.data(); for (u64 L = 0, L2 = 0; L2 < mDomain; ++L, L2 += 3) { // parent control bits auto parentTag = getRow(parentTags, L); + auto m = std::min(3, mDomain - L2); - // child seed - std::array scl{ getRow(curSeed, L2 + 0), getRow(curSeed, L2 + 1), getRow(curSeed, L2 + 2) }; - auto m = std::min(3, mDomain - L2); for (u64 j = 0; j < m; ++j) { - for (u64 k = 0; k < mNumPoints; ++k) + auto cs = curSeed.data(L2 + j); + auto sig = sigma.data(j); + auto tt = t.data(L2 + j); + + for (u64 k = 0; k < numPoints8; k+= 8) + { + block s[8]; + + SIMD8(q, s[q] = cs[k + q] ^ parentTag[k+q] & sig[k+q]); + SIMD8(q, tt[k+q] = lsb(s[q])); + SIMD8(q, ctx.fromBlock(leafIter[q], AES::roundFn(s[q], s[q]))); + SIMD8(q, ctx.plus(sums[k+q], sums[k+q], leafIter[q])); + leafIter += 8;; + } + + for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto s = curSeed[L2 + j][k] ^ parentTag[k] & sigma[j][k]; - t[L2 + j][k] = lsb(s); + auto s = cs[k] ^ parentTag[k] & sig[k]; + tt[k] = lsb(s); ctx.fromBlock(*leafIter, AES::roundFn(s, s)); ctx.plus(sums[k], sums[k], *leafIter); ++leafIter; - - //curSeed[L2 + j][k] = /*convert_G*/ AES::roundFn(s, s);//AES::roundFn is used to get rid of the correlation in the LSB. - //sums[k] = sums[k] ^ curSeed[L2 + j][k]; - //std::cout << mPartyIdx << " " << Trit32(L2 + j) << " " << curSeed[L2 + j][k] << " " << int(curTag[L2 + j][k]) << std::endl; } } } } + + allocation.clear(); //std::cout << "----------" << std::endl; if (values.size()) @@ -420,23 +470,19 @@ namespace osuCrypto for (u64 k = 0; k < mNumPoints; ++k) { - //diff[k] = sums[k] + values[k]; ctx.plus(diff[k], values[k], sums[k]); } co_await sock.send(std::move(diff)); co_await sock.recv(gamma); for (u64 k = 0; k < mNumPoints; ++k) { - //gamma[k] = reveal(sums[k] + values[k]); ctx.plus(gamma[k], gamma[k], sums[k]); ctx.plus(gamma[k], gamma[k], values[k]); } - auto leafIter = leafVals.begin(); + auto leafIter = leafVals.data(); VecF temp; ctx.resize(temp, 1); - //auto& sd = s[mDepth % 3]; - //auto& td = t[mDepth & 1]; for (u64 i = 0; i < mDomain; ++i) { //auto sdi = getRow(sd, i); @@ -460,15 +506,17 @@ namespace osuCrypto } } + if (leafIter != leafVals.data() + leafVals.size()) + throw RTE_LOC; } else { - auto leafIter = leafVals.begin(); - auto tagIter = t.begin(); + auto leafIter = leafVals.data(); + auto tagIter = t.data(); for (u64 i = 0; i < mDomain; ++i) { - for (u64 k = numPoints8; k < mNumPoints; ++k) + for (u64 k = 0; k < mNumPoints; ++k) { if constexpr (std::is_invocable_v) output(k, i, *leafIter++, *tagIter++); @@ -477,6 +525,8 @@ namespace osuCrypto } } + if (leafIter != leafVals.data() + leafVals.size()) + throw RTE_LOC; } } @@ -497,9 +547,9 @@ namespace osuCrypto //} //std::cout << "=======" << iter << "======== " << std::endl; - Matrix sigmaShares(3, mNumPoints); - AlignedUnVector> mask(mNumPoints); - AlignedUnVector> recvBuffer(mNumPoints * 2); + Matrix sigmaShares(3, mNumPoints, AllocType::Uninitialized); + Matrix mask(mNumPoints, 3, AllocType::Uninitialized); + Matrix recvBuffer(mNumPoints * 2, 3, AllocType::Uninitialized); std::array socks; socks[0] = sock; @@ -507,30 +557,28 @@ namespace osuCrypto if (mPartyIdx) std::swap(socks[0], socks[1]); - - auto H = [](const block& a, const block& b) -> block { - return mAesFixedKey.hashBlock(mAesFixedKey.hashBlock(a) ^ b) ^ a; - //RandomOracle ro(sizeof(block)); - //ro.Update(a); - //ro.Update(b); - //block r; - //ro.Final(r); - //return r; + auto expand3 = [](const block& k, span r) { + //r = PRNG(k, 3).get(); + r[0] = k; + r[1] = k ^ block(3450136502437610243, 6108362938092146510); + r[2] = k ^ block(3428970074314387014, 2030711220607601239); + mAesFixedKey.hashBlocks<3>(r.data(), r.data()); }; + auto sender = [&]() -> macoro::task<> { PRNG prng(block(234134, 21452345 * mPartyIdx)); BitVector correction(mNumPoints * 2); - AlignedUnVector> sendBuffer(mNumPoints * 2); + AlignedUnVector sendBuffer(mNumPoints * 2 * sizeof(std::array)); + auto sendIter = sendBuffer.data(); + co_await socks[0].recv(correction); - //auto sendIter = mBaseSendOts.begin() + mOtIdx; for (u64 i = 0; i < mNumPoints; ++i) { auto keys0 = mBaseSendOts[mOtIdx + i * 2 + 0]; auto keys1 = mBaseSendOts[mOtIdx + i * 2 + 1]; - std::array k;// , m; - //std::cout << "p" << mPartyIdx << std::endl;// "\n " << k[0] << "\n " << k[1] << "\n " << k[2] << std::endl; + std::array k; for (u64 j = 0; j < 3; ++j) { auto j0 = j & 1; @@ -541,48 +589,31 @@ namespace osuCrypto auto k0 = keys0[b0]; auto k1 = keys1[b1]; - k[j] = H(k0, k1); - //std::cout << "k" << j << " " << k[j] << " = H( " - // << std::hex << k0.get(0) << " " << b0 << " " - // << std::hex << k1.get(0) << " " << b1 << " ) " << std::endl; + k[j] = k0 ^ k1; } block r = prng.get(); *BitIterator(&r) = mPartyIdx; - //std::array mask;// = prng.get(); - //mask[i] = prng.get(); auto a = points[i][mDepth - iter]; - //std::cout << "a0 " << int(a) << std::endl; - - { - - // sendBuffer[i * 3 + 0] = kj ^ mask ^ unitVec(r, a); - // 0 = kj ^ mask ^ unitVec(r, a); - // mask = kj ^ unitVec(r, a); - - mask[i] = PRNG(k[0], 3).get(); - //setBytes(mask[i], 0); - mask[i][a] ^= r; - } + expand3(k[0], mask[i]); + mask(i,a) ^= r; for (u64 j = 0; j < 2; ++j) { - std::array kj = PRNG(k[j + 1], 3).get(); - //setBytes(kj, 0); - - //sendBuffer[i * 3 + j] = PRNG(k[j], 3).get(); - sendBuffer[i * 2 + j][0] = kj[0] ^ mask[i][0]; - sendBuffer[i * 2 + j][1] = kj[1] ^ mask[i][1]; - sendBuffer[i * 2 + j][2] = kj[2] ^ mask[i][2]; - sendBuffer[i * 2 + j][(j + 1 + a) % 3] ^= r; - - //std::cout << "buffer " << j << std::endl - // << " " << buffer[i * 3 + j][0] << "\n" - // << " " << buffer[i * 3 + j][1] << "\n" - // << " " << buffer[i * 3 + j][2] << "\n"; + std::array kj; + expand3(k[j + 1], kj); +; + kj[0] ^= mask(i,0); + kj[1] ^= mask(i,1); + kj[2] ^= mask(i,2); + kj[(j + 1 + a) % 3] ^= r; + memcpy(sendIter, &kj, sizeof(kj)); + sendIter += sizeof(kj); } } + if (sendIter != sendBuffer.data() + sendBuffer.size()) + throw RTE_LOC; co_await socks[0].send(std::move(sendBuffer)); @@ -610,9 +641,8 @@ namespace osuCrypto { auto a = points[i][mDepth - iter]; - auto k = H( - mBaseRecvOts[mOtIdx + i * 2 + 0], - mBaseRecvOts[mOtIdx + i * 2 + 1]); + auto k = + mBaseRecvOts[mOtIdx + i * 2 + 0] ^ mBaseRecvOts[mOtIdx + i * 2 + 1]; //std::cout << "p" << mPartyIdx << " ka " << k << " = H( " // << std::hex << mBaseRecvOts[i * 2 + 0].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " " // << std::hex << mBaseRecvOts[i * 2 + 1].get(0) << " " << int(mBaseChoice[i * 2 + 0]) << " )" << " a1 " << int(a) << std::endl; @@ -621,17 +651,17 @@ namespace osuCrypto // << " " << buffer[i * 3 + a][0] << "\n" // << " " << buffer[i * 3 + a][1] << "\n" // << " " << buffer[i * 3 + a][2] << "\n"; - std::array ka = PRNG(k, 3).get(); - //setBytes(ka, 0); + std::array ka; + expand3(k, ka); - sigma[0][i] = ka[0] ^ mask[i][0] ^ z[0][i]; - sigma[1][i] = ka[1] ^ mask[i][1] ^ z[1][i]; - sigma[2][i] = ka[2] ^ mask[i][2] ^ z[2][i]; + sigma(0,i) = ka[0] ^ mask(i,0) ^ z(0,i); + sigma(1,i) = ka[1] ^ mask(i,1) ^ z(1,i); + sigma(2,i) = ka[2] ^ mask(i,2) ^ z(2,i); if (a) { - sigma[0][i] ^= recvBuffer[i * 2 + a - 1][0]; - sigma[1][i] ^= recvBuffer[i * 2 + a - 1][1]; - sigma[2][i] ^= recvBuffer[i * 2 + a - 1][2]; + sigma(0,i) ^= recvBuffer(i * 2 + a - 1,0); + sigma(1,i) ^= recvBuffer(i * 2 + a - 1,1); + sigma(2,i) ^= recvBuffer(i * 2 + a - 1,2); } //std::cout << "sigma " << std::endl @@ -643,13 +673,9 @@ namespace osuCrypto co_await sock.recv(sigmaShares); - for (u64 i = 0; i < mNumPoints; ++i) + for (u64 i = 0; i < sigma.size(); ++i) { - for (u64 j = 0; j < 3; ++j) - { - //std::cout << "sigma = " << (sigma[j][i] ^ sigmaShares[j][i]) << " = " << sigma[j][i] << " ^ " << sigmaShares[j][i] << std::endl; - sigma[j][i] ^= sigmaShares[j][i];//^ mask[i][j]; - } + sigma(i) ^= sigmaShares(i);//^ mask[i][j]; } mOtIdx += mNumPoints * 2; @@ -742,6 +768,15 @@ namespace osuCrypto mBaseChoice[i] = baseChoices[i]; } } + + + // extracts the lsb of b and returns a block saturated with that bit. + static block tagBit(const block& b) + { + auto bit = b & block(0, 1); + auto mask = _mm_sub_epi64(_mm_set1_epi64x(0), bit); + return _mm_unpacklo_epi64(mask, mask); + } }; } diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index dc65c785..4a63cb2d 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -171,7 +171,7 @@ namespace osuCrypto if (divCeil(mN, 128) < ALsb.size()) throw RTE_LOC; - if (ALsb.size() != AMsb.size() || + if (ALsb.size() != AMsb.size() || ALsb.size() != CLsb.size() || ALsb.size() != CMsb.size()) throw RTE_LOC; @@ -285,7 +285,7 @@ namespace osuCrypto // next to each other. We do this by using nextIdx to // keep track of the next index for each output block. size_t idx = polyOffset + blockIdx * mT + nextIdx[blockIdx]++; - + // split the position into the portion that will position // the F4 coefficient within the F4^243 coefficient and the // portion that will position the F4^243 coefficient within @@ -315,9 +315,18 @@ namespace osuCrypto // current coefficients are single F4 elements. Expand them into // 3^5=243 elements. These will be used as the new coefficients // in the large tree. - co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyF4Coeffs, [&](u64 treeIdx, u64 leafIdx, u8 v) { - *BitIterator(&prodPolyF4x243Coeffs[treeIdx], leafIdx * 2 + 0) = (v >> 0) & 1; - *BitIterator(&prodPolyF4x243Coeffs[treeIdx], leafIdx * 2 + 1) = (v >> 1) & 1; + co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyF4Coeffs, + [&, byteIdx = 0, bitIdx = 0](u64 treeIdx, u64 leafIdx, u8 v) mutable { + if (treeIdx == 0) + { + byteIdx = leafIdx / 4; + bitIdx = leafIdx % 4; + } + assert(byteIdx == leafIdx / 4); + assert(bitIdx == leafIdx % 4); + + auto ptr = (u8*)&prodPolyF4x243Coeffs.data()[treeIdx]; + ptr[byteIdx] |= u8((v & 3) << (2 * bitIdx)); }, prng, sock); setTimePoint("leafDpf"); diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Tools/Foleage/FoleagePcg.h index 59e96580..acc25fbb 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Tools/Foleage/FoleagePcg.h @@ -66,8 +66,21 @@ namespace osuCrypto // a dpf used to construct the F4x243 leaf value of the larger DPF. TriDpf mDpfLeaf; + + struct FoleageCoeffCtx : CoeffCtxGF2 + { + + OC_FORCEINLINE void fromBlock(FoleageF4x243& ret, const block& b) { + ret.mVal[0] = b; + ret.mVal[1] = b ^ block(2314523225322345310, 3520873105824273452); + ret.mVal[2] = b ^ block(3456459829022368567, 2452343456563201231); + ret.mVal[3] = b ^ block(2430734095872024920, 8425914932983749298); + mAesFixedKey.hashBlocks<4>(ret.mVal.data(), ret.mVal.data()); + } + }; + // the main DPF which outputs 243 F4 elements for each leaf. - TriDpf mDpf; + TriDpf mDpf; // The base OTs used to tensor the coefficients of the sparse polynomial. std::vector mRecvOts; From 062ffb477dc37dfc43daaf88a0a44119522c2fc1 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 19 Feb 2025 13:37:59 -0800 Subject: [PATCH 21/48] cleanup --- frontend/benchmark.h | 4 +- libOTe/Tools/Dpf/DpfMult.h | 2 +- libOTe/Tools/Dpf/RegularDpf.h | 6 +- libOTe/Tools/Dpf/SparseDpf.h | 3 +- libOTe/Tools/Dpf/TriDpf.h | 16 +- libOTe/Tools/Foleage/FoleageMain.cpp | 20 +- libOTe/Tools/Foleage/FoleagePcg.cpp | 6 +- libOTe/Tools/Foleage/FoleageUtils.h | 31 +- libOTe/Tools/Foleage/fft/FoleageFft.cpp | 22 +- libOTe/Tools/Foleage/fft/FoleageFft.h | 4 +- libOTe/Tools/Foleage/spfss_test.cpp | 273 ++- libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp | 94 +- libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h | 6 +- .../Tools/Foleage/tri-dpf/FoleageDpf_test.cpp | 18 +- libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h | 15 +- libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h | 10 +- libOTe/Tools/Foleage/uint128.h | 1578 ++++++++--------- libOTe_Tests/Foleage_Tests.cpp | 273 +-- libOTe_Tests/RegularDpf_Tests.cpp | 8 +- 19 files changed, 1197 insertions(+), 1192 deletions(-) diff --git a/frontend/benchmark.h b/frontend/benchmark.h index dc9cfc36..d066eedb 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -788,7 +788,7 @@ namespace osuCrypto std::vector points(numPoints); std::vector values0(numPoints); std::vector values1(numPoints); - Ctx ctx; + //Ctx ctx; for (u64 i = 0; i < numPoints; ++i) { points[i] = Trit32(prng.get() % domain); @@ -866,7 +866,7 @@ namespace osuCrypto auto logn = cmd.getOr("nn", 10); u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); - bool verbose = cmd.isSet("v"); + //bool verbose = cmd.isSet("v"); u64 trials = cmd.getOr("trials", 1); diff --git a/libOTe/Tools/Dpf/DpfMult.h b/libOTe/Tools/Dpf/DpfMult.h index abcd462b..a3655a58 100644 --- a/libOTe/Tools/Dpf/DpfMult.h +++ b/libOTe/Tools/Dpf/DpfMult.h @@ -128,7 +128,7 @@ namespace osuCrypto { auto Phi = block(-u64(phi[j]), -u64(phi[j])); theta[j] ^= theta1[j]; - xy[j] = C[j] ^ theta[j] & A0[j] ^ Phi & b1[j]; + xy[j] = C[j] ^ (theta[j] & A0[j]) ^ (Phi & b1[j]); if (mPartyIdx) xy[j] ^= theta[j] & Phi; diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Tools/Dpf/RegularDpf.h index 08080ff7..76fc38f9 100644 --- a/libOTe/Tools/Dpf/RegularDpf.h +++ b/libOTe/Tools/Dpf/RegularDpf.h @@ -539,7 +539,7 @@ namespace osuCrypto { for (u64 j = 0; j < 2; ++j) { - SIMD8(q, temp[q] = currentSeed[j][k + q] ^ parentTag[k + q] & sigma[j][k + q]); + SIMD8(q, temp[q] = currentSeed[j][k + q] ^ (parentTag[k + q] & sigma[j][k + q])); SIMD8(q, tag[j][k + q] = tagBit(temp[q])); SIMD8(q, currentSeed[j][k + q] = AES::roundFn(temp[q], temp[q])); SIMD8(q, diff[k + q] ^= currentSeed[j][k + q]); @@ -550,7 +550,7 @@ namespace osuCrypto { for (u64 j = 0; j < 2; ++j) { - temp[0] = currentSeed[j][k] ^ parentTag[k] & sigma[j][k]; + temp[0] = currentSeed[j][k] ^ (parentTag[k] & sigma[j][k]); tag[j][k] = tagBit(temp[0]); currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); diff[k] ^= currentSeed[j][k]; @@ -559,7 +559,7 @@ namespace osuCrypto } } - if (values.size() || inputKey && inputKey->mLeafVals.size()) + if (values.size() || (inputKey && inputKey->mLeafVals.size())) { AlignedUnVector gamma(mNumPoints); if (inputKey) diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Tools/Dpf/SparseDpf.h index 311b2099..be3d8fae 100644 --- a/libOTe/Tools/Dpf/SparseDpf.h +++ b/libOTe/Tools/Dpf/SparseDpf.h @@ -273,7 +273,7 @@ namespace osuCrypto densePoints[i] = points[i] >> depth; Matrix seeds(points.size(), 1ull << mDenseDepth); Matrix tags(points.size(), 1ull << mDenseDepth); - co_await mRegDpf.expand(densePoints, {}, prng.get(), [&](auto treeIdx, auto leafIdx, auto seed, auto tag) { + co_await mRegDpf.expand(densePoints, {}, prng.get(), [&](auto treeIdx, auto leafIdx, auto seed, block tag) { seeds(treeIdx, leafIdx) = seed; tags(treeIdx, leafIdx) = tag.get(0)&1; }, sock); @@ -495,7 +495,6 @@ namespace osuCrypto //std::cout << "-----------final-------------" << std::endl; for (u64 r = 0; r < mNumPoints; ++r) { - auto& tree = trees[r]; auto size = sparsePoints[r].size(); for (u64 i = 0; i < size; ++i) { diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index c8da014b..27494703 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -306,9 +306,9 @@ namespace osuCrypto } std::array aes{ - AES(block(324532455457855483,3575765667434524523)), - AES(block(456475435444364534,9923458239234989843)), - AES(block(324532450985209453,5387987243989842789)) }; + AES(block(324532455457855483ull,3575765667434524523ull)), + AES(block(456475435444364534ull,9923458239234989843ull)), + AES(block(324532450985209453ull,5387987243989842789ull)) }; // at each iteration we first correct the parent level. // The parent level has two siblings which are random. @@ -368,7 +368,7 @@ namespace osuCrypto for (u64 k = 0; k < numPoints8; k += 8) { block s[8];//, w[8]; - SIMD8(q, s[q] = seedj[k + q] ^ parentTag[k + q] & sig[k + q]); + SIMD8(q, s[q] = seedj[k + q] ^ (parentTag[k + q] & sig[k + q])); for (u64 i = 0; i < 3; ++i) { @@ -387,7 +387,7 @@ namespace osuCrypto for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto seedjk = seed[j][k] ^ parentTag[k] & sigma[j][k]; + auto seedjk = seed[j][k] ^ (parentTag[k] & sigma[j][k]); for (u64 i = 0; i < 3; ++i) { @@ -437,7 +437,7 @@ namespace osuCrypto { block s[8]; - SIMD8(q, s[q] = cs[k + q] ^ parentTag[k+q] & sig[k+q]); + SIMD8(q, s[q] = cs[k + q] ^ (parentTag[k+q] & sig[k+q])); SIMD8(q, tt[k+q] = lsb(s[q])); SIMD8(q, ctx.fromBlock(leafIter[q], AES::roundFn(s[q], s[q]))); SIMD8(q, ctx.plus(sums[k+q], sums[k+q], leafIter[q])); @@ -446,7 +446,7 @@ namespace osuCrypto for (u64 k = numPoints8; k < mNumPoints; ++k) { - auto s = cs[k] ^ parentTag[k] & sig[k]; + auto s = cs[k] ^ (parentTag[k] & sig[k]); tt[k] = lsb(s); ctx.fromBlock(*leafIter, AES::roundFn(s, s)); @@ -466,7 +466,7 @@ namespace osuCrypto ctx.resize(gamma, mNumPoints); ctx.resize(diff, mNumPoints); - auto& curSeed = s[mDepth % 3]; + //auto& curSeed = s[mDepth % 3]; for (u64 k = 0; k < mNumPoints; ++k) { diff --git a/libOTe/Tools/Foleage/FoleageMain.cpp b/libOTe/Tools/Foleage/FoleageMain.cpp index 37517dae..2ba24634 100644 --- a/libOTe/Tools/Foleage/FoleageMain.cpp +++ b/libOTe/Tools/Foleage/FoleageMain.cpp @@ -74,7 +74,7 @@ namespace osuCrypto size_t alpha = random_index(block_size, prng); // Pick a random output message for benchmarking purposes - uint128_t beta[DPF_MSG_SIZE]; + block beta[DPF_MSG_SIZE]; prng.get(beta, DPF_MSG_SIZE); // Message (beta) is of size 8 blocks of 128 bits @@ -89,13 +89,13 @@ namespace osuCrypto //************************************************ // Allocate memory for the DPF outputs (this is reused for each evaluation) - AlignedUnVector shares(dpf_block_size); - AlignedUnVector cache(dpf_block_size); + AlignedUnVector shares(dpf_block_size); + AlignedUnVector cache(dpf_block_size); // Allocate memory for the concatenated DPF outputs const size_t packed_block_size = ceil(block_size / 64.0); const size_t packed_poly_size = t * packed_block_size; - AlignedUnVector packed_polys(c * c * packed_poly_size); + AlignedUnVector packed_polys(c * c * packed_poly_size); // Allocate memory for the output FFT AlignedUnVector fft_u(poly_size); @@ -112,14 +112,14 @@ namespace osuCrypto time = clock(); size_t key_index; - uint128_t* poly_block; + block* poly_block; size_t i, j, k, l, w; for (i = 0; i < c; i++) { for (j = 0; j < c; j++) { const size_t poly_index = i * c + j; - uint128_t* packed_poly = &packed_polys[poly_index * packed_poly_size]; + block* packed_poly = &packed_polys[poly_index * packed_poly_size]; for (k = 0; k < t; k++) { @@ -148,7 +148,7 @@ namespace osuCrypto for (size_t i = 0; i < c * c; i++) { size_t poly_index = i * packed_poly_size; - const uint128_t* poly = &packed_polys[poly_index]; + const block* poly = &packed_polys[poly_index]; #ifdef ENABLE_SSE _mm_prefetch((char*)poly, _MM_HINT_T2); @@ -156,7 +156,7 @@ namespace osuCrypto size_t block_idx, packed_coeff_idx, coeff_idx; //uint8_t packed_bit_idx; - uint128_t packed_coeff; + block packed_coeff; block_idx = 0; packed_coeff_idx = 0; @@ -175,7 +175,7 @@ namespace osuCrypto for (size_t l = 0; l < 64; l++) { packed_coeff = packed_coeff >> 2; - fft_u[k + l] |= static_cast(packed_coeff) & 0b11; + fft_u[k + l] |= static_cast(packed_coeff.get(0)) & 0b11; fft_u[k + l] = fft_u[k + l] << 2; } @@ -199,7 +199,7 @@ namespace osuCrypto for (size_t k = poly_size - 64 + 1; k < poly_size; k++) { packed_coeff = packed_coeff >> 2; - fft_u[k] |= static_cast(packed_coeff) & 0b11 ; + fft_u[k] |= static_cast(packed_coeff.get(0)) & 0b11 ; fft_u[k] = fft_u[k] << 2; } } diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index 4a63cb2d..bdeb4af7 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -316,7 +316,7 @@ namespace osuCrypto // 3^5=243 elements. These will be used as the new coefficients // in the large tree. co_await mDpfLeaf.expand(prodPolyLeafPos, prodPolyF4Coeffs, - [&, byteIdx = 0, bitIdx = 0](u64 treeIdx, u64 leafIdx, u8 v) mutable { + [&, byteIdx = 0ull, bitIdx = 0ull](u64 treeIdx, u64 leafIdx, u8 v) mutable { if (treeIdx == 0) { byteIdx = leafIdx / 4; @@ -337,7 +337,7 @@ namespace osuCrypto // to a block together. This will give us the coefficients of the // the product polynomial. co_await mDpf.expand(prodPolyTreePos, prodPolyF4x243Coeffs, - [&, count = 0, out = blocks.data(), end = blocks.data() + blocks.size()] + [&, count = 0ull, out = blocks.data(), end = blocks.data() + blocks.size()] (u64 treeIdx, u64 leafIdx, FoleageF4x243 v) mutable { // the callback is called in column major order but blocks // is row major (leafIdx will be the same). So we need to compute @@ -405,7 +405,7 @@ namespace osuCrypto // XOR the (packed) columns into the accumulator. // Specifically, we perform column-wise XORs to get the result. - uint128_t lsbMask, msbMask; + u32 lsbMask, msbMask; setBytes(lsbMask, 0b01010101); setBytes(msbMask, 0b10101010); for (size_t i = 0; i < outSize; i++) diff --git a/libOTe/Tools/Foleage/FoleageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h index 9c28f269..444724ee 100644 --- a/libOTe/Tools/Foleage/FoleageUtils.h +++ b/libOTe/Tools/Foleage/FoleageUtils.h @@ -9,7 +9,7 @@ namespace osuCrypto { - using uint128_t = absl::uint128_t; + //using uint128_t = absl::uint128_t; //using int128_t = block; //using uint128_t = block; //using uint128_t = __uint128_t; @@ -301,14 +301,33 @@ namespace osuCrypto return result; } - inline int popcount(uint128_t x) + inline int popcount(block x) { - std::array xArr; - memcpy(xArr.data(), &x, 16); - return popcount(xArr[0]) + popcount(xArr[1]); + //std::array xArr; + //memcpy(xArr.data(), &x, 16); + return popcount(x.get(0)) + popcount(x.get(1)); } + //inline int popcount(uint128_t x) + //{ + // std::array xArr; + // memcpy(xArr.data(), &x, 16); + // return popcount(xArr[0]) + popcount(xArr[1]); + //} - inline std::array extractF4(const uint128_t& val) + //inline std::array extractF4(const uint128_t& val) + //{ + // std::array ret; + // const char* ptr = (const char*)&val; + // for (u8 i = 0; i < 16; ++i) + // { + // ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + // ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + // ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + // ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + // } + // return ret; + //} + inline std::array extractF4(const block& val) { std::array ret; const char* ptr = (const char*)&val; diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp index 4c8effd4..7f7b8939 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.cpp +++ b/libOTe/Tools/Foleage/fft/FoleageFft.cpp @@ -53,7 +53,7 @@ namespace osuCrypto { // computed as: mult_l = (h ^ l) and mult_h = l // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); // tL coefficient obtained by evaluating on X_i=1 tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; @@ -129,7 +129,7 @@ namespace osuCrypto { // computed as: mult_l = (h ^ l) and mult_h = l // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); // tL coefficient obtained by evaluating on X_i=1 tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; @@ -205,7 +205,7 @@ namespace osuCrypto { // computed as: mult_l = (h ^ l) and mult_h = l // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); // tL coefficient obtained by evaluating on X_i=1 tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; @@ -281,7 +281,7 @@ namespace osuCrypto { // computed as: mult_l = (h ^ l) and mult_h = l // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = (xor_h >> 1) ^ (xor_l) | (xor_l << 1); + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); // tL coefficient obtained by evaluating on X_i=1 tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; @@ -435,13 +435,13 @@ namespace osuCrypto { block coeffsR0[width]; block coeffsR1[width]; - //{ constexpr u64 VAR = 1; STATEMENT; }\ - //{ constexpr u64 VAR = 2; STATEMENT; }\ - //{ constexpr u64 VAR = 3; STATEMENT; }\ - //{ constexpr u64 VAR = 4; STATEMENT; }\ - //{ constexpr u64 VAR = 5; STATEMENT; }\ - //{ constexpr u64 VAR = 6; STATEMENT; }\ - //{ constexpr u64 VAR = 7; STATEMENT; }\ + //{ constexpr u64 VAR = 1; STATEMENT; } + //{ constexpr u64 VAR = 2; STATEMENT; } + //{ constexpr u64 VAR = 3; STATEMENT; } + //{ constexpr u64 VAR = 4; STATEMENT; } + //{ constexpr u64 VAR = 5; STATEMENT; } + //{ constexpr u64 VAR = 6; STATEMENT; } + //{ constexpr u64 VAR = 7; STATEMENT; } #define SIMD8(VAR, STATEMENT) \ { constexpr u64 VAR = 0; STATEMENT; }\ diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.h b/libOTe/Tools/Foleage/fft/FoleageFft.h index 6eea18b6..3f233e7c 100644 --- a/libOTe/Tools/Foleage/fft/FoleageFft.h +++ b/libOTe/Tools/Foleage/fft/FoleageFft.h @@ -5,6 +5,7 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/MatrixView.h" #include "libOTe/Tools/Foleage/FoleageUtils.h" +#include //#include "libOTe/Tools/Foleage/utils.h" namespace osuCrypto { @@ -328,7 +329,7 @@ namespace osuCrypto { - template + template OC_FORCEINLINE void foleageFFTOne( T* __restrict coeffsL0, T* __restrict coeffsL1, @@ -338,7 +339,6 @@ namespace osuCrypto { T* __restrict coeffsR1) { -#pragma unroll(stride) for (u64 i = 0; i < stride; ++i) { diff --git a/libOTe/Tools/Foleage/spfss_test.cpp b/libOTe/Tools/Foleage/spfss_test.cpp index 9e0b301d..237d331b 100644 --- a/libOTe/Tools/Foleage/spfss_test.cpp +++ b/libOTe/Tools/Foleage/spfss_test.cpp @@ -1,158 +1,115 @@ -#include -#include -#include - -#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -#include "FoleageUtils.h" - -#define SUMT 730 // sum of T DPFs - -#define FULLEVALDOMAIN 10 -#define MESSAGESIZE 8 -#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) -namespace osuCrypto -{ - - //size_t randIndex() - //{ - // srand(time(NULL)); - // return ((size_t)rand()) % ((size_t)MAXRANDINDEX); - //} - - //uint128_t randMsg() - //{ - // uint128_t msg; - // RAND_bytes((uint8_t*)&msg, sizeof(uint128_t)); - // return msg; - //} - - double benchmark_spfss() - { - size_t num_leaves = ipow(3, FULLEVALDOMAIN); - size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points - PRNG prng(block(3423423)); - - size_t secret_index = prng.get() % MAXRANDINDEX; - uint128_t secret_msg = prng.get(); - size_t msg_len = MESSAGESIZE; - - PRFKeys prf_keys; - prf_keys.gen(prng); - - std::vector kA(SUMT); - std::vector kB(SUMT); - - clock_t t; - t = clock(); - - for (size_t i = 0; i < SUMT; i++) - DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA[i], kB[i], prng); - - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - return time_taken; - } - - double benchmarkAES() - { - size_t num_leaves = ipow(3, FULLEVALDOMAIN); - size_t size = FULLEVALDOMAIN; - PRNG prng(block(3423423)); - - PRFKeys prf_keys; - prf_keys.gen(prng); - - AlignedUnVector data_in (num_leaves * MESSAGESIZE); - AlignedUnVector data_out(num_leaves * MESSAGESIZE); - AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); - AlignedUnVector tmp; - - // fill with unique data - for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) - data_tmp[i] = (uint128_t)i; - - // make the input data pseudorandom for correct timing - PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); - - //************************************************ - // Benchmark AES encryption time required in DPF loop - //************************************************ - - clock_t t; - t = clock(); - - for (size_t n = 0; n < SUMT; n++) - { - size_t num_nodes = 1; - for (size_t i = 0; i < size; i++) - { - PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); - PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); - PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); - - tmp = data_out; - data_out = data_in; - data_in = tmp; - - num_nodes *= 3; - } - // compute AES part of output extension - PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); - } - - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - return time_taken; - } - - int mainSpfss(int argc, char** argv) - { - - double time = 0; - int testTrials = 10; - - //printf("******************************************\n"); - //printf("Testing DPF.FullEval\n"); - //for (int i = 0; i < testTrials; i++) - //{ - // time += foliage_spfss_test(); - // printf("Done with trial %i of %i\n", i + 1, testTrials); - //} - //printf("******************************************\n"); - //printf("PASS\n"); - //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); - //printf("******************************************\n\n"); - - time = 0; - printf("******************************************\n"); - printf("Benchmarking DPF.Gen\n"); - for (int i = 0; i < testTrials; i++) - { - time += benchmark_spfss(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("Avg time: %0.4f ms\n", time / testTrials); - printf("******************************************\n\n"); - - time = 0; - printf("******************************************\n"); - printf("Benchmarking AES\n"); - for (int i = 0; i < testTrials; i++) - { - time += benchmarkAES(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("Avg time: %0.2f ms\n", time / testTrials); - printf("******************************************\n\n"); - - return 0; - } -} \ No newline at end of file +//#include +//#include +//#include +// +//#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" +//#include "FoleageUtils.h" +// +//#define SUMT 730 // sum of T DPFs +// +//#define FULLEVALDOMAIN 10 +//#define MESSAGESIZE 8 +//#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) +//namespace osuCrypto +//{ +// +// double benchmarkAES() +// { +// size_t num_leaves = ipow(3, FULLEVALDOMAIN); +// size_t size = FULLEVALDOMAIN; +// PRNG prng(block(3423423)); +// +// PRFKeys prf_keys; +// prf_keys.gen(prng); +// +// AlignedUnVector data_in (num_leaves * MESSAGESIZE); +// AlignedUnVector data_out(num_leaves * MESSAGESIZE); +// AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); +// AlignedUnVector tmp; +// +// // fill with unique data +// for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) +// data_tmp[i] = block(i); +// +// // make the input data pseudorandom for correct timing +// PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); +// +// //************************************************ +// // Benchmark AES encryption time required in DPF loop +// //************************************************ +// +// clock_t t; +// t = clock(); +// +// for (size_t n = 0; n < SUMT; n++) +// { +// size_t num_nodes = 1; +// for (size_t i = 0; i < size; i++) +// { +// PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); +// PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); +// PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); +// +// tmp = data_out; +// data_out = data_in; +// data_in = tmp; +// +// num_nodes *= 3; +// } +// // compute AES part of output extension +// PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); +// } +// +// t = clock() - t; +// double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms +// +// printf("Time %f ms\n", time_taken); +// +// return time_taken; +// } +// +// int mainSpfss(int argc, char** argv) +// { +// +// double time = 0; +// int testTrials = 10; +// +// //printf("******************************************\n"); +// //printf("Testing DPF.FullEval\n"); +// //for (int i = 0; i < testTrials; i++) +// //{ +// // time += foliage_spfss_test(); +// // printf("Done with trial %i of %i\n", i + 1, testTrials); +// //} +// //printf("******************************************\n"); +// //printf("PASS\n"); +// //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); +// //printf("******************************************\n\n"); +// +// time = 0; +// //printf("******************************************\n"); +// //printf("Benchmarking DPF.Gen\n"); +// //for (int i = 0; i < testTrials; i++) +// //{ +// // time += benchmark_spfss(); +// // printf("Done with trial %i of %i\n", i + 1, testTrials); +// //} +// //printf("******************************************\n"); +// //printf("Avg time: %0.4f ms\n", time / testTrials); +// //printf("******************************************\n\n"); +// +// time = 0; +// printf("******************************************\n"); +// printf("Benchmarking AES\n"); +// for (int i = 0; i < testTrials; i++) +// { +// time += benchmarkAES(); +// printf("Done with trial %i of %i\n", i + 1, testTrials); +// } +// printf("******************************************\n"); +// printf("Avg time: %0.2f ms\n", time / testTrials); +// printf("******************************************\n\n"); +// +// return 0; +// } +//} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp index 359f7243..28f4ca66 100644 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp @@ -9,15 +9,15 @@ #define LOG_BATCH_SIZE 6 // operate in smallish batches to maximize cache hits namespace osuCrypto { + // Naming conventions: // - A,B refer to shares given to parties A and B // - 0,1,2 refer to the branch index in the ternary tree - void DPFGen( PRFKeys& prf_keys, size_t domain_size, size_t index, - span msg_blocks, + span msg_blocks, size_t msg_block_len, DPFKey& k0, DPFKey& k1, @@ -25,35 +25,35 @@ namespace osuCrypto { // starting seeds given to each party - uint128_t seedA = prng.get(); - uint128_t seedB = prng.get(); + block seedA = prng.get(); + block seedB = prng.get(); // correction word provided to both parties // (one correction word per level) - std::vector sCW0(domain_size); - std::vector sCW1(domain_size); - std::vector sCW2(domain_size); + std::vector sCW0(domain_size); + std::vector sCW1(domain_size); + std::vector sCW2(domain_size); // variables for the intermediate values - uint128_t parent, parentA, parentB, sA0, sA1, sA2, sB0, sB1, sB2; + block parent, parentA, parentB, sA0, sA1, sA2, sB0, sB1, sB2; // current parent value (xor of the two seeds) parent = seedA ^ seedB; // control bit of the parent on the special path must always be set to 1 // so as to apply the corresponding correction word - if (get_lsb(parent) == uint128_t{ 0 }) + if (get_lsb(parent) == ZeroBlock) seedA = flip_lsb(seedA); parentA = seedA; parentB = seedB; - uint8_t prev_control_bit_A, prev_control_bit_B; + block prev_control_bit_A, prev_control_bit_B; for (size_t i = 0; i < domain_size; i++) { - prev_control_bit_A = static_cast(get_lsb(parentA)); - prev_control_bit_B = static_cast(get_lsb(parentB)); + prev_control_bit_A = get_lsb(parentA); + prev_control_bit_B = get_lsb(parentB); // expand the starting seeds of each party PRFEval(prf_keys.prf_key0, parentA, sA0); @@ -65,7 +65,7 @@ namespace osuCrypto // on-path correction word is set to random // so as to be indistinguishable from the real correction words - uint128_t r = prng.get(); + block r = prng.get(); // get the current trit (ternary bit) of the special index uint8_t trit = get_trit(index, domain_size, i); @@ -74,14 +74,14 @@ namespace osuCrypto { case 0: parent = sA0 ^ sB0 ^ r; - if (get_lsb(parent) == 0) + if (get_lsb(parent) == ZeroBlock) r = flip_lsb(r); sCW0[i] = r; sCW1[i] = sA1 ^ sB1; sCW2[i] = sA2 ^ sB2; - if (get_lsb(parentA) == 1) + if (get_lsb(parentA) == AllOneBlock) { parentA = sA0 ^ r; parentB = sB0; @@ -96,14 +96,14 @@ namespace osuCrypto case 1: parent = sA1 ^ sB1 ^ r; - if (get_lsb(parent) == 0) + if (get_lsb(parent) == ZeroBlock) r = flip_lsb(r); sCW0[i] = sA0 ^ sB0; sCW1[i] = r; sCW2[i] = sA2 ^ sB2; - if (get_lsb(parentA) == 1) + if (get_lsb(parentA) == AllOneBlock) { parentA = sA1 ^ r; parentB = sB1; @@ -118,14 +118,14 @@ namespace osuCrypto case 2: parent = sA2 ^ sB2 ^ r; - if (get_lsb(parent) == 0) + if (get_lsb(parent) == ZeroBlock) r = flip_lsb(r); sCW0[i] = sA0 ^ sB0; sCW1[i] = sA1 ^ sB1; sCW2[i] = r; - if (get_lsb(parentA) == 1) + if (get_lsb(parentA) == AllOneBlock) { parentA = sA2 ^ r; parentB = sB2; @@ -145,29 +145,29 @@ namespace osuCrypto } // set the last correction word to correct the output to msg - uint128_t leaf_seedA, leaf_seedB; + block leaf_seedA, leaf_seedB; uint8_t last_trit = get_trit(index, domain_size, domain_size - 1); if (last_trit == 0) { - leaf_seedA = sA0 ^ uint128_t(prev_control_bit_A * sCW0[domain_size - 1]); - leaf_seedB = sB0 ^ uint128_t(prev_control_bit_B * sCW0[domain_size - 1]); + leaf_seedA = sA0 ^ prev_control_bit_A & sCW0[domain_size - 1]; + leaf_seedB = sB0 ^ prev_control_bit_B & sCW0[domain_size - 1]; } else if (last_trit == 1) { - leaf_seedA = sA1 ^ uint128_t(prev_control_bit_A * sCW1[domain_size - 1]); - leaf_seedB = sB1 ^ uint128_t(prev_control_bit_B * sCW1[domain_size - 1]); - } + leaf_seedA = sA1 ^ prev_control_bit_A & sCW1[domain_size - 1]; + leaf_seedB = sB1 ^ prev_control_bit_B & sCW1[domain_size - 1]; + } else if (last_trit == 2) { - leaf_seedA = sA2 ^ uint128_t(prev_control_bit_A * sCW2[domain_size - 1]); - leaf_seedB = sB2 ^ uint128_t(prev_control_bit_B * sCW2[domain_size - 1]); + leaf_seedA = sA2 ^ prev_control_bit_A & sCW2[domain_size - 1]; + leaf_seedB = sB2 ^ prev_control_bit_B & sCW2[domain_size - 1]; } - AlignedUnVector outputA(msg_block_len); - AlignedUnVector outputB(msg_block_len); - AlignedUnVector cache(msg_block_len); - AlignedUnVector outputCW(msg_block_len); + AlignedUnVector outputA(msg_block_len); + AlignedUnVector outputB(msg_block_len); + AlignedUnVector cache(msg_block_len); + AlignedUnVector outputCW(msg_block_len); outputA[0] = leaf_seedA; outputB[0] = leaf_seedB; @@ -180,9 +180,9 @@ namespace osuCrypto // memcpy all the generated values into two keys // 16 = sizeof(uint128_t) - size_t key_size = sizeof(uint128_t); // initial seed size; - key_size += 3 * domain_size * sizeof(uint128_t); // correction words - key_size += sizeof(uint128_t) * msg_block_len; // output correction word + size_t key_size = sizeof(block); // initial seed size; + key_size += 3 * domain_size * sizeof(block); // correction words + key_size += sizeof(block) * msg_block_len; // output correction word k0.prf_keys = &prf_keys; k0.k.resize(key_size); @@ -215,8 +215,8 @@ namespace osuCrypto // is only expanded once. void DPFFullDomainEval( DPFKey& key, - span cache, - span output) + span cache, + span output) { size_t size = key.size; span k = key.k; @@ -233,15 +233,15 @@ namespace osuCrypto const size_t num_leaves = ipow(3, size); memcpy(&output[0], &k[0], 16); // output[0] is the start seed - const uint128_t* sCW0 = (uint128_t*)&k[16]; - const uint128_t* sCW1 = (uint128_t*)&k[16 * size + 16]; - const uint128_t* sCW2 = (uint128_t*)&k[16 * 2 * size + 16]; + const block* sCW0 = (block*)&k[16]; + const block* sCW1 = (block*)&k[16 * size + 16]; + const block* sCW2 = (block*)&k[16 * 2 * size + 16]; // inner loop variables related to node expansion // and correction word application - span tmp; + span tmp; size_t idx0, idx1, idx2; - uint8_t cb = 0; + block cb = ZeroBlock; // batching variables related to chunking of inner loop processing // for the purpose of maximizing cache hits @@ -275,10 +275,10 @@ namespace osuCrypto while (idx0 < offset + batch_size) { - cb = static_cast(output[idx0]) & 1; // gets the LSB of the parent - cache[idx0] ^= (cb * sCW0[i]); - cache[idx1] ^= (cb * sCW1[i]); - cache[idx2] ^= (cb * sCW2[i]); + cb = get_lsb(output[idx0]); // gets the LSB of the parent + cache[idx0] ^= (cb & sCW0[i]); + cache[idx1] ^= (cb & sCW1[i]); + cache[idx2] ^= (cb & sCW2[i]); idx0++; idx1++; @@ -297,7 +297,7 @@ namespace osuCrypto const size_t output_length = key.msg_len * num_leaves; const size_t msg_len = key.msg_len; - uint128_t* outputCW = (uint128_t*)&k[16 * 3 * size + 16]; + block* outputCW = (block*)&k[16 * 3 * size + 16]; ExtendOutput(prf_keys, output, cache, num_leaves, output_length); size_t j = 0; @@ -307,7 +307,7 @@ namespace osuCrypto // which is the case internally in ExtendOutput. It would be good // to remove this assumption however using memcpy is costly... - if (cache[i * msg_len] & uint128_t{ 1 }) // parent control bit + if (get_lsb(cache[i * msg_len]) != ZeroBlock) // parent control bit { for (j = 0; j < msg_len; j++) output[i * msg_len + j] ^= outputCW[j]; diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h index 0e3f96eb..a81e5c48 100644 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h @@ -21,7 +21,7 @@ namespace osuCrypto PRFKeys& prf_keys, size_t domain_size, size_t index, - span msg_blocks, + span msg_blocks, size_t msg_block_len, DPFKey& k0, DPFKey& k1, @@ -29,7 +29,7 @@ namespace osuCrypto void DPFFullDomainEval( DPFKey& k, - span cache, - span output); + span cache, + span output); } diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp index ce57ed9d..2a596a8f 100644 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp +++ b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp @@ -20,7 +20,7 @@ namespace osuCrypto return prng.get() % (size_t)MAXRANDINDEX; } //using int128_t = uint128_t; - uint128_t randMsg(PRNG& prng) + block randMsg(PRNG& prng) { return prng.get(); //uint128_t msg; @@ -30,11 +30,11 @@ namespace osuCrypto double benchmark_dpfGen() { - size_t num_leaves = ipow(3, FULLEVALDOMAIN); + //size_t num_leaves = ipow(3, FULLEVALDOMAIN); size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points PRNG prng(block(3423423)); size_t secret_index = randIndex(prng); - uint128_t secret_msg = randMsg(prng); + block secret_msg = randMsg(prng); size_t msg_len = 1; PRFKeys prf_keys; @@ -45,7 +45,7 @@ namespace osuCrypto clock_t t; t = clock(); - DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA, kB, prng); + DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA, kB, prng); t = clock() - t; double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms @@ -63,14 +63,14 @@ namespace osuCrypto PRFKeys prf_keys; prf_keys.gen(prng); - AlignedUnVector data_in(num_leaves * MESSAGESIZE); - AlignedUnVector data_out(num_leaves * MESSAGESIZE); - AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); - AlignedUnVector tmp; + AlignedUnVector data_in(num_leaves * MESSAGESIZE); + AlignedUnVector data_out(num_leaves * MESSAGESIZE); + AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); + AlignedUnVector tmp; // fill with unique data for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) - data_tmp[i] = (uint128_t)i; + data_tmp[i] = block(i); // make the input data pseudorandom for correct timing PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h b/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h index 363ac56e..df53f946 100644 --- a/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h +++ b/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h @@ -36,18 +36,15 @@ namespace osuCrypto //void DestroyPRFKey(struct PRFKeys* prf_keys); // XOR with input to prevent inversion using Davies–Meyer construction - inline void PRFEval(EVP_CIPHER_CTX& ctx, uint128_t& input, uint128_t& outputs) + inline void PRFEval(EVP_CIPHER_CTX& ctx, block& input, block& outputs) { - block in, out; - copyBytes(in, input); - out = ctx.hashBlock(in); - copyBytes(outputs, out); + outputs = ctx.hashBlock(input); } // PRF used to expand the DPF tree. Just a call to AES-ECB. // Note: we use ECB-mode (instead of CTR) as we want to manage each block separately. // XOR with input to prevent inversion using Davies–Meyer construction - inline void PRFBatchEval(EVP_CIPHER_CTX& ctx, span input, span outputs, u64 num_blocks) + inline void PRFBatchEval(EVP_CIPHER_CTX& ctx, span input, span outputs, u64 num_blocks) { if (num_blocks > input.size()) throw RTE_LOC; @@ -59,8 +56,8 @@ namespace osuCrypto // extends the output by the provided factor using the PRG inline void ExtendOutput( PRFKeys& prf_keys, - span output, - span cache, + span output, + span cache, const size_t output_size, const size_t new_output_size) { @@ -75,7 +72,7 @@ namespace osuCrypto for (size_t i = 0; i < output_size; i++) { for (size_t j = 0; j < factor; j++) - cache[i * factor + j] = output[i] ^ uint128_t{ j }; + cache[i * factor + j] = output[i] ^ block(0, j); } PRFBatchEval(prf_keys.prf_key_ext, cache, output, new_output_size); diff --git a/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h b/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h index d6648a1d..a3c3a938 100644 --- a/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h +++ b/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h @@ -9,14 +9,14 @@ namespace osuCrypto { - static inline uint128_t flip_lsb(uint128_t input) + static inline block flip_lsb(block input) { - return input ^ uint128_t{ 1 }; + return input ^ block(0, 1); } - static inline uint128_t get_lsb(uint128_t input) + static inline block get_lsb(block input) { - return input & uint128_t{ 1 }; + return block::allSame(-(input.get(0) & 1)); } static inline int get_trit(uint64_t x, int size, int t) @@ -31,7 +31,7 @@ namespace osuCrypto return ternary[t]; } - static inline int get_bit(uint128_t x, int size, int b) + static inline int get_bit(block x, int size, int b) { return *oc::BitIterator((u8*)&x, (size - b)); //return ((x) >> (size - b)) & 1; diff --git a/libOTe/Tools/Foleage/uint128.h b/libOTe/Tools/Foleage/uint128.h index 05fa058b..ed9fdd70 100644 --- a/libOTe/Tools/Foleage/uint128.h +++ b/libOTe/Tools/Foleage/uint128.h @@ -1,790 +1,790 @@ +//// +//// Copyright 2017 The Abseil Authors. +//// +//// Licensed under the Apache License, Version 2.0 (the "License"); +//// you may not use this file except in compliance with the License. +//// You may obtain a copy of the License at +//// +//// https://www.apache.org/licenses/LICENSE-2.0 +//// +//// Unless required by applicable law or agreed to in writing, software +//// distributed under the License is distributed on an "AS IS" BASIS, +//// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//// See the License for the specific language governing permissions and +//// limitations under the License. +//// +//// ----------------------------------------------------------------------------- +//// File: int128_t.h +//// ----------------------------------------------------------------------------- +//// +//// This header file defines 128-bit integer types, `uint128_t` and `int128_t`. // -// Copyright 2017 The Abseil Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// ----------------------------------------------------------------------------- -// File: int128_t.h -// ----------------------------------------------------------------------------- -// -// This header file defines 128-bit integer types, `uint128_t` and `int128_t`. - -#ifndef ABSL_INT128_H_ -#define ABSL_INT128_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define ABSL_IS_LITTLE_ENDIAN -#if defined(_MSC_VER) -// In very old versions of MSVC and when the /Zc:wchar_t flag is off, wchar_t is -// a typedef for unsigned short. Otherwise wchar_t is mapped to the __wchar_t -// builtin type. We need to make sure not to define operator wchar_t() -// alongside operator unsigned short() in these instances. -#define ABSL_INTERNAL_WCHAR_T __wchar_t -#if defined(_M_X64) -#include -#pragma intrinsic(_umul128) -#endif // defined(_M_X64) -#else // defined(_MSC_VER) -#define ABSL_INTERNAL_WCHAR_T wchar_t -#endif // defined(_MSC_VER) - -#ifdef _WIN32 -#ifdef abslint128_t_EXPORTS -#define ABSL_DLL __declspec(dllexport) -#else -#define ABSL_DLL __declspec(dllimport) -#endif -#else // _WIN32 -#define ABSL_DLL -#endif // _WIN32 - -// ABSL_ATTRIBUTE_ALWAYS_INLINE -// ABSL_ATTRIBUTE_NOINLINE -// -// Forces functions to either inline or not inline. Introduced in gcc 3.1. -#if defined(__GNUC__) || defined(__clang__) -#define ABSL_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) -#elif defined(_MSC_VER) && !__INTEL_COMPILER && _MSC_VER >= 1310 // since Visual Studio .NET 2003 -#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline __forceinline -#else -#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline -#endif - -// ABSL_INTERNAL_ASSUME(cond) -// Informs the compiler than a condition is always true and that it can assume -// it to be true for optimization purposes. The call has undefined behavior if -// the condition is false. -// In !NDEBUG mode, the condition is checked with an assert(). -// NOTE: The expression must not have side effects, as it will only be evaluated -// in some compilation modes and not others. -// -// Example: -// -// int x = ...; -// ABSL_INTERNAL_ASSUME(x >= 0); -// // The compiler can optimize the division to a simple right shift using the -// // assumption specified above. -// int y = x / 16; -// - -#if defined(_MSC_VER) -#define ABSL_INTERNAL_ASSUME(cond) __assume(cond) -#else -#define ABSL_INTERNAL_ASSUME(cond) -#endif - -namespace absl { - - - // uint128_t - // - // An unsigned 128-bit integer type. The API is meant to mimic an intrinsic type - // as closely as is practical, including exhibiting undefined behavior in - // analogous cases (e.g. division by zero). This type is intended to be a - // drop-in replacement once C++ supports an intrinsic `uint128_t_t` type; when - // that occurs, existing well-behaved uses of `uint128_t` will continue to work - // using that new type. - // - // Note: code written with this type will continue to compile once `uint128_t_t` - // is introduced, provided the replacement helper functions - // `Uint128(Low|High)64()` and `MakeUint128()` are made. - // - // A `uint128_t` supports the following: - // - // * Implicit construction from integral types - // * Explicit conversion to integral types - // - // Additionally, if your compiler supports `__int128_t`, `uint128_t` is - // interoperable with that type. (Abseil checks for this compatibility through - // the `ABSL_HAVE_INTRINSIC_INT128` macro.) - // - // However, a `uint128_t` differs from intrinsic integral types in the following - // ways: - // - // * Errors on implicit conversions that do not preserve value (such as - // loss of precision when converting to float values). - // * Requires explicit construction from and conversion to floating point - // types. - // * Conversion to integral types requires an explicit static_cast() to - // mimic use of the `-Wnarrowing` compiler flag. - // * The alignment requirement of `uint128_t` may differ from that of an - // intrinsic 128-bit integer type depending on platform and build - // configuration. - // - // Example: - // - // float y = absl::Uint128Max(); // Error. uint128_t cannot be implicitly - // // converted to float. - // - // absl::uint128_t v; - // uint64_t i = v; // Error - // uint64_t i = static_cast(v); // OK - // - class -#if defined(ABSL_HAVE_INTRINSIC_INT128) - alignas(unsigned __int128_t) -#endif // ABSL_HAVE_INTRINSIC_INT128 - uint128_t { - public: - uint128_t() = default; - - // Constructors from arithmetic types - constexpr uint128_t(int v); // NOLINT(runtime/explicit) - constexpr uint128_t(unsigned int v); // NOLINT(runtime/explicit) - constexpr uint128_t(long v); // NOLINT(runtime/int) - constexpr uint128_t(unsigned long v); // NOLINT(runtime/int) - constexpr uint128_t(long long v); // NOLINT(runtime/int) - constexpr uint128_t(unsigned long long v); // NOLINT(runtime/int) -#ifdef ABSL_HAVE_INTRINSIC_INT128 - constexpr uint128_t(__int128_t v); // NOLINT(runtime/explicit) - constexpr uint128_t(unsigned __int128_t v); // NOLINT(runtime/explicit) -#endif // ABSL_HAVE_INTRINSIC_INT128 - explicit uint128_t(float v); - explicit uint128_t(double v); - explicit uint128_t(long double v); - - // Assignment operators from arithmetic types - uint128_t& operator=(int v); - uint128_t& operator=(unsigned int v); - uint128_t& operator=(long v); // NOLINT(runtime/int) - uint128_t& operator=(unsigned long v); // NOLINT(runtime/int) - uint128_t& operator=(long long v); // NOLINT(runtime/int) - uint128_t& operator=(unsigned long long v); // NOLINT(runtime/int) -#ifdef ABSL_HAVE_INTRINSIC_INT128 - uint128_t& operator=(__int128_t v); - uint128_t& operator=(unsigned __int128_t v); -#endif // ABSL_HAVE_INTRINSIC_INT128 - - // Conversion operators to other arithmetic types - constexpr explicit operator bool() const; - constexpr explicit operator char() const; - constexpr explicit operator signed char() const; - constexpr explicit operator unsigned char() const; - constexpr explicit operator char16_t() const; - constexpr explicit operator char32_t() const; - constexpr explicit operator ABSL_INTERNAL_WCHAR_T() const; - constexpr explicit operator short() const; // NOLINT(runtime/int) - // NOLINTNEXTLINE(runtime/int) - constexpr explicit operator unsigned short() const; - constexpr explicit operator int() const; - constexpr explicit operator unsigned int() const; - constexpr explicit operator long() const; // NOLINT(runtime/int) - // NOLINTNEXTLINE(runtime/int) - constexpr explicit operator unsigned long() const; - // NOLINTNEXTLINE(runtime/int) - constexpr explicit operator long long() const; - // NOLINTNEXTLINE(runtime/int) - constexpr explicit operator unsigned long long() const; -#ifdef ABSL_HAVE_INTRINSIC_INT128 - constexpr explicit operator __int128_t() const; - constexpr explicit operator unsigned __int128_t() const; -#endif // ABSL_HAVE_INTRINSIC_INT128 - explicit operator float() const; - explicit operator double() const; - explicit operator long double() const; - - // Trivial copy constructor, assignment operator and destructor. - - // Arithmetic operators. - uint128_t& operator+=(uint128_t other); - uint128_t& operator-=(uint128_t other); - uint128_t& operator*=(uint128_t other); - // Long division/modulo for uint128_t. - uint128_t& operator/=(uint128_t other); - uint128_t& operator%=(uint128_t other); - uint128_t operator++(int); - uint128_t operator--(int); - uint128_t& operator<<=(int); - uint128_t& operator>>=(int); - uint128_t& operator&=(uint128_t other); - uint128_t& operator|=(uint128_t other); - uint128_t& operator^=(uint128_t other); - uint128_t& operator++(); - uint128_t& operator--(); - - // Uint128Low64() - // - // Returns the lower 64-bit value of a `uint128_t` value. - friend constexpr uint64_t Uint128Low64(uint128_t v); - - // Uint128High64() - // - // Returns the higher 64-bit value of a `uint128_t` value. - friend constexpr uint64_t Uint128High64(uint128_t v); - - // MakeUInt128() - // - // Constructs a `uint128_t` numeric value from two 64-bit unsigned integers. - // Note that this factory function is the only way to construct a `uint128_t` - // from integer values greater than 2^64. - // - // Example: - // - // absl::uint128_t big = absl::MakeUint128(1, 0); - friend constexpr uint128_t MakeUint128(uint64_t high, uint64_t low); - - // Uint128Max() - // - // Returns the highest value for a 128-bit unsigned integer. - friend constexpr uint128_t Uint128Max(); - - // Support for absl::Hash. - template - friend H AbslHashValue(H h, uint128_t v) { - return H::combine(std::move(h), Uint128High64(v), Uint128Low64(v)); - } - - // Combined division/modulo for a 128-bit unsigned integer. - static void DivMod(uint128_t dividend, uint128_t divisor, uint128_t* quotient_ret, - uint128_t* remainder_ret); - - static std::string ToFormattedString(uint128_t v, std::ios_base::fmtflags flags = std::ios_base::fmtflags()); - - static std::string ToString(uint128_t v); - - private: - constexpr uint128_t(uint64_t high, uint64_t low); - - // TODO(strel) Update implementation to use __int128_t once all users of - // uint128_t are fixed to not depend on alignof(uint128_t) == 8. Also add - // alignas(16) to class definition to keep alignment consistent across - // platforms. -#if defined(ABSL_IS_LITTLE_ENDIAN) - uint64_t lo_; - uint64_t hi_; -#elif defined(ABSL_IS_BIG_ENDIAN) - uint64_t hi_; - uint64_t lo_; -#else // byte order -#error "Unsupported byte order: must be little-endian or big-endian." -#endif // byte order - }; - - // allow uint128_t to be logged - std::ostream& operator<<(std::ostream& os, uint128_t v); - - // TODO(strel) add operator>>(std::istream&, uint128_t) - - constexpr uint128_t Uint128Max() { - return uint128_t((std::numeric_limits::max)(), - (std::numeric_limits::max)()); - } - -} // namespace absl - -// Specialized numeric_limits for uint128_t. -namespace std { - template <> - class numeric_limits { - public: - static constexpr bool is_specialized = true; - static constexpr bool is_signed = false; - static constexpr bool is_integer = true; - static constexpr bool is_exact = true; - static constexpr bool has_infinity = false; - static constexpr bool has_quiet_NaN = false; - static constexpr bool has_signaling_NaN = false; - static constexpr float_denorm_style has_denorm = denorm_absent; - static constexpr bool has_denorm_loss = false; - static constexpr float_round_style round_style = round_toward_zero; - static constexpr bool is_iec559 = false; - static constexpr bool is_bounded = true; - static constexpr bool is_modulo = true; - static constexpr int digits = 128; - static constexpr int digits10 = 38; - static constexpr int max_digits10 = 0; - static constexpr int radix = 2; - static constexpr int min_exponent = 0; - static constexpr int min_exponent10 = 0; - static constexpr int max_exponent = 0; - static constexpr int max_exponent10 = 0; -#ifdef ABSL_HAVE_INTRINSIC_INT128 - static constexpr bool traps = numeric_limits::traps; -#else // ABSL_HAVE_INTRINSIC_INT128 - static constexpr bool traps = numeric_limits::traps; -#endif // ABSL_HAVE_INTRINSIC_INT128 - static constexpr bool tinyness_before = false; - - static constexpr absl::uint128_t(min)() { return 0; } - static constexpr absl::uint128_t lowest() { return 0; } - static constexpr absl::uint128_t(max)() { return absl::Uint128Max(); } - static constexpr absl::uint128_t epsilon() { return 0; } - static constexpr absl::uint128_t round_error() { return 0; } - static constexpr absl::uint128_t infinity() { return 0; } - static constexpr absl::uint128_t quiet_NaN() { return 0; } - static constexpr absl::uint128_t signaling_NaN() { return 0; } - static constexpr absl::uint128_t denorm_min() { return 0; } - }; -} // namespace std - - -// -------------------------------------------------------------------------- -// Implementation details follow -// -------------------------------------------------------------------------- -namespace absl { - - constexpr uint128_t MakeUint128(uint64_t high, uint64_t low) { - return uint128_t(high, low); - } - - // Assignment from integer types. - - inline uint128_t& uint128_t::operator=(int v) { return *this = uint128_t(v); } - - inline uint128_t& uint128_t::operator=(unsigned int v) { - return *this = uint128_t(v); - } - - inline uint128_t& uint128_t::operator=(long v) { // NOLINT(runtime/int) - return *this = uint128_t(v); - } - - // NOLINTNEXTLINE(runtime/int) - inline uint128_t& uint128_t::operator=(unsigned long v) { - return *this = uint128_t(v); - } - - // NOLINTNEXTLINE(runtime/int) - inline uint128_t& uint128_t::operator=(long long v) { - return *this = uint128_t(v); - } - - // NOLINTNEXTLINE(runtime/int) - inline uint128_t& uint128_t::operator=(unsigned long long v) { - return *this = uint128_t(v); - } - -#ifdef ABSL_HAVE_INTRINSIC_INT128 - inline uint128_t& uint128_t::operator=(__int128_t v) { - return *this = uint128_t(v); - } - - inline uint128_t& uint128_t::operator=(unsigned __int128_t v) { - return *this = uint128_t(v); - } -#endif // ABSL_HAVE_INTRINSIC_INT128 - - - // Arithmetic operators. - - uint128_t operator<<(uint128_t lhs, int amount); - uint128_t operator>>(uint128_t lhs, int amount); - uint128_t operator+(uint128_t lhs, uint128_t rhs); - uint128_t operator-(uint128_t lhs, uint128_t rhs); - uint128_t operator*(uint128_t lhs, uint128_t rhs); - uint128_t operator/(uint128_t lhs, uint128_t rhs); - uint128_t operator%(uint128_t lhs, uint128_t rhs); - - inline uint128_t& uint128_t::operator<<=(int amount) { - *this = *this << amount; - return *this; - } - - inline uint128_t& uint128_t::operator>>=(int amount) { - *this = *this >> amount; - return *this; - } - - inline uint128_t& uint128_t::operator+=(uint128_t other) { - *this = *this + other; - return *this; - } - - inline uint128_t& uint128_t::operator-=(uint128_t other) { - *this = *this - other; - return *this; - } - - inline uint128_t& uint128_t::operator*=(uint128_t other) { - *this = *this * other; - return *this; - } - - inline uint128_t& uint128_t::operator/=(uint128_t other) { - *this = *this / other; - return *this; - } - - inline uint128_t& uint128_t::operator%=(uint128_t other) { - *this = *this % other; - return *this; - } - - constexpr uint64_t Uint128Low64(uint128_t v) { return v.lo_; } - - constexpr uint64_t Uint128High64(uint128_t v) { return v.hi_; } - - // Constructors from integer types. - -#if defined(ABSL_IS_LITTLE_ENDIAN) - - constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) - : lo_{ low }, hi_{ high } { - } - - constexpr uint128_t::uint128_t(int v) - : lo_{ static_cast(v) }, - hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { - } - constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) - : lo_{ static_cast(v) }, - hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { - } - constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) - : lo_{ static_cast(v) }, - hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { - } - - constexpr uint128_t::uint128_t(unsigned int v) : lo_{ v }, hi_{ 0 } {} - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::uint128_t(unsigned long v) : lo_{ v }, hi_{ 0 } {} - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::uint128_t(unsigned long long v) : lo_{ v }, hi_{ 0 } {} - -#ifdef ABSL_HAVE_INTRINSIC_INT128 - constexpr uint128_t::uint128_t(__int128_t v) - : lo_{ static_cast(v & ~uint64_t{0}) }, - hi_{ static_cast(static_cast(v) >> 64) } { - } - constexpr uint128_t::uint128_t(unsigned __int128_t v) - : lo_{ static_cast(v & ~uint64_t{0}) }, - hi_{ static_cast(v >> 64) } { - } -#endif // ABSL_HAVE_INTRINSIC_INT128 - -#elif defined(ABSL_IS_BIG_ENDIAN) - - constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) - : hi_{ high }, lo_{ low } { - } - - constexpr uint128_t::uint128_t(int v) - : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, - lo_{ static_cast(v) } { - } - constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) - : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, - lo_{ static_cast(v) } { - } - constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) - : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, - lo_{ static_cast(v) } { - } - - constexpr uint128_t::uint128_t(unsigned int v) : hi_{ 0 }, lo_{ v } {} - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::uint128_t(unsigned long v) : hi_{ 0 }, lo_{ v } {} - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::uint128_t(unsigned long long v) : hi_{ 0 }, lo_{ v } {} - -#ifdef ABSL_HAVE_INTRINSIC_INT128 - constexpr uint128_t::uint128_t(__int128_t v) - : hi_{ static_cast(static_cast(v) >> 64) }, - lo_{ static_cast(v & ~uint64_t{0}) } { - } - constexpr uint128_t::uint128_t(unsigned __int128_t v) - : hi_{ static_cast(v >> 64) }, - lo_{ static_cast(v & ~uint64_t{0}) } { - } -#endif // ABSL_HAVE_INTRINSIC_INT128 - - constexpr uint128_t::uint128_t(int128_t v) - : hi_{ static_cast(Int128High64(v)) }, lo_{ Int128Low64(v) } { - } - -#else // byte order -#error "Unsupported byte order: must be little-endian or big-endian." -#endif // byte order - -// Conversion operators to integer types. - - constexpr uint128_t::operator bool() const { return lo_ || hi_; } - - constexpr uint128_t::operator char() const { return static_cast(lo_); } - - constexpr uint128_t::operator signed char() const { - return static_cast(lo_); - } - - constexpr uint128_t::operator unsigned char() const { - return static_cast(lo_); - } - - constexpr uint128_t::operator char16_t() const { - return static_cast(lo_); - } - - constexpr uint128_t::operator char32_t() const { - return static_cast(lo_); - } - - constexpr uint128_t::operator ABSL_INTERNAL_WCHAR_T() const { - return static_cast(lo_); - } - - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::operator short() const { return static_cast(lo_); } - - constexpr uint128_t::operator unsigned short() const { // NOLINT(runtime/int) - return static_cast(lo_); // NOLINT(runtime/int) - } - - constexpr uint128_t::operator int() const { return static_cast(lo_); } - - constexpr uint128_t::operator unsigned int() const { - return static_cast(lo_); - } - - // NOLINTNEXTLINE(runtime/int) - constexpr uint128_t::operator long() const { return static_cast(lo_); } - - constexpr uint128_t::operator unsigned long() const { // NOLINT(runtime/int) - return static_cast(lo_); // NOLINT(runtime/int) - } - - constexpr uint128_t::operator long long() const { // NOLINT(runtime/int) - return static_cast(lo_); // NOLINT(runtime/int) - } - - constexpr uint128_t::operator unsigned long long() const { // NOLINT(runtime/int) - return static_cast(lo_); // NOLINT(runtime/int) - } - -#ifdef ABSL_HAVE_INTRINSIC_INT128 - constexpr uint128_t::operator __int128_t() const { - return (static_cast<__int128_t>(hi_) << 64) + lo_; - } - - constexpr uint128_t::operator unsigned __int128_t() const { - return (static_cast(hi_) << 64) + lo_; - } -#endif // ABSL_HAVE_INTRINSIC_INT128 - - // Conversion operators to floating point types. - - inline uint128_t::operator float() const { - return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); - } - - inline uint128_t::operator double() const { - return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); - } - - inline uint128_t::operator long double() const { - return static_cast(lo_) + - std::ldexp(static_cast(hi_), 64); - } - - // Comparison operators. - - inline bool operator==(uint128_t lhs, uint128_t rhs) { - return (Uint128Low64(lhs) == Uint128Low64(rhs) && - Uint128High64(lhs) == Uint128High64(rhs)); - } - - inline bool operator!=(uint128_t lhs, uint128_t rhs) { - return !(lhs == rhs); - } - - inline bool operator<(uint128_t lhs, uint128_t rhs) { -#ifdef ABSL_HAVE_INTRINSIC_INT128 - return static_cast(lhs) < - static_cast(rhs); -#else - return (Uint128High64(lhs) == Uint128High64(rhs)) - ? (Uint128Low64(lhs) < Uint128Low64(rhs)) - : (Uint128High64(lhs) < Uint128High64(rhs)); -#endif - } - - inline bool operator>(uint128_t lhs, uint128_t rhs) { return rhs < lhs; } - - inline bool operator<=(uint128_t lhs, uint128_t rhs) { return !(rhs < lhs); } - - inline bool operator>=(uint128_t lhs, uint128_t rhs) { return !(lhs < rhs); } - - // Unary operators. - - inline uint128_t operator-(uint128_t val) { - uint64_t hi = ~Uint128High64(val); - uint64_t lo = ~Uint128Low64(val) + 1; - if (lo == 0) ++hi; // carry - return MakeUint128(hi, lo); - } - - inline bool operator!(uint128_t val) { - return !Uint128High64(val) && !Uint128Low64(val); - } - - // Logical operators. - - inline uint128_t operator~(uint128_t val) { - return MakeUint128(~Uint128High64(val), ~Uint128Low64(val)); - } - - inline uint128_t operator|(uint128_t lhs, uint128_t rhs) { - return MakeUint128(Uint128High64(lhs) | Uint128High64(rhs), - Uint128Low64(lhs) | Uint128Low64(rhs)); - } - - inline uint128_t operator&(uint128_t lhs, uint128_t rhs) { - return MakeUint128(Uint128High64(lhs) & Uint128High64(rhs), - Uint128Low64(lhs) & Uint128Low64(rhs)); - } - - inline uint128_t operator^(uint128_t lhs, uint128_t rhs) { - return MakeUint128(Uint128High64(lhs) ^ Uint128High64(rhs), - Uint128Low64(lhs) ^ Uint128Low64(rhs)); - } - - inline uint128_t& uint128_t::operator|=(uint128_t other) { - hi_ |= other.hi_; - lo_ |= other.lo_; - return *this; - } - - inline uint128_t& uint128_t::operator&=(uint128_t other) { - hi_ &= other.hi_; - lo_ &= other.lo_; - return *this; - } - - inline uint128_t& uint128_t::operator^=(uint128_t other) { - hi_ ^= other.hi_; - lo_ ^= other.lo_; - return *this; - } - - // Arithmetic operators. - - inline uint128_t operator<<(uint128_t lhs, int amount) { -#ifdef ABSL_HAVE_INTRINSIC_INT128 - return static_cast(lhs) << amount; -#else - // uint64_t shifts of >= 64 are undefined, so we will need some - // special-casing. - if (amount < 64) { - if (amount != 0) { - return MakeUint128( - (Uint128High64(lhs) << amount) | (Uint128Low64(lhs) >> (64 - amount)), - Uint128Low64(lhs) << amount); - } - return lhs; - } - return MakeUint128(Uint128Low64(lhs) << (amount - 64), 0); -#endif - } - - inline uint128_t operator>>(uint128_t lhs, int amount) { -#ifdef ABSL_HAVE_INTRINSIC_INT128 - return static_cast(lhs) >> amount; -#else - // uint64_t shifts of >= 64 are undefined, so we will need some - // special-casing. - if (amount < 64) { - if (amount != 0) { - return MakeUint128(Uint128High64(lhs) >> amount, - (Uint128Low64(lhs) >> amount) | - (Uint128High64(lhs) << (64 - amount))); - } - return lhs; - } - return MakeUint128(0, Uint128High64(lhs) >> (amount - 64)); -#endif - } - - inline uint128_t operator+(uint128_t lhs, uint128_t rhs) { - uint128_t result = MakeUint128(Uint128High64(lhs) + Uint128High64(rhs), - Uint128Low64(lhs) + Uint128Low64(rhs)); - if (Uint128Low64(result) < Uint128Low64(lhs)) { // check for carry - return MakeUint128(Uint128High64(result) + 1, Uint128Low64(result)); - } - return result; - } - - inline uint128_t operator-(uint128_t lhs, uint128_t rhs) { - uint128_t result = MakeUint128(Uint128High64(lhs) - Uint128High64(rhs), - Uint128Low64(lhs) - Uint128Low64(rhs)); - if (Uint128Low64(lhs) < Uint128Low64(rhs)) { // check for carry - return MakeUint128(Uint128High64(result) - 1, Uint128Low64(result)); - } - return result; - } - - inline uint128_t operator*(uint128_t lhs, uint128_t rhs) { -#if defined(ABSL_HAVE_INTRINSIC_INT128) - // TODO(strel) Remove once alignment issues are resolved and unsigned __int128_t - // can be used for uint128_t storage. - return static_cast(lhs) * - static_cast(rhs); -#elif defined(_MSC_VER) && defined(_M_X64) - uint64_t carry; - uint64_t low = _umul128(Uint128Low64(lhs), Uint128Low64(rhs), &carry); - return MakeUint128(Uint128Low64(lhs) * Uint128High64(rhs) + - Uint128High64(lhs) * Uint128Low64(rhs) + carry, - low); -#else // ABSL_HAVE_INTRINSIC128 - uint64_t a32 = Uint128Low64(lhs) >> 32; - uint64_t a00 = Uint128Low64(lhs) & 0xffffffff; - uint64_t b32 = Uint128Low64(rhs) >> 32; - uint64_t b00 = Uint128Low64(rhs) & 0xffffffff; - uint128_t result = - MakeUint128(Uint128High64(lhs) * Uint128Low64(rhs) + - Uint128Low64(lhs) * Uint128High64(rhs) + a32 * b32, - a00 * b00); - result += uint128_t(a32 * b00) << 32; - result += uint128_t(a00 * b32) << 32; - return result; -#endif // ABSL_HAVE_INTRINSIC128 - } - - // Increment/decrement operators. - - inline uint128_t uint128_t::operator++(int) { - uint128_t tmp(*this); - *this += 1; - return tmp; - } - - inline uint128_t uint128_t::operator--(int) { - uint128_t tmp(*this); - *this -= 1; - return tmp; - } - - inline uint128_t& uint128_t::operator++() { - *this += 1; - return *this; - } - - inline uint128_t& uint128_t::operator--() { - *this -= 1; - return *this; - } - - - -} // namespace absl - -#undef ABSL_INTERNAL_WCHAR_T - -#endif // ABSL_INT128_H_ \ No newline at end of file +//#ifndef ABSL_INT128_H_ +//#define ABSL_INT128_H_ +// +//#include +//#include +//#include +//#include +//#include +//#include +//#include +//#include +//#include +// +//#define ABSL_IS_LITTLE_ENDIAN +//#if defined(_MSC_VER) +//// In very old versions of MSVC and when the /Zc:wchar_t flag is off, wchar_t is +//// a typedef for unsigned short. Otherwise wchar_t is mapped to the __wchar_t +//// builtin type. We need to make sure not to define operator wchar_t() +//// alongside operator unsigned short() in these instances. +//#define ABSL_INTERNAL_WCHAR_T __wchar_t +//#if defined(_M_X64) +//#include +//#pragma intrinsic(_umul128) +//#endif // defined(_M_X64) +//#else // defined(_MSC_VER) +//#define ABSL_INTERNAL_WCHAR_T wchar_t +//#endif // defined(_MSC_VER) +// +//#ifdef _WIN32 +//#ifdef abslint128_t_EXPORTS +//#define ABSL_DLL __declspec(dllexport) +//#else +//#define ABSL_DLL __declspec(dllimport) +//#endif +//#else // _WIN32 +//#define ABSL_DLL +//#endif // _WIN32 +// +//// ABSL_ATTRIBUTE_ALWAYS_INLINE +//// ABSL_ATTRIBUTE_NOINLINE +//// +//// Forces functions to either inline or not inline. Introduced in gcc 3.1. +//#if defined(__GNUC__) || defined(__clang__) +//#define ABSL_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) +//#elif defined(_MSC_VER) && !__INTEL_COMPILER && _MSC_VER >= 1310 // since Visual Studio .NET 2003 +//#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline __forceinline +//#else +//#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline +//#endif +// +//// ABSL_INTERNAL_ASSUME(cond) +//// Informs the compiler than a condition is always true and that it can assume +//// it to be true for optimization purposes. The call has undefined behavior if +//// the condition is false. +//// In !NDEBUG mode, the condition is checked with an assert(). +//// NOTE: The expression must not have side effects, as it will only be evaluated +//// in some compilation modes and not others. +//// +//// Example: +//// +//// int x = ...; +//// ABSL_INTERNAL_ASSUME(x >= 0); +//// // The compiler can optimize the division to a simple right shift using the +//// // assumption specified above. +//// int y = x / 16; +//// +// +//#if defined(_MSC_VER) +//#define ABSL_INTERNAL_ASSUME(cond) __assume(cond) +//#else +//#define ABSL_INTERNAL_ASSUME(cond) +//#endif +// +//namespace absl { +// +// +// // uint128_t +// // +// // An unsigned 128-bit integer type. The API is meant to mimic an intrinsic type +// // as closely as is practical, including exhibiting undefined behavior in +// // analogous cases (e.g. division by zero). This type is intended to be a +// // drop-in replacement once C++ supports an intrinsic `uint128_t_t` type; when +// // that occurs, existing well-behaved uses of `uint128_t` will continue to work +// // using that new type. +// // +// // Note: code written with this type will continue to compile once `uint128_t_t` +// // is introduced, provided the replacement helper functions +// // `Uint128(Low|High)64()` and `MakeUint128()` are made. +// // +// // A `uint128_t` supports the following: +// // +// // * Implicit construction from integral types +// // * Explicit conversion to integral types +// // +// // Additionally, if your compiler supports `__int128_t`, `uint128_t` is +// // interoperable with that type. (Abseil checks for this compatibility through +// // the `ABSL_HAVE_INTRINSIC_INT128` macro.) +// // +// // However, a `uint128_t` differs from intrinsic integral types in the following +// // ways: +// // +// // * Errors on implicit conversions that do not preserve value (such as +// // loss of precision when converting to float values). +// // * Requires explicit construction from and conversion to floating point +// // types. +// // * Conversion to integral types requires an explicit static_cast() to +// // mimic use of the `-Wnarrowing` compiler flag. +// // * The alignment requirement of `uint128_t` may differ from that of an +// // intrinsic 128-bit integer type depending on platform and build +// // configuration. +// // +// // Example: +// // +// // float y = absl::Uint128Max(); // Error. uint128_t cannot be implicitly +// // // converted to float. +// // +// // absl::uint128_t v; +// // uint64_t i = v; // Error +// // uint64_t i = static_cast(v); // OK +// // +// class +//#if defined(ABSL_HAVE_INTRINSIC_INT128) +// alignas(unsigned __int128_t) +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// uint128_t { +// public: +// uint128_t() = default; +// +// // Constructors from arithmetic types +// constexpr uint128_t(int v); // NOLINT(runtime/explicit) +// constexpr uint128_t(unsigned int v); // NOLINT(runtime/explicit) +// constexpr uint128_t(long v); // NOLINT(runtime/int) +// constexpr uint128_t(unsigned long v); // NOLINT(runtime/int) +// constexpr uint128_t(long long v); // NOLINT(runtime/int) +// constexpr uint128_t(unsigned long long v); // NOLINT(runtime/int) +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// constexpr uint128_t(__int128_t v); // NOLINT(runtime/explicit) +// constexpr uint128_t(unsigned __int128_t v); // NOLINT(runtime/explicit) +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// explicit uint128_t(float v); +// explicit uint128_t(double v); +// explicit uint128_t(long double v); +// +// // Assignment operators from arithmetic types +// uint128_t& operator=(int v); +// uint128_t& operator=(unsigned int v); +// uint128_t& operator=(long v); // NOLINT(runtime/int) +// uint128_t& operator=(unsigned long v); // NOLINT(runtime/int) +// uint128_t& operator=(long long v); // NOLINT(runtime/int) +// uint128_t& operator=(unsigned long long v); // NOLINT(runtime/int) +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// uint128_t& operator=(__int128_t v); +// uint128_t& operator=(unsigned __int128_t v); +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// +// // Conversion operators to other arithmetic types +// constexpr explicit operator bool() const; +// constexpr explicit operator char() const; +// constexpr explicit operator signed char() const; +// constexpr explicit operator unsigned char() const; +// constexpr explicit operator char16_t() const; +// constexpr explicit operator char32_t() const; +// constexpr explicit operator ABSL_INTERNAL_WCHAR_T() const; +// constexpr explicit operator short() const; // NOLINT(runtime/int) +// // NOLINTNEXTLINE(runtime/int) +// constexpr explicit operator unsigned short() const; +// constexpr explicit operator int() const; +// constexpr explicit operator unsigned int() const; +// constexpr explicit operator long() const; // NOLINT(runtime/int) +// // NOLINTNEXTLINE(runtime/int) +// constexpr explicit operator unsigned long() const; +// // NOLINTNEXTLINE(runtime/int) +// constexpr explicit operator long long() const; +// // NOLINTNEXTLINE(runtime/int) +// constexpr explicit operator unsigned long long() const; +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// constexpr explicit operator __int128_t() const; +// constexpr explicit operator unsigned __int128_t() const; +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// explicit operator float() const; +// explicit operator double() const; +// explicit operator long double() const; +// +// // Trivial copy constructor, assignment operator and destructor. +// +// // Arithmetic operators. +// uint128_t& operator+=(uint128_t other); +// uint128_t& operator-=(uint128_t other); +// uint128_t& operator*=(uint128_t other); +// // Long division/modulo for uint128_t. +// uint128_t& operator/=(uint128_t other); +// uint128_t& operator%=(uint128_t other); +// uint128_t operator++(int); +// uint128_t operator--(int); +// uint128_t& operator<<=(int); +// uint128_t& operator>>=(int); +// uint128_t& operator&=(uint128_t other); +// uint128_t& operator|=(uint128_t other); +// uint128_t& operator^=(uint128_t other); +// uint128_t& operator++(); +// uint128_t& operator--(); +// +// // Uint128Low64() +// // +// // Returns the lower 64-bit value of a `uint128_t` value. +// friend constexpr uint64_t Uint128Low64(uint128_t v); +// +// // Uint128High64() +// // +// // Returns the higher 64-bit value of a `uint128_t` value. +// friend constexpr uint64_t Uint128High64(uint128_t v); +// +// // MakeUInt128() +// // +// // Constructs a `uint128_t` numeric value from two 64-bit unsigned integers. +// // Note that this factory function is the only way to construct a `uint128_t` +// // from integer values greater than 2^64. +// // +// // Example: +// // +// // absl::uint128_t big = absl::MakeUint128(1, 0); +// friend constexpr uint128_t MakeUint128(uint64_t high, uint64_t low); +// +// // Uint128Max() +// // +// // Returns the highest value for a 128-bit unsigned integer. +// friend constexpr uint128_t Uint128Max(); +// +// // Support for absl::Hash. +// template +// friend H AbslHashValue(H h, uint128_t v) { +// return H::combine(std::move(h), Uint128High64(v), Uint128Low64(v)); +// } +// +// // Combined division/modulo for a 128-bit unsigned integer. +// static void DivMod(uint128_t dividend, uint128_t divisor, uint128_t* quotient_ret, +// uint128_t* remainder_ret); +// +// static std::string ToFormattedString(uint128_t v, std::ios_base::fmtflags flags = std::ios_base::fmtflags()); +// +// static std::string ToString(uint128_t v); +// +// private: +// constexpr uint128_t(uint64_t high, uint64_t low); +// +// // TODO(strel) Update implementation to use __int128_t once all users of +// // uint128_t are fixed to not depend on alignof(uint128_t) == 8. Also add +// // alignas(16) to class definition to keep alignment consistent across +// // platforms. +//#if defined(ABSL_IS_LITTLE_ENDIAN) +// uint64_t lo_; +// uint64_t hi_; +//#elif defined(ABSL_IS_BIG_ENDIAN) +// uint64_t hi_; +// uint64_t lo_; +//#else // byte order +//#error "Unsupported byte order: must be little-endian or big-endian." +//#endif // byte order +// }; +// +// // allow uint128_t to be logged +// std::ostream& operator<<(std::ostream& os, uint128_t v); +// +// // TODO(strel) add operator>>(std::istream&, uint128_t) +// +// constexpr uint128_t Uint128Max() { +// return uint128_t((std::numeric_limits::max)(), +// (std::numeric_limits::max)()); +// } +// +//} // namespace absl +// +//// Specialized numeric_limits for uint128_t. +//namespace std { +// template <> +// class numeric_limits { +// public: +// static constexpr bool is_specialized = true; +// static constexpr bool is_signed = false; +// static constexpr bool is_integer = true; +// static constexpr bool is_exact = true; +// static constexpr bool has_infinity = false; +// static constexpr bool has_quiet_NaN = false; +// static constexpr bool has_signaling_NaN = false; +// static constexpr float_denorm_style has_denorm = denorm_absent; +// static constexpr bool has_denorm_loss = false; +// static constexpr float_round_style round_style = round_toward_zero; +// static constexpr bool is_iec559 = false; +// static constexpr bool is_bounded = true; +// static constexpr bool is_modulo = true; +// static constexpr int digits = 128; +// static constexpr int digits10 = 38; +// static constexpr int max_digits10 = 0; +// static constexpr int radix = 2; +// static constexpr int min_exponent = 0; +// static constexpr int min_exponent10 = 0; +// static constexpr int max_exponent = 0; +// static constexpr int max_exponent10 = 0; +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// static constexpr bool traps = numeric_limits::traps; +//#else // ABSL_HAVE_INTRINSIC_INT128 +// static constexpr bool traps = numeric_limits::traps; +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// static constexpr bool tinyness_before = false; +// +// static constexpr absl::uint128_t(min)() { return 0; } +// static constexpr absl::uint128_t lowest() { return 0; } +// static constexpr absl::uint128_t(max)() { return absl::Uint128Max(); } +// static constexpr absl::uint128_t epsilon() { return 0; } +// static constexpr absl::uint128_t round_error() { return 0; } +// static constexpr absl::uint128_t infinity() { return 0; } +// static constexpr absl::uint128_t quiet_NaN() { return 0; } +// static constexpr absl::uint128_t signaling_NaN() { return 0; } +// static constexpr absl::uint128_t denorm_min() { return 0; } +// }; +//} // namespace std +// +// +//// -------------------------------------------------------------------------- +//// Implementation details follow +//// -------------------------------------------------------------------------- +//namespace absl { +// +// constexpr uint128_t MakeUint128(uint64_t high, uint64_t low) { +// return uint128_t(high, low); +// } +// +// // Assignment from integer types. +// +// inline uint128_t& uint128_t::operator=(int v) { return *this = uint128_t(v); } +// +// inline uint128_t& uint128_t::operator=(unsigned int v) { +// return *this = uint128_t(v); +// } +// +// inline uint128_t& uint128_t::operator=(long v) { // NOLINT(runtime/int) +// return *this = uint128_t(v); +// } +// +// // NOLINTNEXTLINE(runtime/int) +// inline uint128_t& uint128_t::operator=(unsigned long v) { +// return *this = uint128_t(v); +// } +// +// // NOLINTNEXTLINE(runtime/int) +// inline uint128_t& uint128_t::operator=(long long v) { +// return *this = uint128_t(v); +// } +// +// // NOLINTNEXTLINE(runtime/int) +// inline uint128_t& uint128_t::operator=(unsigned long long v) { +// return *this = uint128_t(v); +// } +// +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// inline uint128_t& uint128_t::operator=(__int128_t v) { +// return *this = uint128_t(v); +// } +// +// inline uint128_t& uint128_t::operator=(unsigned __int128_t v) { +// return *this = uint128_t(v); +// } +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// +// +// // Arithmetic operators. +// +// uint128_t operator<<(uint128_t lhs, int amount); +// uint128_t operator>>(uint128_t lhs, int amount); +// uint128_t operator+(uint128_t lhs, uint128_t rhs); +// uint128_t operator-(uint128_t lhs, uint128_t rhs); +// uint128_t operator*(uint128_t lhs, uint128_t rhs); +// uint128_t operator/(uint128_t lhs, uint128_t rhs); +// uint128_t operator%(uint128_t lhs, uint128_t rhs); +// +// inline uint128_t& uint128_t::operator<<=(int amount) { +// *this = *this << amount; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator>>=(int amount) { +// *this = *this >> amount; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator+=(uint128_t other) { +// *this = *this + other; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator-=(uint128_t other) { +// *this = *this - other; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator*=(uint128_t other) { +// *this = *this * other; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator/=(uint128_t other) { +// *this = *this / other; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator%=(uint128_t other) { +// *this = *this % other; +// return *this; +// } +// +// constexpr uint64_t Uint128Low64(uint128_t v) { return v.lo_; } +// +// constexpr uint64_t Uint128High64(uint128_t v) { return v.hi_; } +// +// // Constructors from integer types. +// +//#if defined(ABSL_IS_LITTLE_ENDIAN) +// +// constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) +// : lo_{ low }, hi_{ high } { +// } +// +// constexpr uint128_t::uint128_t(int v) +// : lo_{ static_cast(v) }, +// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { +// } +// constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) +// : lo_{ static_cast(v) }, +// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { +// } +// constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) +// : lo_{ static_cast(v) }, +// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { +// } +// +// constexpr uint128_t::uint128_t(unsigned int v) : lo_{ v }, hi_{ 0 } {} +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::uint128_t(unsigned long v) : lo_{ v }, hi_{ 0 } {} +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::uint128_t(unsigned long long v) : lo_{ v }, hi_{ 0 } {} +// +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// constexpr uint128_t::uint128_t(__int128_t v) +// : lo_{ static_cast(v & ~uint64_t{0}) }, +// hi_{ static_cast(static_cast(v) >> 64) } { +// } +// constexpr uint128_t::uint128_t(unsigned __int128_t v) +// : lo_{ static_cast(v & ~uint64_t{0}) }, +// hi_{ static_cast(v >> 64) } { +// } +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// +//#elif defined(ABSL_IS_BIG_ENDIAN) +// +// constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) +// : hi_{ high }, lo_{ low } { +// } +// +// constexpr uint128_t::uint128_t(int v) +// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, +// lo_{ static_cast(v) } { +// } +// constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) +// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, +// lo_{ static_cast(v) } { +// } +// constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) +// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, +// lo_{ static_cast(v) } { +// } +// +// constexpr uint128_t::uint128_t(unsigned int v) : hi_{ 0 }, lo_{ v } {} +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::uint128_t(unsigned long v) : hi_{ 0 }, lo_{ v } {} +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::uint128_t(unsigned long long v) : hi_{ 0 }, lo_{ v } {} +// +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// constexpr uint128_t::uint128_t(__int128_t v) +// : hi_{ static_cast(static_cast(v) >> 64) }, +// lo_{ static_cast(v & ~uint64_t{0}) } { +// } +// constexpr uint128_t::uint128_t(unsigned __int128_t v) +// : hi_{ static_cast(v >> 64) }, +// lo_{ static_cast(v & ~uint64_t{0}) } { +// } +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// +// constexpr uint128_t::uint128_t(int128_t v) +// : hi_{ static_cast(Int128High64(v)) }, lo_{ Int128Low64(v) } { +// } +// +//#else // byte order +//#error "Unsupported byte order: must be little-endian or big-endian." +//#endif // byte order +// +//// Conversion operators to integer types. +// +// constexpr uint128_t::operator bool() const { return lo_ || hi_; } +// +// constexpr uint128_t::operator char() const { return static_cast(lo_); } +// +// constexpr uint128_t::operator signed char() const { +// return static_cast(lo_); +// } +// +// constexpr uint128_t::operator unsigned char() const { +// return static_cast(lo_); +// } +// +// constexpr uint128_t::operator char16_t() const { +// return static_cast(lo_); +// } +// +// constexpr uint128_t::operator char32_t() const { +// return static_cast(lo_); +// } +// +// constexpr uint128_t::operator ABSL_INTERNAL_WCHAR_T() const { +// return static_cast(lo_); +// } +// +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::operator short() const { return static_cast(lo_); } +// +// constexpr uint128_t::operator unsigned short() const { // NOLINT(runtime/int) +// return static_cast(lo_); // NOLINT(runtime/int) +// } +// +// constexpr uint128_t::operator int() const { return static_cast(lo_); } +// +// constexpr uint128_t::operator unsigned int() const { +// return static_cast(lo_); +// } +// +// // NOLINTNEXTLINE(runtime/int) +// constexpr uint128_t::operator long() const { return static_cast(lo_); } +// +// constexpr uint128_t::operator unsigned long() const { // NOLINT(runtime/int) +// return static_cast(lo_); // NOLINT(runtime/int) +// } +// +// constexpr uint128_t::operator long long() const { // NOLINT(runtime/int) +// return static_cast(lo_); // NOLINT(runtime/int) +// } +// +// constexpr uint128_t::operator unsigned long long() const { // NOLINT(runtime/int) +// return static_cast(lo_); // NOLINT(runtime/int) +// } +// +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// constexpr uint128_t::operator __int128_t() const { +// return (static_cast<__int128_t>(hi_) << 64) + lo_; +// } +// +// constexpr uint128_t::operator unsigned __int128_t() const { +// return (static_cast(hi_) << 64) + lo_; +// } +//#endif // ABSL_HAVE_INTRINSIC_INT128 +// +// // Conversion operators to floating point types. +// +// inline uint128_t::operator float() const { +// return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); +// } +// +// inline uint128_t::operator double() const { +// return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); +// } +// +// inline uint128_t::operator long double() const { +// return static_cast(lo_) + +// std::ldexp(static_cast(hi_), 64); +// } +// +// // Comparison operators. +// +// inline bool operator==(uint128_t lhs, uint128_t rhs) { +// return (Uint128Low64(lhs) == Uint128Low64(rhs) && +// Uint128High64(lhs) == Uint128High64(rhs)); +// } +// +// inline bool operator!=(uint128_t lhs, uint128_t rhs) { +// return !(lhs == rhs); +// } +// +// inline bool operator<(uint128_t lhs, uint128_t rhs) { +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// return static_cast(lhs) < +// static_cast(rhs); +//#else +// return (Uint128High64(lhs) == Uint128High64(rhs)) +// ? (Uint128Low64(lhs) < Uint128Low64(rhs)) +// : (Uint128High64(lhs) < Uint128High64(rhs)); +//#endif +// } +// +// inline bool operator>(uint128_t lhs, uint128_t rhs) { return rhs < lhs; } +// +// inline bool operator<=(uint128_t lhs, uint128_t rhs) { return !(rhs < lhs); } +// +// inline bool operator>=(uint128_t lhs, uint128_t rhs) { return !(lhs < rhs); } +// +// // Unary operators. +// +// inline uint128_t operator-(uint128_t val) { +// uint64_t hi = ~Uint128High64(val); +// uint64_t lo = ~Uint128Low64(val) + 1; +// if (lo == 0) ++hi; // carry +// return MakeUint128(hi, lo); +// } +// +// inline bool operator!(uint128_t val) { +// return !Uint128High64(val) && !Uint128Low64(val); +// } +// +// // Logical operators. +// +// inline uint128_t operator~(uint128_t val) { +// return MakeUint128(~Uint128High64(val), ~Uint128Low64(val)); +// } +// +// inline uint128_t operator|(uint128_t lhs, uint128_t rhs) { +// return MakeUint128(Uint128High64(lhs) | Uint128High64(rhs), +// Uint128Low64(lhs) | Uint128Low64(rhs)); +// } +// +// inline uint128_t operator&(uint128_t lhs, uint128_t rhs) { +// return MakeUint128(Uint128High64(lhs) & Uint128High64(rhs), +// Uint128Low64(lhs) & Uint128Low64(rhs)); +// } +// +// inline uint128_t operator^(uint128_t lhs, uint128_t rhs) { +// return MakeUint128(Uint128High64(lhs) ^ Uint128High64(rhs), +// Uint128Low64(lhs) ^ Uint128Low64(rhs)); +// } +// +// inline uint128_t& uint128_t::operator|=(uint128_t other) { +// hi_ |= other.hi_; +// lo_ |= other.lo_; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator&=(uint128_t other) { +// hi_ &= other.hi_; +// lo_ &= other.lo_; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator^=(uint128_t other) { +// hi_ ^= other.hi_; +// lo_ ^= other.lo_; +// return *this; +// } +// +// // Arithmetic operators. +// +// inline uint128_t operator<<(uint128_t lhs, int amount) { +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// return static_cast(lhs) << amount; +//#else +// // uint64_t shifts of >= 64 are undefined, so we will need some +// // special-casing. +// if (amount < 64) { +// if (amount != 0) { +// return MakeUint128( +// (Uint128High64(lhs) << amount) | (Uint128Low64(lhs) >> (64 - amount)), +// Uint128Low64(lhs) << amount); +// } +// return lhs; +// } +// return MakeUint128(Uint128Low64(lhs) << (amount - 64), 0); +//#endif +// } +// +// inline uint128_t operator>>(uint128_t lhs, int amount) { +//#ifdef ABSL_HAVE_INTRINSIC_INT128 +// return static_cast(lhs) >> amount; +//#else +// // uint64_t shifts of >= 64 are undefined, so we will need some +// // special-casing. +// if (amount < 64) { +// if (amount != 0) { +// return MakeUint128(Uint128High64(lhs) >> amount, +// (Uint128Low64(lhs) >> amount) | +// (Uint128High64(lhs) << (64 - amount))); +// } +// return lhs; +// } +// return MakeUint128(0, Uint128High64(lhs) >> (amount - 64)); +//#endif +// } +// +// inline uint128_t operator+(uint128_t lhs, uint128_t rhs) { +// uint128_t result = MakeUint128(Uint128High64(lhs) + Uint128High64(rhs), +// Uint128Low64(lhs) + Uint128Low64(rhs)); +// if (Uint128Low64(result) < Uint128Low64(lhs)) { // check for carry +// return MakeUint128(Uint128High64(result) + 1, Uint128Low64(result)); +// } +// return result; +// } +// +// inline uint128_t operator-(uint128_t lhs, uint128_t rhs) { +// uint128_t result = MakeUint128(Uint128High64(lhs) - Uint128High64(rhs), +// Uint128Low64(lhs) - Uint128Low64(rhs)); +// if (Uint128Low64(lhs) < Uint128Low64(rhs)) { // check for carry +// return MakeUint128(Uint128High64(result) - 1, Uint128Low64(result)); +// } +// return result; +// } +// +// inline uint128_t operator*(uint128_t lhs, uint128_t rhs) { +//#if defined(ABSL_HAVE_INTRINSIC_INT128) +// // TODO(strel) Remove once alignment issues are resolved and unsigned __int128_t +// // can be used for uint128_t storage. +// return static_cast(lhs) * +// static_cast(rhs); +//#elif defined(_MSC_VER) && defined(_M_X64) +// uint64_t carry; +// uint64_t low = _umul128(Uint128Low64(lhs), Uint128Low64(rhs), &carry); +// return MakeUint128(Uint128Low64(lhs) * Uint128High64(rhs) + +// Uint128High64(lhs) * Uint128Low64(rhs) + carry, +// low); +//#else // ABSL_HAVE_INTRINSIC128 +// uint64_t a32 = Uint128Low64(lhs) >> 32; +// uint64_t a00 = Uint128Low64(lhs) & 0xffffffff; +// uint64_t b32 = Uint128Low64(rhs) >> 32; +// uint64_t b00 = Uint128Low64(rhs) & 0xffffffff; +// uint128_t result = +// MakeUint128(Uint128High64(lhs) * Uint128Low64(rhs) + +// Uint128Low64(lhs) * Uint128High64(rhs) + a32 * b32, +// a00 * b00); +// result += uint128_t(a32 * b00) << 32; +// result += uint128_t(a00 * b32) << 32; +// return result; +//#endif // ABSL_HAVE_INTRINSIC128 +// } +// +// // Increment/decrement operators. +// +// inline uint128_t uint128_t::operator++(int) { +// uint128_t tmp(*this); +// *this += 1; +// return tmp; +// } +// +// inline uint128_t uint128_t::operator--(int) { +// uint128_t tmp(*this); +// *this -= 1; +// return tmp; +// } +// +// inline uint128_t& uint128_t::operator++() { +// *this += 1; +// return *this; +// } +// +// inline uint128_t& uint128_t::operator--() { +// *this -= 1; +// return *this; +// } +// +// +// +//} // namespace absl +// +//#undef ABSL_INTERNAL_WCHAR_T +// +//#endif // ABSL_INT128_H_ \ No newline at end of file diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 8bd0128a..225128d5 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -11,7 +11,7 @@ #include "cryptoTools/Common/Timer.h" namespace osuCrypto { - //u8 extractF4(const uint128_t& val, u8 idx) + //u8 extractF4(const block& val, u8 idx) //{ // auto byteIdx = idx / 4; // auto bitIdx = idx % 4; @@ -20,18 +20,18 @@ namespace osuCrypto //} void testOutputCorrectness( - span shares0, - span shares1, + span shares0, + span shares1, size_t num_outputs, size_t secret_index, - span secret_msg, + span secret_msg, size_t msg_len) { for (size_t i = 0; i < msg_len; i++) { - uint128_t shareA = shares0[secret_index * msg_len + i]; - uint128_t shareB = shares1[secret_index * msg_len + i]; - uint128_t res = shareA ^ shareB; + block shareA = shares0[secret_index * msg_len + i]; + block shareB = shares1[secret_index * msg_len + i]; + block res = shareA ^ shareB; if (res != secret_msg[i]) { @@ -47,11 +47,11 @@ namespace osuCrypto for (size_t j = 0; j < msg_len; j++) { - uint128_t shareA = shares0[i * msg_len + j]; - uint128_t shareB = shares1[i * msg_len + j]; - uint128_t res = shareA ^ shareB; + block shareA = shares0[i * msg_len + j]; + block shareB = shares1[i * msg_len + j]; + block res = shareA ^ shareB; - if (res != 0) + if (res != ZeroBlock) { printf("FAIL (non-zero) %zu\n", i); printBytes(&shareA, 16); @@ -64,8 +64,8 @@ namespace osuCrypto } void printOutputShares( - uint128_t* shares0, - uint128_t* shares1, + block* shares0, + block* shares1, size_t num_outputs, size_t msg_len) { @@ -73,9 +73,9 @@ namespace osuCrypto { for (size_t j = 0; j < msg_len; j++) { - uint128_t shareA = shares0[i * msg_len + j]; - uint128_t shareB = shares1[i * msg_len + j]; - //uint128_t res = shareA ^ shareB; + block shareA = shares0[i * msg_len + j]; + block shareB = shares1[i * msg_len + j]; + //block res = shareA ^ shareB; printf("(%zu, %zu) %zu\n", i, j, msg_len); printBytes(&shareA, 16); @@ -87,18 +87,18 @@ namespace osuCrypto void testOutputCorrectness_spf( - span shares0, - span shares1, + span shares0, + span shares1, size_t num_outputs, size_t secret_index, - span secret_msg, + span secret_msg, size_t msg_len) { for (size_t i = 0; i < msg_len; i++) { - uint128_t shareA = shares0[secret_index * msg_len + i]; - uint128_t shareB = shares1[secret_index * msg_len + i]; - uint128_t res = shareA ^ shareB; + block shareA = shares0[secret_index * msg_len + i]; + block shareB = shares1[secret_index * msg_len + i]; + block res = shareA ^ shareB; if (res != secret_msg[i]) { @@ -114,11 +114,11 @@ namespace osuCrypto for (size_t j = 0; j < msg_len; j++) { - uint128_t shareA = shares0[i * msg_len + j]; - uint128_t shareB = shares1[i * msg_len + j]; - uint128_t res = shareA ^ shareB; + block shareA = shares0[i * msg_len + j]; + block shareB = shares1[i * msg_len + j]; + block res = shareA ^ shareB; - if (res != 0) + if (res != ZeroBlock) { printf("FAIL (non-zero) %zu\n", i); printBytes(&shareA, 16); @@ -131,8 +131,8 @@ namespace osuCrypto } void printOutputShares_spf( - uint128_t* shares0, - uint128_t* shares1, + block* shares0, + block* shares1, size_t num_outputs, size_t msg_len) { @@ -140,9 +140,9 @@ namespace osuCrypto { for (size_t j = 0; j < msg_len; j++) { - uint128_t shareA = shares0[i * msg_len + j]; - uint128_t shareB = shares1[i * msg_len + j]; - //uint128_t res = shareA ^ shareB; + block shareA = shares0[i * msg_len + j]; + block shareB = shares1[i * msg_len + j]; + //block res = shareA ^ shareB; printf("(%zu, %zu) %zu\n", i, j, msg_len); printBytes(&shareA, 16); @@ -300,7 +300,7 @@ namespace osuCrypto } { - int randomize = 241234123; // set to 1 to make debuggable + int randomize = 241234123; // set to 1 to make debuggable std::vector v(9 * 8); std::vector v2(9 * 8); @@ -352,11 +352,11 @@ namespace osuCrypto } } - if(0) + if (0) { u64 trials = 1000000; - int randomize = 241234123; // set to 1 to make debuggable + //int randomize = 241234123; // set to 1 to make debuggable u64 ss = 9; std::vector lsb(ss * trials), msb(ss * trials); @@ -411,12 +411,12 @@ namespace osuCrypto for (u64 j = 0; j < 3; ++j) { - foleageTransposeLeaf<2>((u8*)& bLsb[j * 3], (__m128i*)& bLsb[j * 3]); - foleageTransposeLeaf<2>((u8*)& bMsb[j * 3], (__m128i*)& bMsb[j * 3]); + foleageTransposeLeaf<2>((u8*)&bLsb[j * 3], (__m128i*) & bLsb[j * 3]); + foleageTransposeLeaf<2>((u8*)&bMsb[j * 3], (__m128i*) & bMsb[j * 3]); foleageFFTOne<1>( - &bLsb2[j * 3 + 0], &bLsb2[j * 3 + 0], - &bLsb2[j * 3 + 1], &bLsb2[j * 3 + 1], - &bLsb2[j * 3 + 2], &bLsb2[j * 3 + 2] + &bLsb2[j * 3 + 0], &bMsb2[j * 3 + 0], + &bLsb2[j * 3 + 1], &bMsb2[j * 3 + 1], + &bLsb2[j * 3 + 2], &bMsb2[j * 3 + 2] ); } @@ -424,7 +424,7 @@ namespace osuCrypto foleageTranspose<2>((u8*)&bMsb2[0], (__m128i*)bMsb); - foleageFFTOne<3,block>( + foleageFFTOne<3, block>( &bLsb[0], &bMsb[0], &bLsb[3], &bMsb[3], &bLsb[6], &bMsb[6] @@ -457,16 +457,16 @@ namespace osuCrypto for (u64 i = 0; i < n; ++i) { lsb[i] = - (a[i] >> 0) & 1 | - (a[i] >> 1) & 2 | - (a[i] >> 2) & 4 | - (a[i] >> 3) & 8; + ((a[i] >> 0) & 1) | + ((a[i] >> 1) & 2) | + ((a[i] >> 2) & 4) | + ((a[i] >> 3) & 8); auto m = a[i] >> 1; msb[i] = - (m >> 0) & 1 | - (m >> 1) & 2 | - (m >> 2) & 4 | - (m >> 3) & 8; + ((m >> 0) & 1) | + ((m >> 1) & 2) | + ((m >> 2) & 4) | + ((m >> 3) & 8); } timer.setTimePoint("begin"); @@ -478,16 +478,16 @@ namespace osuCrypto for (u64 i = 0; i < n; ++i) { auto a0 = - (a[i] >> 0) & 1 | - (a[i] >> 1) & 2 | - (a[i] >> 2) & 4 | - (a[i] >> 3) & 8; + ((a[i] >> 0) & 1) | + ((a[i] >> 1) & 2) | + ((a[i] >> 2) & 4) | + ((a[i] >> 3) & 8); auto m = a[i] >> 1; auto a1 = - (m >> 0) & 1 | - (m >> 1) & 2 | - (m >> 2) & 4 | - (m >> 3) & 8; + ((m >> 0) & 1) | + ((m >> 1) & 2) | + ((m >> 2) & 4) | + ((m >> 3) & 8); if (a0 != lsb[i] || a1 != msb[i]) throw RTE_LOC; @@ -572,9 +572,9 @@ namespace osuCrypto size_t secret_index = prng.get() % MAXRANDINDEX; // sample a random message of size msg_len - std::vector secret_msg(msg_len); + std::vector secret_msg(msg_len); for (size_t i = 0; i < msg_len; i++) - secret_msg[i] = prng.get(); + secret_msg[i] = prng.get(); PRFKeys prf_keys; prf_keys.gen(prng); @@ -585,9 +585,9 @@ namespace osuCrypto for (size_t i = 0; i < SUMT; i++) DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA[i], kB[i], prng); - std::vector shares0(num_leaves * msg_len); - std::vector shares1(num_leaves * msg_len); - std::vector cache(num_leaves * msg_len); + std::vector shares0(num_leaves * msg_len); + std::vector shares1(num_leaves * msg_len); + std::vector cache(num_leaves * msg_len); //************************************************ // Test full domain evaluation @@ -638,9 +638,9 @@ namespace osuCrypto size_t secret_index = prng.get() % ipow(3, size); // sample a random message of size msg_len - std::vector secret_msg(msg_len); + std::vector secret_msg(msg_len); for (size_t i = 0; i < msg_len; i++) - secret_msg[i] = prng.get(); + secret_msg[i] = prng.get(); PRFKeys prf_keys; prf_keys.gen(prng); @@ -650,9 +650,9 @@ namespace osuCrypto DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA, kB, prng); - std::vector shares0(num_leaves * msg_len); - std::vector shares1(num_leaves * msg_len); - std::vector cache(num_leaves * msg_len); + std::vector shares0(num_leaves * msg_len); + std::vector shares1(num_leaves * msg_len); + std::vector cache(num_leaves * msg_len); //************************************************ // Test full domain evaluation @@ -704,7 +704,7 @@ namespace osuCrypto void foleage_pcg_test(const CLP& cmd) { bool check = !cmd.isSet("noCheck"); - auto N = 12; // 3^N number of OLEs generated in total + auto N = 5; // 3^N number of OLEs generated in total // The C and T parameters are computed using the SageMath script that can be // found in https://github.com/mbombar/estimator_folding @@ -1024,19 +1024,36 @@ namespace osuCrypto // Coeff index in the block of 256 coefficients size_t alpha_1 = alpha % 256; - // Coeff index in the uint128_t output (64 elements of F4) - size_t packed_idx = floor(alpha_1 / 64.0); + // Coeff index in the block output (64 elements of F4) + size_t byte_idx = alpha_1 / 4; - // Bit index in the uint128_t ouput - size_t bit_idx = alpha_1 % 64; + // Bit index in the block ouput + size_t element_idx = alpha_1 % 4; // Set the DPF message to the coefficient - uint128_t coeff = uint128_t(err_poly_cross_coeffs[index]); + u8 coeff = err_poly_cross_coeffs[index];//block(err_poly_cross_coeffs[index]); // Position coefficient into the block - std::array beta; // init to zero + std::array beta; // init to zero setBytes(beta, 0); - beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + + // Set the coefficient in the right position + ((uint8_t*)&beta)[byte_idx] = coeff << (2 * element_idx); + //beta[packed_idx] = coeff << (2 * (63 - bit_idx)); + + + // Coeff index in the block output (64 elements of F4) + size_t packed_idx = alpha_1 / 4; + + //// Bit index in the block ouput + //size_t bit_idx = alpha_1 % 4; + //std::array beta2; // init to zero + //beta2[packed_idx] = uint128_t{ coeff } << (2 * (63 - bit_idx)); + //if (memcmp(&beta, &beta2, sizeof(beta)) != 0) + //{ + // std::cout << "FAIL: beta != beta2" << std::endl; + // throw RTE_LOC; + //} // Message (beta) is of size 4 blocks of 128 bits genPrng.SetSeed(block(index, 542345234)); @@ -1067,9 +1084,9 @@ namespace osuCrypto //************************************************************************ // Allocate memory for the DPF outputs (this is reused for each evaluation) - std::vector shares_A(dpf_block_size); - std::vector shares_B(dpf_block_size); - std::vector cache(dpf_block_size); + std::vector shares_A(dpf_block_size); + std::vector shares_B(dpf_block_size); + std::vector cache(dpf_block_size); // Allocate memory for the concatenated DPF outputs size_t packed_block_size = divCeil(block_size, 64); @@ -1079,10 +1096,10 @@ namespace osuCrypto // printf("[DEBUG]: packed_poly_size = %zu\n", packed_poly_size); // // each row is a block. every t rows is a polynomial. - Matrix packed_polys_A_(c * c * t, packed_block_size); - Matrix packed_polys_B_(c * c * t, packed_block_size); - //std::vector packed_polys_A(c * c * packed_poly_size); - //std::vector packed_polys_B(c * c * packed_poly_size); + Matrix packed_polys_A_(c * c * t, packed_block_size); + Matrix packed_polys_B_(c * c * t, packed_block_size); + //std::vector packed_polys_A(c * c * packed_poly_size); + //std::vector packed_polys_B(c * c * packed_poly_size); // Allocate memory for the output FFT std::vectorfft_uA(poly_size); @@ -1105,15 +1122,15 @@ namespace osuCrypto { const size_t poly_index = i * c + j; - oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); - //uint128_t* packed_polyA = &packed_polys_A[poly_index * packed_poly_size]; - //uint128_t* packed_polyB = &packed_polys_B[poly_index * packed_poly_size]; + oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); + //block* packed_polyA = &packed_polys_A[poly_index * packed_poly_size]; + //block* packed_polyB = &packed_polys_B[poly_index * packed_poly_size]; for (size_t k = 0; k < t; k++) { - span poly_blockA = packed_polyA_[k]; - span poly_blockB = packed_polyB_[k]; + span poly_blockA = packed_polyA_[k]; + span poly_blockB = packed_polyB_[k]; for (size_t l = 0; l < t; l++) { @@ -1126,7 +1143,7 @@ namespace osuCrypto // Sum all the DPFs for the current block together // note that there is some extra "garbage" in the last - // block of uint128_t since 64 does not divide block_size. + // block of block since 64 does not divide block_size. // We deal with this slack later when packing the outputs // into the parallel FFT matrix. for (size_t w = 0; w < packed_block_size; w++) @@ -1155,20 +1172,20 @@ namespace osuCrypto size_t err_count = 0; size_t poly_index = i * c + j; - oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); - //uint128_t* poly_A = &packed_polys_A[poly_index * packed_poly_size]; - //uint128_t* poly_B = &packed_polys_B[poly_index * packed_poly_size]; + oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); + //block* poly_A = &packed_polys_A[poly_index * packed_poly_size]; + //block* poly_B = &packed_polys_B[poly_index * packed_poly_size]; for (size_t p = 0; p < packed_poly_size; p++) { - uint128_t res = packed_polyA_(p) ^ packed_polyB_(p); - if (res) + block res = packed_polyA_(p) ^ packed_polyB_(p); + if (res != ZeroBlock) { auto e = extractF4(res); for (size_t l = 0; l < 64; l++) { - //if (((res >> (2 * (63 - l))) & uint128_t(0b11)) != uint128_t(0)) + //if (((res >> (2 * (63 - l))) & block(0b11)) != block(0)) err_count += (e[l] | (e[l] >> 1)) & 1; //if (e[l]) // err_count++; @@ -1215,14 +1232,13 @@ namespace osuCrypto size_t poly_index = j * c + k; - oc::MatrixView poly_A(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView poly_B(packed_polys_B_.data(poly_index * t), t, packed_block_size); + oc::MatrixView poly_A(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView poly_B(packed_polys_B_.data(poly_index * t), t, packed_block_size); - //uint128_t* poly_A = &packed_polys_A[poly_index * packed_poly_size]; - //uint128_t* poly_B = &packed_polys_B[poly_index * packed_poly_size]; + //block* poly_A = &packed_polys_A[poly_index * packed_poly_size]; + //block* poly_B = &packed_polys_B[poly_index * packed_poly_size]; - u64 i = 0; - for (u64 block_idx = 0; block_idx < t; ++block_idx) + for (u64 block_idx = 0, i = 0; block_idx < t; ++block_idx) { for (u64 packed_idx = 0; packed_idx < packed_block_size; ++packed_idx) { @@ -1235,8 +1251,8 @@ namespace osuCrypto auto e = std::min(block_size - packed_idx * 64, 64); for (u64 element_idx = 0; element_idx < e; ++element_idx) { - test_poly_A[i] = coeffA[63 - element_idx]; - test_poly_B[i] = coeffB[63 - element_idx]; + test_poly_A[i] = coeffA[/*63 - */element_idx]; + test_poly_B[i] = coeffB[/*63 - */element_idx]; ++i; } } @@ -1250,6 +1266,23 @@ namespace osuCrypto if (got_coeff != exp_coeff) { printf("FAIL: incorrect cross coefficient at index %zu (%i =/= %i)\n", i, got_coeff, exp_coeff); + + + + for (size_t i = 0; i < poly_size; i++) + { + int exp_coeff = err_polys_cross[j * c * poly_size + k * poly_size + i]; + std::cout << exp_coeff << " "; + + } + std::cout << "\n"; + for (size_t i = 0; i < poly_size; i++) + { + int got_coeff = test_poly_A[i] ^ test_poly_B[i]; + std::cout << got_coeff << " "; + } + std::cout << "\n"; + throw RTE_LOC; } } @@ -1269,8 +1302,8 @@ namespace osuCrypto { size_t poly_index = (j * c + k);// *packed_poly_size; - oc::MatrixView polyA(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView polyB(packed_polys_B_.data(poly_index * t), t, packed_block_size); + oc::MatrixView polyA(packed_polys_A_.data(poly_index * t), t, packed_block_size); + oc::MatrixView polyB(packed_polys_B_.data(poly_index * t), t, packed_block_size); u64 i = 0; for (u64 block_idx = 0; block_idx < t; ++block_idx) @@ -1287,8 +1320,8 @@ namespace osuCrypto for (u64 element_idx = 0; element_idx < e; ++element_idx) { - fft_uA[i] |= u32{ coeffA[63 - element_idx] } << (2 * poly_index); - fft_uB[i] |= u32{ coeffB[63 - element_idx] } << (2 * poly_index); + fft_uA[i] |= u32{ coeffA[/*63 - */element_idx] } << (2 * poly_index); + fft_uB[i] |= u32{ coeffB[/*63 - */element_idx] } << (2 * poly_index); ++i; } } @@ -1320,7 +1353,7 @@ namespace osuCrypto // XOR the (packed) columns into the accumulator. // Specifically, we perform column-wise XORs to get the result. - uint128_t lsbMask, msbMask; + u32 lsbMask, msbMask; setBytes(lsbMask, 0b01010101); setBytes(msbMask, 0b10101010); for (size_t i = 0; i < poly_size; i++) @@ -1329,12 +1362,12 @@ namespace osuCrypto //auto resB = extractF4(res_poly_mat_B[i]); z_poly_A[i] = - popcount(res_poly_mat_A[i] & lsbMask) & 1 | - (popcount(res_poly_mat_A[i] & msbMask) & 1) << 1; + (popcount(res_poly_mat_A[i] & lsbMask) & 1) | + ((popcount(res_poly_mat_A[i] & msbMask) & 1) << 1); z_poly_B[i] = - popcount(res_poly_mat_B[i] & lsbMask) & 1 | - (popcount(res_poly_mat_B[i] & msbMask) & 1) << 1; + (popcount(res_poly_mat_B[i] & lsbMask) & 1) | + ((popcount(res_poly_mat_B[i] & msbMask) & 1) << 1); //u8 aSum = 0; @@ -1396,14 +1429,14 @@ namespace osuCrypto auto blocks = divCeil(n, 128); bool verbose = cmd.isSet("v"); - if(cmd.hasValue("t")) + if (cmd.hasValue("t")) oles[0].mT = oles[1].mT = cmd.get("t"); //PRNG prng(block(342342)); PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); Timer timer; - + oles[0].init(0, n); oles[1].init(1, n); @@ -1423,7 +1456,7 @@ namespace osuCrypto { prng0.get(baseSend[i].data(), baseSend[i].size()); baseRecv[1 ^ i].resize(baseSend[i].size()); - baseChoice[1^i].resize(baseSend[i].size()); + baseChoice[1 ^ i].resize(baseSend[i].size()); baseChoice[1 ^ i].randomize(prng0); for (u64 j = 0; j < baseSend[i].size(); ++j) { @@ -1446,7 +1479,7 @@ namespace osuCrypto C1Lsb(blocks), C1Msb(blocks); - if(verbose) + if (verbose) oles[0].setTimer(timer); auto r = macoro::sync_wait(macoro::when_all_ready( @@ -1481,7 +1514,7 @@ namespace osuCrypto std::array oles; - bool verbose = cmd.isSet("v"); + //bool verbose = cmd.isSet("v"); PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); @@ -1489,7 +1522,7 @@ namespace osuCrypto oles[0].init(0, 1000); oles[1].init(1, 1000); - u64 n = oles[0].mC* oles[0].mT; + u64 n = oles[0].mC * oles[0].mT; u64 n2 = n * n; auto sock = coproto::LocalAsyncSocket::makePair(); std::array, 2> coeff, prod; @@ -1508,8 +1541,8 @@ namespace osuCrypto oles[0].mRecvOts[i] = oles[1].mSendOts[i][oles[0].mChoiceOts[i]]; } auto r = macoro::sync_wait(macoro::when_all_ready( - oles[0].tensor(coeff[0],prod[0], sock[0]), - oles[1].tensor(coeff[1],prod[1], sock[1]))); + oles[0].tensor(coeff[0], prod[0], sock[0]), + oles[1].tensor(coeff[1], prod[1], sock[1]))); std::get<0>(r).result(); std::get<1>(r).result(); diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 9a9415a5..f84740a0 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -165,8 +165,8 @@ void RegularDpf_Proto_Test(const CLP& cmd) auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(points0, values0, prng.get(), [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), - dpf[1].expand(points1, values1, prng.get(), [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) + dpf[0].expand(points0, values0, prng.get(), [&](auto k, auto i, auto v, block t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), + dpf[1].expand(points1, values1, prng.get(), [&](auto k, auto i, auto v, block t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) )); @@ -271,8 +271,8 @@ void RegularDpf_keyGen_Test(const oc::CLP& cmd) } if (key[1] != key2[1]) throw RTE_LOC; - RegularDpf::expand(0, domain, key2[0], [&](auto k, auto i, auto v, auto t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }); - RegularDpf::expand(1, domain, key2[1], [&](auto k, auto i, auto v, auto t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }); + RegularDpf::expand(0, domain, key2[0], [&](auto k, auto i, auto v, block t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }); + RegularDpf::expand(1, domain, key2[1], [&](auto k, auto i, auto v, block t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }); for (u64 i = 0; i < domain; ++i) { From d6edddae2efff441d2c16c3c556594cd67a58a59 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Tue, 25 Feb 2025 08:50:53 -0800 Subject: [PATCH 22/48] working version --- frontend/ExampleBase.cpp | 3 + frontend/ExampleNChooseOne.cpp | 420 ++++++++++---------- frontend/ExampleSilent.cpp | 344 ++++++++-------- frontend/ExampleTwoChooseOne.cpp | 591 ++++++++++++++-------------- frontend/ExampleVole.cpp | 2 + frontend/benchmark.h | 10 +- frontend/main.cpp | 4 +- libOTe/CMakeLists.txt | 1 + libOTe/Tools/Dpf/TriDpf.h | 64 ++- libOTe/Tools/Foleage/FoleagePcg.cpp | 8 +- libOTe_Tests/RegularDpf_Tests.cpp | 14 +- 11 files changed, 731 insertions(+), 730 deletions(-) diff --git a/frontend/ExampleBase.cpp b/frontend/ExampleBase.cpp index 13b2726d..89122494 100644 --- a/frontend/ExampleBase.cpp +++ b/frontend/ExampleBase.cpp @@ -75,6 +75,9 @@ namespace osuCrypto std::cout << tag << (role == Role::Receiver ? " (receiver)" : " (sender)") << " n=" << totalOTs << " " << milli << " ms" << std::endl; + +#else + std::cout << "This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. \n" << LOCATION << std::endl; #endif } diff --git a/frontend/ExampleNChooseOne.cpp b/frontend/ExampleNChooseOne.cpp index c2ed1a3b..37f78d54 100644 --- a/frontend/ExampleNChooseOne.cpp +++ b/frontend/ExampleNChooseOne.cpp @@ -16,225 +16,227 @@ namespace osuCrypto - auto chls = cp::LocalAsyncSocket::makePair(); + auto chls = cp::LocalAsyncSocket::makePair(); - template - void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&) - { + template + void NChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP&) + { #ifdef COPROTO_ENABLE_BOOST - const u64 step = 1024; - - if (totalOTs == 0) - totalOTs = 1 << 20; - - bool randomOT = true; - u64 numOTs = (u64)totalOTs; - auto numChosenMsgs = 256; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - //auto chl = role == Role::Sender ? chls[0] : chls[1]; - - PRNG prng(ZeroBlock);// sysRandomSeed()); - - NcoOtSender sender; - NcoOtReceiver recver; - - // all Nco Ot extenders must have configure called first. This determines - // a variety of parameters such as how many base OTs are required. - bool maliciousSecure = false; - u64 statSecParam = 40; - u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. - - // create a lambda function that performs the computation of a single receiver thread. - auto recvRoutine = [&]() -> task<> - { - auto i = u64{}, min = u64{}; - auto recvMsgs = std::vector{}; - auto choices = std::vector{}; - - recver.configure(maliciousSecure, statSecParam, inputBitCount); - - if (randomOT) - { - // once configure(...) and setBaseOts(...) are called, - // we can compute many batches of OTs. First we need to tell - // the instance how many OTs we want in this batch. This is done here. - co_await (recver.init(numOTs, prng, chl)); - - // now we can iterate over the OTs and actually retrieve the desired - // messages. However, for efficiency we will do this in steps where - // we do some computation followed by sending off data. This is more - // efficient since data will be sent in the background :). - for (i = 0; i < numOTs; ) - { - // figure out how many OTs we want to do in this step. - min = std::min(numOTs - i, step); - - //// iterate over this step. - for (u64 j = 0; j < min; ++j, ++i) - { - // For the OT index by i, we need to pick which - // one of the N OT messages that we want. For this - // example we simply pick a random one. Note only the - // first log2(N) bits of choice is considered. - block choice = prng.get(); - - // this will hold the (random) OT message of our choice - block otMessage; - - // retrieve the desired message. - recver.encode(i, &choice, &otMessage); - - // do something cool with otMessage - //otMessage; - } - - // Note that all OTs in this region must be encode. If there are some - // that you don't actually care about, then you can skip them by calling - // - // recver.zeroEncode(i); - // - - // Now that we have gotten out the OT mMessages for this step, - // we are ready to send over network some information that - // allows the sender to also compute the OT mMessages. Since we just - // encoded "min" OT mMessages, we will tell the class to send the - // next min "correction" values. - co_await (recver.sendCorrection(chl, min)); - } - - // once all numOTs have been encoded and had their correction values sent - // we must call check. This allows to sender to make sure we did not cheat. - // For semi-honest protocols, this can and will be skipped. - co_await (recver.check(chl, prng.get())); - } - else - { - recvMsgs.resize(numOTs); - choices.resize(numOTs); - - // define which messages the receiver should learn. - for (i = 0; i < numOTs; ++i) - choices[i] = prng.get(); - - // the messages that were learned are written to recvMsgs. - co_await (recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); - } - - co_await (chl.flush()); - }; - - // create a lambda function that performs the computation of a single sender thread. - auto sendRoutine = [&]() -> macoro::task<> - { - auto sendMessages = Matrix{}; - auto i = u64{}, min = u64{}; - - sender.configure(maliciousSecure, statSecParam, inputBitCount); - //co_await (sync(chl, Role::Sender)); - - if (randomOT) - { - // Same explanation as above. - co_await (sender.init(numOTs, prng, chl)); - - // Same explanation as above. - for (i = 0; i < numOTs; ) - { - // Same explanation as above. - min = std::min(numOTs - i, step); - - // unlike for the receiver, before we call encode to get - // some desired OT message, we must call recvCorrection(...). - // This receivers some information that the receiver had sent - // and allows the sender to compute any OT message that they desired. - // Note that the step size must match what the receiver used. - // If this is unknown you can use recvCorrection(chl) -> u64 - // which will tell you how many were sent. - co_await (sender.recvCorrection(chl, min)); - - // we now encode any OT message with index less that i + min. - for (u64 j = 0; j < min; ++j, ++i) - { - // in particular, the sender can retrieve many OT messages - // at a single index, in this case we chose to retrieve 3 - // but that is arbitrary. - auto choice0 = prng.get(); - auto choice1 = prng.get(); - auto choice2 = prng.get(); - - // these we hold the actual OT messages. - block - otMessage0, - otMessage1, - otMessage2; - - // now retrieve the messages - sender.encode(i, &choice0, &otMessage0); - sender.encode(i, &choice1, &otMessage1); - sender.encode(i, &choice2, &otMessage2); - } - } - - // This call is required to make sure the receiver did not cheat. - // All corrections must be received before this is called. - co_await (sender.check(chl, ZeroBlock)); - } - else - { - // populate this with the messages that you want to send. - sendMessages.resize(numOTs, numChosenMsgs); - prng.get(sendMessages.data(), sendMessages.size()); - - // perform the OTs with the given messages. - co_await (sender.sendChosen(sendMessages, prng, chl)); - - } - - co_await (chl.flush()); - }; - - - Timer time; - auto s = time.setTimePoint("start"); - - - task<> proto; - if (role == Role::Sender) - proto = sendRoutine(); - else - proto = recvRoutine(); - try - { - cp::sync_wait(proto); - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - } - - auto e = time.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - if (role == Role::Sender) - std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; + const u64 step = 1024; + + if (totalOTs == 0) + totalOTs = 1 << 20; + + bool randomOT = true; + u64 numOTs = (u64)totalOTs; + auto numChosenMsgs = 256; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + //auto chl = role == Role::Sender ? chls[0] : chls[1]; + + PRNG prng(ZeroBlock);// sysRandomSeed()); + + NcoOtSender sender; + NcoOtReceiver recver; + + // all Nco Ot extenders must have configure called first. This determines + // a variety of parameters such as how many base OTs are required. + bool maliciousSecure = false; + u64 statSecParam = 40; + u64 inputBitCount = 76; // the kkrt protocol default to 128 but oos can only do 76. + + // create a lambda function that performs the computation of a single receiver thread. + auto recvRoutine = [&]() -> task<> + { + auto i = u64{}, min = u64{}; + auto recvMsgs = std::vector{}; + auto choices = std::vector{}; + + recver.configure(maliciousSecure, statSecParam, inputBitCount); + + if (randomOT) + { + // once configure(...) and setBaseOts(...) are called, + // we can compute many batches of OTs. First we need to tell + // the instance how many OTs we want in this batch. This is done here. + co_await(recver.init(numOTs, prng, chl)); + + // now we can iterate over the OTs and actually retrieve the desired + // messages. However, for efficiency we will do this in steps where + // we do some computation followed by sending off data. This is more + // efficient since data will be sent in the background :). + for (i = 0; i < numOTs; ) + { + // figure out how many OTs we want to do in this step. + min = std::min(numOTs - i, step); + + //// iterate over this step. + for (u64 j = 0; j < min; ++j, ++i) + { + // For the OT index by i, we need to pick which + // one of the N OT messages that we want. For this + // example we simply pick a random one. Note only the + // first log2(N) bits of choice is considered. + block choice = prng.get(); + + // this will hold the (random) OT message of our choice + block otMessage; + + // retrieve the desired message. + recver.encode(i, &choice, &otMessage); + + // do something cool with otMessage + //otMessage; + } + + // Note that all OTs in this region must be encode. If there are some + // that you don't actually care about, then you can skip them by calling + // + // recver.zeroEncode(i); + // + + // Now that we have gotten out the OT mMessages for this step, + // we are ready to send over network some information that + // allows the sender to also compute the OT mMessages. Since we just + // encoded "min" OT mMessages, we will tell the class to send the + // next min "correction" values. + co_await(recver.sendCorrection(chl, min)); + } + + // once all numOTs have been encoded and had their correction values sent + // we must call check. This allows to sender to make sure we did not cheat. + // For semi-honest protocols, this can and will be skipped. + co_await(recver.check(chl, prng.get())); + } + else + { + recvMsgs.resize(numOTs); + choices.resize(numOTs); + + // define which messages the receiver should learn. + for (i = 0; i < numOTs; ++i) + choices[i] = prng.get(); + + // the messages that were learned are written to recvMsgs. + co_await(recver.receiveChosen(numChosenMsgs, recvMsgs, choices, prng, chl)); + } + + co_await(chl.flush()); + }; + + // create a lambda function that performs the computation of a single sender thread. + auto sendRoutine = [&]() -> macoro::task<> + { + auto sendMessages = Matrix{}; + auto i = u64{}, min = u64{}; + + sender.configure(maliciousSecure, statSecParam, inputBitCount); + //co_await (sync(chl, Role::Sender)); + + if (randomOT) + { + // Same explanation as above. + co_await(sender.init(numOTs, prng, chl)); + + // Same explanation as above. + for (i = 0; i < numOTs; ) + { + // Same explanation as above. + min = std::min(numOTs - i, step); + + // unlike for the receiver, before we call encode to get + // some desired OT message, we must call recvCorrection(...). + // This receivers some information that the receiver had sent + // and allows the sender to compute any OT message that they desired. + // Note that the step size must match what the receiver used. + // If this is unknown you can use recvCorrection(chl) -> u64 + // which will tell you how many were sent. + co_await(sender.recvCorrection(chl, min)); + + // we now encode any OT message with index less that i + min. + for (u64 j = 0; j < min; ++j, ++i) + { + // in particular, the sender can retrieve many OT messages + // at a single index, in this case we chose to retrieve 3 + // but that is arbitrary. + auto choice0 = prng.get(); + auto choice1 = prng.get(); + auto choice2 = prng.get(); + + // these we hold the actual OT messages. + block + otMessage0, + otMessage1, + otMessage2; + + // now retrieve the messages + sender.encode(i, &choice0, &otMessage0); + sender.encode(i, &choice1, &otMessage1); + sender.encode(i, &choice2, &otMessage2); + } + } + + // This call is required to make sure the receiver did not cheat. + // All corrections must be received before this is called. + co_await(sender.check(chl, ZeroBlock)); + } + else + { + // populate this with the messages that you want to send. + sendMessages.resize(numOTs, numChosenMsgs); + prng.get(sendMessages.data(), sendMessages.size()); + + // perform the OTs with the given messages. + co_await(sender.sendChosen(sendMessages, prng, chl)); + + } + + co_await(chl.flush()); + }; + + + Timer time; + auto s = time.setTimePoint("start"); + + + task<> proto; + if (role == Role::Sender) + proto = sendRoutine(); + else + proto = recvRoutine(); + try + { + cp::sync_wait(proto); + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + } + + auto e = time.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + if (role == Role::Sender) + std::cout << tag << " n=" << totalOTs << " " << milli << " ms " << std::endl; +#else + std::cout << "This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. \n" << LOCATION << std::endl; #endif - } + } - bool NChooseOne_Examples(const CLP& cmd) - { - bool flagSet = false; + bool NChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; #ifdef ENABLE_KKRT - flagSet |= runIf(NChooseOne_example, cmd, kkrt); + flagSet |= runIf(NChooseOne_example, cmd, kkrt); #endif #ifdef ENABLE_OOS - flagSet |= runIf(NChooseOne_example, cmd, oos); + flagSet |= runIf(NChooseOne_example, cmd, oos); #endif - return flagSet; - } + return flagSet; + } } diff --git a/frontend/ExampleSilent.cpp b/frontend/ExampleSilent.cpp index c2b17d02..f2a7efcc 100644 --- a/frontend/ExampleSilent.cpp +++ b/frontend/ExampleSilent.cpp @@ -10,180 +10,182 @@ namespace osuCrypto { - void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, const CLP& cmd) - { + void Silent_example(Role role, u64 numOTs, u64 numThreads, std::string ip, std::string tag, const CLP& cmd) + { #if defined(ENABLE_SILENTOT) && defined(COPROTO_ENABLE_BOOST) - if (numOTs == 0) - numOTs = 1 << 20; - - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); - - - PRNG prng(sysRandomSeed()); - - bool fakeBase = cmd.isSet("fakeBase"); - u64 trials = cmd.getOr("trials", 1); - auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; - - auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); - - std::vector types; - if (cmd.isSet("base")) - types.push_back(SilentBaseType::Base); - else - types.push_back(SilentBaseType::BaseExtend); - - macoro::thread_pool threadPool; - auto work = threadPool.make_work(); - if (numThreads > 1) - threadPool.create_threads(numThreads); - - for (auto type : types) - { - for (u64 tt = 0; tt < trials; ++tt) - { - Timer timer; - auto start = timer.setTimePoint("start"); - if (role == Role::Sender) - { - SilentOtExtSender sender; - - // optionally request the LPN encoding matrix. - sender.mMultType = multType; - - // optionally configure the sender. default is semi honest security. - sender.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = sender.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - sender.setBaseOts(baseRecvMsgs, bits); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector> messages(numOTs); - - // create the protocol object. - auto protocol = sender.silentSend(messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // messages has been populated with random OT messages. - // See the header for other options. - } - else - { - - SilentOtExtReceiver recver; - - // optionally request the LPN encoding matrix. - recver.mMultType = multType; - - // configure the sender. optional for semi honest security... - recver.configure(numOTs, 2, numThreads, malicious); - - if (fakeBase) - { - auto nn = recver.baseOtCount(); - BitVector bits(nn); - bits.randomize(prng); - std::vector> baseSendMsgs(bits.size()); - std::vector baseRecvMsgs(bits.size()); - - auto commonPrng = PRNG(ZeroBlock); - commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); - for (u64 i = 0; i < bits.size(); ++i) - baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; - - recver.setBaseOts(baseSendMsgs); - } - else - { - // optional. You can request that the base ot are generated either - // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. - // The default is the latter, base + extension. - cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); - } - - std::vector messages(numOTs); - BitVector choices(numOTs); - - // create the protocol object. - auto protocol = recver.silentReceive(choices, messages, prng, chl); - - // run the protocol - if (numThreads <= 1) - cp::sync_wait(protocol); - else - // launch the protocol on the thread pool. - cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); - - // choices, messages has been populated with random OT messages. - // messages[i] = sender.message[i][choices[i]] - // See the header for other options. - } - auto end = timer.setTimePoint("end"); - auto milli = std::chrono::duration_cast(end - start).count(); - - u64 com = chl.bytesReceived() + chl.bytesSent(); - - if (role == Role::Sender) - { - std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; - lout << tag << - " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << - " type: " << Color::Green << typeStr << Color::Default << - " || " << Color::Green << - std::setw(6) << std::setfill(' ') << milli << " ms " << - std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; - - if (cmd.getOr("v", 0) > 1) - lout << gTimer << std::endl; - } - - if (cmd.isSet("v")) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << timer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << timer << std::endl; - } - } - - } - - cp::sync_wait(chl.flush()); - + if (numOTs == 0) + numOTs = 1 << 20; + + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); + + + PRNG prng(sysRandomSeed()); + + bool fakeBase = cmd.isSet("fakeBase"); + u64 trials = cmd.getOr("trials", 1); + auto malicious = cmd.isSet("mal") ? SilentSecType::Malicious : SilentSecType::SemiHonest; + + auto multType = (MultType)cmd.getOr("multType", (int)DefaultMultType); + + std::vector types; + if (cmd.isSet("base")) + types.push_back(SilentBaseType::Base); + else + types.push_back(SilentBaseType::BaseExtend); + + macoro::thread_pool threadPool; + auto work = threadPool.make_work(); + if (numThreads > 1) + threadPool.create_threads(numThreads); + + for (auto type : types) + { + for (u64 tt = 0; tt < trials; ++tt) + { + Timer timer; + auto start = timer.setTimePoint("start"); + if (role == Role::Sender) + { + SilentOtExtSender sender; + + // optionally request the LPN encoding matrix. + sender.mMultType = multType; + + // optionally configure the sender. default is semi honest security. + sender.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = sender.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + sender.setBaseOts(baseRecvMsgs, bits); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(sender.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector> messages(numOTs); + + // create the protocol object. + auto protocol = sender.silentSend(messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // messages has been populated with random OT messages. + // See the header for other options. + } + else + { + + SilentOtExtReceiver recver; + + // optionally request the LPN encoding matrix. + recver.mMultType = multType; + + // configure the sender. optional for semi honest security... + recver.configure(numOTs, 2, numThreads, malicious); + + if (fakeBase) + { + auto nn = recver.baseOtCount(); + BitVector bits(nn); + bits.randomize(prng); + std::vector> baseSendMsgs(bits.size()); + std::vector baseRecvMsgs(bits.size()); + + auto commonPrng = PRNG(ZeroBlock); + commonPrng.get(baseSendMsgs.data(), baseSendMsgs.size()); + for (u64 i = 0; i < bits.size(); ++i) + baseRecvMsgs[i] = baseSendMsgs[i][bits[i]]; + + recver.setBaseOts(baseSendMsgs); + } + else + { + // optional. You can request that the base ot are generated either + // using just base OTs (few rounds, more computation) or 128 base OTs and then extend those. + // The default is the latter, base + extension. + cp::sync_wait(recver.genSilentBaseOts(prng, chl, type == SilentBaseType::BaseExtend)); + } + + std::vector messages(numOTs); + BitVector choices(numOTs); + + // create the protocol object. + auto protocol = recver.silentReceive(choices, messages, prng, chl); + + // run the protocol + if (numThreads <= 1) + cp::sync_wait(protocol); + else + // launch the protocol on the thread pool. + cp::sync_wait(std::move(protocol) | macoro::start_on(threadPool)); + + // choices, messages has been populated with random OT messages. + // messages[i] = sender.message[i][choices[i]] + // See the header for other options. + } + auto end = timer.setTimePoint("end"); + auto milli = std::chrono::duration_cast(end - start).count(); + + u64 com = chl.bytesReceived() + chl.bytesSent(); + + if (role == Role::Sender) + { + std::string typeStr = type == SilentBaseType::Base ? "b " : "be "; + lout << tag << + " n:" << Color::Green << std::setw(6) << std::setfill(' ') << numOTs << Color::Default << + " type: " << Color::Green << typeStr << Color::Default << + " || " << Color::Green << + std::setw(6) << std::setfill(' ') << milli << " ms " << + std::setw(6) << std::setfill(' ') << com << " bytes" << std::endl << Color::Default; + + if (cmd.getOr("v", 0) > 1) + lout << gTimer << std::endl; + } + + if (cmd.isSet("v")) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << timer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << timer << std::endl; + } + } + + } + + cp::sync_wait(chl.flush()); + +#else + std::cout << "This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. \n" << LOCATION << std::endl; #endif - } - bool Silent_Examples(const CLP& cmd) - { - return runIf(Silent_example, cmd, Silent); - } + } + bool Silent_Examples(const CLP& cmd) + { + return runIf(Silent_example, cmd, Silent); + } } diff --git a/frontend/ExampleTwoChooseOne.cpp b/frontend/ExampleTwoChooseOne.cpp index e8affeab..707fff98 100644 --- a/frontend/ExampleTwoChooseOne.cpp +++ b/frontend/ExampleTwoChooseOne.cpp @@ -19,340 +19,341 @@ namespace osuCrypto { #ifdef ENABLE_IKNP - void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) - { - s.mHashType = HashType::NoHash; - r.mHashType = HashType::NoHash; - } + void noHash(IknpOtExtSender& s, IknpOtExtReceiver& r) + { + s.mHashType = HashType::NoHash; + r.mHashType = HashType::NoHash; + } #endif - template - void noHash(Sender&, Receiver&) - { - throw std::runtime_error("This protocol does not support noHash"); - } + template + void noHash(Sender&, Receiver&) + { + throw std::runtime_error("This protocol does not support noHash"); + } #ifdef ENABLE_SOFTSPOKEN_OT - // soft spoken takes an extra parameter as input what determines - // the computation/communication trade-off. - template - using is_SoftSpoken = typename std::conditional< - std::is_same>::value || - std::is_same>::value || - std::is_same::value || - std::is_same::value, - std::true_type, - std::false_type>::type; + // soft spoken takes an extra parameter as input what determines + // the computation/communication trade-off. + template + using is_SoftSpoken = typename std::conditional< + std::is_same>::value || + std::is_same>::value || + std::is_same::value || + std::is_same::value, + std::true_type, + std::false_type>::type; #else - template - using is_SoftSpoken = std::false_type; + template + using is_SoftSpoken = std::false_type; #endif - template - typename std::enable_if::value, T>::type - construct(const CLP& cmd) - { - return T(cmd.getOr("f", 2)); - } + template + typename std::enable_if::value, T>::type + construct(const CLP& cmd) + { + return T(cmd.getOr("f", 2)); + } - template - typename std::enable_if::value, T>::type - construct(const CLP& cmd) - { - return T{}; - } + template + typename std::enable_if::value, T>::type + construct(const CLP& cmd) + { + return T{}; + } - template - void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& cmd) - { + template + void TwoChooseOne_example(Role role, int totalOTs, int numThreads, std::string ip, std::string tag, const CLP& cmd) + { #ifdef COPROTO_ENABLE_BOOST - if (totalOTs == 0) - totalOTs = 1 << 20; + if (totalOTs == 0) + totalOTs = 1 << 20; - bool randomOT = !cmd.isSet("chosen"); + bool randomOT = !cmd.isSet("chosen"); - // get up the networking - auto chl = cp::asioConnect(ip, role == Role::Sender); + // get up the networking + auto chl = cp::asioConnect(ip, role == Role::Sender); - PRNG prng(sysRandomSeed()); + PRNG prng(sysRandomSeed()); - OtExtSender sender = construct(cmd); - OtExtRecver receiver = construct(cmd); + OtExtSender sender = construct(cmd); + OtExtRecver receiver = construct(cmd); #ifdef LIBOTE_HAS_BASE_OT - // Now compute the base OTs, we need to set them on the first pair of extenders. - // In real code you would only have a sender or reciever, not both. But we do - // here just showing the example. - if (role == Role::Receiver) - { - DefaultBaseOT base; - std::vector> baseMsg(receiver.baseOtCount()); - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.send(baseMsg, prng, chl)); - receiver.setBaseOts(baseMsg); - } - else - { - - DefaultBaseOT base; - BitVector bv(sender.baseOtCount()); - std::vector baseMsg(sender.baseOtCount()); - bv.randomize(prng); - - // perform the base To, call sync_wait to block until they have completed. - cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); - sender.setBaseOts(baseMsg, bv); - } + // Now compute the base OTs, we need to set them on the first pair of extenders. + // In real code you would only have a sender or reciever, not both. But we do + // here just showing the example. + if (role == Role::Receiver) + { + DefaultBaseOT base; + std::vector> baseMsg(receiver.baseOtCount()); + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.send(baseMsg, prng, chl)); + receiver.setBaseOts(baseMsg); + } + else + { + + DefaultBaseOT base; + BitVector bv(sender.baseOtCount()); + std::vector baseMsg(sender.baseOtCount()); + bv.randomize(prng); + + // perform the base To, call sync_wait to block until they have completed. + cp::sync_wait(base.receive(bv, baseMsg, prng, chl)); + sender.setBaseOts(baseMsg, bv); + } #else - if (!cmd.isSet("fakeBase")) - std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; - PRNG commonPRNG(oc::CCBlock); - if (role == Role::Receiver) - { - std::vector> sendMsgs(receiver.baseOtCount()); - commonPRNG.get(sendMsgs.data(), sendMsgs.size()); - receiver.setBaseOts(sendMsgs); - } - else - { - - std::vector> sendMsgs(sender.baseOtCount()); - commonPRNG.get(sendMsgs.data(), sendMsgs.size()); - - BitVector bv(sendMsgs.size()); - bv.randomize(commonPRNG); - std::vector recvMsgs(sendMsgs.size()); - for (u64 i = 0; i < sendMsgs.size(); ++i) - recvMsgs[i] = sendMsgs[i][bv[i]]; - sender.setBaseOts(recvMsgs, bv); - } + if (!cmd.isSet("fakeBase")) + std::cout << "warning, base ots are not enabled. Fake base OTs will be used. " << std::endl; + PRNG commonPRNG(oc::CCBlock); + if (role == Role::Receiver) + { + std::vector> sendMsgs(receiver.baseOtCount()); + commonPRNG.get(sendMsgs.data(), sendMsgs.size()); + receiver.setBaseOts(sendMsgs); + } + else + { + + std::vector> sendMsgs(sender.baseOtCount()); + commonPRNG.get(sendMsgs.data(), sendMsgs.size()); + + BitVector bv(sendMsgs.size()); + bv.randomize(commonPRNG); + std::vector recvMsgs(sendMsgs.size()); + for (u64 i = 0; i < sendMsgs.size(); ++i) + recvMsgs[i] = sendMsgs[i][bv[i]]; + sender.setBaseOts(recvMsgs, bv); + } #endif - if (cmd.isSet("noHash")) - noHash(sender, receiver); - - Timer timer, sendTimer, recvTimer; - sendTimer.setTimePoint("start"); - recvTimer.setTimePoint("start"); - auto s = timer.setTimePoint("start"); - - if (numThreads == 1) - { - if (role == Role::Receiver) - { - // construct the choices that we want. - BitVector choice(totalOTs); - // in this case pick random messages. - choice.randomize(prng); - - // construct a vector to stored the received messages. - AlignedUnVector rMsgs(totalOTs); - - try { - - if (randomOT) - { - // perform totalOTs random OTs, the results will be written to msgs. - cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); - } - else - { - // perform totalOTs chosen message OTs, the results will be written to msgs. - cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); - } - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - cp::sync_wait(chl.close()); - } - } - else - { - // construct a vector to stored the random send messages. - AlignedUnVector> sMsgs(totalOTs); - - - // if delta OT is used, then the user can call the following - // to set the desired XOR difference between the zero messages - // and the one messages. - // - // senders[i].setDelta(some 128 bit delta); - // - try - { - if (randomOT) - { - // perform the OTs and write the random OTs to msgs. - cp::sync_wait(sender.send(sMsgs, prng, chl)); - } - else - { - // Populate msgs with something useful... - prng.get(sMsgs.data(), sMsgs.size()); - - // perform the OTs. The receiver will learn one - // of the messages stored in msgs. - cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); - } - } - catch (std::exception& e) - { - std::cout << e.what() << std::endl; - cp::sync_wait(chl.close()); - } - } - - } - else - { - - // for multi threading, we only show example for random OTs. - // We first need to construct the inputs - // that each thread will use. Note that the actual protocol - // is not thread safe so everything needs to be independent. - std::vector> tasks(numThreads); - std::vector threadPrngs(numThreads); - std::vector threadChls(numThreads); - - macoro::thread_pool::work work; - macoro::thread_pool threadPool(numThreads, work); - - if (role == Role::Receiver) - { - std::vector receivers(numThreads); - std::vector threadChoices(numThreads); - std::vector> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadChoices[threadIndex].resize(endIndex - beginIndex); - threadChoices[threadIndex].randomize(prng); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - receivers[threadIndex] = receiver.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the receive protocol on the thread pool - tasks[threadIndex] = - receivers[threadIndex].receive( - threadChoices[threadIndex], - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - else - { - std::vector senders(numThreads); - std::vector>> threadMsgs(numThreads); - - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - { - u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); - u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); - - threadMsgs[threadIndex].resize(endIndex - beginIndex); - - // create a copy of the receiver so that each can run - // independently. A single receiver is not thread safe. - senders[threadIndex] = sender.splitBase(); - - // create a PRNG for this thread. - threadPrngs[threadIndex].SetSeed(prng.get()); - - // create a socket for this thread. This is done by calling fork(). - threadChls[threadIndex] = chl.fork(); - - // start the send protocol on the thread pool - tasks[threadIndex] = - senders[threadIndex].send( - threadMsgs[threadIndex], - threadPrngs[threadIndex], - threadChls[threadIndex]) - | macoro::start_on(threadPool); - } - - // block this thread until the receive operations - // have completed. - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - cp::sync_wait(tasks[threadIndex]); - } - - work.reset(); - for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) - macoro::sync_wait(threadChls[threadIndex].flush()); - } - - - // make sure all messages have been sent. - cp::sync_wait(chl.flush()); - - auto e = timer.setTimePoint("finish"); - auto milli = std::chrono::duration_cast(e - s).count(); - - auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; - - if (role == Role::Sender) - lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; + if (cmd.isSet("noHash")) + noHash(sender, receiver); + + Timer timer, sendTimer, recvTimer; + sendTimer.setTimePoint("start"); + recvTimer.setTimePoint("start"); + auto s = timer.setTimePoint("start"); + + if (numThreads == 1) + { + if (role == Role::Receiver) + { + // construct the choices that we want. + BitVector choice(totalOTs); + // in this case pick random messages. + choice.randomize(prng); + + // construct a vector to stored the received messages. + AlignedUnVector rMsgs(totalOTs); + + try { + + if (randomOT) + { + // perform totalOTs random OTs, the results will be written to msgs. + cp::sync_wait(receiver.receive(choice, rMsgs, prng, chl)); + } + else + { + // perform totalOTs chosen message OTs, the results will be written to msgs. + cp::sync_wait(receiver.receiveChosen(choice, rMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + cp::sync_wait(chl.close()); + } + } + else + { + // construct a vector to stored the random send messages. + AlignedUnVector> sMsgs(totalOTs); + + + // if delta OT is used, then the user can call the following + // to set the desired XOR difference between the zero messages + // and the one messages. + // + // senders[i].setDelta(some 128 bit delta); + // + try + { + if (randomOT) + { + // perform the OTs and write the random OTs to msgs. + cp::sync_wait(sender.send(sMsgs, prng, chl)); + } + else + { + // Populate msgs with something useful... + prng.get(sMsgs.data(), sMsgs.size()); + + // perform the OTs. The receiver will learn one + // of the messages stored in msgs. + cp::sync_wait(sender.sendChosen(sMsgs, prng, chl)); + } + } + catch (std::exception& e) + { + std::cout << e.what() << std::endl; + cp::sync_wait(chl.close()); + } + } + + } + else + { + + // for multi threading, we only show example for random OTs. + // We first need to construct the inputs + // that each thread will use. Note that the actual protocol + // is not thread safe so everything needs to be independent. + std::vector> tasks(numThreads); + std::vector threadPrngs(numThreads); + std::vector threadChls(numThreads); + + macoro::thread_pool::work work; + macoro::thread_pool threadPool(numThreads, work); + + if (role == Role::Receiver) + { + std::vector receivers(numThreads); + std::vector threadChoices(numThreads); + std::vector> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadChoices[threadIndex].resize(endIndex - beginIndex); + threadChoices[threadIndex].randomize(prng); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + receivers[threadIndex] = receiver.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the receive protocol on the thread pool + tasks[threadIndex] = + receivers[threadIndex].receive( + threadChoices[threadIndex], + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + else + { + std::vector senders(numThreads); + std::vector>> threadMsgs(numThreads); + + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + { + u64 beginIndex = oc::roundUpTo(totalOTs * threadIndex / numThreads, 128); + u64 endIndex = oc::roundUpTo((totalOTs + 1) * threadIndex / numThreads, 128); + + threadMsgs[threadIndex].resize(endIndex - beginIndex); + + // create a copy of the receiver so that each can run + // independently. A single receiver is not thread safe. + senders[threadIndex] = sender.splitBase(); + + // create a PRNG for this thread. + threadPrngs[threadIndex].SetSeed(prng.get()); + + // create a socket for this thread. This is done by calling fork(). + threadChls[threadIndex] = chl.fork(); + + // start the send protocol on the thread pool + tasks[threadIndex] = + senders[threadIndex].send( + threadMsgs[threadIndex], + threadPrngs[threadIndex], + threadChls[threadIndex]) + | macoro::start_on(threadPool); + } + + // block this thread until the receive operations + // have completed. + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + cp::sync_wait(tasks[threadIndex]); + } + + work.reset(); + for (u64 threadIndex = 0; threadIndex < (u64)numThreads; ++threadIndex) + macoro::sync_wait(threadChls[threadIndex].flush()); + } + + + // make sure all messages have been sent. + cp::sync_wait(chl.flush()); + + auto e = timer.setTimePoint("finish"); + auto milli = std::chrono::duration_cast(e - s).count(); + + auto com = 0;// (chls[0].getTotalDataRecv() + chls[0].getTotalDataSent())* numThreads; + + if (role == Role::Sender) + lout << tag << " n=" << Color::Green << totalOTs << " " << milli << " ms " << com << " bytes" << std::endl << Color::Default; - if (cmd.isSet("v") && role == Role::Sender) - { - if (role == Role::Sender) - lout << " **** sender ****\n" << sendTimer << std::endl; - - if (role == Role::Receiver) - lout << " **** receiver ****\n" << recvTimer << std::endl; - } + if (cmd.isSet("v") && role == Role::Sender) + { + if (role == Role::Sender) + lout << " **** sender ****\n" << sendTimer << std::endl; + + if (role == Role::Receiver) + lout << " **** receiver ****\n" << recvTimer << std::endl; + } #else - throw std::runtime_error("This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. " LOCATION); + + std::cout << "This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. \n" << LOCATION << std::endl; #endif - } + } - bool TwoChooseOne_Examples(const CLP& cmd) - { - bool flagSet = false; + bool TwoChooseOne_Examples(const CLP& cmd) + { + bool flagSet = false; #ifdef ENABLE_IKNP - flagSet |= runIf(TwoChooseOne_example, cmd, iknp); + flagSet |= runIf(TwoChooseOne_example, cmd, iknp); #endif #ifdef ENABLE_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, kos); + flagSet |= runIf(TwoChooseOne_example, cmd, kos); #endif #ifdef ENABLE_DELTA_KOS - flagSet |= runIf(TwoChooseOne_example, cmd, dkos); + flagSet |= runIf(TwoChooseOne_example, cmd, dkos); #endif #ifdef ENABLE_SOFTSPOKEN_OT - flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); - flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); + flagSet |= runIf(TwoChooseOne_example, SoftSpokenShOtReceiver<>>, cmd, sshonest); + flagSet |= runIf(TwoChooseOne_example, cmd, smalicious); #endif - return flagSet; - } + return flagSet; + } } diff --git a/frontend/ExampleVole.cpp b/frontend/ExampleVole.cpp index 2ef71894..902e68a3 100644 --- a/frontend/ExampleVole.cpp +++ b/frontend/ExampleVole.cpp @@ -100,6 +100,8 @@ namespace osuCrypto // make sure all messages are sent. cp::sync_wait(chl.flush()); +#else + std::cout << "This example requires coproto to enable boost support. Please build libOTe with `-DCOPROTO_ENABLE_BOOST=ON`. \n" << LOCATION << std::endl; #endif } bool Vole_Examples(const CLP& cmd) diff --git a/frontend/benchmark.h b/frontend/benchmark.h index d066eedb..e56a863a 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -783,16 +783,16 @@ namespace osuCrypto u64 numPoints = cmd.getOr("numPoints", 1000); u64 trials = cmd.getOr("trials", 1); - std::vector points0(numPoints); - std::vector points1(numPoints); - std::vector points(numPoints); + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector points(numPoints); std::vector values0(numPoints); std::vector values1(numPoints); //Ctx ctx; for (u64 i = 0; i < numPoints; ++i) { - points[i] = Trit32(prng.get() % domain); - points1[i] = Trit32(prng.get() % domain); + points[i] = F3x32(prng.get() % domain); + points1[i] = F3x32(prng.get() % domain); points0[i] = points[i] - points1[i]; //std::cout << points[i] << " = " << points0[i] <<" + "<< points1[i] << std::endl; values0[i] = prng.get(); diff --git a/frontend/main.cpp b/frontend/main.cpp index 06eb1fe2..7feac005 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -183,7 +183,7 @@ int main(int argc, char** argv) std::cout - << "Protocols:\n" + << "Example Protocols:\n" << Color::Green << " -simplest-asm " << Color::Default << " : to run the ASM-SimplestOT active secure 1-out-of-2 base OT " << Color::Red << (spaEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -simplest " << Color::Default << " : to run the SimplestOT active secure 1-out-of-2 base OT " << Color::Red << (spEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -moellerpopf " << Color::Default << " : to run the McRosRoyTwist active secure 1-out-of-2 base OT " << Color::Red << (popfotMoellerEnabled ? "" : "(disabled)") << "\n" << Color::Default @@ -192,7 +192,7 @@ int main(int argc, char** argv) << Color::Green << " -np " << Color::Default << " : to run the NaorPinkas active secure 1-out-of-2 base OT " << Color::Red << (npEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -iknp " << Color::Default << " : to run the IKNP passive secure 1-out-of-2 OT " << Color::Red << (iknpEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -diknp " << Color::Default << " : to run the IKNP passive secure 1-out-of-2 Delta-OT " << Color::Red << (diknpEnabled ? "" : "(disabled)") << "\n" << Color::Default - << Color::Green << " -Silent " << Color::Default << " : to run the Silent passive secure 1-out-of-2 OT " << Color::Red << (silentEnabled ? "" : "(disabled)") << "\n" << Color::Default + << Color::Green << " -Silent " << Color::Default << " : to run the Silent active secure 1-out-of-2 OT " << Color::Red << (silentEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -kos " << Color::Default << " : to run the KOS active secure 1-out-of-2 OT " << Color::Red << (kosEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -dkos " << Color::Default << " : to run the KOS active secure 1-out-of-2 Delta-OT " << Color::Red << (dkosEnabled ? "" : "(disabled)") << "\n" << Color::Default << Color::Green << " -ssdelta " << Color::Default << " : to run the SoftSpoken passive secure 1-out-of-2 Delta-OT " << Color::Red << (softSpokenEnabled ? "" : "(disabled)") << "\n" << Color::Default diff --git a/libOTe/CMakeLists.txt b/libOTe/CMakeLists.txt index 6e1b8c27..62fc59ca 100644 --- a/libOTe/CMakeLists.txt +++ b/libOTe/CMakeLists.txt @@ -18,6 +18,7 @@ target_include_directories(libOTe PUBLIC target_link_libraries(libOTe cryptoTools) if(MSVC) + target_compile_definitions(libOTe PUBLIC _NO_DEBUG_HEAP=1) target_compile_options(libOTe PRIVATE $<$:/std:c++${LIBOTE_STD_VER}>) #target_compile_options(libOTe PRIVATE -openmp:experimental) else() diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Tools/Dpf/TriDpf.h index 27494703..73167ae4 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Tools/Dpf/TriDpf.h @@ -15,23 +15,23 @@ namespace osuCrypto { // a value representing (Z_3)^32. // The value is stored in 2 bits per Z_3 element. - struct Trit32 + struct F3x32 { u64 mVal; - Trit32() = default; - Trit32(const Trit32&) = default; + F3x32() = default; + F3x32(const F3x32&) = default; - Trit32(u64 v) + F3x32(u64 v) { fromInt(v); } - Trit32& operator=(const Trit32&) = default; + F3x32& operator=(const F3x32&) = default; - Trit32 operator+(const Trit32& t) const + F3x32 operator+(const F3x32& t) const { - Trit32 r; + F3x32 r; r.mVal = 0; for (u64 i = 0; i < 32; ++i) { @@ -45,9 +45,9 @@ namespace osuCrypto } - Trit32 operator-(const Trit32& t) const + F3x32 operator-(const F3x32& t) const { - Trit32 r; + F3x32 r; r.mVal = 0; for (u64 i = 0; i < 32; ++i) { @@ -61,7 +61,7 @@ namespace osuCrypto } - bool operator==(const Trit32& t) const + bool operator==(const F3x32& t) const { return mVal == t.mVal; } @@ -89,15 +89,15 @@ namespace osuCrypto } } - Trit32 lower(u64 digits) + F3x32 lower(u64 digits) { - Trit32 r; + F3x32 r; r.mVal = mVal & ((1ull << (2 * digits)) - 1); return r; } - Trit32 upper(u64 digits) + F3x32 upper(u64 digits) { - Trit32 r; + F3x32 r; r.mVal = mVal >> (2 * digits); return r; } @@ -109,7 +109,7 @@ namespace osuCrypto } }; - inline std::ostream& operator<<(std::ostream& o, const Trit32& t) + inline std::ostream& operator<<(std::ostream& o, const F3x32& t) { u64 m = 0; u64 v = t.mVal; @@ -150,9 +150,9 @@ namespace osuCrypto u64 mOtIdx = 0; - std::vector> mBaseSendOts; - std::vector mBaseRecvOts; - std::vector mBaseChoice; + AlignedUnVector> mBaseSendOts; + AlignedUnVector mBaseRecvOts; + AlignedUnVector mBaseChoice; void init( u64 partyIdx, @@ -191,7 +191,7 @@ namespace osuCrypto template macoro::task<> expand( - span points, + span points, Fs&& values, Output&& output, PRNG& prng, @@ -320,7 +320,7 @@ namespace osuCrypto for (u64 iter = 1; iter <= mDepth; ++iter) { - co_await correctionWord(points, z, sigma, iter, sock); + co_await correctionWord(points, z, sigma, iter, prng, sock); //std::cout << "sigma[" << iter << "] " << sigma[0][0] << " " << sigma[1][0] << " " << sigma[2][0] << std::endl; //std::cout << "tau[" << iter << "] " << int(tau[0][0]) << " " << int(tau[1][0]) << " " << int(tau[2][0]) << std::endl; @@ -531,22 +531,14 @@ namespace osuCrypto } - // we are going to create 3 ot message - // - // m0, m1, m2 - // - // such that m_{-a0} = r || 1 for some random r. - // - // the receiver will use choice a1. - macoro::task<> correctionWord(span points, MatrixView z, MatrixView sigma, u64 iter, coproto::Socket& sock) + macoro::task<> correctionWord( + span points, + MatrixView z, + MatrixView sigma, + u64 iter, + PRNG& prng, + coproto::Socket& sock) { - //{ - // char x = 0; - // co_await sock.send(char{ x }); - // co_await sock.recv(x); - //} - //std::cout << "=======" << iter << "======== " << std::endl; - Matrix sigmaShares(3, mNumPoints, AllocType::Uninitialized); Matrix mask(mNumPoints, 3, AllocType::Uninitialized); Matrix recvBuffer(mNumPoints * 2, 3, AllocType::Uninitialized); @@ -558,7 +550,6 @@ namespace osuCrypto std::swap(socks[0], socks[1]); auto expand3 = [](const block& k, span r) { - //r = PRNG(k, 3).get(); r[0] = k; r[1] = k ^ block(3450136502437610243, 6108362938092146510); r[2] = k ^ block(3428970074314387014, 2030711220607601239); @@ -567,7 +558,6 @@ namespace osuCrypto auto sender = [&]() -> macoro::task<> { - PRNG prng(block(234134, 21452345 * mPartyIdx)); BitVector correction(mNumPoints * 2); AlignedUnVector sendBuffer(mNumPoints * 2 * sizeof(std::array)); diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Tools/Foleage/FoleagePcg.cpp index bdeb4af7..8c28fd8b 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Tools/Foleage/FoleagePcg.cpp @@ -253,12 +253,12 @@ namespace osuCrypto // each prodPolyF4Coeffs is positioned at prodPolyLeafPos. This // will allow the main DPF to be more efficient as we are outputting // 243 F4 elements for each leaf. - std::vector prodPolyLeafPos(mC * mC * mT * mT); + std::vector prodPolyLeafPos(mC * mC * mT * mT); // once we construct large F4^243 coefficients, we will expand them // the main DPF to get the full shared polynomail. prodPolyTreePos // is the location that the F4^243 coefficient should be mapped to. - std::vector prodPolyTreePos(mC * mC * mT * mT); + std::vector prodPolyTreePos(mC * mC * mT * mT); @@ -278,7 +278,7 @@ namespace osuCrypto // the block of the product coefficient is known // purely using the block index of the input coefficients. - auto blockPos = Trit32(jA) + Trit32(jB); + auto blockPos = F3x32(jA) + F3x32(jB); auto blockIdx = blockPos.toInt(); // We want to put all DPF that will be added together @@ -290,7 +290,7 @@ namespace osuCrypto // the F4 coefficient within the F4^243 coefficient and the // portion that will position the F4^243 coefficient within // the main DPF. - auto pos = Trit32(mSparsePositions(i, j)); + auto pos = F3x32(mSparsePositions(i, j)); // (F_3)^n + (F_3)^n prodPolyLeafPos[idx] = pos.lower(mDpfLeafDepth); prodPolyTreePos[idx] = pos.upper(mDpfLeafDepth); diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index f84740a0..705c5c0b 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -402,16 +402,16 @@ void TritDpf_Proto_Test_(const oc::CLP& cmd) u64 depth = cmd.getOr("depth", 3); u64 domain = ipow(3, depth) - 3; u64 numPoints = cmd.getOr("numPoints", 17); - std::vector points0(numPoints); - std::vector points1(numPoints); - std::vector points(numPoints); + std::vector points0(numPoints); + std::vector points1(numPoints); + std::vector points(numPoints); std::vector values0(numPoints); std::vector values1(numPoints); Ctx ctx; for (u64 i = 0; i < numPoints; ++i) { - points[i] = Trit32(prng.get() % domain); - points1[i] = Trit32(prng.get() % domain); + points[i] = F3x32(prng.get() % domain); + points1[i] = F3x32(prng.get() % domain); points0[i] = points[i] - points1[i]; //std::cout << points[i] << " = " << points0[i] <<" + "<< points1[i] << std::endl; values0[i] = prng.get(); @@ -462,7 +462,7 @@ void TritDpf_Proto_Test_(const oc::CLP& cmd) for (u64 i = 0; i < domain; ++i) { - Trit32 I(i); + F3x32 I(i); for (u64 k = 0; k < numPoints; ++k) { F act; @@ -477,7 +477,7 @@ void TritDpf_Proto_Test_(const oc::CLP& cmd) if (exp != act) { - std::cout << "i " << i << "=" << Trit32(i) << " " << t << std::endl; + std::cout << "i " << i << "=" << F3x32(i) << " " << t << std::endl; std::cout << "exp " << exp << std::endl; std::cout << "act " << act << std::endl; throw RTE_LOC; From 871bbc2353ba8e4fb1dbdca0c920c26321a3a7df Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 10:12:58 -0800 Subject: [PATCH 23/48] foleage cleanup --- frontend/benchmark.h | 8 +- libOTe/{Tools => }/Dpf/DpfMult.h | 0 libOTe/{Tools => }/Dpf/RegularDpf.h | 0 libOTe/{Tools => }/Dpf/SparseDpf.h | 0 libOTe/{Tools => }/Dpf/TriDpf.h | 2 +- libOTe/Tools/Foleage/F4Ops.h | 218 --- libOTe/Tools/Foleage/FoleageMain.cpp | 304 ---- libOTe/Tools/Foleage/FoleageUtils.h | 392 ---- libOTe/Tools/Foleage/PerfectShuffle.h | 437 ----- libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp | 138 -- libOTe/Tools/Foleage/fft/FoleageFFT_bench.h | 13 - libOTe/Tools/Foleage/fft/FoleageFft.cpp | 856 --------- libOTe/Tools/Foleage/fft/FoleageFft.h | 388 ---- libOTe/Tools/Foleage/spfss_test.cpp | 115 -- libOTe/Tools/Foleage/tri-dpf/.gitignore | 5 - libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp | 317 ---- libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h | 35 - .../Tools/Foleage/tri-dpf/FoleageDpf_test.cpp | 166 -- .../Tools/Foleage/tri-dpf/FoleageDpf_test.h | 9 - libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h | 80 - libOTe/Tools/Foleage/tri-dpf/LICENSE | 9 - libOTe/Tools/Foleage/tri-dpf/README.md | 116 -- libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h | 68 - libOTe/Tools/Foleage/uint128.h | 790 --------- .../Foleage/FoleageTriple.cpp} | 231 ++- .../Foleage/FoleageTriple.h} | 73 +- libOTe/Triple/Foleage/FoleageUtils.h | 266 +++ libOTe/Triple/Foleage/fft/FoleageFft.cpp | 310 ++++ libOTe/Triple/Foleage/fft/FoleageFft.h | 41 + .../Silent/SilentOtExtReceiver.cpp | 18 +- .../TwoChooseOne/Silent/SilentOtExtSender.cpp | 10 +- libOTe/TwoChooseOne/TcoOtDefines.h | 2 +- libOTe_Tests/Foleage_Tests.cpp | 1570 ++--------------- libOTe_Tests/Foleage_Tests.h | 12 +- libOTe_Tests/RegularDpf_Tests.cpp | 6 +- libOTe_Tests/UnitTests.cpp | 9 +- 36 files changed, 1054 insertions(+), 5960 deletions(-) rename libOTe/{Tools => }/Dpf/DpfMult.h (100%) rename libOTe/{Tools => }/Dpf/RegularDpf.h (100%) rename libOTe/{Tools => }/Dpf/SparseDpf.h (100%) rename libOTe/{Tools => }/Dpf/TriDpf.h (99%) delete mode 100644 libOTe/Tools/Foleage/F4Ops.h delete mode 100644 libOTe/Tools/Foleage/FoleageMain.cpp delete mode 100644 libOTe/Tools/Foleage/FoleageUtils.h delete mode 100644 libOTe/Tools/Foleage/PerfectShuffle.h delete mode 100644 libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp delete mode 100644 libOTe/Tools/Foleage/fft/FoleageFFT_bench.h delete mode 100644 libOTe/Tools/Foleage/fft/FoleageFft.cpp delete mode 100644 libOTe/Tools/Foleage/fft/FoleageFft.h delete mode 100644 libOTe/Tools/Foleage/spfss_test.cpp delete mode 100644 libOTe/Tools/Foleage/tri-dpf/.gitignore delete mode 100644 libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp delete mode 100644 libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h delete mode 100644 libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp delete mode 100644 libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h delete mode 100644 libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h delete mode 100644 libOTe/Tools/Foleage/tri-dpf/LICENSE delete mode 100644 libOTe/Tools/Foleage/tri-dpf/README.md delete mode 100644 libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h delete mode 100644 libOTe/Tools/Foleage/uint128.h rename libOTe/{Tools/Foleage/FoleagePcg.cpp => Triple/Foleage/FoleageTriple.cpp} (72%) rename libOTe/{Tools/Foleage/FoleagePcg.h => Triple/Foleage/FoleageTriple.h} (72%) create mode 100644 libOTe/Triple/Foleage/FoleageUtils.h create mode 100644 libOTe/Triple/Foleage/fft/FoleageFft.cpp create mode 100644 libOTe/Triple/Foleage/fft/FoleageFft.h diff --git a/frontend/benchmark.h b/frontend/benchmark.h index e56a863a..855daf72 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -15,9 +15,9 @@ #include "libOTe/Tools/CoeffCtx.h" #include "libOTe/Tools/TungstenCode/TungstenCode.h" #include "libOTe/Tools/ExConvCodeOld/ExConvCodeOld.h" -#include "libOTe/Tools/Dpf/RegularDpf.h" -#include "libOTe/Tools/Dpf/TriDpf.h" -#include "libOTe/Tools/Foleage/FoleagePcg.h" +#include "libOTe/Dpf/RegularDpf.h" +#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Triple/Foleage/FoleageTriple.h" namespace osuCrypto { @@ -885,7 +885,7 @@ namespace osuCrypto for (u64 ii = 0; ii < trials; ++ii) { - std::array oles; + std::array oles; if (cmd.hasValue("t")) oles[0].mT = oles[1].mT = cmd.get("t"); if (cmd.hasValue("c")) diff --git a/libOTe/Tools/Dpf/DpfMult.h b/libOTe/Dpf/DpfMult.h similarity index 100% rename from libOTe/Tools/Dpf/DpfMult.h rename to libOTe/Dpf/DpfMult.h diff --git a/libOTe/Tools/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h similarity index 100% rename from libOTe/Tools/Dpf/RegularDpf.h rename to libOTe/Dpf/RegularDpf.h diff --git a/libOTe/Tools/Dpf/SparseDpf.h b/libOTe/Dpf/SparseDpf.h similarity index 100% rename from libOTe/Tools/Dpf/SparseDpf.h rename to libOTe/Dpf/SparseDpf.h diff --git a/libOTe/Tools/Dpf/TriDpf.h b/libOTe/Dpf/TriDpf.h similarity index 99% rename from libOTe/Tools/Dpf/TriDpf.h rename to libOTe/Dpf/TriDpf.h index 73167ae4..d0abb00e 100644 --- a/libOTe/Tools/Dpf/TriDpf.h +++ b/libOTe/Dpf/TriDpf.h @@ -8,7 +8,7 @@ #include "cryptoTools/Common/Matrix.h" #include "DpfMult.h" -#include "libOTe/Tools/Foleage/FoleageUtils.h" +#include "libOTe/Triple/Foleage/FoleageUtils.h" #include "libOTe/Tools/CoeffCtx.h" namespace osuCrypto diff --git a/libOTe/Tools/Foleage/F4Ops.h b/libOTe/Tools/Foleage/F4Ops.h deleted file mode 100644 index 49198ed7..00000000 --- a/libOTe/Tools/Foleage/F4Ops.h +++ /dev/null @@ -1,218 +0,0 @@ -#pragma once - -#include "libOTe/Tools/Foleage/FoleageUtils.h" - -namespace osuCrypto -{ - //typedef __int128 int128_t; - //typedef unsigned __int128 uint128_t; - - // Samples a non-zero element of F4 - inline uint8_t rand_f4x(PRNG& prng) - { - uint8_t t = 0; - while (t == 0) - { - t = prng.get() & 3; - } - return t; - } - - // Multiplies two elements of F4 (optionally: 4 elements packed into uint8_t) - // and returns the result. - inline uint8_t mult_f4(uint8_t a, uint8_t b) - { - u8 tmp = ((a & 0b10) & (b & 0b10)); - uint8_t res = tmp ^ (((a & 0b10) & ((b & 0b01) << 1)) ^ (((a & 0b01) << 1) & (b & 0b10))); - res |= ((a & 0b01) & (b & 0b01)) ^ (tmp >> 1); - return res; - } - - inline void f4Mult( - block aLsb, block aMsb, - block bLsb, block bMsb, - block& cLsb, block& cMsb) - { - auto tmp = aMsb & bMsb;// msb only - cMsb = tmp ^ (aMsb & bLsb) ^ (aLsb & bMsb);// msb only - cLsb = (aLsb & bLsb) ^ tmp; - } - - - // Multiplies two packed matrices of F4 elements column-by-column. - // Note that here the "columns" are packed into an element of uint8_t - // resulting in a matrix with 4 columns. - inline void F4Multiply( - span a_poly, - span b_poly, - span res_poly, - size_t poly_size) - { - const uint8_t pattern = 0xaa; - uint8_t mask_h = pattern; // 0b10101010 - uint8_t mask_l = mask_h >> 1; // 0b01010101 - - uint8_t tmp; - uint8_t a_h, a_l, b_h, b_l; - - for (size_t i = 0; i < poly_size; i++) - { - // multiplication over F4 - a_h = (a_poly[i] & mask_h); - a_l = (a_poly[i] & mask_l); - b_h = (b_poly[i] & mask_h); - b_l = (b_poly[i] & mask_l); - - tmp = (a_h & b_h); - res_poly[i] = tmp ^ (a_h & (b_l << 1)); - res_poly[i] ^= ((a_l << 1) & b_h); - res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); - } - } - - // Multiplies two packed matrices of F4 elements column-by-column. - // Note that here the "columns" are packed into an element of uint16_t - // resulting in a matrix with 8 columns. - inline void multiply_fft_16( - span a_poly, - span b_poly, - span res_poly, - size_t poly_size) - { - const uint16_t pattern = 0xaaaa; - uint16_t mask_h = pattern; // 0b101010101010101001010 - uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - uint16_t tmp; - uint16_t a_h, a_l, b_h, b_l; - - for (size_t i = 0; i < poly_size; i++) - { - // multiplication over F4 - a_h = (a_poly[i] & mask_h); - a_l = (a_poly[i] & mask_l); - b_h = (b_poly[i] & mask_h); - b_l = (b_poly[i] & mask_l); - - tmp = (a_h & b_h); - res_poly[i] = tmp ^ (a_h & (b_l << 1)); - res_poly[i] ^= ((a_l << 1) & b_h); - res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); - } - } - - // Multiplies two packed matrices of F4 elements column-by-column. - // Note that here the "columns" are packed into an element of uint32_t - // resulting in a matrix with 16 columns. - inline void multiply_fft_32( - span a_poly, - span b_poly, - span res_poly, - size_t poly_size) - { - const uint32_t pattern = 0xaaaaaaaa; - uint32_t mask_h = pattern; // 0b101010101010101001010 - uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - uint32_t tmp; - uint32_t a_h, a_l, b_h, b_l; - - for (size_t i = 0; i < poly_size; i++) - { - // multiplication over F4 - a_h = (a_poly[i] & mask_h); - a_l = (a_poly[i] & mask_l); - b_h = (b_poly[i] & mask_h); - b_l = (b_poly[i] & mask_l); - - tmp = (a_h & b_h); - res_poly[i] = tmp ^ (a_h & (b_l << 1)); - res_poly[i] ^= ((a_l << 1) & b_h); - res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); - } - } - - // Multiplies two packed matrices of F4 elements column-by-column. - // Note that here the "columns" are packed into an element of uint64_t - // resulting in a matrix with 32 columns. - inline void multiply_fft_64( - span a_poly, - span b_poly, - span res_poly, - size_t poly_size) - { - const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; - uint64_t mask_h = pattern; // 0b101010101010101001010 - uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - uint64_t tmp; - uint64_t a_h, a_l, b_h, b_l; - - for (size_t i = 0; i < poly_size; i++) - { - // multiplication over F4 - a_h = (a_poly[i] & mask_h); - a_l = (a_poly[i] & mask_l); - b_h = (b_poly[i] & mask_h); - b_l = (b_poly[i] & mask_l); - - tmp = (a_h & b_h); - res_poly[i] = tmp ^ (a_h & (b_l << 1)); - res_poly[i] ^= ((a_l << 1) & b_h); - res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); - } - } - - - - // samples the a polynomials and axa polynomials - inline void sample_a_and_a2(span fft_a, span fft_a2, size_t poly_size, size_t c, PRNG& prng) - { - if (c > 16) - throw RTE_LOC; - - prng.get(fft_a.data(), poly_size); - - // make a_0 the identity polynomial (in FFT space) i.e., all 1s - for (size_t i = 0; i < poly_size; i++) - { - fft_a[i] = (fft_a[i] & ~3ull) | 1; - } - - //std::cout << "sampleA " << int(fft_a[0]) << int(fft_a[1]) << int(fft_a[2]) << int(fft_a[3]) << std::endl; - - // FOR DEBUGGING: set fft_a to the identity - // for (size_t i = 0; i < poly_size; i++) - // { - // fft_a[i] = (0xaaaa >> 1); - // } - - uint32_t prod; - for (size_t j = 0; j < c; j++) - { - for (size_t k = 0; k < c; k++) - { - for (size_t i = 0; i < poly_size; i++) - { - auto a = (fft_a[i] >> (2 * j)) & 0b11; - auto b = (fft_a[i] >> (2 * k)) & 0b11; - auto a1 = a & 1; - auto a2 = a & 2; - auto b1 = b & 1; - auto b2 = b & 2; - - { - u8 tmp = (a2 & b2); - prod = tmp ^ ((a2 & (b1 << 1)) ^ ((a1 << 1) & b2)); - prod |= (a1 & b1) ^ (tmp >> 1); - //return res; - } - //prod = mult_f4(, ); - size_t slot = j * c + k; - fft_a2[i] |= prod << (2 * slot); - } - } - } - } - -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleageMain.cpp b/libOTe/Tools/Foleage/FoleageMain.cpp deleted file mode 100644 index 2ba24634..00000000 --- a/libOTe/Tools/Foleage/FoleageMain.cpp +++ /dev/null @@ -1,304 +0,0 @@ -#include -#include -#include - -#include "libOTe/Tools/Foleage/F4Ops.h" -#include "libOTe/Tools/Foleage/fft/FoleageFft.h" - -#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" - -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -// Benchmarks are less documented compared to test.c; see test.c to -// better understand what is done here for timing purposes. - -#define DPF_MSG_SIZE 8 -namespace osuCrypto -{ - - - double bench_pcg(size_t n, size_t c, size_t t) - { - if (c > 4) - { - printf("ERROR: currently only implemented for c <= 4"); - exit(0); - } - - const size_t poly_size = ipow(3, n); - PRNG prng(block(342)); - - //************************************************************************ - // Step 0: Sample the global (1, a1 ... a_c-1) polynomials - //************************************************************************ - AlignedUnVector fft_a(poly_size); - AlignedUnVector fft_a2(poly_size); - sample_a_and_a2(fft_a, fft_a2, poly_size, c, prng); - - //************************************************************************ - // Step 1: Sample DPF keys for the cross product. - // For benchmarking purposes, we sample random DPF functions for a - // sufficiently large domain size to express a block of coefficients. - //************************************************************************ - size_t dpf_domain_bits = ceil(log_base(poly_size / (t * DPF_MSG_SIZE * 64), 3)); - printf("dpf_domain_bits = %zu \n", dpf_domain_bits); - - size_t seed_size_bits = (128 * (dpf_domain_bits * 3 + 1) + DPF_MSG_SIZE * 128) * c * c * t * t; - printf("PCG seed size: %.2f MB\n", seed_size_bits / 8000000.0); - - size_t dpf_block_size = DPF_MSG_SIZE * ipow(3, dpf_domain_bits); - size_t block_size = ceil(poly_size / t); - - printf("block_size = %zu \n", block_size); - - std::vectordpf_keys_A(c * c * t * t); - std::vectordpf_keys_B(c * c * t * t); - - // Sample PRF keys for the DPFs - PRFKeys prf_keys; - prf_keys.gen(prng); - - // Sample DPF keys for each of the t errors in the t blocks - for (size_t i = 0; i < c; i++) - { - for (size_t j = 0; j < c; j++) - { - for (size_t k = 0; k < t; k++) - { - for (size_t l = 0; l < t; l++) - { - size_t index = i * c * t * t + j * t * t + k * t + l; - - // Pick a random index for benchmarking purposes - size_t alpha = random_index(block_size, prng); - - // Pick a random output message for benchmarking purposes - block beta[DPF_MSG_SIZE]; - prng.get(beta, DPF_MSG_SIZE); - - // Message (beta) is of size 8 blocks of 128 bits - DPFGen(prf_keys, dpf_domain_bits, alpha, beta, DPF_MSG_SIZE, dpf_keys_A[index], dpf_keys_B[index], prng); - } - } - } - } - - //************************************************ - printf("Benchmarking PCG evaluation \n"); - //************************************************ - - // Allocate memory for the DPF outputs (this is reused for each evaluation) - AlignedUnVector shares(dpf_block_size); - AlignedUnVector cache(dpf_block_size); - - // Allocate memory for the concatenated DPF outputs - const size_t packed_block_size = ceil(block_size / 64.0); - const size_t packed_poly_size = t * packed_block_size; - AlignedUnVector packed_polys(c * c * packed_poly_size); - - // Allocate memory for the output FFT - AlignedUnVector fft_u(poly_size); - - // Allocate memory for the final inner product - AlignedUnVector z_poly(poly_size); - AlignedUnVector res_poly_mat(poly_size); - - //************************************************************************ - // Step 3: Evaluate all the DPFs to recover shares of the c*c polynomials. - //************************************************************************ - - clock_t time; - time = clock(); - - size_t key_index; - block* poly_block; - size_t i, j, k, l, w; - for (i = 0; i < c; i++) - { - for (j = 0; j < c; j++) - { - const size_t poly_index = i * c + j; - block* packed_poly = &packed_polys[poly_index * packed_poly_size]; - - for (k = 0; k < t; k++) - { - poly_block = &packed_poly[k * packed_block_size]; - - for (l = 0; l < t; l++) - { - key_index = i * c * t * t + j * t * t + k * t + l; - - DPFFullDomainEval(dpf_keys_A[key_index], cache, shares); - - for (w = 0; w < packed_block_size; w++) - poly_block[w] ^= shares[w]; - } - } - } - } - - //************************************************************************ - // Step 3: Compute the transpose of the polynomials to pack them into - // the parallel FFT format. - // - // TODO: this is the bottleneck of the computation and can be improved - // using SIMD operations for performing matrix transposes (see TODO in test.c). - //************************************************************************ - for (size_t i = 0; i < c * c; i++) - { - size_t poly_index = i * packed_poly_size; - const block* poly = &packed_polys[poly_index]; - -#ifdef ENABLE_SSE - _mm_prefetch((char*)poly, _MM_HINT_T2); -#endif // ENABLE_SSE - - size_t block_idx, packed_coeff_idx, coeff_idx; - //uint8_t packed_bit_idx; - block packed_coeff; - - block_idx = 0; - packed_coeff_idx = 0; - coeff_idx = 0; - - for (size_t k = 0; k < poly_size - 64; k += 64) - { - packed_coeff = poly[block_idx * packed_block_size + packed_coeff_idx]; - -#ifdef ENABLE_SSE - _mm_prefetch((char*)&fft_u[k], _MM_HINT_T2); -#endif // ENABLE_SSE - //__builtin_prefetch(&fft_u[k], 0, 0); - //__builtin_prefetch(&fft_u[k], 1, 0); - - for (size_t l = 0; l < 64; l++) - { - packed_coeff = packed_coeff >> 2; - fft_u[k + l] |= static_cast(packed_coeff.get(0)) & 0b11; - fft_u[k + l] = fft_u[k + l] << 2; - } - - packed_coeff_idx++; - coeff_idx += 64; - - if (coeff_idx > block_size) - { - coeff_idx = 0; - block_idx++; - packed_coeff_idx = 0; - -#ifdef ENABLE_SSE - _mm_prefetch((char*)&poly[block_idx * packed_block_size], _MM_HINT_T2); - //__builtin_prefetch(&poly[block_idx * packed_block_size], 0, 2); -#endif // ENABLE_SSE - } - } - - packed_coeff = poly[block_idx * packed_block_size + packed_coeff_idx]; - for (size_t k = poly_size - 64 + 1; k < poly_size; k++) - { - packed_coeff = packed_coeff >> 2; - fft_u[k] |= static_cast(packed_coeff.get(0)) & 0b11 ; - fft_u[k] = fft_u[k] << 2; - } - } - - fft_recursive_uint32(fft_u, n, poly_size / 3); - multiply_fft_32(fft_a2, fft_u, res_poly_mat, poly_size); - - // Perform column-wise XORs to get the result - for (size_t i = 0; i < poly_size; i++) - { - // XOR the (packed) columns into the accumulator - for (size_t j = 0; j < c * c; j++) - { - z_poly[i] ^= res_poly_mat[i] & 0b11; - res_poly_mat[i] = res_poly_mat[i] >> 2; - } - } - - time = clock() - time; - double time_taken = ((double)time) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Eval time (total) %f ms\n", time_taken); - printf("DONE\n\n"); - - //DestroyPRFKey(prf_keys); - //free(fft_a); - //free(fft_a2); - //free(dpf_keys_A); - //free(dpf_keys_B); - //free(shares); - //free(cache); - //free(fft_u); - //free(packed_polys); - //free(res_poly_mat); - //free(z_poly); - - return time_taken; - } - - void printUsage() - { - printf("Usage: ./pcg [OPTIONS]\n"); - printf("Options:\n"); - printf(" --test\tTests correctness of the PCG.\n"); - printf(" --bench\tBenchmarks the PCG on conservative and aggressive parameters.\n"); - } - - void runBenchmarks(size_t n, size_t c, size_t t, int num_trials) - { - double time = 0; - - for (int i = 0; i < num_trials; i++) - { - time += bench_pcg(n, c, t); - printf("Done with trial %i of %i\n", i + 1, num_trials); - } - printf("******************************************\n"); - printf("Avg time (N=3^%zu, c=%zu, t=%zu): %0.4f ms\n", n, c, t, time / num_trials); - printf("******************************************\n\n"); - } - - int main_foliage(int argc, char** argv) - { - int num_trials = 5; - - for (int i = 1; i < argc; i++) - { - if (strcmp(argv[i], "--bench") == 0) - { - printf("******************************************\n"); - printf("Benchmarking PCG with conservative parameters (c=4, t=27)\n"); - runBenchmarks(14, 4, 27, num_trials); - runBenchmarks(16, 4, 27, num_trials); - runBenchmarks(18, 4, 27, num_trials); - - printf("******************************************\n"); - printf("Benchmarking PCG with aggressive parameters (c=3, t=27)\n"); - runBenchmarks(14, 3, 27, num_trials); - runBenchmarks(16, 3, 27, num_trials); - runBenchmarks(18, 3, 27, num_trials); - } - //else if (strcmp(argv[i], "--test") == 0) - //{ - // printf("******************************************\n"); - // printf("Testing PCG\n"); - // foliage_pcg_test(); - // printf("******************************************\n"); - // printf("PASS\n"); - // printf("******************************************\n\n"); - //} - else - { - printUsage(); - } - } - - if (argc == 1) - printUsage(); - - return 0; - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleageUtils.h b/libOTe/Tools/Foleage/FoleageUtils.h deleted file mode 100644 index 444724ee..00000000 --- a/libOTe/Tools/Foleage/FoleageUtils.h +++ /dev/null @@ -1,392 +0,0 @@ -#pragma once -#include "cryptoTools/Crypto/AES.h" -#include "cryptoTools/Crypto/PRNG.h" -#include "cryptoTools/Crypto/RandomOracle.h" -#include -#include "uint128.h" -#include -#include - -namespace osuCrypto -{ - //using uint128_t = absl::uint128_t; - //using int128_t = block; - //using uint128_t = block; - //using uint128_t = __uint128_t; - //struct uint128_t - //{ - // std::array mVals; - - // uint128_t() = default; - // uint128_t(const uint128_t&) = default; - // uint128_t& operator=(const uint128_t&) = default; - - // uint128_t(const u64& v) : mVals({ v,0 }) {}; - - // bool operator==(const uint128_t& o) const { return mVals[0] == o.mVals[0] && mVals[1] == o.mVals[1]; } - // bool operator!=(const uint128_t& o) const { return !(*this == o); } - - // bool operator==(const u64& o) const { return *this == uint128_t{ o }; } - // bool operator!=(const u64& o) const { return *this != uint128_t{ o }; } - // bool operator==(const int& o) const { return *this == uint128_t{ u64(o) }; } - // bool operator!=(const int& o) const { return *this != uint128_t{ u64(o) }; } - - - // uint128_t operator^(const uint128_t&o) const { - // uint128_t r = *this; - // r ^= o; - // return r; - // } - // uint128_t& operator^=(const uint128_t& o) - // { - // mVals[0] ^= o.mVals[0]; - // mVals[1] ^= o.mVals[1]; - // return *this; - // } - - // uint128_t operator&(const uint128_t&o) const { - // uint128_t r = *this; - // r &= o; - // return r; - // } - // uint128_t& operator&=(const uint128_t&o) - // { - // mVals[0] &= o.mVals[0]; - // mVals[1] &= o.mVals[1]; - // return *this; - // } - - - // uint128_t operator+(const uint128_t&o) const - // { - // uint128_t r = *this; - // r += o; - // return r; - // } - // uint128_t& operator+=(const uint128_t&o) - // { - // u64 v; - // char cout = _addcarry_u64(0, mVals[0], o.mVals[0], &mVals[0]); - // _addcarry_u64(cout, mVals[1], o.mVals[1], &mVals[1]); - // return *this; - // } - - - // uint128_t operator-(const uint128_t&o) const - // { - // uint128_t r = *this; - // r -= o; - // return r; - // } - // uint128_t& operator-=(const uint128_t&o) - // { - // auto borrow = _subborrow_u64(0, mVals[0], o.mVals[0], &mVals[0]); - // _subborrow_u64(borrow, mVals[1], o.mVals[1], &mVals[1]); - // return *this; - // } - - - // uint128_t operator>>(u64 s) const - // { - // auto r = *this; - // r >>= s; - // return r; - // } - // uint128_t& operator>>=(u64 s) - // { - // assert(s <= 128); - // if (s < 64) - // { - // mVals[0] = (mVals[0] >> s) | (mVals[1] << (64-s)); - // mVals[1] >>= s; - // } - // else - // { - // s = s - 64; - // mVals[0] = mVals[1] >> s; - // mVals[1] = 0; - // } - // return *this; - // } - - // uint128_t operator<<(u64 s) const - // { - // auto r = *this; - // r <<= s; - // return r; - // } - // uint128_t& operator<<=(u64 s) - // { - // assert(s <= 128); - // if (s < 64) - // { - // mVals[1] = (mVals[1] << s) | (mVals[0] >> (64 - s)); - // mVals[0] <<= s; - // } - // else - // { - // s = s - 64; - // mVals[1] = mVals[0] << s; - // mVals[0] = 0; - // } - // return *this; - // } - - // uint128_t operator>>(int s) const { return *this >> u64(s); } - // uint128_t& operator>>=(int s) { return *this >>= u64(s); } - - // uint128_t operator<<(int s) const { return *this << u64(s); } - // uint128_t& operator<<=(int s) { return *this >>= u64(s); } - - // operator u64 () const - // { - // return mVals[0]; - // } - - - //}; - - - - inline void printBytes(void* p, int num) - { - unsigned char* c = (unsigned char*)p; - for (int i = 0; i < num; i++) - { - printf("%02x", c[i]); - } - printf("\n"); - } - - template - inline block hash(T* ptr, u64 size) - { - oc::RandomOracle ro(16); - ro.Update(ptr, size); - block f; - ro.Final(f); - return f; - } - - - inline std::string hex32(span ptr) - { - std::stringstream ss; - for (u64 i = 0; i < ptr.size(); ++i) - ss << std::setw(8)<() % (max + 1); - //while (1) - //{ - - // // Use rejection sampling to ensure uniformity - // if (rand_value <= (UINT64_MAX - (UINT64_MAX % (max + 1)))) - // return rand_value % (max + 1); - //} - } - - // Samples a random trit (0,1,2) via rejection sampling - inline uint8_t rand_trit(PRNG& prng) - { - uint8_t t; - - while (1) - { - //RAND_bytes(&rand_byte, 1); - t = prng.get(); - if (t <= 170) // Rejecting values greater than 170 - return t % 3; - } - } - - // Reverses the order of elements in an array of uint8_t values - inline void reverse_uint8_array(span trits) - { - size_t i = 0; - size_t j = trits.size() - 1; - - while (i < j) - { - // Swap elements at positions i and j - uint8_t temp = trits[i]; - trits[i] = trits[j]; - trits[j] = temp; - - // Move towards the center of the array - i++; - j--; - } - } - - // Converts an array of trits (not packed) into their integer representation. - inline size_t trits_to_int(span trits) - { - if (trits.size() == 0) - return 0; - reverse_uint8_array(trits); - size_t result = 0; - for (size_t i = 0; i < trits.size(); i++) - result = result * 3 + (size_t)trits[i]; - - return result; - } - - // Converts an integer into ternary representation (each trit = 0,1,2) - inline void int_to_trits(size_t n, span trits) - { - for (size_t i = 0; i < trits.size(); i++) - trits[i] = 0; - - size_t index = 0; - while (n > 0 && index < trits.size()) - { - trits[index] = (uint8_t)(n % 3); - n = n / 3; - index++; - } - } - - // Computes the log of `a` base `base` - inline double log_base(double a, double base) - { - return std::log2(a) / std::log2(base); - } - - inline u64 log3ceil(u64 x) - { - if (x == 0) return 0; - u64 i = 0; - u64 v = 1; - while (v < x) - { - v *= 3; - i++; - } - //assert(i == ceil(log_base(x, 3))); - - return i; - } - - // Compute base^exp without the floating-point precision - // errors of the built-in pow function. - inline constexpr size_t ipow(size_t base, size_t exp) - { - if (exp == 1) - return base; - - if (exp == 0) - return 1; - - size_t result = 1; - while (1) - { - if (exp & 1) - result *= base; - exp >>= 1; - if (!exp) - break; - base *= base; - } - - return result; - } - - inline int popcount(block x) - { - //std::array xArr; - //memcpy(xArr.data(), &x, 16); - return popcount(x.get(0)) + popcount(x.get(1)); - } - //inline int popcount(uint128_t x) - //{ - // std::array xArr; - // memcpy(xArr.data(), &x, 16); - // return popcount(xArr[0]) + popcount(xArr[1]); - //} - - //inline std::array extractF4(const uint128_t& val) - //{ - // std::array ret; - // const char* ptr = (const char*)&val; - // for (u8 i = 0; i < 16; ++i) - // { - // ret[i * 4 + 0] = (ptr[i] >> 0) & 3; - // ret[i * 4 + 1] = (ptr[i] >> 2) & 3; - // ret[i * 4 + 2] = (ptr[i] >> 4) & 3; - // ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; - // } - // return ret; - //} - inline std::array extractF4(const block& val) - { - std::array ret; - const char* ptr = (const char*)&val; - for (u8 i = 0; i < 16; ++i) - { - ret[i * 4 + 0] = (ptr[i] >> 0) & 3; - ret[i * 4 + 1] = (ptr[i] >> 2) & 3; - ret[i * 4 + 2] = (ptr[i] >> 4) & 3; - ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; - } - return ret; - } - - // A 512 bit value that is used to represent a vector of 3^5=243 F4 elements. - // We use this value because its greater than 128 bits and almost a power of 2. - // the last 26 bits are unused. - struct FoleageF4x243 - { - std::array mVal; - - FoleageF4x243 operator^(const FoleageF4x243& o) const - { - FoleageF4x243 r; - r.mVal[0] = mVal[0] ^ o.mVal[0]; - r.mVal[1] = mVal[1] ^ o.mVal[1]; - r.mVal[2] = mVal[2] ^ o.mVal[2]; - r.mVal[3] = mVal[3] ^ o.mVal[3]; - return r; - } - FoleageF4x243& operator^=(const FoleageF4x243& o) - { - mVal[0] = mVal[0] ^ o.mVal[0]; - mVal[1] = mVal[1] ^ o.mVal[1]; - mVal[2] = mVal[2] ^ o.mVal[2]; - mVal[3] = mVal[3] ^ o.mVal[3]; - return *this; - } - - bool operator==(const FoleageF4x243& o) const - { - return - mVal[0] == o.mVal[0] && - mVal[1] == o.mVal[1] && - mVal[2] == o.mVal[2] && - mVal[3] == o.mVal[3]; - } - }; - - inline std::array extractF4(const FoleageF4x243& val) - { - std::array ret; - const char* ptr = (const char*)&val; - for (u8 i = 0; i < 64; ++i) - { - ret[i * 4 + 0] = (ptr[i] >> 0) & 3; - ret[i * 4 + 1] = (ptr[i] >> 2) & 3; - ret[i * 4 + 2] = (ptr[i] >> 4) & 3; - ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; - } - return ret; - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/PerfectShuffle.h b/libOTe/Tools/Foleage/PerfectShuffle.h deleted file mode 100644 index c5367fdb..00000000 --- a/libOTe/Tools/Foleage/PerfectShuffle.h +++ /dev/null @@ -1,437 +0,0 @@ -#pragma once -#include "cryptoTools/Common/Defines.h" -#include -#include - -namespace osuCrypto -{ - - - - // given a shuffle on blocks of 2*Shift, shuffle - // them together to have block size Shift. - template - inline u32 cPerfectShuffle_round(u32 x) - { - static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); - u32 t; - t = (x ^ (x >> Shift)) & v; - x = x ^ t ^ (t << Shift); - return x; - } - - - - // Hackers Delight perfect shuffle, Sec 7.2. Interlace bits. - // https://doc.lagout.org/security/Hackers%20Delight.pdf - // - // input : abcd efgh ijkl mnop ABCD EFGH IJKL MNOP, - // output: aAbB cCdD eEfF gGhH iIjJ kKlL mMnN oOpP - inline u32 cPerfectShuffle(u16 x0, u16 x1) - { - u32 x = x0 | (u32{ x1 } << 16); - x = cPerfectShuffle_round<8>(x); - x = cPerfectShuffle_round<4>(x); - x = cPerfectShuffle_round<2>(x); - x = cPerfectShuffle_round<1>(x); - return x; - } - - // Hackers Delight perfect shuffle, Sec 7.2. Uninterlace bits. - // https://doc.lagout.org/security/Hackers%20Delight.pdf - // - // input : aAbB cCdD eEfF gGhH iIjJ kKlL mMnN oOpP - // output: abcd efgh ijkl mnop ABCD EFGH IJKL MNOP, - inline std::array cPerfectUnshuffle(u32 x) - { - x = cPerfectShuffle_round<1>(x); - x = cPerfectShuffle_round<2>(x); - x = cPerfectShuffle_round<4>(x); - x = cPerfectShuffle_round<8>(x); - - std::array r; - r[0] = x; - r[1] = x >> 16; - return r; - } - - // perfect shuffle the bits of `input0` and `input1` into `output`. - // bits from `input0` and `input1` alternate. - inline void cPerfectShuffle(span input0, span input1, span output) - { - if (input0.size() != input1.size()) - throw RTE_LOC; - if (input0.size() != (output.size() + 1) / 2) - throw RTE_LOC; - - u64 n32 = output.size() / sizeof(u32); - - auto in0 = (u16*)input0.data(); - auto in1 = (u16*)input1.data(); - auto out = (u32*)output.data(); - for (u64 i = 0; i < n32; ++i) - { - out[i] = cPerfectShuffle(in0[i], in1[i]); - } - - auto n8 = n32 * sizeof(u32); - if (output.size() != n8) - { - u16 x0 = 0, x1 = 0; - copyBytesMin(x0, input0.subspan(n8 / 2)); - copyBytesMin(x1, input1.subspan(n8 / 2)); - auto t = cPerfectShuffle(x0, x1); - copyBytesMin(output.subspan(n8), t); - } - } - - // perfect unshuffle the bits of `input` into `output0` and `output1`. - // even indexed bits of `input` go to `output0`. - inline void cPerfectUnshuffle(span input, span output0, span output1) - { - if (output0.size() != output1.size()) - throw RTE_LOC; - if (output0.size() != (input.size() + 1) / 2) - throw RTE_LOC; - u64 n32 = input.size() / sizeof(u32); - auto out0 = (u16*)output0.data(); - auto out1 = (u16*)output1.data(); - auto in = (u32*)input.data(); - for (u64 i = 0; i < n32; ++i) - { - auto t = cPerfectUnshuffle(in[i]); - assert((u8*)&(out0[i]) < output0.data() + output0.size()); - assert((u8*)&(out1[i]) < output1.data() + output1.size()); - - out0[i] = ((u16*)&t)[0]; - out1[i] = ((u16*)&t)[1]; - } - - auto n8 = n32 * sizeof(u32); - if (input.size() != n8) - { - // auto rem = output0.size() - n8 / 2; - u32 t = 0; - copyBytesMin(t, input.subspan(n8)); - auto r = cPerfectUnshuffle(t); - copyBytesMin(output0.subspan(n8 / 2), r[0]); - copyBytesMin(output1.subspan(n8 / 2), r[1]); - } - } - -#ifdef ENABLE_SSE - - // given a shuffle on blocks of 2*Shift, shuffle - // them together to have block size Shift. - template - inline void ssePerfectShuffle_round(oc::block& x) - { - static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); - oc::block t; - - //t = (x ^ (x >> shift)) & 0x0000FF00; - t = _mm_srli_epi32(x, Shift); - t = _mm_xor_si128(t, x); - t = _mm_and_si128(t, _mm_set_epi32(v, v, v, v)); - - // x = x ^ t ^ (t << shift); - x = _mm_xor_si128(t, x); - t = _mm_slli_epi32(t, Shift); - x = _mm_xor_si128(t, x); - } - - // given a shuffle on blocks of 2*Shift, shuffle - // them together to have block size Shift. - template - inline void ssePerfectShuffle_round(oc::block* x) - { - static_assert(Shift, "Shift must be 1,2,4,8. That is, we assume the x is split into chunks of size 2*Shift and we will shuffle these into chunks of size Shift"); - oc::block t[8]; - auto V = _mm_set_epi32(v, v, v, v); - - //t = (x ^ (x >> shift)) & 0x0000FF00; - t[0] = _mm_srli_epi32(x[0], Shift); - t[1] = _mm_srli_epi32(x[1], Shift); - t[2] = _mm_srli_epi32(x[2], Shift); - t[3] = _mm_srli_epi32(x[3], Shift); - t[4] = _mm_srli_epi32(x[4], Shift); - t[5] = _mm_srli_epi32(x[5], Shift); - t[6] = _mm_srli_epi32(x[6], Shift); - t[7] = _mm_srli_epi32(x[7], Shift); - - t[0] = _mm_xor_si128(t[0], x[0]); - t[1] = _mm_xor_si128(t[1], x[1]); - t[2] = _mm_xor_si128(t[2], x[2]); - t[3] = _mm_xor_si128(t[3], x[3]); - t[4] = _mm_xor_si128(t[4], x[4]); - t[5] = _mm_xor_si128(t[5], x[5]); - t[6] = _mm_xor_si128(t[6], x[6]); - t[7] = _mm_xor_si128(t[7], x[7]); - - t[0] = _mm_and_si128(t[0], V); - t[1] = _mm_and_si128(t[1], V); - t[2] = _mm_and_si128(t[2], V); - t[3] = _mm_and_si128(t[3], V); - t[4] = _mm_and_si128(t[4], V); - t[5] = _mm_and_si128(t[5], V); - t[6] = _mm_and_si128(t[6], V); - t[7] = _mm_and_si128(t[7], V); - - // x = x ^ t ^ (t << shift); - x[0] = _mm_xor_si128(t[0], x[0]); - x[1] = _mm_xor_si128(t[1], x[1]); - x[2] = _mm_xor_si128(t[2], x[2]); - x[3] = _mm_xor_si128(t[3], x[3]); - x[4] = _mm_xor_si128(t[4], x[4]); - x[5] = _mm_xor_si128(t[5], x[5]); - x[6] = _mm_xor_si128(t[6], x[6]); - x[7] = _mm_xor_si128(t[7], x[7]); - t[0] = _mm_slli_epi32(t[0], Shift); - t[1] = _mm_slli_epi32(t[1], Shift); - t[2] = _mm_slli_epi32(t[2], Shift); - t[3] = _mm_slli_epi32(t[3], Shift); - t[4] = _mm_slli_epi32(t[4], Shift); - t[5] = _mm_slli_epi32(t[5], Shift); - t[6] = _mm_slli_epi32(t[6], Shift); - t[7] = _mm_slli_epi32(t[7], Shift); - x[0] = _mm_xor_si128(t[0], x[0]); - x[1] = _mm_xor_si128(t[1], x[1]); - x[2] = _mm_xor_si128(t[2], x[2]); - x[3] = _mm_xor_si128(t[3], x[3]); - x[4] = _mm_xor_si128(t[4], x[4]); - x[5] = _mm_xor_si128(t[5], x[5]); - x[6] = _mm_xor_si128(t[6], x[6]); - x[7] = _mm_xor_si128(t[7], x[7]); - } - - inline oc::block ssePerfectShuffle(u64 x0, u64 x1) - { - // perfect shuffle the bytes. - const oc::block b = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); - oc::block y = _mm_set_epi64x(x1, x0); - y = _mm_shuffle_epi8(y, b); - - // perfect shuffle the bits. - ssePerfectShuffle_round<4>(y); - ssePerfectShuffle_round<2>(y); - ssePerfectShuffle_round<1>(y); - return y; - } - - inline std::array ssePerfectUnshuffle(oc::block y) - { - // perfect shuffle the bits. - ssePerfectShuffle_round<1>(y); - ssePerfectShuffle_round<2>(y); - ssePerfectShuffle_round<4>(y); - - // perfect shuffle the bytes. - const oc::block b = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - y = _mm_shuffle_epi8(y, b); - - return std::bit_cast>(y); - } - - // perfect shuffle 4 blocks on x0,x1 into 8 blocks of y. - inline void ssePerfectShuffle(const oc::block* x0, const oc::block* x1, oc::block* y) - { - // perfect shuffle the bytes. - const oc::block b = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); - y[0] = _mm_set_epi64x(((u64*)x1)[0], ((u64*)x0)[0]); - y[1] = _mm_set_epi64x(((u64*)x1)[1], ((u64*)x0)[1]); - y[2] = _mm_set_epi64x(((u64*)x1)[2], ((u64*)x0)[2]); - y[3] = _mm_set_epi64x(((u64*)x1)[3], ((u64*)x0)[3]); - y[4] = _mm_set_epi64x(((u64*)x1)[4], ((u64*)x0)[4]); - y[5] = _mm_set_epi64x(((u64*)x1)[5], ((u64*)x0)[5]); - y[6] = _mm_set_epi64x(((u64*)x1)[6], ((u64*)x0)[6]); - y[7] = _mm_set_epi64x(((u64*)x1)[7], ((u64*)x0)[7]); - y[0] = _mm_shuffle_epi8(y[0], b); - y[1] = _mm_shuffle_epi8(y[1], b); - y[2] = _mm_shuffle_epi8(y[2], b); - y[3] = _mm_shuffle_epi8(y[3], b); - y[4] = _mm_shuffle_epi8(y[4], b); - y[5] = _mm_shuffle_epi8(y[5], b); - y[6] = _mm_shuffle_epi8(y[6], b); - y[7] = _mm_shuffle_epi8(y[7], b); - - // perfect shuffle the bits. - ssePerfectShuffle_round<4>(y); - ssePerfectShuffle_round<2>(y); - ssePerfectShuffle_round<1>(y); - } - - // perfect unshuffle 8 blocks of y into 4 blocks on x0,x1 into. - inline void ssePerfectUnshuffle(const oc::block* yy, oc::block* x0, oc::block* x1) - { - std::array y; - std::copy((u8*)yy, (u8*)(yy + y.size()), (u8*)y.data()); - // m emcpy(y.data(), yy, sizeof(y)); - - // perfect shuffle the bits. - ssePerfectShuffle_round<1>(y.data()); - ssePerfectShuffle_round<2>(y.data()); - ssePerfectShuffle_round<4>(y.data()); - - // perfect shuffle the bytes. - const oc::block b = _mm_set_epi8(15, 13, 11, 9, 7, 5, 3, 1, 14, 12, 10, 8, 6, 4, 2, 0); - y[0] = _mm_shuffle_epi8(y[0], b); - y[1] = _mm_shuffle_epi8(y[1], b); - y[2] = _mm_shuffle_epi8(y[2], b); - y[3] = _mm_shuffle_epi8(y[3], b); - y[4] = _mm_shuffle_epi8(y[4], b); - y[5] = _mm_shuffle_epi8(y[5], b); - y[6] = _mm_shuffle_epi8(y[6], b); - y[7] = _mm_shuffle_epi8(y[7], b); - - - u64* yyy = (u64*)y.data(); - u64* xx1 = (u64*)x1; - u64* xx0 = (u64*)x0; - xx0[0] = yyy[0]; - xx1[0] = yyy[1]; - xx0[1] = yyy[2]; - xx1[1] = yyy[3]; - xx0[2] = yyy[4]; - xx1[2] = yyy[5]; - xx0[3] = yyy[6]; - xx1[3] = yyy[7]; - - xx0[4] = yyy[8]; - xx1[4] = yyy[9]; - xx0[5] = yyy[10]; - xx1[5] = yyy[11]; - xx0[6] = yyy[12]; - xx1[6] = yyy[13]; - xx0[7] = yyy[14]; - xx1[7] = yyy[15]; - } - - inline void ssePerfectShuffle(span input0, span input1, span output) - { - assert(input0.size() == input1.size()); - assert(input0.size() == (output.size() + 1) / 2); - u64 n1024 = output.size() / sizeof(std::array); - - auto in0 = (oc::block*)input0.data(); - auto in1 = (oc::block*)input1.data(); - auto out = (oc::block*)output.data(); - for (u64 i = 0; i < n1024; ++i) - { - ssePerfectShuffle(in0, in1, out); - in0 += 4; - in1 += 4; - out += 8; - } - - auto n64 = n1024 * 16; - auto n8 = n64 * sizeof(u64); - auto rem = input0.size() - n8 / 2; - while (rem) - { - auto min = std::min(rem, sizeof(u64)); - u64 x0 = 0, x1 = 0; - std::copy(input0.data() + n8 / 2, input0.data() + n8 / 2 + min, (u8*)&x0); - std::copy(input1.data() + n8 / 2, input1.data() + n8 / 2 + min, (u8*)&x1); - //m emcpy(&x0, &input0[n8 / 2], min); - //m emcpy(&x1, &input1[n8 / 2], min); - rem -= min; - - auto t = ssePerfectShuffle(x0, x1); - - auto min2 = std::min(output.size() - n8, sizeof(oc::block)); - std::copy((u8*)&t, (u8*)&t + min2, output.data() + n8); - //m emcpy(&output[n8], &t, min2); - n8 += min2; - } - } - - - inline void ssePerfectUnshuffle(span input, span output0, span output1) - { - assert(output0.size() == output1.size()); - assert(output0.size() == (input.size() + 1) / 2); - - u64 n1024 = input.size() / sizeof(std::array); - - auto out0 = (oc::block*)output0.data(); - auto out1 = (oc::block*)output1.data(); - auto in = (oc::block*)input.data(); - for (u64 i = 0; i < n1024; ++i) - { - assert((u8*)(in + 8) <= input.data() + input.size()); - assert((u8*)(out0 + 4) <= output0.data() + output0.size()); - assert((u8*)(out1 + 4) <= output1.data() + output1.size()); - ssePerfectUnshuffle(in, out0, out1); - - in += 8; - out0 += 4; - out1 += 4; - } - - - auto n64 = n1024 * 16; - auto n8 = n64 * sizeof(u64); - //auto n8 = n32 * sizeof(u32); - while (input.size() != n8) - { - auto rem = input.size() - n8; - auto min = std::min(rem, sizeof(oc::block)); - oc::block t = oc::ZeroBlock; - // m emcpy(&t, &input[n8], min); - std::copy(&input[n8], &input[n8] + min, (u8*)&t); - - auto r = ssePerfectUnshuffle(t); - - auto min2 = std::min(output0.size() - n8 / 2, sizeof(u64)); - // m emcpy(&output0[n8 / 2], &r[0], min2); - std::copy((u8*)&r[0], (u8*)&r[0] + min2, output0.data() + n8 / 2); - //m emcpy(&output1[n8 / 2], &r[1], min2); - std::copy((u8*)&r[1], (u8*)&r[1] + min2, output1.data() + n8 / 2); - - n8 += min; - } - } -#endif - - inline void perfectShuffle(span input0, span input1, span output) - { -#ifdef ENABLE_SSE - ssePerfectShuffle(input0, input1, output); -#else - cPerfectShuffle(input0, input1, output); -#endif - } - - inline void perfectUnshuffle(span input, span output0, span output1) - { -#ifdef ENABLE_SSE - ssePerfectUnshuffle(input, output0, output1); -#else - cPerfectUnshuffle(input, output0, output1); -#endif - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp deleted file mode 100644 index 576f4a66..00000000 --- a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.cpp +++ /dev/null @@ -1,138 +0,0 @@ -//#include -//#include -//#include -//#include - -#include -#include - -#include "libOTe/Tools/Foleage/fft/FoleageFft.h" -#include "cryptoTools/Common/Aligned.h" -#include "cryptoTools/Crypto/PRNG.h" - -#include "libOTe/Tools/Foleage/FoleageUtils.h" - -#define NUMVARS 16 - -namespace osuCrypto -{ - - - double Foleage_FFT64_bench() - { - size_t num_vars = NUMVARS; - size_t num_coeffs = ipow(3, num_vars); - AlignedUnVector coeffs (num_coeffs); - PRNG prng(block(342)); - prng.get(coeffs.data(), num_coeffs); - - //************************************************ - printf("Benchmarking FFT evaluation with uint64_t packing \n"); - //************************************************ - - clock_t t; - t = clock(); - fft_recursive_uint64(coeffs, num_vars, num_coeffs / 3); - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("FFT (uint64) eval time (total) %f ms\n", time_taken); - - return time_taken; - } - - double Foleage_FFT32_bench() - { - size_t num_vars = NUMVARS; - size_t num_coeffs = ipow(3, num_vars); - - AlignedUnVector < uint32_t> coeffs(num_coeffs); - PRNG prng(block(342)); - prng.get(coeffs.data(), num_coeffs); - - //************************************************ - printf("Benchmarking FFT evaluation with uint32_t packing \n"); - //************************************************ - - clock_t t; - t = clock(); - fft_recursive_uint32(coeffs, num_vars, num_coeffs / 3); - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("FFT (uint32) eval time (total) %f ms\n", time_taken); - - - return time_taken; - } - - double Foleage_FFT8_bench() - { - size_t num_vars = NUMVARS; - size_t num_coeffs = ipow(3, num_vars); - AlignedUnVector coeffs (num_coeffs); - PRNG prng(block(342)); - prng.get(coeffs.data(), num_coeffs); - - //************************************************ - printf("Benchmarking FFT evaluation without packing \n"); - //************************************************ - - clock_t t; - t = clock(); - foliageFftUint8(coeffs, num_vars, num_coeffs / 3); - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("FFT (uint8) eval time (total) %f ms\n", time_taken); - - //free(coeffs); - - return time_taken; - } - - int mainFFT(int argc, char** argv) - { - double time = 0; - int testTrials = 5; - - printf("******************************************\n"); - printf("Testing FFT (uint8 packing)\n"); - for (int i = 0; i < testTrials; i++) - { - time += Foleage_FFT8_bench(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("DONE\n"); - printf("Avg time: %0.2f\n", time / testTrials); - printf("******************************************\n\n"); - - printf("******************************************\n"); - printf("Testing FFT (uint32 packing) \n"); - time = 0; - for (int i = 0; i < testTrials; i++) - { - time += Foleage_FFT32_bench(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("DONE\n"); - printf("Avg time: %0.2f\n", time / testTrials); - printf("******************************************\n\n"); - - printf("******************************************\n"); - printf("Testing FFT (uint64 packing) \n"); - time = 0; - for (int i = 0; i < testTrials; i++) - { - time += Foleage_FFT64_bench(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("DONE\n"); - printf("Avg time: %0.2f\n", time / testTrials); - printf("******************************************\n\n"); - return 0; - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h b/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h deleted file mode 100644 index bc05951f..00000000 --- a/libOTe/Tools/Foleage/fft/FoleageFFT_bench.h +++ /dev/null @@ -1,13 +0,0 @@ -#pragma once - - -namespace osuCrypto -{ - - double Foleage_FFT8_bench(); - double Foleage_FFT32_bench(); - double Foleage_FFT64_bench(); - - - -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.cpp b/libOTe/Tools/Foleage/fft/FoleageFft.cpp deleted file mode 100644 index 7f7b8939..00000000 --- a/libOTe/Tools/Foleage/fft/FoleageFft.cpp +++ /dev/null @@ -1,856 +0,0 @@ -#include -#include -#include "libOTe/Tools/Foleage/fft/FoleageFft.h" -#include "libOTe/Tools/Foleage/PerfectShuffle.h" -namespace osuCrypto { - - void fft_recursive_uint64( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint64( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint64( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint64( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint64_t tL, tM; - uint64_t mult, xor_h, xor_l; - - uint64_t* coeffsL = coeffs.data() + 0; - uint64_t* coeffsM = coeffs.data() + num_coeffs; - uint64_t* coeffsR = coeffs.data() + 2 * num_coeffs; - - const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; - const uint64_t mask_h = pattern; // 0b101010101010101001010 - const uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void fft_recursive_uint32( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint32( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint32( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint32( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint32_t tL, tM; - uint32_t mult, xor_h, xor_l; - - uint32_t* coeffsL = coeffs.data() + 0; - uint32_t* coeffsM = coeffs.data() + num_coeffs; - uint32_t* coeffsR = coeffs.data() + 2 * num_coeffs; - - const uint32_t pattern = 0xaaaaaaaa; - const uint32_t mask_h = pattern; // 0b101010101010101001010 - const uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void fft_recursive_uint16( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - fft_recursive_uint16( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - fft_recursive_uint16( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - fft_recursive_uint16( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint16_t tL, tM; - uint16_t mult, xor_h, xor_l; - - uint16_t* coeffsL = coeffs.data() + 0; - uint16_t* coeffsM = coeffs.data() + num_coeffs; - uint16_t* coeffsR = coeffs.data() + 2 * num_coeffs; - - const uint16_t pattern = 0xaaaa; - const uint16_t mask_h = pattern; // 0b101010101010101001010 - const uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - void foliageFftUint8( - span coeffs, - const size_t num_vars, - const size_t num_coeffs) - { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - - if (num_vars > 1) - { - // apply FFT on all left coefficients - foliageFftUint8( - coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - foliageFftUint8( - coeffs.subspan(num_coeffs), - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - foliageFftUint8( - coeffs.subspan(2 * num_coeffs), - num_vars - 1, - num_coeffs / 3); - } - - // temp variables to store intermediate values - uint8_t tL, tM; - uint8_t mult, xor_h, xor_l; - - uint8_t* coeffsL = coeffs.data() + 0; - uint8_t* coeffsM = coeffs.data() + num_coeffs; - uint8_t* coeffsR = coeffs.data() + 2 * num_coeffs; - - const uint8_t pattern = 0xaa; - const uint8_t mask_h = pattern; // 0b101010101010101001010 - const uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 - - for (size_t j = 0; j < num_coeffs; j++) - { - xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; - xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); - - // tL coefficient obtained by evaluating on X_i=1 - tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; - - // tM coefficient obtained by evaluating on X_i=\alpha - tM = coeffsL[j] ^ coeffsR[j] ^ mult; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - coeffsL[j] = tL; - coeffsM[j] = tM; - } - } - - - void foleageFFT2( - uint8_t* lsb, - uint8_t* msb, - const size_t num_vars, - const size_t num_coeffs) - { - if (num_vars > 1) - { - // apply FFT on all left coefficients - foleageFFT2( - lsb, msb, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - foleageFFT2( - lsb + num_coeffs, - msb + num_coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - foleageFFT2( - lsb + 2 * num_coeffs, - msb + 2 * num_coeffs, - num_vars - 1, - num_coeffs / 3); - } - - uint8_t* __restrict ptrL0 = lsb + 0; - uint8_t* __restrict ptrL1 = msb + 0; - uint8_t* __restrict ptrM0 = lsb + num_coeffs; - uint8_t* __restrict ptrM1 = msb + num_coeffs; - uint8_t* __restrict ptrR0 = lsb + 2 * num_coeffs; - uint8_t* __restrict ptrR1 = msb + 2 * num_coeffs; - - for (size_t j = 0; j < num_coeffs; j++) - { - - auto coeffsL0 = *ptrL0; - auto coeffsL1 = *ptrL1; - auto coeffsM0 = *ptrM0; - auto coeffsM1 = *ptrM1; - auto coeffsR0 = *ptrR0; - auto coeffsR1 = *ptrR1; - - auto xor_h = coeffsM1 ^ coeffsR1; - auto xor_l = coeffsM0 ^ coeffsR0; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - auto mult0 = xor_h ^ xor_l; - auto mult1 = xor_l; - - // tL coefficient obtained by evaluating on X_i=1 - auto tL0 = coeffsL0 ^ coeffsM0 ^ coeffsR0; - auto tL1 = coeffsL1 ^ coeffsM1 ^ coeffsR1; - - // tM coefficient obtained by evaluating on X_i=\alpha - auto tM0 = coeffsL0 ^ coeffsR0 ^ mult0; - auto tM1 = coeffsL1 ^ coeffsR1 ^ mult1; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - *ptrR0 = coeffsL0 ^ coeffsM0 ^ mult0; - *ptrR1 = coeffsL1 ^ coeffsM1 ^ mult1; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - *ptrL0 = tL0; - *ptrL1 = tL1; - *ptrM0 = tM0; - *ptrM1 = tM1; - - ++ptrL0; - ++ptrL1; - ++ptrM0; - ++ptrM1; - ++ptrR0; - ++ptrR1; - } - } - - template - void foleageFFTLevel( - u8* lsb, - u8* msb, - BlockSize blockSize, - u64 regions - ) - { - //static_assert(depth); - //u64 blockSize = ipow(3, depth - 1); - - for (u64 r = 0; r < regions; ++r) - { - - uint8_t* __restrict ptrL0 = lsb + r * 3 * blockSize + 0; - uint8_t* __restrict ptrL1 = msb + r * 3 * blockSize + 0; - uint8_t* __restrict ptrM0 = lsb + r * 3 * blockSize + blockSize; - uint8_t* __restrict ptrM1 = msb + r * 3 * blockSize + blockSize; - uint8_t* __restrict ptrR0 = lsb + r * 3 * blockSize + 2 * blockSize; - uint8_t* __restrict ptrR1 = msb + r * 3 * blockSize + 2 * blockSize; - - constexpr u64 width = 1; - auto main = blockSize / (width * 16); - for (u64 k = 0; k < main; ++k) - { - - block coeffsL0[width]; - block coeffsL1[width]; - block coeffsM0[width]; - block coeffsM1[width]; - block coeffsR0[width]; - block coeffsR1[width]; - - //{ constexpr u64 VAR = 1; STATEMENT; } - //{ constexpr u64 VAR = 2; STATEMENT; } - //{ constexpr u64 VAR = 3; STATEMENT; } - //{ constexpr u64 VAR = 4; STATEMENT; } - //{ constexpr u64 VAR = 5; STATEMENT; } - //{ constexpr u64 VAR = 6; STATEMENT; } - //{ constexpr u64 VAR = 7; STATEMENT; } - -#define SIMD8(VAR, STATEMENT) \ - { constexpr u64 VAR = 0; STATEMENT; }\ - do{}while(0) - - - SIMD8(q, coeffsL0[q] = _mm_loadu_si128((__m128i*)(ptrL0 + q * 16))); - SIMD8(q, coeffsL1[q] = _mm_loadu_si128((__m128i*)(ptrL1 + q * 16))); - SIMD8(q, coeffsM0[q] = _mm_loadu_si128((__m128i*)(ptrM0 + q * 16))); - SIMD8(q, coeffsM1[q] = _mm_loadu_si128((__m128i*)(ptrM1 + q * 16))); - SIMD8(q, coeffsR0[q] = _mm_loadu_si128((__m128i*)(ptrR0 + q * 16))); - SIMD8(q, coeffsR1[q] = _mm_loadu_si128((__m128i*)(ptrR1 + q * 16))); - - - - block xor_h[width], xor_l[width]; - SIMD8(j, xor_h[j] = coeffsM1[j] ^ coeffsR1[j]); - SIMD8(j, xor_l[j] = coeffsM0[j] ^ coeffsR0[j]); - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - block mult0[width];// , mult1[width]; - SIMD8(j, mult0[j] = xor_h[j] ^ xor_l[j]); - //SIMD8(j, mult1[j] = xor_l[j]); - - // tL coefficient obtained by evaluating on X_i=1 - block tL0[width], tL1[width]; - SIMD8(j, tL0[j] = coeffsL0[j] ^ coeffsM0[j] ^ coeffsR0[j]); - SIMD8(j, tL1[j] = coeffsL1[j] ^ coeffsM1[j] ^ coeffsR1[j]); - - // tM coefficient obtained by evaluating on X_i=\alpha - block tM0[width], tM1[width]; - SIMD8(j, tM0[j] = coeffsL0[j] ^ coeffsR0[j] ^ mult0[j]); - SIMD8(j, tM1[j] = coeffsL1[j] ^ coeffsR1[j] ^ xor_l[j]); - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - SIMD8(j, coeffsR0[j] = coeffsL0[j] ^ coeffsM0[j] ^ mult0[j]); - SIMD8(j, coeffsR1[j] = coeffsL1[j] ^ coeffsM1[j] ^ xor_l[j]); - - - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrR0 + j * 16), coeffsR0[j])); - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrR1 + j * 16), coeffsR1[j])); - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrL0 + j * 16), tL0[j])); - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrL1 + j * 16), tL1[j])); - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrM0 + j * 16), tM0[j])); - SIMD8(j, _mm_storeu_si128((__m128i*)(ptrM1 + j * 16), tM1[j])); - - - ptrL0 += width * 16; - ptrL1 += width * 16; - ptrM0 += width * 16; - ptrM1 += width * 16; - ptrR0 += width * 16; - ptrR1 += width * 16; - } - -#undef SIMD8 - - for (size_t j = main * width * 16; j < blockSize; j++) - { - - auto coeffsL0 = *ptrL0; - auto coeffsL1 = *ptrL1; - auto coeffsM0 = *ptrM0; - auto coeffsM1 = *ptrM1; - auto coeffsR0 = *ptrR0; - auto coeffsR1 = *ptrR1; - - auto xor_h = coeffsM1 ^ coeffsR1; - auto xor_l = coeffsM0 ^ coeffsR0; - - // pre compute: \alpha * (cM[j] ^ cR[j]) - // computed as: mult_l = (h ^ l) and mult_h = l - // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] - // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] - auto mult0 = xor_h ^ xor_l; - auto mult1 = xor_l; - - // tL coefficient obtained by evaluating on X_i=1 - auto tL0 = coeffsL0 ^ coeffsM0 ^ coeffsR0; - auto tL1 = coeffsL1 ^ coeffsM1 ^ coeffsR1; - - // tM coefficient obtained by evaluating on X_i=\alpha - auto tM0 = coeffsL0 ^ coeffsR0 ^ mult0; - auto tM1 = coeffsL1 ^ coeffsR1 ^ mult1; - - // Explanation: - // cL + cM*\alpha + cR*\alpha^2 - // = cL + cM*\alpha + cR*\alpha + cR - // = cL + cR + \alpha*(cM + cR) - - // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 - *ptrR0 = coeffsL0 ^ coeffsM0 ^ mult0; - *ptrR1 = coeffsL1 ^ coeffsM1 ^ mult1; - - // Explanation: - // cL + cM*(\alpha+1) + cR(\alpha+1)^2 - // = cL + cM + cM*\alpha + cR*(3\alpha + 2) - // = cL + cM + \alpha*(cM + cR) - // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. - - *ptrL0 = tL0; - *ptrL1 = tL1; - *ptrM0 = tM0; - *ptrM1 = tM1; - - ++ptrL0; - ++ptrL1; - ++ptrM0; - ++ptrM1; - ++ptrR0; - ++ptrR1; - } - } - } - - template - void foleageFFTL1L2( - u8* lsb, - u8* msb, - u64 regions - ) - { - //static_assert(depth); - //u64 blockSize = ipow(3, depth - 1); - u64 r = 0; - if constexpr (0 && stride == 2) - { - constexpr auto stepSize = 24; - auto main = regions / stepSize; - block tempLsb[9]; - block tempMsb[9]; - - for (u64 i = 0; i < main; ++i, r += stepSize) - { - auto lsb0 = lsb + r * stride; - auto lsb1 = lsb + r * stride + 16; - auto lsb2 = lsb + r * stride + 32; - - auto msb0 = msb + r * stride; - auto msb1 = msb + r * stride + 16; - auto msb2 = msb + r * stride + 32; - - // 0 1 2 3 4 5 6 7 - // 8 9 10 11 12 13 14 15 - // 16 17 18 19 20 21 22 23 - foleageTransposeLeaf(lsb0, (__m128i*) & tempLsb[0]); - foleageTransposeLeaf(lsb1, (__m128i*) & tempLsb[1]); - foleageTransposeLeaf(lsb2, (__m128i*) & tempLsb[2]); - - foleageTransposeLeaf(msb0, (__m128i*) & tempMsb[0]); - foleageTransposeLeaf(msb1, (__m128i*) & tempMsb[1]); - foleageTransposeLeaf(msb2, (__m128i*) & tempMsb[2]); - - - foleageFFTOne<1>( - &tempLsb[0], &tempMsb[0], - &tempLsb[1], &tempMsb[1], - &tempLsb[2], &tempMsb[2] - ); - - foleageFFTOne<1>( - &tempLsb[3], &tempMsb[3], - &tempLsb[4], &tempMsb[4], - &tempLsb[5], &tempMsb[5] - ); - - foleageFFTOne<1>( - &tempLsb[6], &tempMsb[6], - &tempLsb[7], &tempMsb[7], - &tempLsb[8], &tempMsb[8] - ); - - foleageTranspose((u8*)&tempLsb[0], (__m128i*)lsb0); - - foleageTranspose((u8*)&tempMsb[0], (__m128i*)msb0); - - foleageFFTOne<3>( - (block*)lsb0, (block*)msb0, - (block*)lsb1, (block*)msb1, - (block*)lsb2, (block*)msb2 - ); - } - } - - for (; r < regions; ++r) - { - constexpr u8 blockSize = 3 * stride; - uint8_t* __restrict ptrL0 = lsb + r * 3 * blockSize + 0; - uint8_t* __restrict ptrL1 = msb + r * 3 * blockSize + 0; - uint8_t* __restrict ptrM0 = lsb + r * 3 * blockSize + blockSize; - uint8_t* __restrict ptrM1 = msb + r * 3 * blockSize + blockSize; - uint8_t* __restrict ptrR0 = lsb + r * 3 * blockSize + 2 * blockSize; - uint8_t* __restrict ptrR1 = msb + r * 3 * blockSize + 2 * blockSize; - - - for (u64 j = 0; j < 9; j += 3) - { - - foleageFFTOne( - ptrL0 + (j + 0) * stride, ptrL1 + (j + 0) * stride, - ptrL0 + (j + 1) * stride, ptrL1 + (j + 1) * stride, - ptrL0 + (j + 2) * stride, ptrL1 + (j + 2) * stride - ); - } - - //foleageFFTOne( - // ptrL0 + 0 * stride, ptrL1 + 0 * stride, - // ptrL0 + 1 * stride, ptrL1 + 1 * stride, - // ptrL0 + 2 * stride, ptrL1 + 2 * stride - //); - //foleageFFTOne( - // ptrM0 + 0 * stride, ptrM1 + 0 * stride, - // ptrM0 + 1 * stride, ptrM1 + 1 * stride, - // ptrM0 + 2 * stride, ptrM1 + 2 * stride - //); - - //foleageFFTOne( - // ptrR0 + 0 * stride, ptrR1 + 0 * stride, - // ptrR0 + 1 * stride, ptrR1 + 1 * stride, - // ptrR0 + 2 * stride, ptrR1 + 2 * stride - //); - - - foleageFFTOne( - ptrL0, ptrL1, - ptrM0, ptrM1, - ptrR0, ptrR1 - ); - //foleageFFTOne( - // ptrL0 + 1 * stride, ptrL1 + 1 * stride, - // ptrM0 + 1 * stride, ptrM1 + 1 * stride, - // ptrR0 + 1 * stride, ptrR1 + 1 * stride - //); - //foleageFFTOne( - // ptrL0 + 2 * stride, ptrL1 + 2 * stride, - // ptrM0 + 2 * stride, ptrM1 + 2 * stride, - // ptrR0 + 2 * stride, ptrR1 + 2 * stride - //); - } - } - - template - void foleageFFTL1L2L3( - u8* lsb, - u8* msb, - u64 regions - ) - { - - for (u64 r = 0; r < regions; ++r) - { - constexpr u8 L4blockSize = 27 * stride; - - // L3 has 3 blocks of size L3blockSize - constexpr u8 L3blockSize = 9 * stride; - constexpr u8 L2blockSize = 3 * stride; - constexpr u8 L1blockSize = 1 * stride; - - uint8_t* baseLsb = lsb + r * L4blockSize; - uint8_t* baseMsb = msb + r * L4blockSize; - - - for (u64 k = 0; k < 3; ++k) - { - // left 1/3 - uint8_t* __restrict ptrL0 = baseLsb + k * L3blockSize + 0 * L2blockSize; - uint8_t* __restrict ptrL1 = baseMsb + k * L3blockSize + 0 * L2blockSize; - // middle 1/3 - uint8_t* __restrict ptrM0 = baseLsb + k * L3blockSize + 1 * L2blockSize; - uint8_t* __restrict ptrM1 = baseMsb + k * L3blockSize + 1 * L2blockSize; - // right 1/3 - uint8_t* __restrict ptrR0 = baseLsb + k * L3blockSize + 2 * L2blockSize; - uint8_t* __restrict ptrR1 = baseMsb + k * L3blockSize + 2 * L2blockSize; - - - for (u64 j = 0; j < 9; j += 3) - { - foleageFFTOne( - ptrL0 + (j + 0) * stride, ptrL1 + (j + 0) * stride, - ptrL0 + (j + 1) * stride, ptrL1 + (j + 1) * stride, - ptrL0 + (j + 2) * stride, ptrL1 + (j + 2) * stride - ); - } - - foleageFFTOne( - ptrL0, ptrL1, - ptrM0, ptrM1, - ptrR0, ptrR1 - ); - } - - - foleageFFTOne( - baseLsb + 0 * L3blockSize, baseMsb + 0 * L3blockSize, - baseLsb + 1 * L3blockSize, baseMsb + 1 * L3blockSize, - baseLsb + 2 * L3blockSize, baseMsb + 2 * L3blockSize - ); - - } - } - - void foleageFFT( - uint8_t* lsb, - uint8_t* msb, - const size_t num_vars, - const size_t num_coeffs) - { - //assert(lsb.size() == msb.size()); - //assert(lsb.size() % stride == 0); - //assert(blockSize == 1 || blockSize % 3 == 0); - //assert(blockSize < lsb.size() / stride); - - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) - //u64 stepSize = blockSize * stride; - - - if (num_vars > 1) - { - // apply FFT on all left coefficients - foleageFFT( - lsb, msb, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all middle coefficients - foleageFFT( - lsb + num_coeffs, - msb + num_coeffs, - num_vars - 1, - num_coeffs / 3); - - // apply FFT on all right coefficients - foleageFFT( - lsb + 2 * num_coeffs, - msb + 2 * num_coeffs, - num_vars - 1, - num_coeffs / 3); - } - - - foleageFFTLevel(lsb, msb, num_coeffs, 1); - } - - template - void foleageFFT2( - span lsb, - span msb) - { - auto n = lsb.size() / stride; - - auto log3N = log3ceil(n); - if (n != ipow(3, log3N)) - throw RTE_LOC; - if (lsb.size() != n * stride) - throw RTE_LOC; - if (lsb.size() != msb.size()) - throw RTE_LOC; - for (u64 i = 1; i <= log3N; ++i) - { - auto regionSize = ipow(3, i); - auto regions = n / regionSize; - - switch (i) - { - case 1: - if(log3N == 1) - foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); - break; - case 2: - // foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); - //if (log3N == 2) - foleageFFTL1L2(lsb.data(), msb.data(), regions); - break; - case 3: - foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); - //foleageFFTL1L2L3(lsb.data(), msb.data(), regions); - break; - case 4: - foleageFFTLevel(lsb.data(), msb.data(), std::integral_constant{}, regions); - break; - default: - u64 blockSize = regionSize / 3 * stride; - foleageFFTLevel(lsb.data(), msb.data(), blockSize, regions); - break; - } - } - - } - - template - void foleageFFT2<2>( - span lsb, - span msb); - template - void foleageFFT2<1>( - span lsb, - span msb); -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/fft/FoleageFft.h b/libOTe/Tools/Foleage/fft/FoleageFft.h deleted file mode 100644 index 3f233e7c..00000000 --- a/libOTe/Tools/Foleage/fft/FoleageFft.h +++ /dev/null @@ -1,388 +0,0 @@ -#pragma once - -#include -#include -#include "cryptoTools/Common/Defines.h" -#include "cryptoTools/Common/MatrixView.h" -#include "libOTe/Tools/Foleage/FoleageUtils.h" -#include - -//#include "libOTe/Tools/Foleage/utils.h" -namespace osuCrypto { - - //typedef __int128 int128_t; - //typedef unsigned __int128 uint128_t; - - // FFT for (up to) 32 polynomials over F4 - void fft_recursive_uint64( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 16 polynomials over F4 - void fft_recursive_uint32( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 8 polynomials over F4 - void fft_recursive_uint16( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - // FFT for (up to) 4 polynomials over F4 - void foliageFftUint8( - span coeffs, - const size_t num_vars, - const size_t num_coeffs); - - - - void foleageFFT( - uint8_t* lsb, - uint8_t* msb, - const size_t num_vars, - const size_t num_coeffs); - - inline void printShuffle1(const u16* ptr) - { - - for (u64 j = 0; j < 8; ++j) - { - auto v = ptr[j]; - std::cout << std::setw(2) << std::setfill(' ') << v << " "; - } - } - inline void printShuffle3(const u16* ptr) - { - for (u64 i = 0; i < 3; ++i) - { - printShuffle1(ptr + i * 8); - std::cout << std::endl; - } - } - - inline void printShuffle9(const u16* ptr) - { - for (u64 i = 0; i < 3; ++i) - { - printShuffle1(ptr + i * 24); - printShuffle1(ptr + i * 24 + 8); - printShuffle1(ptr + i * 24 + 16); - std::cout << std::endl; - } - } - // shuffles 3 blocks or 48 bytes - template - void foleageTransposeLeaf(u8* src, __m128i* dst) - { - - if constexpr (stride == 2) - { - // input: - // 0 1 2 3 4 5 6 7 - // 8 9 10 11 12 13 14 15 - // 16 17 18 19 20 21 22 23 - // - // output: - // 0 3 6 9 12 15 18 21 - // 1 4 7 10 13 16 19 22 - // 2 5 8 11 14 17 20 23 - - if (1) - { - // 0 6 12 18 - auto a0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(18, 12, 6, 0), 2); - // 3 9 15 21 - auto a1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(20, 14, 8, 2), 2); - // 0 3 6 9 12 15 18 21 - dst[0] = _mm_blendv_epi8(a0, a1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - // 1 7 13 19 - auto b0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(19, 13, 7, 1), 2); - // 4 10 16 22 - auto b1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(21, 15, 9, 3), 2); - // 1 4 7 10 13 16 19 22 - dst[1] = _mm_blendv_epi8(b0, b1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - // 2 8 14 20 - auto c0 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(20, 14, 8, 2), 2); - // 5 11 17 23 - auto c1 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(22, 16, 10, 4), 2); - // 2 5 8 11 14 17 20 23 - dst[2] = _mm_blendv_epi8(c0, c1, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - } - else - { - - - // 0 1 2 3 4 5 6 7 - auto v0 = _mm_loadu_si128((__m128i*)src); - - // 8 9 10 11 12 13 14 15 - auto v1 = _mm_loadu_si128((__m128i*)(src + 16)); - - // 16 17 18 19 20 21 22 23 - auto v2 = _mm_loadu_si128((__m128i*)(src + 32)); - - // 0 3 6 1 4 7 2 5 - // 0 0c 0d 0e 0f, 1a 1b 1c 1d 1e 1f, 2a 2b 2c 2d - v0 = _mm_shuffle_epi8(v0, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); - - // 8 11 14 9 12 15 10 13 - // 2e ef 2g 2h 2i ej, 0g 0h 0i 0j 0k 0l, 1g 1h 1i 1j - v1 = _mm_shuffle_epi8(v1, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); - - // 16 19 22 17 20 23 18 21 - // 1k 1l 1m 1n 1o 1p, 2k 2l 2m 2n 2o 2p, 0m 0n 0o 0p - v2 = _mm_shuffle_epi8(v2, _mm_set_epi8(11, 10, 5, 4, 15, 14, 9, 8, 3, 2, 13, 12, 7, 6, 1, 0)); - - // 0 3 6 9 12 15 18 21 - // 0 0c 0d 0e 0f, 0g 0h 0i 0j 0k 0l, 1g 1h 1i 1j - auto u0 = _mm_blendv_epi8(v0, v1, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); - - // 0 3 6 17 20 23 18 21 - // 0 0c 0d 0e 0f, 0g 0h 0i 0j 0k 0l, 0m 0n 0o 0p - u0 = _mm_blendv_epi8(u0, v2, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); - - // 16 19 22 1 4 7 2 5 - // 1k 1l 1m 1n 1o 1p, 1a 1b 1c 1d 1e 1f, 2a 2b 2c 2d - auto u1 = _mm_blendv_epi8(v2, v0, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); - - // 16 19 22 1 4 7 10 13 - // 1k 1l 1m 1n 1o 1p, 1a 1b 1c 1d 1e 1f, 1g 1h 1i 1j - u1 = _mm_blendv_epi8(u1, v1, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); - - // 1 4 7 10 13 16 19 22 - // 1a 1b 1c 1d 1e 1f 1g 1h 1i 1j 1k 1l 1m 1n 1o 1p - u1 = _mm_shuffle_epi8(u1, _mm_set_epi8(5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6)); - - - // 8 11 14 17 20 23 18 21 - // 2e ef 2g 2h 2i ej, 2k 2l 2m 2n 2o 2p, 0m 0n 0o 0p - auto u2 = _mm_blendv_epi8(v1, v2, _mm_set_epi16(-1, -1, -1, -1, -1, 0, 0, 0)); - - // 8 11 14 17 20 23 2 5 - // 2e ef 2g 2h 2i ej, 2k 2l 2m 2n 2o 2p, 2a 2b 2c 2d - u2 = _mm_blendv_epi8(u2, v0, _mm_set_epi16(-1, -1, 0, 0, 0, 0, 0, 0)); - - // 2 5 8 11 14 17 20 23 - // 2a 2b 2c 2d 2e ef 2g 2h 2i ej 2k 2l 2m 2n 2o 2p, - u2 = _mm_shuffle_epi8(u2, _mm_set_epi8(11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12)); - - - _mm_store_si128(dst, u0); - _mm_store_si128(dst + 1, u1); - _mm_store_si128(dst + 2, u2); - } - } - else - { - throw RTE_LOC; - } - } - - // src points at the input data. Logically, there are 3 rows and 24 columns. - // each element is of stride bytes. The output is 3 rows and 8 columns. Each - // element is of stride * 3 bytes. The i'th element in the output are the - // three elements in the i'th column of the input. - // - // the input has 8 columns of row 0, then 8 columns row 1, 8 columns row 2, then repeates. - template - void foleageTranspose(u8* __restrict src, __m128i* __restrict dst) - { - if constexpr (stride == 2) - { - // input data: - // 0 1 2 3 4 5 6 7 - // 8 9 10 11 12 13 14 15 - // 16 17 18 19 20 21 22 23 - // - // 24 25 26 27 28 29 30 31 - // 32 33 34 35 36 37 38 39 - // 40 41 42 43 44 45 46 47 - // - // 48 49 50 51 52 53 54 55 - // 56 57 58 59 60 61 62 63 - // 64 65 66 67 68 69 70 71 - // - - // the input comes in 16 byte chunks. chunks {0,3,6},{1,4,7},{2,5,8} each belong to the same FFT position {0,1,2}. If we lay out the data - // logically we get: - // | | | | | | | - // 0 1 2 3 4 5 6 7 24 25 26 27 28 29 30 31 48 49 50 51 52 53 54 55 - // 8 9 10 11 12 13 14 15 32 33 34 35 36 37 38 39 56 57 58 59 60 61 62 63 - // 16 17 18 19 20 21 22 23 40 41 42 43 44 45 46 47 64 65 66 67 68 69 70 71 - // | | | | | | | - // - // at the previous FFT level, each column corresponds to a FFT instance, e.g. sub blocks {0,8,16}, {1,9,17}, ... - // - // We now want to merge these sub blocks into a single block. This corresponds - // to doing a 3x3 sub block transpose. - // - // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 - // | | | | | | | - // 0 8 16 3 11 19 6 14 22 25 33 41 28 36 44 31 39 47 50 58 66 53 61 69 - // 1 9 17 4 12 20 7 15 23 26 34 42 29 37 45 48 56 64 51 59 67 54 62 70 - // 2 10 18 5 13 21 24 32 40 27 35 43 30 38 46 49 57 65 52 60 68 55 63 71 - // | | | | | | | - // - // We are going to transpose using the i32gather instruction. We want the output to be stored with - // each row being contiguous, e.g. "0 8 ... 69" should all be next to eachother. - // Each position takes up stride=2 bytes. But the i32gather instruction works on 4 byte chunks. - // So we will split the 8 gathered values across two instructions. One to gather the even - // positions and one the odd. We will then blend these two together to get the final output. eg: - - //0 8 16 3 11 19 6 14 | 22 25 33 41 28 36 44 31 |39 47 50 58 66 53 61 69 - //0 16 11 6 | 22 33 28 44 |39 50 66 61 - // 8 3 19 14 | 25 41 36 31 | 47 58 53 69 - // - // For row 0 we want to select - // * blend(gather(0,16,11,6),gatherHigh(8,3,19,14)) - // * blend(gather(22,33,28,44),gatherHigh(35,41,36,31)) - // * blend(gather(39,50,66,61),gatherHigh(47,58,53,69)) - // - // where gatherHigh(a,b,c,d) = gather(a-1,b-1,c-1,d-1), - // and blend(...) takes every other 16 bits. - // - // The other rows follow the same logic. - // - // the final set of indices are (each 4 are in reverse order to match _mm_set_epi32): - // - // 6,11,16, 0 | 44,28,33,22 | 61,66,50,39 - // 13,18, 2, 7 | 30,35,40,34 | 68,52,57,46 - // - //* 7,12,17,1 | 45,29,34,23 | 62,67,51,56 - //* 14,19, 3,8 | 47,36,41,25 | 69,53,58,63 - // - //* 24,13,18,2 | 46,30,35,40 | 63,68,52,57 - //* 31,20, 4,9 | 48,37,42,26 | 70,54,59,64 - - // 0 0 - auto a00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(6, 11, 16, 0), 2); - auto a01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(13, 18, 2, 7), 2); - dst[0] = _mm_blendv_epi8(a00, a01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto a10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(44, 28, 33, 22), 2); - auto a11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(30, 35, 40, 24), 2); - dst[1] = _mm_blendv_epi8(a10, a11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto a20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(61, 66, 50, 39), 2); - auto a21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(68, 52, 57, 46), 2); - dst[2] = _mm_blendv_epi8(a20, a21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - - auto b00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(7, 12, 17, 1), 2); - auto b01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(14, 19, 3, 8), 2); - dst[3] = _mm_blendv_epi8(b00, b01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto b10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(45, 29, 34, 23), 2); - auto b11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(47, 36, 41, 25), 2); - dst[4] = _mm_blendv_epi8(b10, b11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto b20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(62, 67, 51, 56), 2); - auto b21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(69, 53, 58, 63), 2); - dst[5] = _mm_blendv_epi8(b20, b21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - - auto c00 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(24, 13, 18, 2), 2); - auto c01 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(31, 20, 4, 9), 2); - dst[6] = _mm_blendv_epi8(c00, c01, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto c10 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(46, 30, 35, 40), 2); - auto c11 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(48, 37, 42, 26), 2); - dst[7] = _mm_blendv_epi8(c10, c11, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - auto c20 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(63, 68, 52, 57), 2); - auto c21 = _mm_i32gather_epi32((int*)src, _mm_set_epi32(70, 54, 59, 64), 2); - dst[8] = _mm_blendv_epi8(c20, c21, _mm_set_epi16(-1, 0, -1, 0, -1, 0, -1, 0)); - - } - else - { - throw RTE_LOC; - } - } - - template - void foliageUnTranspose(u8* src, __m128i* dst) - { - constexpr std::array inv{ - 0, 1, 2, 24, 25, 26, 48, 49, 50, - 3, 4, 5, 27, 28, 29, 51, 52, 53, - 6, 7, 8, 30, 31, 32, 54, 55, 56, - 9, 10, 11, 33, 34, 35, 57, 58, 59, - 12, 13, 14, 36, 37, 38, 60, 61, 62, - 15, 16, 17, 39, 40, 41, 63, 64, 65, - 18, 19, 20, 42, 43, 44, 66, 67, 68, - 21, 22, 23, 45, 46, 47, 69, 70, 71 - }; - - auto dstPtr = (u8*)dst; - for (u64 i = 0; i < inv.size(); ++i) - { - memcpy(dstPtr, src + inv[i] * stride, stride); - dstPtr += stride; - } - - - // 0 1 2 24 25 26 48 49 50 3 4 5 27 28 29 51 52 53 6 7 8 30 31 32 - // 54 55 56 9 10 11 33 34 35 57 58 59 12 13 14 36 37 38 60 61 62 15 16 17 - // 39 40 41 63 64 65 18 19 20 42 43 44 66 67 68 21 22 23 45 46 47 69 70 71 - } - - - - template - OC_FORCEINLINE void foleageFFTOne( - T* __restrict coeffsL0, - T* __restrict coeffsL1, - T* __restrict coeffsM0, - T* __restrict coeffsM1, - T* __restrict coeffsR0, - T* __restrict coeffsR1) - { - - for (u64 i = 0; i < stride; ++i) - { - - auto xor_h = coeffsM1[i] ^ coeffsR1[i]; - auto xor_l = coeffsM0[i] ^ coeffsR0[i]; - - auto mult0 = xor_h ^ xor_l; - auto mult1 = xor_l; - - // tL coefficient obtained by evaluating on X_i=1 - auto tL0 = coeffsL0[i] ^ xor_l; - auto tL1 = coeffsL1[i] ^ xor_h; - auto tM0 = coeffsL0[i] ^ coeffsR0[i] ^ mult0; - auto tM1 = coeffsL1[i] ^ coeffsR1[i] ^ mult1; - coeffsR0[i] = coeffsL0[i] ^ coeffsM0[i] ^ mult0; - coeffsR1[i] = coeffsL1[i] ^ coeffsM1[i] ^ mult1; - coeffsL0[i] = tL0; - coeffsL1[i] = tL1; - coeffsM0[i] = tM0; - coeffsM1[i] = tM1; - } - } - - - - inline void foleageFFT( - MatrixView lsb, - MatrixView msb) - { - if (lsb.rows() != msb.rows()) - throw RTE_LOC; - if (lsb.cols() != msb.cols()) - throw RTE_LOC; - auto numCoeffs = lsb.rows(); - if (numCoeffs % 3) - throw RTE_LOC; - auto numVars = log3ceil(numCoeffs); - foleageFFT(lsb.data(), msb.data(), numVars, lsb.size() / 3); - } - - - template - void foleageFFT2( - span lsb, - span msb); - -} diff --git a/libOTe/Tools/Foleage/spfss_test.cpp b/libOTe/Tools/Foleage/spfss_test.cpp deleted file mode 100644 index 237d331b..00000000 --- a/libOTe/Tools/Foleage/spfss_test.cpp +++ /dev/null @@ -1,115 +0,0 @@ -//#include -//#include -//#include -// -//#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -//#include "FoleageUtils.h" -// -//#define SUMT 730 // sum of T DPFs -// -//#define FULLEVALDOMAIN 10 -//#define MESSAGESIZE 8 -//#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) -//namespace osuCrypto -//{ -// -// double benchmarkAES() -// { -// size_t num_leaves = ipow(3, FULLEVALDOMAIN); -// size_t size = FULLEVALDOMAIN; -// PRNG prng(block(3423423)); -// -// PRFKeys prf_keys; -// prf_keys.gen(prng); -// -// AlignedUnVector data_in (num_leaves * MESSAGESIZE); -// AlignedUnVector data_out(num_leaves * MESSAGESIZE); -// AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); -// AlignedUnVector tmp; -// -// // fill with unique data -// for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) -// data_tmp[i] = block(i); -// -// // make the input data pseudorandom for correct timing -// PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); -// -// //************************************************ -// // Benchmark AES encryption time required in DPF loop -// //************************************************ -// -// clock_t t; -// t = clock(); -// -// for (size_t n = 0; n < SUMT; n++) -// { -// size_t num_nodes = 1; -// for (size_t i = 0; i < size; i++) -// { -// PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); -// PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); -// PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); -// -// tmp = data_out; -// data_out = data_in; -// data_in = tmp; -// -// num_nodes *= 3; -// } -// // compute AES part of output extension -// PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); -// } -// -// t = clock() - t; -// double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms -// -// printf("Time %f ms\n", time_taken); -// -// return time_taken; -// } -// -// int mainSpfss(int argc, char** argv) -// { -// -// double time = 0; -// int testTrials = 10; -// -// //printf("******************************************\n"); -// //printf("Testing DPF.FullEval\n"); -// //for (int i = 0; i < testTrials; i++) -// //{ -// // time += foliage_spfss_test(); -// // printf("Done with trial %i of %i\n", i + 1, testTrials); -// //} -// //printf("******************************************\n"); -// //printf("PASS\n"); -// //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); -// //printf("******************************************\n\n"); -// -// time = 0; -// //printf("******************************************\n"); -// //printf("Benchmarking DPF.Gen\n"); -// //for (int i = 0; i < testTrials; i++) -// //{ -// // time += benchmark_spfss(); -// // printf("Done with trial %i of %i\n", i + 1, testTrials); -// //} -// //printf("******************************************\n"); -// //printf("Avg time: %0.4f ms\n", time / testTrials); -// //printf("******************************************\n\n"); -// -// time = 0; -// printf("******************************************\n"); -// printf("Benchmarking AES\n"); -// for (int i = 0; i < testTrials; i++) -// { -// time += benchmarkAES(); -// printf("Done with trial %i of %i\n", i + 1, testTrials); -// } -// printf("******************************************\n"); -// printf("Avg time: %0.2f ms\n", time / testTrials); -// printf("******************************************\n\n"); -// -// return 0; -// } -//} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/tri-dpf/.gitignore b/libOTe/Tools/Foleage/tri-dpf/.gitignore deleted file mode 100644 index 71035e10..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/.gitignore +++ /dev/null @@ -1,5 +0,0 @@ -*.json -*.o -*.a -.DS_Store -bin diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp deleted file mode 100644 index 28f4ca66..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.cpp +++ /dev/null @@ -1,317 +0,0 @@ - -#include "FoleageDpf.h" - -#include "libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h" - - -//#include - -#define LOG_BATCH_SIZE 6 // operate in smallish batches to maximize cache hits -namespace osuCrypto -{ - - // Naming conventions: - // - A,B refer to shares given to parties A and B - // - 0,1,2 refer to the branch index in the ternary tree - void DPFGen( - PRFKeys& prf_keys, - size_t domain_size, - size_t index, - span msg_blocks, - size_t msg_block_len, - DPFKey& k0, - DPFKey& k1, - PRNG& prng) - { - - // starting seeds given to each party - block seedA = prng.get(); - block seedB = prng.get(); - - // correction word provided to both parties - // (one correction word per level) - std::vector sCW0(domain_size); - std::vector sCW1(domain_size); - std::vector sCW2(domain_size); - - // variables for the intermediate values - block parent, parentA, parentB, sA0, sA1, sA2, sB0, sB1, sB2; - - // current parent value (xor of the two seeds) - parent = seedA ^ seedB; - - // control bit of the parent on the special path must always be set to 1 - // so as to apply the corresponding correction word - if (get_lsb(parent) == ZeroBlock) - seedA = flip_lsb(seedA); - - parentA = seedA; - parentB = seedB; - - block prev_control_bit_A, prev_control_bit_B; - - for (size_t i = 0; i < domain_size; i++) - { - prev_control_bit_A = get_lsb(parentA); - prev_control_bit_B = get_lsb(parentB); - - // expand the starting seeds of each party - PRFEval(prf_keys.prf_key0, parentA, sA0); - PRFEval(prf_keys.prf_key1, parentA, sA1); - PRFEval(prf_keys.prf_key2, parentA, sA2); - PRFEval(prf_keys.prf_key0, parentB, sB0); - PRFEval(prf_keys.prf_key1, parentB, sB1); - PRFEval(prf_keys.prf_key2, parentB, sB2); - - // on-path correction word is set to random - // so as to be indistinguishable from the real correction words - block r = prng.get(); - - // get the current trit (ternary bit) of the special index - uint8_t trit = get_trit(index, domain_size, i); - - switch (trit) - { - case 0: - parent = sA0 ^ sB0 ^ r; - if (get_lsb(parent) == ZeroBlock) - r = flip_lsb(r); - - sCW0[i] = r; - sCW1[i] = sA1 ^ sB1; - sCW2[i] = sA2 ^ sB2; - - if (get_lsb(parentA) == AllOneBlock) - { - parentA = sA0 ^ r; - parentB = sB0; - } - else - { - parentA = sA0; - parentB = sB0 ^ r; - } - - break; - - case 1: - parent = sA1 ^ sB1 ^ r; - if (get_lsb(parent) == ZeroBlock) - r = flip_lsb(r); - - sCW0[i] = sA0 ^ sB0; - sCW1[i] = r; - sCW2[i] = sA2 ^ sB2; - - if (get_lsb(parentA) == AllOneBlock) - { - parentA = sA1 ^ r; - parentB = sB1; - } - else - { - parentA = sA1; - parentB = sB1 ^ r; - } - - break; - - case 2: - parent = sA2 ^ sB2 ^ r; - if (get_lsb(parent) == ZeroBlock) - r = flip_lsb(r); - - sCW0[i] = sA0 ^ sB0; - sCW1[i] = sA1 ^ sB1; - sCW2[i] = r; - - if (get_lsb(parentA) == AllOneBlock) - { - parentA = sA2 ^ r; - parentB = sB2; - } - else - { - parentA = sA2; - parentB = sB2 ^ r; - } - - break; - - default: - printf("error: not a ternary digit!\n"); - exit(0); - } - } - - // set the last correction word to correct the output to msg - block leaf_seedA, leaf_seedB; - uint8_t last_trit = get_trit(index, domain_size, domain_size - 1); - if (last_trit == 0) - { - leaf_seedA = sA0 ^ prev_control_bit_A & sCW0[domain_size - 1]; - leaf_seedB = sB0 ^ prev_control_bit_B & sCW0[domain_size - 1]; - } - else if (last_trit == 1) - { - leaf_seedA = sA1 ^ prev_control_bit_A & sCW1[domain_size - 1]; - leaf_seedB = sB1 ^ prev_control_bit_B & sCW1[domain_size - 1]; - } - - else if (last_trit == 2) - { - leaf_seedA = sA2 ^ prev_control_bit_A & sCW2[domain_size - 1]; - leaf_seedB = sB2 ^ prev_control_bit_B & sCW2[domain_size - 1]; - } - - AlignedUnVector outputA(msg_block_len); - AlignedUnVector outputB(msg_block_len); - AlignedUnVector cache(msg_block_len); - AlignedUnVector outputCW(msg_block_len); - - outputA[0] = leaf_seedA; - outputB[0] = leaf_seedB; - - ExtendOutput(prf_keys, outputA, cache, 1, msg_block_len); - ExtendOutput(prf_keys, outputB, cache, 1, msg_block_len); - - for (size_t i = 0; i < msg_block_len; i++) - outputCW[i] = outputA[i] ^ outputB[i] ^ msg_blocks[i]; - - // memcpy all the generated values into two keys - // 16 = sizeof(uint128_t) - size_t key_size = sizeof(block); // initial seed size; - key_size += 3 * domain_size * sizeof(block); // correction words - key_size += sizeof(block) * msg_block_len; // output correction word - - k0.prf_keys = &prf_keys; - k0.k.resize(key_size); - k0.size = domain_size; - k0.msg_len = msg_block_len; - memcpy(&k0.k[0], &seedA, 16); - memcpy(&k0.k[16], &sCW0[0], domain_size * 16); - memcpy(&k0.k[16 * domain_size + 16], &sCW1[0], domain_size * 16); - memcpy(&k0.k[16 * 2 * domain_size + 16], &sCW2[0], domain_size * 16); - memcpy(&k0.k[16 * 3 * domain_size + 16], &outputCW[0], msg_block_len * 16); - - k1.prf_keys = &prf_keys; - k1.k.resize(key_size); - k1.size = domain_size; - k1.msg_len = msg_block_len; - memcpy(&k1.k[0], &seedB, 16); - memcpy(&k1.k[16], &sCW0[0], domain_size * 16); - memcpy(&k1.k[16 * domain_size + 16], &sCW1[0], domain_size * 16); - memcpy(&k1.k[16 * 2 * domain_size + 16], &sCW2[0], domain_size * 16); - memcpy(&k1.k[16 * 3 * domain_size + 16], &outputCW[0], msg_block_len * 16); - - //free(outputA); - //free(outputB); - //free(cache); - //free(outputCW); - } - - // evaluates the full DPF domain; much faster than - // batching the evaluation points since each level of the DPF tree - // is only expanded once. - void DPFFullDomainEval( - DPFKey& key, - span cache, - span output) - { - size_t size = key.size; - span k = key.k; - PRFKeys& prf_keys = *key.prf_keys; - - if (size % 2 == 1) - { - auto tmp = cache; - cache = output; - output = tmp; - } - - // full_eval_size = pow(3, size); - const size_t num_leaves = ipow(3, size); - - memcpy(&output[0], &k[0], 16); // output[0] is the start seed - const block* sCW0 = (block*)&k[16]; - const block* sCW1 = (block*)&k[16 * size + 16]; - const block* sCW2 = (block*)&k[16 * 2 * size + 16]; - - // inner loop variables related to node expansion - // and correction word application - span tmp; - size_t idx0, idx1, idx2; - block cb = ZeroBlock; - - // batching variables related to chunking of inner loop processing - // for the purpose of maximizing cache hits - size_t max_batch_size = ipow(3, LOG_BATCH_SIZE); - size_t batch, num_batches, batch_size, offset; - - size_t num_nodes = 1; - for (uint8_t i = 0; i < size; i++) - { - if (i < LOG_BATCH_SIZE) - { - batch_size = num_nodes; - num_batches = 1; - } - else - { - batch_size = max_batch_size; - num_batches = num_nodes / max_batch_size; - } - - offset = 0; - for (batch = 0; batch < num_batches; batch++) - { - PRFBatchEval(prf_keys.prf_key0, output.subspan(offset), cache.subspan(offset), batch_size); - PRFBatchEval(prf_keys.prf_key1, output.subspan(offset), cache.subspan(num_nodes + offset), batch_size); - PRFBatchEval(prf_keys.prf_key2, output.subspan(offset), cache.subspan((num_nodes * 2) + offset), batch_size); - - idx0 = offset; - idx1 = num_nodes + offset; - idx2 = (num_nodes * 2) + offset; - - while (idx0 < offset + batch_size) - { - cb = get_lsb(output[idx0]); // gets the LSB of the parent - cache[idx0] ^= (cb & sCW0[i]); - cache[idx1] ^= (cb & sCW1[i]); - cache[idx2] ^= (cb & sCW2[i]); - - idx0++; - idx1++; - idx2++; - } - - offset += batch_size; - } - - tmp = output; - output = cache; - cache = tmp; - - num_nodes *= 3; - } - - const size_t output_length = key.msg_len * num_leaves; - const size_t msg_len = key.msg_len; - block* outputCW = (block*)&k[16 * 3 * size + 16]; - ExtendOutput(prf_keys, output, cache, num_leaves, output_length); - - size_t j = 0; - for (size_t i = 0; i < num_leaves; i++) - { - // TODO: a bit hacky, assumes that cache[i*msg_len] = old_output[i] - // which is the case internally in ExtendOutput. It would be good - // to remove this assumption however using memcpy is costly... - - if (get_lsb(cache[i * msg_len]) != ZeroBlock) // parent control bit - { - for (j = 0; j < msg_len; j++) - output[i * msg_len + j] ^= outputCW[j]; - } - } - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h deleted file mode 100644 index a81e5c48..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h +++ /dev/null @@ -1,35 +0,0 @@ -#pragma once - -#include -#include - -#include "libOTe/Tools/Foleage/FoleageUtils.h" -#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" - - -namespace osuCrypto -{ - struct DPFKey - { - PRFKeys* prf_keys; - AlignedUnVector k; - size_t msg_len; - size_t size; - }; - - void DPFGen( - PRFKeys& prf_keys, - size_t domain_size, - size_t index, - span msg_blocks, - size_t msg_block_len, - DPFKey& k0, - DPFKey& k1, - PRNG& prng); - - void DPFFullDomainEval( - DPFKey& k, - span cache, - span output); - -} diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp deleted file mode 100644 index 2a596a8f..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.cpp +++ /dev/null @@ -1,166 +0,0 @@ -//#include -//#include -//#include -//#include -#include -#include -#include - -#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -//#include "libOTe/Tools/Foleage/tri-dpf/FoleageHalfDpf.h" -#include - -#define FULLEVALDOMAIN 14 -#define MESSAGESIZE 2 -#define MAXRANDINDEX ipow(3, FULLEVALDOMAIN) -namespace osuCrypto -{ - size_t randIndex(PRNG& prng) - { - return prng.get() % (size_t)MAXRANDINDEX; - } - //using int128_t = uint128_t; - block randMsg(PRNG& prng) - { - return prng.get(); - //uint128_t msg; - //RAND_bytes((uint8_t*)&msg, sizeof(uint128_t)); - //return msg; - } - - double benchmark_dpfGen() - { - //size_t num_leaves = ipow(3, FULLEVALDOMAIN); - size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points - PRNG prng(block(3423423)); - size_t secret_index = randIndex(prng); - block secret_msg = randMsg(prng); - size_t msg_len = 1; - - PRFKeys prf_keys; - prf_keys.gen(prng); - - DPFKey kA; - DPFKey kB; - - clock_t t; - t = clock(); - DPFGen(prf_keys, size, secret_index, span(&secret_msg,1), msg_len, kA, kB, prng); - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - return time_taken; - } - - double benchmark_dpfAES() - { - size_t num_leaves = ipow(3, FULLEVALDOMAIN); - size_t size = FULLEVALDOMAIN; - - PRNG prng(block(3423423)); - PRFKeys prf_keys; - prf_keys.gen(prng); - - AlignedUnVector data_in(num_leaves * MESSAGESIZE); - AlignedUnVector data_out(num_leaves * MESSAGESIZE); - AlignedUnVector data_tmp(num_leaves * MESSAGESIZE); - AlignedUnVector tmp; - - // fill with unique data - for (size_t i = 0; i < num_leaves * MESSAGESIZE; i++) - data_tmp[i] = block(i); - - // make the input data pseudorandom for correct timing - PRFBatchEval(prf_keys.prf_key0, data_tmp, data_in, num_leaves * MESSAGESIZE); - - //************************************************ - // Benchmark AES encryption time required in DPF loop - //************************************************ - - clock_t t; - t = clock(); - size_t num_nodes = 1; - for (size_t i = 0; i < size; i++) - { - PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes); - PRFBatchEval(prf_keys.prf_key1, data_in, data_out.subspan(num_nodes), num_nodes); - PRFBatchEval(prf_keys.prf_key2, data_in, data_out.subspan(num_nodes * 2), num_nodes); - - tmp = data_out; - data_out = data_in; - data_in = tmp; - - num_nodes *= 3; - } - - // compute AES part of output extension - PRFBatchEval(prf_keys.prf_key0, data_in, data_out, num_nodes * MESSAGESIZE); - - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - return time_taken; - } - - int main_test_tridpf(int argc, char** argv) - { - - double time = 0; - int testTrials = 3; - - //printf("******************************************\n"); - //printf("Testing DPF.FullEval\n"); - //for (int i = 0; i < testTrials; i++) - //{ - // time += foliage_dpf_test(); - // printf("Done with trial %i of %i\n", i + 1, testTrials); - //} - //printf("******************************************\n"); - //printf("PASS\n"); - //printf("DPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); - //printf("******************************************\n\n"); - - //time = 0; - //printf("******************************************\n"); - //printf("Testing HalfDPF.FullEval\n"); - //for (int i = 0; i < testTrials; i++) - //{ - // time += foliage_Halfdpf_test(); - // printf("Done with trial %i of %i\n", i + 1, testTrials); - //} - //printf("******************************************\n"); - //printf("PASS\n"); - //printf("HalfDPF.FullEval: (avg time) %0.2f ms\n", time / testTrials); - //printf("******************************************\n\n"); - - time = 0; - printf("******************************************\n"); - printf("Benchmarking DPF.Gen\n"); - for (int i = 0; i < testTrials; i++) - { - time += benchmark_dpfGen(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("Avg time: %0.4f ms\n", time / testTrials); - printf("******************************************\n\n"); - - time = 0; - printf("******************************************\n"); - printf("Benchmarking AES\n"); - for (int i = 0; i < testTrials; i++) - { - time += benchmark_dpfAES(); - printf("Done with trial %i of %i\n", i + 1, testTrials); - } - printf("******************************************\n"); - printf("Avg time: %0.2f ms\n", time / testTrials); - printf("******************************************\n\n"); - - return 0; - } -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h b/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h deleted file mode 100644 index 9776388a..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/FoleageDpf_test.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - - -namespace osuCrypto -{ - void foliage_dpf_test(); - void foliage_Halfdpf_test(); - -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h b/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h deleted file mode 100644 index df53f946..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - - -#include -#include "cryptoTools/Crypto/AES.h" -//#include "utils.h" -#include "libOTe/Tools/Foleage/FoleageUtils.h" - -namespace osuCrypto -{ - - - using EVP_CIPHER_CTX = oc::AES; - struct PRFKeys - { - PRFKeys() = default; - - - void gen(PRNG& prng) - { - prf_key0.setKey(prng.get()); - prf_key1.setKey(prng.get()); - prf_key2.setKey(prng.get()); - prf_key_ext.setKey(prng.get()); - } - - - - EVP_CIPHER_CTX prf_key0; - EVP_CIPHER_CTX prf_key1; - EVP_CIPHER_CTX prf_key2; - EVP_CIPHER_CTX prf_key_ext; - }; - - //void PRFKeyGen(struct PRFKeys* prf_keys); - //void DestroyPRFKey(struct PRFKeys* prf_keys); - - // XOR with input to prevent inversion using Davies–Meyer construction - inline void PRFEval(EVP_CIPHER_CTX& ctx, block& input, block& outputs) - { - outputs = ctx.hashBlock(input); - } - - // PRF used to expand the DPF tree. Just a call to AES-ECB. - // Note: we use ECB-mode (instead of CTR) as we want to manage each block separately. - // XOR with input to prevent inversion using Davies–Meyer construction - inline void PRFBatchEval(EVP_CIPHER_CTX& ctx, span input, span outputs, u64 num_blocks) - { - if (num_blocks > input.size()) - throw RTE_LOC; - if (num_blocks > outputs.size()) - throw RTE_LOC; - ctx.hashBlocks((block*)input.data(), num_blocks, (block*)outputs.data()); - } - - // extends the output by the provided factor using the PRG - inline void ExtendOutput( - PRFKeys& prf_keys, - span output, - span cache, - const size_t output_size, - const size_t new_output_size) - { - - if (new_output_size % output_size != 0) - throw std::runtime_error("ERROR: new_output_size needs to be a multiple of output_size. " LOCATION); - if (new_output_size < output_size) - throw std::runtime_error("ERROR: new_output_size < output_size" LOCATION); - - size_t factor = new_output_size / output_size; - - for (size_t i = 0; i < output_size; i++) - { - for (size_t j = 0; j < factor; j++) - cache[i * factor + j] = output[i] ^ block(0, j); - } - - PRFBatchEval(prf_keys.prf_key_ext, cache, output, new_output_size); - } -} diff --git a/libOTe/Tools/Foleage/tri-dpf/LICENSE b/libOTe/Tools/Foleage/tri-dpf/LICENSE deleted file mode 100644 index 2aa6fcd1..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/LICENSE +++ /dev/null @@ -1,9 +0,0 @@ -MIT License - -Copyright © 2024 Sacha Servan-Schreiber - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Softwareâ€), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED “AS ISâ€, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/libOTe/Tools/Foleage/tri-dpf/README.md b/libOTe/Tools/Foleage/tri-dpf/README.md deleted file mode 100644 index f628b4f1..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/README.md +++ /dev/null @@ -1,116 +0,0 @@ -# Ternary-tree DPF Implementation - -A simple C implementation of Distributed Point Functions (DPFs) with several performance optimizations. - -Optimizations include: - -- Ternary instead of a binary tree (increases communication slightly but improves evaluation performance by having a flatter tree). -- Using batched AES for fast PRF evaluation with AES-NI. -- The half-tree optimization of [Guo et al.](https://eprint.iacr.org/2022/1431.pdf), however, this only improves performance by 2\%-4\% in the ternary-tree case. - -## Dependencies - -- OpenSSL -- GNU Make -- Cmake -- Clang - -## Getting everything to run (tested on Ubuntu, CentOS, and MacOS) - -| Install dependencies (Ubuntu): | Install dependencies (CentOS): | -| -------------------------------------- | ------------------------------------------- | -| `sudo apt-get install build-essential` | `sudo yum groupinstall 'Development Tools'` | -| `sudo apt-get install cmake` | `sudo yum install cmake` | -| `sudo apt install libssl-dev` | `sudo yum install openssl-devel` | -| `sudo apt install clang` | `sudo yum install clang` | - -## Running tests and benchmarks - -``` -make -./bin/test -``` - -## Possible extensions (TODOs): - -- Arbitrary output size and full domain evaluation optimization of [Boyle et al.](https://eprint.iacr.org/2018/707). -- Serialization for DPF keys. - -## Minimal example - -```c -size_t domain_size = 10; -size_t num_leaves = ipow(3, domain_size); // domain of size 3^10 - -size_t secret_index = 5; -uint128_t secret_msg = 1; - -// common PRF keys -struct PRFKeys *prf_keys = malloc(sizeof(struct PRFKeys)); -PRFKeyGen(prf_keys); - -// DPF keys for each party -struct DPFKey *kA = malloc(sizeof(struct DPFKey)); -struct DPFKey *kB = malloc(sizeof(struct DPFKey)); - -DPFGen(prf_keys, domain_size, secret_index, &secret_msg, 1, kA, kB); - -uint128_t *shares0 = malloc(sizeof(uint128_t) * num_leaves); -uint128_t *shares1 = malloc(sizeof(uint128_t) * num_leaves); - -// cache is used to speed up evaluations when running many -// DPF evaluations sequentially -uint128_t *cache = malloc(sizeof(uint128_t) * num_leaves); - -// evaluate the DPF using the key of party A -DPFFullDomainEval(kA, cache, shares0); - -// evaluate the DPF using the key of party B -DPFFullDomainEval(kB, cache, shares1); - -DestroyPRFKey(prf_keys); -free(kA); -free(kB); -free(shares0); -free(shares1); -free(cache); -``` - -#### Performance on M1 Macbook Pro - -Domain of size $3^{14} \approx 2^{22}$ and message size of 256 bits. - -``` -****************************************** -Testing DPF.FullEval -****************************************** -PASS -Avg time for DPF.FullEval: 68.29 ms -****************************************** - -****************************************** -Testing HalfDPF.FullEval -****************************************** -PASS -Avg time for HalfDPF.FullEval: 65.38 ms -****************************************** -``` - -## Citation - -``` -@misc{foleage, - author = {Maxime Bombar and Dung Bui and Geoffroy Couteau and Alain Couvreur and Clément Ducros and Sacha Servan-Schreiber}, - title = {FOLEAGE: $\mathbb{F}_4$OLE-Based Multi-Party Computation for Boolean Circuits}, - howpublished = {Cryptology ePrint Archive, Paper 2024/429}, - year = {2024}, - note = {\url{https://eprint.iacr.org/2024/429}}, - url = {https://eprint.iacr.org/2024/429} -} - -``` - -## âš ï¸ Important Warning - -This implementation is intended for _research purposes only_. The code has NOT been vetted by security experts. -As such, no portion of the code should be used in any real-world or production setting! diff --git a/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h b/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h deleted file mode 100644 index a3c3a938..00000000 --- a/libOTe/Tools/Foleage/tri-dpf/TriDpfUtils.h +++ /dev/null @@ -1,68 +0,0 @@ -#pragma once - - -#include -#include -#include "libOTe/Tools/Foleage/FoleageUtils.h" -#include "cryptoTools/Common/BitIterator.h" - -namespace osuCrypto -{ - - static inline block flip_lsb(block input) - { - return input ^ block(0, 1); - } - - static inline block get_lsb(block input) - { - return block::allSame(-(input.get(0) & 1)); - } - - static inline int get_trit(uint64_t x, int size, int t) - { - std::vector ternary(size); - for (int i = 0; i < size; i++) - { - ternary[i] = x % 3; - x /= 3; - } - - return ternary[t]; - } - - static inline int get_bit(block x, int size, int b) - { - return *oc::BitIterator((u8*)&x, (size - b)); - //return ((x) >> (size - b)) & 1; - } - - //static void printBytes(void* p, int num) - //{ - // unsigned char* c = (unsigned char*)p; - // for (int i = 0; i < num; i++) - // { - // printf("%02x", c[i]); - // } - // printf("\n"); - //} - - //// Compute base^exp without the floating-point precision - //// errors of the built-in pow function. - //static inline int ipow(int base, int exp) - //{ - // int result = 1; - // while (1) - // { - // if (exp & 1) - // result *= base; - // exp >>= 1; - // if (!exp) - // break; - // base *= base; - // } - - // return result; - //} - -} \ No newline at end of file diff --git a/libOTe/Tools/Foleage/uint128.h b/libOTe/Tools/Foleage/uint128.h deleted file mode 100644 index ed9fdd70..00000000 --- a/libOTe/Tools/Foleage/uint128.h +++ /dev/null @@ -1,790 +0,0 @@ -//// -//// Copyright 2017 The Abseil Authors. -//// -//// Licensed under the Apache License, Version 2.0 (the "License"); -//// you may not use this file except in compliance with the License. -//// You may obtain a copy of the License at -//// -//// https://www.apache.org/licenses/LICENSE-2.0 -//// -//// Unless required by applicable law or agreed to in writing, software -//// distributed under the License is distributed on an "AS IS" BASIS, -//// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -//// See the License for the specific language governing permissions and -//// limitations under the License. -//// -//// ----------------------------------------------------------------------------- -//// File: int128_t.h -//// ----------------------------------------------------------------------------- -//// -//// This header file defines 128-bit integer types, `uint128_t` and `int128_t`. -// -//#ifndef ABSL_INT128_H_ -//#define ABSL_INT128_H_ -// -//#include -//#include -//#include -//#include -//#include -//#include -//#include -//#include -//#include -// -//#define ABSL_IS_LITTLE_ENDIAN -//#if defined(_MSC_VER) -//// In very old versions of MSVC and when the /Zc:wchar_t flag is off, wchar_t is -//// a typedef for unsigned short. Otherwise wchar_t is mapped to the __wchar_t -//// builtin type. We need to make sure not to define operator wchar_t() -//// alongside operator unsigned short() in these instances. -//#define ABSL_INTERNAL_WCHAR_T __wchar_t -//#if defined(_M_X64) -//#include -//#pragma intrinsic(_umul128) -//#endif // defined(_M_X64) -//#else // defined(_MSC_VER) -//#define ABSL_INTERNAL_WCHAR_T wchar_t -//#endif // defined(_MSC_VER) -// -//#ifdef _WIN32 -//#ifdef abslint128_t_EXPORTS -//#define ABSL_DLL __declspec(dllexport) -//#else -//#define ABSL_DLL __declspec(dllimport) -//#endif -//#else // _WIN32 -//#define ABSL_DLL -//#endif // _WIN32 -// -//// ABSL_ATTRIBUTE_ALWAYS_INLINE -//// ABSL_ATTRIBUTE_NOINLINE -//// -//// Forces functions to either inline or not inline. Introduced in gcc 3.1. -//#if defined(__GNUC__) || defined(__clang__) -//#define ABSL_ATTRIBUTE_ALWAYS_INLINE __attribute__((always_inline)) -//#elif defined(_MSC_VER) && !__INTEL_COMPILER && _MSC_VER >= 1310 // since Visual Studio .NET 2003 -//#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline __forceinline -//#else -//#define ABSL_ATTRIBUTE_ALWAYS_INLINE inline -//#endif -// -//// ABSL_INTERNAL_ASSUME(cond) -//// Informs the compiler than a condition is always true and that it can assume -//// it to be true for optimization purposes. The call has undefined behavior if -//// the condition is false. -//// In !NDEBUG mode, the condition is checked with an assert(). -//// NOTE: The expression must not have side effects, as it will only be evaluated -//// in some compilation modes and not others. -//// -//// Example: -//// -//// int x = ...; -//// ABSL_INTERNAL_ASSUME(x >= 0); -//// // The compiler can optimize the division to a simple right shift using the -//// // assumption specified above. -//// int y = x / 16; -//// -// -//#if defined(_MSC_VER) -//#define ABSL_INTERNAL_ASSUME(cond) __assume(cond) -//#else -//#define ABSL_INTERNAL_ASSUME(cond) -//#endif -// -//namespace absl { -// -// -// // uint128_t -// // -// // An unsigned 128-bit integer type. The API is meant to mimic an intrinsic type -// // as closely as is practical, including exhibiting undefined behavior in -// // analogous cases (e.g. division by zero). This type is intended to be a -// // drop-in replacement once C++ supports an intrinsic `uint128_t_t` type; when -// // that occurs, existing well-behaved uses of `uint128_t` will continue to work -// // using that new type. -// // -// // Note: code written with this type will continue to compile once `uint128_t_t` -// // is introduced, provided the replacement helper functions -// // `Uint128(Low|High)64()` and `MakeUint128()` are made. -// // -// // A `uint128_t` supports the following: -// // -// // * Implicit construction from integral types -// // * Explicit conversion to integral types -// // -// // Additionally, if your compiler supports `__int128_t`, `uint128_t` is -// // interoperable with that type. (Abseil checks for this compatibility through -// // the `ABSL_HAVE_INTRINSIC_INT128` macro.) -// // -// // However, a `uint128_t` differs from intrinsic integral types in the following -// // ways: -// // -// // * Errors on implicit conversions that do not preserve value (such as -// // loss of precision when converting to float values). -// // * Requires explicit construction from and conversion to floating point -// // types. -// // * Conversion to integral types requires an explicit static_cast() to -// // mimic use of the `-Wnarrowing` compiler flag. -// // * The alignment requirement of `uint128_t` may differ from that of an -// // intrinsic 128-bit integer type depending on platform and build -// // configuration. -// // -// // Example: -// // -// // float y = absl::Uint128Max(); // Error. uint128_t cannot be implicitly -// // // converted to float. -// // -// // absl::uint128_t v; -// // uint64_t i = v; // Error -// // uint64_t i = static_cast(v); // OK -// // -// class -//#if defined(ABSL_HAVE_INTRINSIC_INT128) -// alignas(unsigned __int128_t) -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// uint128_t { -// public: -// uint128_t() = default; -// -// // Constructors from arithmetic types -// constexpr uint128_t(int v); // NOLINT(runtime/explicit) -// constexpr uint128_t(unsigned int v); // NOLINT(runtime/explicit) -// constexpr uint128_t(long v); // NOLINT(runtime/int) -// constexpr uint128_t(unsigned long v); // NOLINT(runtime/int) -// constexpr uint128_t(long long v); // NOLINT(runtime/int) -// constexpr uint128_t(unsigned long long v); // NOLINT(runtime/int) -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// constexpr uint128_t(__int128_t v); // NOLINT(runtime/explicit) -// constexpr uint128_t(unsigned __int128_t v); // NOLINT(runtime/explicit) -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// explicit uint128_t(float v); -// explicit uint128_t(double v); -// explicit uint128_t(long double v); -// -// // Assignment operators from arithmetic types -// uint128_t& operator=(int v); -// uint128_t& operator=(unsigned int v); -// uint128_t& operator=(long v); // NOLINT(runtime/int) -// uint128_t& operator=(unsigned long v); // NOLINT(runtime/int) -// uint128_t& operator=(long long v); // NOLINT(runtime/int) -// uint128_t& operator=(unsigned long long v); // NOLINT(runtime/int) -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// uint128_t& operator=(__int128_t v); -// uint128_t& operator=(unsigned __int128_t v); -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// -// // Conversion operators to other arithmetic types -// constexpr explicit operator bool() const; -// constexpr explicit operator char() const; -// constexpr explicit operator signed char() const; -// constexpr explicit operator unsigned char() const; -// constexpr explicit operator char16_t() const; -// constexpr explicit operator char32_t() const; -// constexpr explicit operator ABSL_INTERNAL_WCHAR_T() const; -// constexpr explicit operator short() const; // NOLINT(runtime/int) -// // NOLINTNEXTLINE(runtime/int) -// constexpr explicit operator unsigned short() const; -// constexpr explicit operator int() const; -// constexpr explicit operator unsigned int() const; -// constexpr explicit operator long() const; // NOLINT(runtime/int) -// // NOLINTNEXTLINE(runtime/int) -// constexpr explicit operator unsigned long() const; -// // NOLINTNEXTLINE(runtime/int) -// constexpr explicit operator long long() const; -// // NOLINTNEXTLINE(runtime/int) -// constexpr explicit operator unsigned long long() const; -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// constexpr explicit operator __int128_t() const; -// constexpr explicit operator unsigned __int128_t() const; -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// explicit operator float() const; -// explicit operator double() const; -// explicit operator long double() const; -// -// // Trivial copy constructor, assignment operator and destructor. -// -// // Arithmetic operators. -// uint128_t& operator+=(uint128_t other); -// uint128_t& operator-=(uint128_t other); -// uint128_t& operator*=(uint128_t other); -// // Long division/modulo for uint128_t. -// uint128_t& operator/=(uint128_t other); -// uint128_t& operator%=(uint128_t other); -// uint128_t operator++(int); -// uint128_t operator--(int); -// uint128_t& operator<<=(int); -// uint128_t& operator>>=(int); -// uint128_t& operator&=(uint128_t other); -// uint128_t& operator|=(uint128_t other); -// uint128_t& operator^=(uint128_t other); -// uint128_t& operator++(); -// uint128_t& operator--(); -// -// // Uint128Low64() -// // -// // Returns the lower 64-bit value of a `uint128_t` value. -// friend constexpr uint64_t Uint128Low64(uint128_t v); -// -// // Uint128High64() -// // -// // Returns the higher 64-bit value of a `uint128_t` value. -// friend constexpr uint64_t Uint128High64(uint128_t v); -// -// // MakeUInt128() -// // -// // Constructs a `uint128_t` numeric value from two 64-bit unsigned integers. -// // Note that this factory function is the only way to construct a `uint128_t` -// // from integer values greater than 2^64. -// // -// // Example: -// // -// // absl::uint128_t big = absl::MakeUint128(1, 0); -// friend constexpr uint128_t MakeUint128(uint64_t high, uint64_t low); -// -// // Uint128Max() -// // -// // Returns the highest value for a 128-bit unsigned integer. -// friend constexpr uint128_t Uint128Max(); -// -// // Support for absl::Hash. -// template -// friend H AbslHashValue(H h, uint128_t v) { -// return H::combine(std::move(h), Uint128High64(v), Uint128Low64(v)); -// } -// -// // Combined division/modulo for a 128-bit unsigned integer. -// static void DivMod(uint128_t dividend, uint128_t divisor, uint128_t* quotient_ret, -// uint128_t* remainder_ret); -// -// static std::string ToFormattedString(uint128_t v, std::ios_base::fmtflags flags = std::ios_base::fmtflags()); -// -// static std::string ToString(uint128_t v); -// -// private: -// constexpr uint128_t(uint64_t high, uint64_t low); -// -// // TODO(strel) Update implementation to use __int128_t once all users of -// // uint128_t are fixed to not depend on alignof(uint128_t) == 8. Also add -// // alignas(16) to class definition to keep alignment consistent across -// // platforms. -//#if defined(ABSL_IS_LITTLE_ENDIAN) -// uint64_t lo_; -// uint64_t hi_; -//#elif defined(ABSL_IS_BIG_ENDIAN) -// uint64_t hi_; -// uint64_t lo_; -//#else // byte order -//#error "Unsupported byte order: must be little-endian or big-endian." -//#endif // byte order -// }; -// -// // allow uint128_t to be logged -// std::ostream& operator<<(std::ostream& os, uint128_t v); -// -// // TODO(strel) add operator>>(std::istream&, uint128_t) -// -// constexpr uint128_t Uint128Max() { -// return uint128_t((std::numeric_limits::max)(), -// (std::numeric_limits::max)()); -// } -// -//} // namespace absl -// -//// Specialized numeric_limits for uint128_t. -//namespace std { -// template <> -// class numeric_limits { -// public: -// static constexpr bool is_specialized = true; -// static constexpr bool is_signed = false; -// static constexpr bool is_integer = true; -// static constexpr bool is_exact = true; -// static constexpr bool has_infinity = false; -// static constexpr bool has_quiet_NaN = false; -// static constexpr bool has_signaling_NaN = false; -// static constexpr float_denorm_style has_denorm = denorm_absent; -// static constexpr bool has_denorm_loss = false; -// static constexpr float_round_style round_style = round_toward_zero; -// static constexpr bool is_iec559 = false; -// static constexpr bool is_bounded = true; -// static constexpr bool is_modulo = true; -// static constexpr int digits = 128; -// static constexpr int digits10 = 38; -// static constexpr int max_digits10 = 0; -// static constexpr int radix = 2; -// static constexpr int min_exponent = 0; -// static constexpr int min_exponent10 = 0; -// static constexpr int max_exponent = 0; -// static constexpr int max_exponent10 = 0; -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// static constexpr bool traps = numeric_limits::traps; -//#else // ABSL_HAVE_INTRINSIC_INT128 -// static constexpr bool traps = numeric_limits::traps; -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// static constexpr bool tinyness_before = false; -// -// static constexpr absl::uint128_t(min)() { return 0; } -// static constexpr absl::uint128_t lowest() { return 0; } -// static constexpr absl::uint128_t(max)() { return absl::Uint128Max(); } -// static constexpr absl::uint128_t epsilon() { return 0; } -// static constexpr absl::uint128_t round_error() { return 0; } -// static constexpr absl::uint128_t infinity() { return 0; } -// static constexpr absl::uint128_t quiet_NaN() { return 0; } -// static constexpr absl::uint128_t signaling_NaN() { return 0; } -// static constexpr absl::uint128_t denorm_min() { return 0; } -// }; -//} // namespace std -// -// -//// -------------------------------------------------------------------------- -//// Implementation details follow -//// -------------------------------------------------------------------------- -//namespace absl { -// -// constexpr uint128_t MakeUint128(uint64_t high, uint64_t low) { -// return uint128_t(high, low); -// } -// -// // Assignment from integer types. -// -// inline uint128_t& uint128_t::operator=(int v) { return *this = uint128_t(v); } -// -// inline uint128_t& uint128_t::operator=(unsigned int v) { -// return *this = uint128_t(v); -// } -// -// inline uint128_t& uint128_t::operator=(long v) { // NOLINT(runtime/int) -// return *this = uint128_t(v); -// } -// -// // NOLINTNEXTLINE(runtime/int) -// inline uint128_t& uint128_t::operator=(unsigned long v) { -// return *this = uint128_t(v); -// } -// -// // NOLINTNEXTLINE(runtime/int) -// inline uint128_t& uint128_t::operator=(long long v) { -// return *this = uint128_t(v); -// } -// -// // NOLINTNEXTLINE(runtime/int) -// inline uint128_t& uint128_t::operator=(unsigned long long v) { -// return *this = uint128_t(v); -// } -// -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// inline uint128_t& uint128_t::operator=(__int128_t v) { -// return *this = uint128_t(v); -// } -// -// inline uint128_t& uint128_t::operator=(unsigned __int128_t v) { -// return *this = uint128_t(v); -// } -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// -// -// // Arithmetic operators. -// -// uint128_t operator<<(uint128_t lhs, int amount); -// uint128_t operator>>(uint128_t lhs, int amount); -// uint128_t operator+(uint128_t lhs, uint128_t rhs); -// uint128_t operator-(uint128_t lhs, uint128_t rhs); -// uint128_t operator*(uint128_t lhs, uint128_t rhs); -// uint128_t operator/(uint128_t lhs, uint128_t rhs); -// uint128_t operator%(uint128_t lhs, uint128_t rhs); -// -// inline uint128_t& uint128_t::operator<<=(int amount) { -// *this = *this << amount; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator>>=(int amount) { -// *this = *this >> amount; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator+=(uint128_t other) { -// *this = *this + other; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator-=(uint128_t other) { -// *this = *this - other; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator*=(uint128_t other) { -// *this = *this * other; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator/=(uint128_t other) { -// *this = *this / other; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator%=(uint128_t other) { -// *this = *this % other; -// return *this; -// } -// -// constexpr uint64_t Uint128Low64(uint128_t v) { return v.lo_; } -// -// constexpr uint64_t Uint128High64(uint128_t v) { return v.hi_; } -// -// // Constructors from integer types. -// -//#if defined(ABSL_IS_LITTLE_ENDIAN) -// -// constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) -// : lo_{ low }, hi_{ high } { -// } -// -// constexpr uint128_t::uint128_t(int v) -// : lo_{ static_cast(v) }, -// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { -// } -// constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) -// : lo_{ static_cast(v) }, -// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { -// } -// constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) -// : lo_{ static_cast(v) }, -// hi_{ v < 0 ? (std::numeric_limits::max)() : 0 } { -// } -// -// constexpr uint128_t::uint128_t(unsigned int v) : lo_{ v }, hi_{ 0 } {} -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::uint128_t(unsigned long v) : lo_{ v }, hi_{ 0 } {} -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::uint128_t(unsigned long long v) : lo_{ v }, hi_{ 0 } {} -// -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// constexpr uint128_t::uint128_t(__int128_t v) -// : lo_{ static_cast(v & ~uint64_t{0}) }, -// hi_{ static_cast(static_cast(v) >> 64) } { -// } -// constexpr uint128_t::uint128_t(unsigned __int128_t v) -// : lo_{ static_cast(v & ~uint64_t{0}) }, -// hi_{ static_cast(v >> 64) } { -// } -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// -//#elif defined(ABSL_IS_BIG_ENDIAN) -// -// constexpr uint128_t::uint128_t(uint64_t high, uint64_t low) -// : hi_{ high }, lo_{ low } { -// } -// -// constexpr uint128_t::uint128_t(int v) -// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, -// lo_{ static_cast(v) } { -// } -// constexpr uint128_t::uint128_t(long v) // NOLINT(runtime/int) -// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, -// lo_{ static_cast(v) } { -// } -// constexpr uint128_t::uint128_t(long long v) // NOLINT(runtime/int) -// : hi_{ v < 0 ? (std::numeric_limits::max)() : 0 }, -// lo_{ static_cast(v) } { -// } -// -// constexpr uint128_t::uint128_t(unsigned int v) : hi_{ 0 }, lo_{ v } {} -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::uint128_t(unsigned long v) : hi_{ 0 }, lo_{ v } {} -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::uint128_t(unsigned long long v) : hi_{ 0 }, lo_{ v } {} -// -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// constexpr uint128_t::uint128_t(__int128_t v) -// : hi_{ static_cast(static_cast(v) >> 64) }, -// lo_{ static_cast(v & ~uint64_t{0}) } { -// } -// constexpr uint128_t::uint128_t(unsigned __int128_t v) -// : hi_{ static_cast(v >> 64) }, -// lo_{ static_cast(v & ~uint64_t{0}) } { -// } -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// -// constexpr uint128_t::uint128_t(int128_t v) -// : hi_{ static_cast(Int128High64(v)) }, lo_{ Int128Low64(v) } { -// } -// -//#else // byte order -//#error "Unsupported byte order: must be little-endian or big-endian." -//#endif // byte order -// -//// Conversion operators to integer types. -// -// constexpr uint128_t::operator bool() const { return lo_ || hi_; } -// -// constexpr uint128_t::operator char() const { return static_cast(lo_); } -// -// constexpr uint128_t::operator signed char() const { -// return static_cast(lo_); -// } -// -// constexpr uint128_t::operator unsigned char() const { -// return static_cast(lo_); -// } -// -// constexpr uint128_t::operator char16_t() const { -// return static_cast(lo_); -// } -// -// constexpr uint128_t::operator char32_t() const { -// return static_cast(lo_); -// } -// -// constexpr uint128_t::operator ABSL_INTERNAL_WCHAR_T() const { -// return static_cast(lo_); -// } -// -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::operator short() const { return static_cast(lo_); } -// -// constexpr uint128_t::operator unsigned short() const { // NOLINT(runtime/int) -// return static_cast(lo_); // NOLINT(runtime/int) -// } -// -// constexpr uint128_t::operator int() const { return static_cast(lo_); } -// -// constexpr uint128_t::operator unsigned int() const { -// return static_cast(lo_); -// } -// -// // NOLINTNEXTLINE(runtime/int) -// constexpr uint128_t::operator long() const { return static_cast(lo_); } -// -// constexpr uint128_t::operator unsigned long() const { // NOLINT(runtime/int) -// return static_cast(lo_); // NOLINT(runtime/int) -// } -// -// constexpr uint128_t::operator long long() const { // NOLINT(runtime/int) -// return static_cast(lo_); // NOLINT(runtime/int) -// } -// -// constexpr uint128_t::operator unsigned long long() const { // NOLINT(runtime/int) -// return static_cast(lo_); // NOLINT(runtime/int) -// } -// -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// constexpr uint128_t::operator __int128_t() const { -// return (static_cast<__int128_t>(hi_) << 64) + lo_; -// } -// -// constexpr uint128_t::operator unsigned __int128_t() const { -// return (static_cast(hi_) << 64) + lo_; -// } -//#endif // ABSL_HAVE_INTRINSIC_INT128 -// -// // Conversion operators to floating point types. -// -// inline uint128_t::operator float() const { -// return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); -// } -// -// inline uint128_t::operator double() const { -// return static_cast(lo_) + std::ldexp(static_cast(hi_), 64); -// } -// -// inline uint128_t::operator long double() const { -// return static_cast(lo_) + -// std::ldexp(static_cast(hi_), 64); -// } -// -// // Comparison operators. -// -// inline bool operator==(uint128_t lhs, uint128_t rhs) { -// return (Uint128Low64(lhs) == Uint128Low64(rhs) && -// Uint128High64(lhs) == Uint128High64(rhs)); -// } -// -// inline bool operator!=(uint128_t lhs, uint128_t rhs) { -// return !(lhs == rhs); -// } -// -// inline bool operator<(uint128_t lhs, uint128_t rhs) { -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// return static_cast(lhs) < -// static_cast(rhs); -//#else -// return (Uint128High64(lhs) == Uint128High64(rhs)) -// ? (Uint128Low64(lhs) < Uint128Low64(rhs)) -// : (Uint128High64(lhs) < Uint128High64(rhs)); -//#endif -// } -// -// inline bool operator>(uint128_t lhs, uint128_t rhs) { return rhs < lhs; } -// -// inline bool operator<=(uint128_t lhs, uint128_t rhs) { return !(rhs < lhs); } -// -// inline bool operator>=(uint128_t lhs, uint128_t rhs) { return !(lhs < rhs); } -// -// // Unary operators. -// -// inline uint128_t operator-(uint128_t val) { -// uint64_t hi = ~Uint128High64(val); -// uint64_t lo = ~Uint128Low64(val) + 1; -// if (lo == 0) ++hi; // carry -// return MakeUint128(hi, lo); -// } -// -// inline bool operator!(uint128_t val) { -// return !Uint128High64(val) && !Uint128Low64(val); -// } -// -// // Logical operators. -// -// inline uint128_t operator~(uint128_t val) { -// return MakeUint128(~Uint128High64(val), ~Uint128Low64(val)); -// } -// -// inline uint128_t operator|(uint128_t lhs, uint128_t rhs) { -// return MakeUint128(Uint128High64(lhs) | Uint128High64(rhs), -// Uint128Low64(lhs) | Uint128Low64(rhs)); -// } -// -// inline uint128_t operator&(uint128_t lhs, uint128_t rhs) { -// return MakeUint128(Uint128High64(lhs) & Uint128High64(rhs), -// Uint128Low64(lhs) & Uint128Low64(rhs)); -// } -// -// inline uint128_t operator^(uint128_t lhs, uint128_t rhs) { -// return MakeUint128(Uint128High64(lhs) ^ Uint128High64(rhs), -// Uint128Low64(lhs) ^ Uint128Low64(rhs)); -// } -// -// inline uint128_t& uint128_t::operator|=(uint128_t other) { -// hi_ |= other.hi_; -// lo_ |= other.lo_; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator&=(uint128_t other) { -// hi_ &= other.hi_; -// lo_ &= other.lo_; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator^=(uint128_t other) { -// hi_ ^= other.hi_; -// lo_ ^= other.lo_; -// return *this; -// } -// -// // Arithmetic operators. -// -// inline uint128_t operator<<(uint128_t lhs, int amount) { -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// return static_cast(lhs) << amount; -//#else -// // uint64_t shifts of >= 64 are undefined, so we will need some -// // special-casing. -// if (amount < 64) { -// if (amount != 0) { -// return MakeUint128( -// (Uint128High64(lhs) << amount) | (Uint128Low64(lhs) >> (64 - amount)), -// Uint128Low64(lhs) << amount); -// } -// return lhs; -// } -// return MakeUint128(Uint128Low64(lhs) << (amount - 64), 0); -//#endif -// } -// -// inline uint128_t operator>>(uint128_t lhs, int amount) { -//#ifdef ABSL_HAVE_INTRINSIC_INT128 -// return static_cast(lhs) >> amount; -//#else -// // uint64_t shifts of >= 64 are undefined, so we will need some -// // special-casing. -// if (amount < 64) { -// if (amount != 0) { -// return MakeUint128(Uint128High64(lhs) >> amount, -// (Uint128Low64(lhs) >> amount) | -// (Uint128High64(lhs) << (64 - amount))); -// } -// return lhs; -// } -// return MakeUint128(0, Uint128High64(lhs) >> (amount - 64)); -//#endif -// } -// -// inline uint128_t operator+(uint128_t lhs, uint128_t rhs) { -// uint128_t result = MakeUint128(Uint128High64(lhs) + Uint128High64(rhs), -// Uint128Low64(lhs) + Uint128Low64(rhs)); -// if (Uint128Low64(result) < Uint128Low64(lhs)) { // check for carry -// return MakeUint128(Uint128High64(result) + 1, Uint128Low64(result)); -// } -// return result; -// } -// -// inline uint128_t operator-(uint128_t lhs, uint128_t rhs) { -// uint128_t result = MakeUint128(Uint128High64(lhs) - Uint128High64(rhs), -// Uint128Low64(lhs) - Uint128Low64(rhs)); -// if (Uint128Low64(lhs) < Uint128Low64(rhs)) { // check for carry -// return MakeUint128(Uint128High64(result) - 1, Uint128Low64(result)); -// } -// return result; -// } -// -// inline uint128_t operator*(uint128_t lhs, uint128_t rhs) { -//#if defined(ABSL_HAVE_INTRINSIC_INT128) -// // TODO(strel) Remove once alignment issues are resolved and unsigned __int128_t -// // can be used for uint128_t storage. -// return static_cast(lhs) * -// static_cast(rhs); -//#elif defined(_MSC_VER) && defined(_M_X64) -// uint64_t carry; -// uint64_t low = _umul128(Uint128Low64(lhs), Uint128Low64(rhs), &carry); -// return MakeUint128(Uint128Low64(lhs) * Uint128High64(rhs) + -// Uint128High64(lhs) * Uint128Low64(rhs) + carry, -// low); -//#else // ABSL_HAVE_INTRINSIC128 -// uint64_t a32 = Uint128Low64(lhs) >> 32; -// uint64_t a00 = Uint128Low64(lhs) & 0xffffffff; -// uint64_t b32 = Uint128Low64(rhs) >> 32; -// uint64_t b00 = Uint128Low64(rhs) & 0xffffffff; -// uint128_t result = -// MakeUint128(Uint128High64(lhs) * Uint128Low64(rhs) + -// Uint128Low64(lhs) * Uint128High64(rhs) + a32 * b32, -// a00 * b00); -// result += uint128_t(a32 * b00) << 32; -// result += uint128_t(a00 * b32) << 32; -// return result; -//#endif // ABSL_HAVE_INTRINSIC128 -// } -// -// // Increment/decrement operators. -// -// inline uint128_t uint128_t::operator++(int) { -// uint128_t tmp(*this); -// *this += 1; -// return tmp; -// } -// -// inline uint128_t uint128_t::operator--(int) { -// uint128_t tmp(*this); -// *this -= 1; -// return tmp; -// } -// -// inline uint128_t& uint128_t::operator++() { -// *this += 1; -// return *this; -// } -// -// inline uint128_t& uint128_t::operator--() { -// *this -= 1; -// return *this; -// } -// -// -// -//} // namespace absl -// -//#undef ABSL_INTERNAL_WCHAR_T -// -//#endif // ABSL_INT128_H_ \ No newline at end of file diff --git a/libOTe/Tools/Foleage/FoleagePcg.cpp b/libOTe/Triple/Foleage/FoleageTriple.cpp similarity index 72% rename from libOTe/Tools/Foleage/FoleagePcg.cpp rename to libOTe/Triple/Foleage/FoleageTriple.cpp index 8c28fd8b..74975c74 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.cpp +++ b/libOTe/Triple/Foleage/FoleageTriple.cpp @@ -1,16 +1,14 @@ -#include "FoleagePcg.h" -#include "libOTe/Tools/Foleage/FoleageUtils.h" -#include "libOTe/Tools/Foleage/F4Ops.h" -#include "libOTe/Tools/Foleage/fft/FoleageFft.h" +#include "FoleageTriple.h" +#include "libOTe/Triple/Foleage/FoleageUtils.h" +#include "libOTe/Triple/Foleage/fft/FoleageFft.h" #include "cryptoTools/Common/BitIterator.h" -#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -#include "libOTe/Tools/Foleage/tri-dpf/FoleagePrf.h" -#include "libOTe/Tools/Dpf/TriDpf.h" +#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Base/BaseOT.h" namespace osuCrypto { - void FoleageF4Ole::init(u64 partyIdx, u64 n) + void FoleageTriple::init(u64 partyIdx, u64 n) { mPartyIdx = partyIdx; mLog3N = log3ceil(n); @@ -27,8 +25,6 @@ namespace osuCrypto mDpfLeafSize = ipow(3, mDpfLeafDepth); mDpfTreeSize = ipow(3, mDpfTreeDepth); - //std::cout << "mLeafSize " << mDpfLeafSize << " " << mDpfLeafDepth << std::endl; - //std::cout << "mTreeSize " << mDpfTreeSize << " " << mDpfTreeDepth << std::endl; mDpfLeaf.init(mPartyIdx, mDpfLeafSize, mC * mC * mT * mT); mDpf.init(mPartyIdx, mDpfTreeSize, mC * mC * mT * mT); @@ -36,11 +32,11 @@ namespace osuCrypto if (mBlockSize < 2) throw RTE_LOC; - sampleA(block(431234234, 213434234123)); + sampleA(block(3127894527893612049, 240925987420932408)); } - FoleageF4Ole::BaseOtCount FoleageF4Ole::baseOtCount() const + FoleageTriple::BaseOtCount FoleageTriple::baseOtCount() const { BaseOtCount counts; @@ -54,7 +50,7 @@ namespace osuCrypto } - void FoleageF4Ole::setBaseOts( + void FoleageTriple::setBaseOts( span> baseSendOts, span recvBaseOts, const oc::BitVector& baseChoices) @@ -94,13 +90,158 @@ namespace osuCrypto mChoiceOts = BitVector(baseChoices.data(), baseChoices.size() - offset, offset); } - bool FoleageF4Ole::hasBaseOts() const + bool FoleageTriple::hasBaseOts() const { return mSendOts.size() + mRecvOts.size() > 0; } + macoro::task<> FoleageTriple::genBaseOts( + PRNG& prng, + Socket& sock, + SilentBaseType baseType) + { + if (isInitialized() == false) + { + throw std::runtime_error("init must be called first. " LOCATION); + } + auto baseCount = baseOtCount(); + + setTimePoint("genBase.start"); + if (mPartyIdx) + { + if (baseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + if (!mOtExtRecver) + mOtExtRecver.emplace(); + if (!mOtExtSender) + mOtExtSender.emplace(); - void FoleageF4Ole::sampleA(block seed) + if (mOtExtRecver->hasBaseOts() == false) + co_await mOtExtRecver->genBaseOts(prng, sock); + + u64 extSenderCount = 0; + if (mOtExtSender->hasBaseOts() == false) + { + extSenderCount = mOtExtSender->baseOtCount(); + baseCount.mRecvCount += extSenderCount; + } + + + BitVector choice(baseCount.mRecvCount); + choice.randomize(prng); + std::vector recvMsg(choice.size()); + co_await mOtExtRecver->receive(choice, recvMsg, prng, sock); + + if (extSenderCount) + { + BitVector senderChoice(choice.data(), extSenderCount); + span senderMsg(recvMsg.data(), extSenderCount); + mOtExtSender->setBaseOts(senderMsg, senderChoice); + } + + std::vector> sendMsg(baseCount.mSendCount); + co_await mOtExtSender->send(sendMsg, prng, sock); + + choice = BitVector(choice.data(), choice.size() - extSenderCount, extSenderCount); + setBaseOts(sendMsg, span(recvMsg).subspan(extSenderCount), choice); +#else + throw std::runtime_error("ENABLE_SOFTSPOKEN_OT = false, must enable soft spoken. " LOCATION); +#endif + } + else + { +#ifdef LIBOTE_HAS_BASE_OT + auto sock2 = sock.fork(); + auto prng2 = prng.fork(); + auto baseOt1 = DefaultBaseOT{}; + auto baseOt2 = DefaultBaseOT{}; + std::vector recvMsg(baseCount.mRecvCount); + std::vector> sendMsg(baseCount.mSendCount); + BitVector choice(baseCount.mRecvCount); + choice.randomize(prng); + + co_await( + macoro::when_all_ready( + baseOt1.send(sendMsg, prng, sock), + baseOt2.receive(choice,recvMsg, prng2, sock2))); + + setBaseOts(sendMsg, recvMsg, choice); +#else + throw std::runtime_error("A base OT must be enabled. " LOCATION); +#endif + } + } + else + { + + if (baseType == SilentBaseType::BaseExtend) + { +#ifdef ENABLE_SOFTSPOKEN_OT + if (!mOtExtRecver) + mOtExtRecver.emplace(); + if (!mOtExtSender) + mOtExtSender.emplace(); + + if (mOtExtSender->hasBaseOts() == false) + co_await mOtExtSender->genBaseOts(prng, sock); + + u64 extRecverCount = 0; + if (mOtExtRecver->hasBaseOts() == false) + { + extRecverCount = mOtExtRecver->baseOtCount(); + baseCount.mSendCount += extRecverCount; + } + + std::vector> sendMsg(baseCount.mSendCount); + co_await mOtExtSender->send(sendMsg, prng, sock); + + if (extRecverCount) + { + span> recverMsg(sendMsg.data(), extRecverCount); + mOtExtRecver->setBaseOts(recverMsg); + } + + BitVector choice(baseCount.mRecvCount); + choice.randomize(prng); + std::vector recvMsg(choice.size()); + co_await mOtExtRecver->receive(choice, recvMsg, prng, sock); + + setBaseOts(span(sendMsg).subspan(extRecverCount), recvMsg, choice); +#else + throw std::runtime_error("ENABLE_SOFTSPOKEN_OT = false, must enable soft spoken. " LOCATION); +#endif + } + else + { +#ifdef LIBOTE_HAS_BASE_OT + auto sock2 = sock.fork(); + auto prng2 = prng.fork(); + auto baseOt1 = DefaultBaseOT{}; + auto baseOt2 = DefaultBaseOT{}; + std::vector recvMsg(baseCount.mRecvCount); + std::vector> sendMsg(baseCount.mSendCount); + BitVector choice(baseCount.mRecvCount); + choice.randomize(prng); + + co_await( + macoro::when_all_ready( + baseOt1.receive(choice, recvMsg, prng, sock), + baseOt2.send(sendMsg, prng2, sock2) + )); + + setBaseOts(sendMsg, recvMsg, choice); +#else + throw std::runtime_error("A base OT must be enabled. " LOCATION); +#endif + } + + } + setTimePoint("genBase.done"); + } + + + void FoleageTriple::sampleA(block seed) { if (mC > 4) @@ -118,12 +259,6 @@ namespace osuCrypto mFftA[i] = (mFftA[i] & ~3) | 1; } - - // FOR DEBUGGING: set fft_a to the identity - // for (size_t i = 0; i < mN; i++) - // { - // mFftA[i] = (0xaaaa >> 1); - // } uint32_t prod; for (size_t i = 0; i < mN; i++) { @@ -143,9 +278,7 @@ namespace osuCrypto u8 tmp = (a2 & b2); prod = tmp ^ ((a2 & (b1 << 1)) ^ ((a1 << 1) & b2)); prod |= (a1 & b1) ^ (tmp >> 1); - //return res; } - //prod = mult_f4(, ); size_t slot = j * mC + k; mFftASquared[i] |= prod << (2 * slot); } @@ -156,7 +289,7 @@ namespace osuCrypto - macoro::task<> FoleageF4Ole::expand( + macoro::task<> FoleageTriple::expand( span ALsb, span AMsb, span CLsb, @@ -167,13 +300,16 @@ namespace osuCrypto setTimePoint("expand start"); if (hasBaseOts() == false) - throw RTE_LOC; + { + co_await genBaseOts(prng, sock); + } if (divCeil(mN, 128) < ALsb.size()) throw RTE_LOC; if (ALsb.size() != AMsb.size() || - ALsb.size() != CLsb.size() || - ALsb.size() != CMsb.size()) + ALsb.size() != CLsb.size()) + throw RTE_LOC; + if (ALsb.size() != CMsb.size() && CMsb.size()) throw RTE_LOC; // the coefficient of the sparse polynomial. @@ -399,19 +535,34 @@ namespace osuCrypto fft_recursive_uint32(fft, mLog3N, mN / 3); setTimePoint("product fft"); - multiply_fft_32(mFftASquared, fft, fftRes, mN); + F4Multiply(mFftASquared, fft, fftRes, mN); setTimePoint("product mult"); - // XOR the (packed) columns into the accumulator. - // Specifically, we perform column-wise XORs to get the result. - u32 lsbMask, msbMask; - setBytes(lsbMask, 0b01010101); - setBytes(msbMask, 0b10101010); - for (size_t i = 0; i < outSize; i++) + if (CMsb.size()) + { + + // XOR the (packed) columns into the accumulator. + // Specifically, we perform column-wise XORs to get the result. + u32 lsbMask, msbMask; + setBytes(lsbMask, 0b01010101); + setBytes(msbMask, 0b10101010); + for (size_t i = 0; i < outSize; i++) + { + *BitIterator(CLsb.data(), i) = popcount(fftRes[i] & lsbMask) & 1; + *BitIterator(CMsb.data(), i) = popcount(fftRes[i] & msbMask) & 1; + } + } + else { - *BitIterator(CLsb.data(), i) = popcount(fftRes[i] & lsbMask) & 1; - *BitIterator(CMsb.data(), i) = popcount(fftRes[i] & msbMask) & 1; + // XOR the (packed) columns into the accumulator. + // Specifically, we perform column-wise XORs to get the result. + u32 lsbMask, msbMask; + setBytes(lsbMask, 0b01010101); + for (size_t i = 0; i < outSize; i++) + { + *BitIterator(CLsb.data(), i) = popcount(fftRes[i] & lsbMask) & 1; + } } @@ -420,7 +571,7 @@ namespace osuCrypto } - macoro::task<> FoleageF4Ole::tensor(span coeffs, span prod, coproto::Socket& sock) + macoro::task<> FoleageTriple::tensor(span coeffs, span prod, coproto::Socket& sock) { if (coeffs.size() * coeffs.size() != prod.size()) throw RTE_LOC; @@ -455,7 +606,7 @@ namespace osuCrypto a[0][i] = t0[i] ^ t1[i]; // a[1] = 2 * a[0] - f4Mult(a[0][0], a[0][1], ZeroBlock, AllOneBlock, a[1][0], a[1][1]); + F4Multiply(a[0][0], a[0][1], ZeroBlock, AllOneBlock, a[1][0], a[1][1]); { auto lsbIter = BitIterator(&a[0][0]); @@ -552,7 +703,7 @@ namespace osuCrypto } } - //macoro::task<> FoleageF4Ole::checkTensor(span coeffs, span tensoredCoefficients, coproto::Socket& sock) + //macoro::task<> FoleageTriple::checkTensor(span coeffs, span tensoredCoefficients, coproto::Socket& sock) //{ // std::array, 2> pCoeffs;// (coeffs.size()); // pCoeffs[mPartyIdx] = std::vector(coeffs.begin(), coeffs.end()); @@ -575,7 +726,7 @@ namespace osuCrypto // auto scaler = pCoeffs[0][i]; // for (u64 j = 0; j < coeffs.size(); ++j) // { - // u8 exp = mult_f4(scaler, pCoeffs[1][j]); + // u8 exp = F4Multiply(scaler, pCoeffs[1][j]); // auto prod = pProd(i, j); // if (prod != exp) // { diff --git a/libOTe/Tools/Foleage/FoleagePcg.h b/libOTe/Triple/Foleage/FoleageTriple.h similarity index 72% rename from libOTe/Tools/Foleage/FoleagePcg.h rename to libOTe/Triple/Foleage/FoleageTriple.h index acc25fbb..fe58d5d4 100644 --- a/libOTe/Tools/Foleage/FoleagePcg.h +++ b/libOTe/Triple/Foleage/FoleageTriple.h @@ -5,12 +5,20 @@ #include "coproto/Socket/Socket.h" #include "cryptoTools/Crypto/PRNG.h" #include "cryptoTools/Common/Timer.h" -#include "libOTe/Tools/Dpf/TriDpf.h" +#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.h" + namespace osuCrypto { - - class FoleageF4Ole : public TimerAdapter + // The two party Foleage PCG protocol for generating F4 OLEs + // and Binary Beaver triples. The caller should call + // + // FoleageTriple::init(...) + // FoleageTriple::expand(...) + // + // There are two expand function, one for OLEs and one for Triples. + class FoleageTriple : public TimerAdapter { public: u64 mPartyIdx = 0; @@ -66,10 +74,13 @@ namespace osuCrypto // a dpf used to construct the F4x243 leaf value of the larger DPF. TriDpf mDpfLeaf; +#ifdef ENABLE_SOFTSPOKEN_OT + std::optional> mOtExtRecver; + std::optional> mOtExtSender; +#endif struct FoleageCoeffCtx : CoeffCtxGF2 { - OC_FORCEINLINE void fromBlock(FoleageF4x243& ret, const block& b) { ret.mVal[0] = b; ret.mVal[1] = b ^ block(2314523225322345310, 3520873105824273452); @@ -84,19 +95,21 @@ namespace osuCrypto // The base OTs used to tensor the coefficients of the sparse polynomial. std::vector mRecvOts; - + // The base OTs used to tensor the coefficients of the sparse polynomial. - std::vector> mSendOts; + std::vector> mSendOts; // The base OTs used to tensor the coefficients of the sparse polynomial. BitVector mChoiceOts; - // Intializes the protocol to generate n OLEs. Most efficient when n + // Intializes the protocol to generate n F4 OLEs. Most efficient when n // is a power of 3. Once called, baseOtCount() can be called to // determine the required number of base OTs. void init(u64 partyIdx, u64 n); + bool isInitialized() const { return mN > 0; } + struct BaseOtCount { // the number of base OTs as sender. @@ -118,6 +131,8 @@ namespace osuCrypto // returns true of the base OTs have been set. bool hasBaseOts() const; + macoro::task<> genBaseOts(PRNG& prng, Socket& sock, SilentBaseType baseType = SilentBaseType::BaseExtend); + // The F4 OLE protocol. This will generate n OLEs. // the resulting OLEs are in bit decomposition form. // A = (AMsb || ALsb), C = (CMsb || CLsb). This party will @@ -127,10 +142,48 @@ namespace osuCrypto span ALsb, span AMsb, span CLsb, - span CMsb, - PRNG& prng, + span CMsb, + PRNG& prng, coproto::Socket& sock); + + // The F2 beaver triple protocol. This will generate n beaver triples. + macoro::task<> expand( + span A, + span B, + span C, + PRNG& prng, + coproto::Socket& sock) + { + if (mPartyIdx) + { + co_await expand(B, A, C, {}, prng, sock); + + for (u64 i = 0; i < A.size(); ++i) + { + // b(0)b(1) + auto bb = B[i] & A[i]; + // b(0)b(1) + [ab]1(0) + C[i] ^= bb; + } + } + else + { + //auto bLsb = temp[0]; + //auto bMsb = temp[1]; + co_await expand(A, B, C, {}, prng, sock); + + for (u64 i = 0; i < A.size(); ++i) + { + // a(0)a(1) + auto aa = A[i] & B[i]; + + // a(0)a(1) + [ab]0(0) + C[i] ^= aa; + } + } + } + // sample random coefficients for the sparse polynomial and tensor // them with the other parties coefficients. The result is shared // as tensoredCoefficients. We allow the coeff to be zero. @@ -142,4 +195,6 @@ namespace osuCrypto }; + + } diff --git a/libOTe/Triple/Foleage/FoleageUtils.h b/libOTe/Triple/Foleage/FoleageUtils.h new file mode 100644 index 00000000..d3834a00 --- /dev/null +++ b/libOTe/Triple/Foleage/FoleageUtils.h @@ -0,0 +1,266 @@ +#pragma once +#include "cryptoTools/Crypto/AES.h" +#include "cryptoTools/Crypto/PRNG.h" +#include "cryptoTools/Crypto/RandomOracle.h" +#include +#include +#include + +namespace osuCrypto +{ + + // Multiplies two elements of F4 + // and returns the result. + inline uint8_t F4Multiply(uint8_t a, uint8_t b) + { + u8 tmp = ((a & 0b10) & (b & 0b10)); + uint8_t res = tmp ^ (((a & 0b10) & ((b & 0b01) << 1)) ^ (((a & 0b01) << 1) & (b & 0b10))); + res |= ((a & 0b01) & (b & 0b01)) ^ (tmp >> 1); + return res; + } + + + // component-wise Multiplies two elements of F4^64 + // and returns the result. + inline void F4Multiply( + block aLsb, block aMsb, + block bLsb, block bMsb, + block& cLsb, block& cMsb) + { + auto tmp = aMsb & bMsb;// msb only + cMsb = tmp ^ (aMsb & bLsb) ^ (aLsb & bMsb);// msb only + cLsb = (aLsb & bLsb) ^ tmp; + } + + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint8_t + // resulting in a matrix with 4 columns. + inline void F4Multiply( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint8_t pattern = 0xaa; + uint8_t mask_h = pattern; // 0b10101010 + uint8_t mask_l = mask_h >> 1; // 0b01010101 + + uint8_t tmp; + uint8_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint16_t + // resulting in a matrix with 8 columns. + inline void F4Multiply( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint16_t pattern = 0xaaaa; + uint16_t mask_h = pattern; // 0b101010101010101001010 + uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint16_t tmp; + uint16_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint32_t + // resulting in a matrix with 16 columns. + inline void F4Multiply( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint32_t pattern = 0xaaaaaaaa; + uint32_t mask_h = pattern; // 0b101010101010101001010 + uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint32_t tmp; + uint32_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + // Multiplies two packed matrices of F4 elements column-by-column. + // Note that here the "columns" are packed into an element of uint64_t + // resulting in a matrix with 32 columns. + inline void F4Multiply( + span a_poly, + span b_poly, + span res_poly, + size_t poly_size) + { + const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; + uint64_t mask_h = pattern; // 0b101010101010101001010 + uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + uint64_t tmp; + uint64_t a_h, a_l, b_h, b_l; + + for (size_t i = 0; i < poly_size; i++) + { + // multiplication over F4 + a_h = (a_poly[i] & mask_h); + a_l = (a_poly[i] & mask_l); + b_h = (b_poly[i] & mask_h); + b_l = (b_poly[i] & mask_l); + + tmp = (a_h & b_h); + res_poly[i] = tmp ^ (a_h & (b_l << 1)); + res_poly[i] ^= ((a_l << 1) & b_h); + res_poly[i] |= (a_l & b_l) ^ (tmp >> 1); + } + } + + inline u64 log3ceil(u64 x) + { + if (x == 0) return 0; + u64 i = 0; + u64 v = 1; + while (v < x) + { + v *= 3; + i++; + } + //assert(i == ceil(log_base(x, 3))); + + return i; + } + + // Compute base^exp without the floating-point precision + // errors of the built-in pow function. + inline constexpr size_t ipow(size_t base, size_t exp) + { + if (exp == 1) + return base; + + if (exp == 0) + return 1; + + size_t result = 1; + while (1) + { + if (exp & 1) + result *= base; + exp >>= 1; + if (!exp) + break; + base *= base; + } + + return result; + } + + inline int popcount(block x) + { + return popcount(x.get(0)) + popcount(x.get(1)); + } + + inline std::array extractF4(const block& val) + { + std::array ret; + const char* ptr = (const char*)&val; + for (u8 i = 0; i < 16; ++i) + { + ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + } + return ret; + } + + // A 512 bit value that is used to represent a vector of 3^5=243 F4 elements. + // We use this value because its greater than 128 bits and almost a power of 2. + // the last 26 bits are unused. + struct FoleageF4x243 + { + std::array mVal; + + FoleageF4x243 operator^(const FoleageF4x243& o) const + { + FoleageF4x243 r; + r.mVal[0] = mVal[0] ^ o.mVal[0]; + r.mVal[1] = mVal[1] ^ o.mVal[1]; + r.mVal[2] = mVal[2] ^ o.mVal[2]; + r.mVal[3] = mVal[3] ^ o.mVal[3]; + return r; + } + FoleageF4x243& operator^=(const FoleageF4x243& o) + { + mVal[0] = mVal[0] ^ o.mVal[0]; + mVal[1] = mVal[1] ^ o.mVal[1]; + mVal[2] = mVal[2] ^ o.mVal[2]; + mVal[3] = mVal[3] ^ o.mVal[3]; + return *this; + } + + bool operator==(const FoleageF4x243& o) const + { + return + mVal[0] == o.mVal[0] && + mVal[1] == o.mVal[1] && + mVal[2] == o.mVal[2] && + mVal[3] == o.mVal[3]; + } + }; + + inline std::array extractF4(const FoleageF4x243& val) + { + std::array ret; + const char* ptr = (const char*)&val; + for (u8 i = 0; i < 64; ++i) + { + ret[i * 4 + 0] = (ptr[i] >> 0) & 3; + ret[i * 4 + 1] = (ptr[i] >> 2) & 3; + ret[i * 4 + 2] = (ptr[i] >> 4) & 3; + ret[i * 4 + 3] = (ptr[i] >> 6) & 3;; + } + return ret; + } +} \ No newline at end of file diff --git a/libOTe/Triple/Foleage/fft/FoleageFft.cpp b/libOTe/Triple/Foleage/fft/FoleageFft.cpp new file mode 100644 index 00000000..bb8fb491 --- /dev/null +++ b/libOTe/Triple/Foleage/fft/FoleageFft.cpp @@ -0,0 +1,310 @@ +#include +#include +#include "libOTe/Triple/Foleage/fft/FoleageFft.h" + +namespace osuCrypto { + + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint64( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint64( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint64( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint64_t tL, tM; + uint64_t mult, xor_h, xor_l; + + uint64_t* coeffsL = coeffs.data() + 0; + uint64_t* coeffsM = coeffs.data() + num_coeffs; + uint64_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint64_t pattern = 0xaaaaaaaaaaaaaaaa; + const uint64_t mask_h = pattern; // 0b101010101010101001010 + const uint64_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint32( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint32( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint32( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint32_t tL, tM; + uint32_t mult, xor_h, xor_l; + + uint32_t* coeffsL = coeffs.data() + 0; + uint32_t* coeffsM = coeffs.data() + num_coeffs; + uint32_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint32_t pattern = 0xaaaaaaaa; + const uint32_t mask_h = pattern; // 0b101010101010101001010 + const uint32_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + fft_recursive_uint16( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + fft_recursive_uint16( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + fft_recursive_uint16( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint16_t tL, tM; + uint16_t mult, xor_h, xor_l; + + uint16_t* coeffsL = coeffs.data() + 0; + uint16_t* coeffsM = coeffs.data() + num_coeffs; + uint16_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint16_t pattern = 0xaaaa; + const uint16_t mask_h = pattern; // 0b101010101010101001010 + const uint16_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } + + void foliageFftUint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs) + { + // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) + + if (num_vars > 1) + { + // apply FFT on all left coefficients + foliageFftUint8( + coeffs, + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all middle coefficients + foliageFftUint8( + coeffs.subspan(num_coeffs), + num_vars - 1, + num_coeffs / 3); + + // apply FFT on all right coefficients + foliageFftUint8( + coeffs.subspan(2 * num_coeffs), + num_vars - 1, + num_coeffs / 3); + } + + // temp variables to store intermediate values + uint8_t tL, tM; + uint8_t mult, xor_h, xor_l; + + uint8_t* coeffsL = coeffs.data() + 0; + uint8_t* coeffsM = coeffs.data() + num_coeffs; + uint8_t* coeffsR = coeffs.data() + 2 * num_coeffs; + + const uint8_t pattern = 0xaa; + const uint8_t mask_h = pattern; // 0b101010101010101001010 + const uint8_t mask_l = mask_h >> 1; // 0b010101010101010100101 + + for (size_t j = 0; j < num_coeffs; j++) + { + xor_h = (coeffsM[j] ^ coeffsR[j]) & mask_h; + xor_l = (coeffsM[j] ^ coeffsR[j]) & mask_l; + + // pre compute: \alpha * (cM[j] ^ cR[j]) + // computed as: mult_l = (h ^ l) and mult_h = l + // mult_l = (xor&mask_h>>1) ^ (xor & mask_l) [align h and l then xor] + // mult_h = (xor&mask_l) shifted left by 1 to put in h place [shift and OR into place] + mult = ((xor_h >> 1) ^ xor_l) | (xor_l << 1); + + // tL coefficient obtained by evaluating on X_i=1 + tL = coeffsL[j] ^ coeffsM[j] ^ coeffsR[j]; + + // tM coefficient obtained by evaluating on X_i=\alpha + tM = coeffsL[j] ^ coeffsR[j] ^ mult; + + // Explanation: + // cL + cM*\alpha + cR*\alpha^2 + // = cL + cM*\alpha + cR*\alpha + cR + // = cL + cR + \alpha*(cM + cR) + + // tR: coefficient obtained by evaluating on X_i=\alpha^2=\alpha + 1 + coeffsR[j] = coeffsL[j] ^ coeffsM[j] ^ mult; + + // Explanation: + // cL + cM*(\alpha+1) + cR(\alpha+1)^2 + // = cL + cM + cM*\alpha + cR*(3\alpha + 2) + // = cL + cM + \alpha*(cM + cR) + // Note: we're in the F_2 field extension so 3\alpha+2 = \alpha+0. + + coeffsL[j] = tL; + coeffsM[j] = tM; + } + } +} \ No newline at end of file diff --git a/libOTe/Triple/Foleage/fft/FoleageFft.h b/libOTe/Triple/Foleage/fft/FoleageFft.h new file mode 100644 index 00000000..3e8c280e --- /dev/null +++ b/libOTe/Triple/Foleage/fft/FoleageFft.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include "cryptoTools/Common/Defines.h" +#include "cryptoTools/Common/MatrixView.h" +#include "libOTe/Triple/Foleage/FoleageUtils.h" +#include + +//#include "libOTe/Tools/Foleage/utils.h" +namespace osuCrypto { + + //typedef __int128 int128_t; + //typedef unsigned __int128 uint128_t; + + // FFT for (up to) 32 polynomials over F4 + void fft_recursive_uint64( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 16 polynomials over F4 + void fft_recursive_uint32( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 8 polynomials over F4 + void fft_recursive_uint16( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + // FFT for (up to) 4 polynomials over F4 + void foliageFftUint8( + span coeffs, + const size_t num_vars, + const size_t num_coeffs); + + +} diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp index b8559a59..e0d390e3 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtReceiver.cpp @@ -18,10 +18,6 @@ namespace osuCrypto { - - //u64 getPartitions(u64 scaler, u64 p, u64 secParam); - - // sets the KOS base OTs that are then used to extend void SilentOtExtReceiver::setBaseOts( span> baseSendOts) { @@ -30,7 +26,7 @@ namespace osuCrypto mOtExtRecver.emplace(); mOtExtRecver->setBaseOts(baseSendOts); #else - throw std::runtime_error("soft spoken ot must be enabled"); + throw std::runtime_error("softSpoken ot must be enabled. " LOCATION); #endif } @@ -45,7 +41,7 @@ namespace osuCrypto } return mOtExtRecver->baseOtCount(); #else - throw std::runtime_error("soft spoken ot must be enabled"); + throw std::runtime_error("softSpoken ot must be enabled. " LOCATION); #endif } @@ -56,7 +52,7 @@ namespace osuCrypto return false; return mOtExtRecver->hasBaseOts(); #else - throw std::runtime_error("soft spoken ot must be enabled"); + throw std::runtime_error("softSpoken ot must be enabled. " LOCATION); #endif }; @@ -73,7 +69,6 @@ namespace osuCrypto mGen.setBase(genOts); std::copy(malOts.begin(), malOts.end(), mMalCheckOts.begin()); - } task<> SilentOtExtReceiver::genBaseOts( @@ -82,16 +77,15 @@ namespace osuCrypto { setTimePoint("recver.gen.start"); #ifdef ENABLE_SOFTSPOKEN_OT - //mOtExtRecver.mFiatShamir = true; - if (!mOtExtRecver) mOtExtRecver.emplace(); co_await mOtExtRecver->genBaseOts(prng, chl); #else - throw std::runtime_error("soft spoken ot must be enabled"); + throw std::runtime_error("softSpoken ot must be enabled. " LOCATION); co_return; #endif } + // Returns an independent copy of this extender. std::unique_ptr SilentOtExtReceiver::split() { @@ -104,7 +98,7 @@ namespace osuCrypto ptr->mOtExtRecver = mOtExtRecver->splitBase(); return ret; #else - throw std::runtime_error("soft spoken ot must be enabled"); + throw std::runtime_error("softSpoken ot must be enabled. " LOCATION); #endif }; diff --git a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp index cf8b283b..9a5a6101 100644 --- a/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp +++ b/libOTe/TwoChooseOne/Silent/SilentOtExtSender.cpp @@ -28,7 +28,7 @@ namespace osuCrypto mOtExtSender->setBaseOts(baseRecvOts, choices); #else - throw std::runtime_error("KOS must be enabled"); + throw std::runtime_error("softspoken must be enabled. " LOCATION); #endif } @@ -44,7 +44,7 @@ namespace osuCrypto ptr->mOtExtSender = mOtExtSender->splitBase(); return ret; #else - throw std::runtime_error("KOS must be enabled"); + throw std::runtime_error("softspoken must be enabled. " LOCATION); #endif } @@ -57,7 +57,7 @@ namespace osuCrypto mOtExtSender.emplace(); return mOtExtSender->genBaseOts(prng, chl); #else - throw std::runtime_error("KOS must be enabled"); + throw std::runtime_error("softspoken must be enabled. " LOCATION); #endif } @@ -73,7 +73,7 @@ namespace osuCrypto } return mOtExtSender->baseOtCount(); #else - throw std::runtime_error("KOS must be enabled"); + throw std::runtime_error("softspoken must be enabled. " LOCATION); #endif } @@ -85,7 +85,7 @@ namespace osuCrypto return false; return mOtExtSender->hasBaseOts(); #else - throw std::runtime_error("KOS must be enabled"); + throw std::runtime_error("softspoken must be enabled. " LOCATION); #endif } diff --git a/libOTe/TwoChooseOne/TcoOtDefines.h b/libOTe/TwoChooseOne/TcoOtDefines.h index 9be38fac..410b1215 100644 --- a/libOTe/TwoChooseOne/TcoOtDefines.h +++ b/libOTe/TwoChooseOne/TcoOtDefines.h @@ -22,7 +22,7 @@ namespace osuCrypto const u64 commStepSize(512); // TODO: try increasing this for optimization. const u64 superBlkShift(3); const u64 superBlkSize(1 << superBlkShift); - const u64 gKosChallengeRepititions(1); + const u64 gKosChallengeRepititions(4); enum class SilentBaseType { // Use a standalone base OT protocol to generate the required base OTs diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index 225128d5..bd24f546 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -1,1430 +1,108 @@ #include "Foleage_Tests.h" -#include "libOTe/Tools/Foleage/tri-dpf/FoleageDpf.h" -#include "libOTe/Tools/Foleage/fft/FoleageFft.h" -//#include "libOTe/Tools/Foleage/tri-dpf/FoleageHalfDpf.h" -#include "libOTe/Tools/Foleage/F4Ops.h" +#include "libOTe/Triple/Foleage/fft/FoleageFft.h" #include "cryptoTools/Common/Matrix.h" -#include "libOTe/Tools/Foleage/FoleagePcg.h" +#include "libOTe/Triple/Foleage/FoleageTriple.h" #include "coproto/Socket/LocalAsyncSock.h" -#include "libOTe/Tools/Foleage/PerfectShuffle.h" #include "cryptoTools/Common/Timer.h" namespace osuCrypto { - //u8 extractF4(const block& val, u8 idx) - //{ - // auto byteIdx = idx / 4; - // auto bitIdx = idx % 4; - // auto byte = ((u8*)&val)[byteIdx]; - // return (byte >> (bitIdx * 2)) & 0b11; - //} - void testOutputCorrectness( - span shares0, - span shares1, - size_t num_outputs, - size_t secret_index, - span secret_msg, - size_t msg_len) - { - for (size_t i = 0; i < msg_len; i++) - { - block shareA = shares0[secret_index * msg_len + i]; - block shareB = shares1[secret_index * msg_len + i]; - block res = shareA ^ shareB; - - if (res != secret_msg[i]) - { - printf("FAIL (wrong message)\n"); - exit(0); - } - } - - for (size_t i = 0; i < num_outputs; i++) - { - if (i == secret_index) - continue; - - for (size_t j = 0; j < msg_len; j++) - { - block shareA = shares0[i * msg_len + j]; - block shareB = shares1[i * msg_len + j]; - block res = shareA ^ shareB; - - if (res != ZeroBlock) - { - printf("FAIL (non-zero) %zu\n", i); - printBytes(&shareA, 16); - printBytes(&shareB, 16); - - exit(0); - } - } - } - } - - void printOutputShares( - block* shares0, - block* shares1, - size_t num_outputs, - size_t msg_len) - { - for (size_t i = 0; i < num_outputs; i++) - { - for (size_t j = 0; j < msg_len; j++) - { - block shareA = shares0[i * msg_len + j]; - block shareB = shares1[i * msg_len + j]; - //block res = shareA ^ shareB; - - printf("(%zu, %zu) %zu\n", i, j, msg_len); - printBytes(&shareA, 16); - printBytes(&shareB, 16); - } - } - } - - - - void testOutputCorrectness_spf( - span shares0, - span shares1, - size_t num_outputs, - size_t secret_index, - span secret_msg, - size_t msg_len) - { - for (size_t i = 0; i < msg_len; i++) - { - block shareA = shares0[secret_index * msg_len + i]; - block shareB = shares1[secret_index * msg_len + i]; - block res = shareA ^ shareB; - - if (res != secret_msg[i]) - { - printf("FAIL (wrong message)\n"); - throw RTE_LOC; - } - } - - for (size_t i = 0; i < num_outputs; i++) - { - if (i == secret_index) - continue; - - for (size_t j = 0; j < msg_len; j++) - { - block shareA = shares0[i * msg_len + j]; - block shareB = shares1[i * msg_len + j]; - block res = shareA ^ shareB; - - if (res != ZeroBlock) - { - printf("FAIL (non-zero) %zu\n", i); - printBytes(&shareA, 16); - printBytes(&shareB, 16); - throw RTE_LOC; - //exit(0); - } - } - } - } - - void printOutputShares_spf( - block* shares0, - block* shares1, - size_t num_outputs, - size_t msg_len) - { - for (size_t i = 0; i < num_outputs; i++) - { - for (size_t j = 0; j < msg_len; j++) - { - block shareA = shares0[i * msg_len + j]; - block shareB = shares1[i * msg_len + j]; - //block res = shareA ^ shareB; - - printf("(%zu, %zu) %zu\n", i, j, msg_len); - printBytes(&shareA, 16); - printBytes(&shareB, 16); - } - } - } - - void foleage_transpose_test(const oc::CLP& cmd) - { - { - - std::vector v(3 * 8); - std::vector v2(3 * 8); - - for (u64 i = 0; i < v.size(); ++i) - { - v[i] = i; - } - - - // input: - // 0 1 2 3 4 5 6 7 - // 8 9 10 11 12 13 14 15 - // 16 17 18 19 20 21 22 23 - // - // output: - // 0 3 6 9 12 15 18 21 - // 1 4 7 10 13 16 19 22 - // 2 5 8 11 14 17 20 23 - //printShuffle3(v.data()); - foleageTransposeLeaf<2>((u8*)v.data(), (__m128i*)v2.data()); - //printShuffle3(v2.data()); - - for (u64 i = 0; i < v2.size(); ++i) - { - auto e = i * 3 % 24 + (i / 8); - if (v2[i] != e) - throw RTE_LOC; - - } - - } - - { - int randomize = 1;// 241234123; // set to 1 to make debuggable - - std::vector v(9 * 8); - std::vector v2(9 * 8); - - - for (u64 i = 0; i < v.size(); ++i) - { - v[i] = i * randomize; - } - - - //std::cout << "\n"; - //printShuffle3(v.data()); - //std::cout << "\n"; - //printShuffle3(v.data() + 3 * 8); - //std::cout << "\n"; - //printShuffle3(v.data() + 6 * 8); - //std::cout << "--------------\n"; - - ////dst[i * 3 + j] = a0; - foleageTranspose<2>((u8*)v.data(), (__m128i*)v2.data()); - - - //printShuffle9(v2.data()); - //std::cout << "\n"; - - - // 0 1 2 3 4 5 6 7 - // 8 9 10 11 12 13 14 15 - // 16 17 18 19 20 21 22 23 - // - // 24 25 26 27 28 29 30 31 - // 32 33 34 35 36 37 38 39 - // 40 41 42 43 44 45 46 47 - // - // 48 49 50 51 52 53 54 55 - // 56 57 58 59 60 61 62 63 - // 64 65 66 67 68 69 70 71 - - - // 0 8 16 3 11 19 6 14 22 25 33 41 28 36 44 31 39 47 50 58 66 53 61 69 - // 1 9 17 4 12 20 7 15 23 26 34 42 29 37 45 48 56 64 51 59 67 54 62 70 - // 2 10 18 5 13 21 24 32 40 27 45 43 30 38 46 49 57 65 52 60 68 55 63 71 - //std::cout << std::endl; - - std::vector> exp(3); - for (u64 i = 0, k = 0; i < 3; ++i) - { - for (u64 j = 0; j < 24; ++j, ++k) - { - auto row = j / 8; - exp[row].push_back(k * randomize); - } - } - //std::cout << "before\n"; - //for (u64 i = 0; i < 3; ++i) - //{ - // for (u64 j = 0; j < 24; ++j) - // { - // //std::cout << v2[i * 24 + j] << " "; - // std::cout << std::setw(2) << std::setfill(' ') << exp[i][j] << " "; - // } - // std::cout << std::endl; - //} - - - for (u64 i = 0; i < 3; ++i) - { - for (u64 j = 0; j < 8; ++j) - { - auto b = j * 3; - for (u64 k = 0; k < 3; ++k) - { - for (u64 l = 0; l < k; ++l) - { - std::swap(exp[k][b + l], exp[l][b + k]); - } - } - } - } - //std::cout << "after\n"; - //for (u64 i = 0; i < 3; ++i) - //{ - // for (u64 j = 0; j < 24; ++j) - // { - // //std::cout << v2[i * 24 + j] << " "; - // std::cout << std::setw(2) << std::setfill(' ') << exp[i][j] << " "; - // } - // std::cout << std::endl; - //} - - for (u64 i = 0; i < 3; ++i) - { - for (u64 j = 0; j < 24; ++j) - { - if (exp[i][j] != v2[i * 24 + j]) - throw RTE_LOC; - } - } - - //printShuffle9(v.data()); - //foleageTranspose<2>((u8*)v2.data(), (__m128i*)v.data()); - - //for (u64 i = 0; i < v.size(); ++i) - //{ - // if(v[i] != i * randomize) - // throw RTE_LOC; - //} - } - - { - int randomize = 241234123; // set to 1 to make debuggable - - std::vector v(9 * 8); - std::vector v2(9 * 8); - - - for (u64 i = 0; i < v.size(); ++i) - { - v[i] = i * randomize; - } - //std::cout << "in\n" << std::endl; - //printShuffle9(v.data()); - - - foleageTransposeLeaf<2>((u8*)&v[0], (__m128i*)& v2[0]); - foleageTransposeLeaf<2>((u8*)&v[3 * 8], (__m128i*)& v2[3 * 8]); - foleageTransposeLeaf<2>((u8*)&v[6 * 8], (__m128i*)& v2[6 * 8]); - - //std::cout << "l1\n" << std::endl; - //printShuffle9(v2.data()); - - foleageTranspose<2>((u8*)v2.data(), (__m128i*)v.data()); - - - //std::cout << "l2\n" << std::endl; - //printShuffle9(v.data()); - - //std::vector inverse(v.size()); - //for (u64 i = 0; i < v.size(); ++i) - //{ - // inverse[v[i]] = i; - //} - - - foliageUnTranspose<2>((u8*)v.data(), (__m128i*)v2.data()); - - //std::cout << "inv\n" << std::endl; - //for (u64 i = 0; i < inverse.size(); ++i) - //{ - // std::cout << std::setw(2) << std::setfill(' ') << inverse[i] << ", "; - //} - ////printShuffle9(inverse.data()); - //std::cout << "f\n"; - //printShuffle9(v2.data()); - - for (u64 i = 0; i < v.size(); ++i) - { - if (v2[i] != u16(i * randomize)) - throw RTE_LOC; - } - } - - if (0) - { - - u64 trials = 1000000; - //int randomize = 241234123; // set to 1 to make debuggable - - u64 ss = 9; - std::vector lsb(ss * trials), msb(ss * trials); - std::vector lsb2(ss * trials), msb2(ss * trials); - - PRNG prng(block(342134213421, 2341234123421)); - prng.get(lsb.data(), lsb.size()); - prng.get(msb.data(), msb.size()); - - - //for (u64 i = 0; i < 3 * 24; ++i) - //{ - // ((u16*)lsb.data())[i] = i * randomize; - // ((u16*)msb.data())[i] = i * randomize ^ 2134123423; - //} - std::cout << "in\n" << std::endl; - //printShuffle9(v.data()); - Timer t; - t.setTimePoint("b"); - - auto l = (u16*)lsb.data(); - auto m = (u16*)msb.data(); - for (u64 i = 0; i < trials * 8; ++i) - { - for (u64 j = 0; j < 3; ++j) - { - foleageFFTOne<1>( - &l[i * ss + j * 3 + 0], &m[i * ss + j * 3 + 0], - &l[i * ss + j * 3 + 1], &m[i * ss + j * 3 + 1], - &l[i * ss + j * 3 + 2], &m[i * ss + j * 3 + 2] - ); - } - - - for (u64 j = 0; j < 3; ++j) - { - foleageFFTOne<2>( - &l[i * ss + 0 * 3 + j], &m[i * ss + 0 * 3 + j], - &l[i * ss + 1 * 3 + j], &m[i * ss + 1 * 3 + j], - &l[i * ss + 2 * 3 + j], &m[i * ss + 2 * 3 + j] - ); - } - } - - t.setTimePoint("o"); - for (u64 i = 0; i < trials; ++i) - { - auto bLsb = lsb.data() + i * ss; - auto bMsb = msb.data() + i * ss; - auto bLsb2 = lsb2.data() + i * ss; - auto bMsb2 = msb2.data() + i * ss; - for (u64 j = 0; j < 3; ++j) - { - - foleageTransposeLeaf<2>((u8*)&bLsb[j * 3], (__m128i*) & bLsb[j * 3]); - foleageTransposeLeaf<2>((u8*)&bMsb[j * 3], (__m128i*) & bMsb[j * 3]); - foleageFFTOne<1>( - &bLsb2[j * 3 + 0], &bMsb2[j * 3 + 0], - &bLsb2[j * 3 + 1], &bMsb2[j * 3 + 1], - &bLsb2[j * 3 + 2], &bMsb2[j * 3 + 2] - ); - } - - foleageTranspose<2>((u8*)&bLsb2[0], (__m128i*)bLsb); - - foleageTranspose<2>((u8*)&bMsb2[0], (__m128i*)bMsb); - - foleageFFTOne<3, block>( - &bLsb[0], &bMsb[0], - &bLsb[3], &bMsb[3], - &bLsb[6], &bMsb[6] - ); - - } - t.setTimePoint("e"); - - std::cout << t << std::endl; - - } - } - - void foleage_fft_test(const oc::CLP& cmd) - { - PRNG prng(block(342134213421, 2341234123421)); - u64 nn = 14; - u64 n = ipow(3, nn); - Timer timer; - u64 trials = cmd.getOr("trials", 1); - - if (0) - { - - std::vector a(n); - std::vector lsb(n); - std::vector msb(n); - - prng.get(a.data(), a.size()); - for (u64 i = 0; i < n; ++i) - { - lsb[i] = - ((a[i] >> 0) & 1) | - ((a[i] >> 1) & 2) | - ((a[i] >> 2) & 4) | - ((a[i] >> 3) & 8); - auto m = a[i] >> 1; - msb[i] = - ((m >> 0) & 1) | - ((m >> 1) & 2) | - ((m >> 2) & 4) | - ((m >> 3) & 8); - } - - timer.setTimePoint("begin"); - foliageFftUint8(a, nn, n / 3); - timer.setTimePoint("fft_recursive_uint8"); - foleageFFT(lsb.data(), msb.data(), nn, n / 3); - timer.setTimePoint("foleageFFT 8 bit"); - - for (u64 i = 0; i < n; ++i) - { - auto a0 = - ((a[i] >> 0) & 1) | - ((a[i] >> 1) & 2) | - ((a[i] >> 2) & 4) | - ((a[i] >> 3) & 8); - auto m = a[i] >> 1; - auto a1 = - ((m >> 0) & 1) | - ((m >> 1) & 2) | - ((m >> 2) & 4) | - ((m >> 3) & 8); - - if (a0 != lsb[i] || a1 != msb[i]) - throw RTE_LOC; - } - - } - { - - std::vector a(n), a2(n); - oc::Matrix lsb(n, 2); - oc::Matrix msb(n, 2); - - prng.get(a.data(), a.size()); - - auto av = span((u8*)a.data(), n * 4); - auto av2 = span((u8*)a2.data(), n * 4); - - perfectUnshuffle(av, lsb, msb); - - auto lsb2 = lsb; - auto msb2 = msb; - - timer.setTimePoint("beign"); - for (u64 i = 0; i < trials; ++i) - fft_recursive_uint32(a, nn, n / 3); - timer.setTimePoint("fft_recursive_uint32"); - - - if (0) - { - - for (u64 i = 0; i < trials; ++i) - foleageFFT(lsb.data(), msb.data(), nn, 2 * n / 3); - timer.setTimePoint("foleageFFT 32bit"); - - perfectShuffle(lsb, msb, av2); - for (u64 i = 0; i < n; ++i) - { - if (a[i] != a2[i]) - throw RTE_LOC; - } - timer.setTimePoint("foleageFFT 32bit check"); - } - - if (1) - { - - for (u64 i = 0; i < trials; ++i) - foleageFFT2<2>(lsb2, msb2); - - timer.setTimePoint("foleageFFT2 32bit"); - - perfectShuffle(lsb2, msb2, av2); - for (u64 i = 0; i < n; ++i) - { - if (a[i] != a2[i]) - throw RTE_LOC; - } - timer.setTimePoint("foleageFFT2 32bit check"); - } - - std::cout << timer << std::endl; - - } - - } - - void foleage_spfss_test() - { - - size_t SUMT = 730;// sum of T DPFs - size_t FULLEVALDOMAIN = 10; - size_t MESSAGESIZE = 8; - size_t MAXRANDINDEX = ipow(3, FULLEVALDOMAIN); - - const size_t size = FULLEVALDOMAIN; // evaluation will result in 3^size points - const size_t msg_len = MESSAGESIZE; - PRNG prng(block(3423423)); - - size_t num_leaves = ipow(3, size); - - size_t secret_index = prng.get() % MAXRANDINDEX; - - // sample a random message of size msg_len - std::vector secret_msg(msg_len); - for (size_t i = 0; i < msg_len; i++) - secret_msg[i] = prng.get(); - - PRFKeys prf_keys; - prf_keys.gen(prng); - - std::vector kA(SUMT); - std::vector kB(SUMT); - - for (size_t i = 0; i < SUMT; i++) - DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA[i], kB[i], prng); - - std::vector shares0(num_leaves * msg_len); - std::vector shares1(num_leaves * msg_len); - std::vector cache(num_leaves * msg_len); - - //************************************************ - // Test full domain evaluation - //************************************************ - - for (size_t i = 0; i < SUMT; i++) - DPFFullDomainEval(kA[i], cache, shares0); - - clock_t t; - t = clock(); - - for (size_t i = 0; i < SUMT; i++) - DPFFullDomainEval(kB[i], cache, shares1); // we can reuse the same shares and cache - - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - // printOutputShares(shares0, shares1, num_leaves, msg_len); - - testOutputCorrectness_spf( - shares0, - shares1, - num_leaves, - secret_index, - secret_msg, - msg_len); - - //DestroyPRFKey(prf_keys); - //free(kA); - //free(kB); - //free(shares0); - //free(shares1); - //free(cache); - - } - - - void foleage_dpf_test() - { - const size_t size = 14; // evaluation will result in 3^size points - const size_t msg_len = 2; - PRNG prng(block(342134)); - - size_t num_leaves = ipow(3, size); - - size_t secret_index = prng.get() % ipow(3, size); - - // sample a random message of size msg_len - std::vector secret_msg(msg_len); - for (size_t i = 0; i < msg_len; i++) - secret_msg[i] = prng.get(); - - PRFKeys prf_keys; - prf_keys.gen(prng); - - DPFKey kA; - DPFKey kB; - - DPFGen(prf_keys, size, secret_index, secret_msg, msg_len, kA, kB, prng); - - std::vector shares0(num_leaves * msg_len); - std::vector shares1(num_leaves * msg_len); - std::vector cache(num_leaves * msg_len); - - //************************************************ - // Test full domain evaluation - //************************************************ - - DPFFullDomainEval(kA, cache, shares0); - - clock_t t; - t = clock(); - DPFFullDomainEval(kB, cache, shares1); - t = clock() - t; - double time_taken = ((double)t) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("Time %f ms\n", time_taken); - - // printOutputShares(shares0, shares1, num_leaves, msg_len); - - testOutputCorrectness( - shares0, - shares1, - num_leaves, - secret_index, - secret_msg, - msg_len); - - } - - - - // This test case implements Figure 1 from https://eprint.iacr.org/2024/429.pdf. - // It uses /libs/fft and libs/tri-dpf extensively. - // Several simplifying design choices are made: - // 1. We assume that c*c <= 16 so that we can use a parallel FFT packing of F4 - // elements using a uint32_t type. - // 2. We assume that t is a power of 3 so that the block size of each error - // vector divides the size of the polynomial. This makes the code significantly - // more readable and easier to understand. - - // TODO[feature]: The current implementation assumes that C*C <= 16 in order - // to parallelize the FFTs and other components. Making the code work with - // arbitrary values of C is left for future work. - - // TODO[feature]: modularize the different components of the test case and - // design more unit tests. - - - // This test evaluates the full PCG.Expand for both parties and - // checks correctness of the resulting OLE correlation. - void foleage_pcg_test(const CLP& cmd) - { - bool check = !cmd.isSet("noCheck"); - auto N = 5; // 3^N number of OLEs generated in total - - // The C and T parameters are computed using the SageMath script that can be - // found in https://github.com/mbombar/estimator_folding - - auto C = 4;// compression factor - auto T = 27;// noise weight - - - clock_t time; - time = clock(); - PRNG prng0(block(2424523452345, 111124521521455324)); - PRNG prng1(block(6474567454546, 567546754674345444)); - - const size_t n = N; - const size_t c = C; - const size_t t = T; - - // 3^n - const size_t poly_size = ipow(3, n); - - //************************************************************************ - // Step 0: Sample the global (1, a1 ... a_c-1) polynomials - //************************************************************************ - std::vector fft_a(poly_size); - std::vector fft_a2(poly_size); - PRNG APrng(block(431234234, 213434234123)); - sample_a_and_a2(fft_a, fft_a2, poly_size, c, APrng); - - //std::cout << "a " << hash(fft_a.data(), fft_a.size()) << std::endl; - //std::cout << "a2 " << hash(fft_a2.data(), fft_a2.size()) << std::endl; - - - //************************************************************************ - // Here, we figure out a good block size for the error vectors such that - // t*block_size = 3^n and block_size/L*128 is close to a power of 3. - // We pack L=256 coefficients of F4 into each DPF output (note that larger - // packing values are also okay, but they will do increase key size). - //************************************************************************ - size_t dpf_domain_bits = log3ceil(divCeil(poly_size, t * 256.0)); - if (dpf_domain_bits == 0) - dpf_domain_bits = 1; - - printf("DPF domain bits %zu \n", dpf_domain_bits); - - // 4*128 ==> 256 coefficients in F4 - size_t dpf_block_size = 4 * ipow(3, dpf_domain_bits); - - printf("dpf_block_size = %zu\n", dpf_block_size); - - // Note: We assume that t is a power of 3 and so it divides poly_size - assert(poly_size % t == 0); - - // the size of a single regular block. We have t blocks in each polynomial - // poly_size = 2^n / t = 3^{n-3} - size_t block_size = poly_size / t; - - printf("block_size = %zu \n", block_size); - - printf("[ ]Done with Step 0 (sampling the public values)\n"); - - //************************************************************************ - // Step 1: Sample error polynomials eA and eB (c polynomials in total) - // each polynomial is t-sparse and has degree (t * block_size) = poly_size. - //************************************************************************ - std::vector err_polys_A(c * poly_size); - std::vector err_polys_B(c * poly_size); - - // coefficients associated with each error vector - std::vector err_poly_coeffs_A(c * t); - std::vector err_poly_coeffs_B(c * t); - - // positions of the T errors in each error vector - std::vector err_poly_positions_A(c * t); - std::vector err_poly_positions_B(c * t); - - for (size_t i = 0; i < c; i++) - { - for (size_t j = 0; j < t; j++) - { - size_t offset = i * t + j; - - // random *non-zero* coefficients in F4 - uint8_t a = rand_f4x(prng0); - uint8_t b = rand_f4x(prng1); - err_poly_coeffs_A[offset] = a; - err_poly_coeffs_B[offset] = b; - - // random index within the block - size_t pos_A = random_index(block_size - 1, prng0); - size_t pos_B = random_index(block_size - 1, prng1); - - if (pos_A >= block_size || pos_B >= block_size) - { - printf("FAIL: position > block_size: %zu, %zu\n", pos_A, pos_B); - throw RTE_LOC; - //exit(0); - } - - err_poly_positions_A[offset] = pos_A; - err_poly_positions_B[offset] = pos_B; - - // set the coefficient at the error position to the error value - err_polys_A[i * poly_size + j * block_size + pos_A] = a; - err_polys_B[i * poly_size + j * block_size + pos_B] = b; - } - } - - - //std::cout << "posA " << hash(err_poly_positions_A.data(), err_poly_positions_A.size()) << std::endl; - //std::cout << "posB " << hash(err_poly_positions_B.data(), err_poly_positions_B.size()) << std::endl; - //std::cout << "coeffA " << hash(err_poly_coeffs_A.data(), err_poly_coeffs_A.size()) << std::endl; - //std::cout << "coeffB " << hash(err_poly_coeffs_B.data(), err_poly_coeffs_B.size()) << std::endl; - - - // Compute FFT of eA and eB in packed form. - // Note that because c = 4, we can pack 4 FFTs into a uint8_t - std::vector fft_eA(poly_size); - std::vector fft_eB(poly_size); - uint8_t coeff_A, coeff_B; - - // This loop essentially computes a transpose to pack the coefficients - // of each polynomial into one "row" of the parallel FFT matrix - for (size_t j = 0; j < c; j++) - { - for (size_t i = 0; i < poly_size; i++) - { - // extract the i-th coefficient of the j-th error polynomial - coeff_A = err_polys_A[j * poly_size + i]; - coeff_B = err_polys_B[j * poly_size + i]; - - // pack the extracted coefficient into the j-th FFT slot - fft_eA[i] |= (coeff_A << (2 * j)); - fft_eB[i] |= (coeff_B << (2 * j)); - } - } - - //std::cout << "sparseA " << hash(fft_eA.data(), fft_eA.size()) << std::endl; - //std::cout << "sparseB " << hash(fft_eB.data(), fft_eB.size()) << std::endl; - - - // Evaluate the FFTs on the error polynomials eA and eB - foliageFftUint8(fft_eA, n, poly_size / 3); - foliageFftUint8(fft_eB, n, poly_size / 3); - - printf("[. ]Done with Step 1 (sampling error vectors)\n"); - - //************************************************************************ - // Step 2: compute the inner product xA = and xB = - //************************************************************************ - - // Initialize polynomials to zero (accumulators for inner product) - std::vector x_poly_A(poly_size); - std::vector x_poly_B(poly_size); - - // Compute the coordinate-wise multiplication over the packed FFT result - std::vector res_poly_A(poly_size); - std::vector res_poly_B(poly_size); - F4Multiply(fft_a, fft_eA, res_poly_A, poly_size); // a*eA - F4Multiply(fft_a, fft_eB, res_poly_B, poly_size); // a*eB - - - //std::cout << "multA " << hash(res_poly_A.data(), res_poly_A.size()) << std::endl; - //std::cout << "multB " << hash(res_poly_B.data(), res_poly_B.size()) << std::endl; - - - - // XOR the result into the accumulator. - // Specifically, we XOR all the columns of the FFT result to get a - // vector of size poly_size. - for (size_t j = 0; j < c; j++) - { - for (size_t i = 0; i < poly_size; i++) - { - x_poly_A[i] ^= (res_poly_A[i] >> (2 * j)) & 0b11; - x_poly_B[i] ^= (res_poly_B[i] >> (2 * j)) & 0b11; - } - } - - //std::cout << "compressA " << hash(x_poly_A.data(), x_poly_A.size()) << std::endl; - //std::cout << "compressB " << hash(x_poly_B.data(), x_poly_B.size()) << std::endl; - - - printf("[.. ]Done with Step 2 (computing the local vectors)\n"); - - //************************************************************************ - // Step 3: Compute cross product (eA x eB) using the position vectors - //************************************************************************ - std::vector err_poly_cross_coeffs(c * c * t * t); - std::vector err_poly_cross_positions(c * c * t * t); - std::vector err_polys_cross(c * c * poly_size); - std::vector trit_decomp_A(n); - std::vector trit_decomp_B(n); - std::vector trit_decomp(n); - - for (size_t iA = 0; iA < c; iA++) - { - for (size_t iB = 0; iB < c; iB++) - { - size_t poly_index = iA * c * t * t + iB * t * t; - std::vector next_idx(t); - - for (size_t jA = 0; jA < t; jA++) - { - for (size_t jB = 0; jB < t; jB++) - { - // jA-th coefficient value of the iA-th polynomial - uint8_t vA = err_poly_coeffs_A[iA * t + jA]; - - // jB-th coefficient value of the iB-th polynomial - uint8_t vB = err_poly_coeffs_B[iB * t + jB]; - - // Resulting cross-product coefficient - uint8_t v = mult_f4(vA, vB); - - // Compute the position (in the full polynomial) - size_t posA = jA * block_size + err_poly_positions_A[iA * t + jA]; - size_t posB = jB * block_size + err_poly_positions_B[iB * t + jB]; - - if (err_polys_A[iA * poly_size + posA] == 0) - { - printf("FAIL: Incorrect position recovered\n"); - throw RTE_LOC; - //exit(0); - } - - if (err_polys_B[iB * poly_size + posB] == 0) - { - printf("FAIL: Incorrect position recovered\n"); - throw RTE_LOC; - } - - // Decompose the position into the ternary basis - int_to_trits(posA, trit_decomp_A); - int_to_trits(posB, trit_decomp_B); - - // printf("[DEBUG]: posA=%zu, posB=%zu\n", posA, posB); - - // Sum ternary decomposition coordinate-wise to - // get the new position (in ternary). - for (size_t k = 0; k < n; k++) - { - // printf("[DEBUG]: trits_A[%zu]=%i, trits_B[%zu]=%i\n", - // k, trit_decomp_A[k], k, trit_decomp_B[k]); - trit_decomp[k] = (trit_decomp_A[k] + trit_decomp_B[k]) % 3; - } - - // Get back the resulting cross-product position as an integer - size_t pos = trits_to_int(trit_decomp); - size_t block_idx = floor(pos / block_size); // block index in polynomial - //size_t in_block_idx = pos % block_size; // index within the block - - err_polys_cross[(iA * c + iB) * poly_size + pos] ^= v; - - size_t idx = next_idx[block_idx]; - next_idx[block_idx]++; - - // printf("[DEBUG]: pos=%zu, block_idx=%zu, idx=%zu\n", pos, block_idx, idx); - err_poly_cross_coeffs[poly_index + block_idx * t + idx] = v; - err_poly_cross_positions[poly_index + block_idx * t + idx] = pos % block_size; - } - } - - for (size_t k = 0; k < t; k++) - { - if (next_idx[k] > t) - { - std::cout << "FAIL: next_idx > t at the end: " << next_idx[k] << std::endl; - throw RTE_LOC; - } - } - - //free(next_idx); - } - } - - - // cleanup temporary values - //free(trit_decomp); - //free(trit_decomp_A); - //free(trit_decomp_B); - - printf("[... ]Done with Step 3 (computing the cross product)\n"); - - //************************************************************************ - // Step 4: Sample the DPF keys for the cross product (eA x eB) - //************************************************************************ - - std::vector dpf_keys_A(c * c * t * t); - std::vector dpf_keys_B(c * c * t * t); - - // Sample PRF keys for the DPFs - PRFKeys prf_keys; - PRNG prfSeedPrng(block(3412342134, 56453452362346)); - prf_keys.gen(prfSeedPrng); - PRNG genPrng; - oc::RandomOracle dpfHash0(16); - oc::RandomOracle dpfHash1(16); - - // Sample DPF keys for each of the t errors in the t blocks - for (size_t i = 0; i < c; i++) - { - for (size_t j = 0; j < c; j++) - { - for (size_t k = 0; k < t; k++) - { - for (size_t l = 0; l < t; l++) - { - size_t index = i * c * t * t + j * t * t + k * t + l; - - // Parse the index into the right format - size_t alpha = err_poly_cross_positions[index]; - - // Output message index in the DPF output space - // which consists of 256 F4 elements - size_t alpha_0 = floor(alpha / 256.0); - - // Coeff index in the block of 256 coefficients - size_t alpha_1 = alpha % 256; - - // Coeff index in the block output (64 elements of F4) - size_t byte_idx = alpha_1 / 4; - - // Bit index in the block ouput - size_t element_idx = alpha_1 % 4; - - // Set the DPF message to the coefficient - u8 coeff = err_poly_cross_coeffs[index];//block(err_poly_cross_coeffs[index]); - - // Position coefficient into the block - std::array beta; // init to zero - setBytes(beta, 0); - - // Set the coefficient in the right position - ((uint8_t*)&beta)[byte_idx] = coeff << (2 * element_idx); - //beta[packed_idx] = coeff << (2 * (63 - bit_idx)); - - - // Coeff index in the block output (64 elements of F4) - size_t packed_idx = alpha_1 / 4; - - //// Bit index in the block ouput - //size_t bit_idx = alpha_1 % 4; - //std::array beta2; // init to zero - //beta2[packed_idx] = uint128_t{ coeff } << (2 * (63 - bit_idx)); - //if (memcmp(&beta, &beta2, sizeof(beta)) != 0) - //{ - // std::cout << "FAIL: beta != beta2" << std::endl; - // throw RTE_LOC; - //} - - // Message (beta) is of size 4 blocks of 128 bits - genPrng.SetSeed(block(index, 542345234)); - DPFGen(prf_keys, dpf_domain_bits, alpha_0, beta, 4, dpf_keys_A[index], dpf_keys_B[index], genPrng); - - - dpfHash0.Update(dpf_keys_A[index].k.data(), dpf_keys_A[index].k.size()); - dpfHash0.Update(dpf_keys_A[index].msg_len); - dpfHash0.Update(dpf_keys_A[index].size); - dpfHash1.Update(dpf_keys_B[index].k.data(), dpf_keys_B[index].k.size()); - dpfHash1.Update(dpf_keys_B[index].msg_len); - dpfHash1.Update(dpf_keys_B[index].size); - } - } - } - } - - block dpfHashVal0, dpfHashVal1; - dpfHash0.Final(dpfHashVal0); - dpfHash1.Final(dpfHashVal1); - //std::cout << "dpfA " << dpfHashVal0 << std::endl; - //std::cout << "dpfB " << dpfHashVal1 << std::endl; - - printf("[.... ]Done with Step 4 (sampling DPF keys)\n"); - - //************************************************************************ - // Step 5: Evaluate the DPFs to compute shares of (eA x eB) - //************************************************************************ - - // Allocate memory for the DPF outputs (this is reused for each evaluation) - std::vector shares_A(dpf_block_size); - std::vector shares_B(dpf_block_size); - std::vector cache(dpf_block_size); - - // Allocate memory for the concatenated DPF outputs - size_t packed_block_size = divCeil(block_size, 64); - size_t packed_poly_size = t * packed_block_size; - - // printf("[DEBUG]: packed_block_size = %zu\n", packed_block_size); - // printf("[DEBUG]: packed_poly_size = %zu\n", packed_poly_size); - // - // each row is a block. every t rows is a polynomial. - Matrix packed_polys_A_(c * c * t, packed_block_size); - Matrix packed_polys_B_(c * c * t, packed_block_size); - //std::vector packed_polys_A(c * c * packed_poly_size); - //std::vector packed_polys_B(c * c * packed_poly_size); - - // Allocate memory for the output FFT - std::vectorfft_uA(poly_size); - std::vectorfft_uB(poly_size); - //std::vectorfft_uA2(poly_size); - //std::vectorfft_uB2(poly_size); - - // Allocate memory for the final inner product - std::vector z_poly_A(poly_size); - std::vector z_poly_B(poly_size); - std::vector res_poly_mat_A(poly_size); - std::vector res_poly_mat_B(poly_size); - - auto dpf_keys_A_iter = dpf_keys_A.begin(); - auto dpf_keys_B_iter = dpf_keys_B.begin(); - - for (size_t i = 0; i < c; i++) - { - for (size_t j = 0; j < c; j++) - { - const size_t poly_index = i * c + j; - - oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); - //block* packed_polyA = &packed_polys_A[poly_index * packed_poly_size]; - //block* packed_polyB = &packed_polys_B[poly_index * packed_poly_size]; - - for (size_t k = 0; k < t; k++) - { - span poly_blockA = packed_polyA_[k]; - span poly_blockB = packed_polyB_[k]; - - for (size_t l = 0; l < t; l++) - { - - DPFKey& dpf_keyA = *dpf_keys_A_iter++; - DPFKey& dpf_keyB = *dpf_keys_B_iter++; - - DPFFullDomainEval(dpf_keyA, cache, shares_A); - DPFFullDomainEval(dpf_keyB, cache, shares_B); - - // Sum all the DPFs for the current block together - // note that there is some extra "garbage" in the last - // block of block since 64 does not divide block_size. - // We deal with this slack later when packing the outputs - // into the parallel FFT matrix. - for (size_t w = 0; w < packed_block_size; w++) - { - poly_blockA[w] ^= shares_A[w]; - poly_blockB[w] ^= shares_B[w]; - } - } - } - } - } - - //std::cout << "blockA " << hash(packed_polys_A_.data(), packed_polys_A_.size()) << std::endl; - //std::cout << "blockB " << hash(packed_polys_B_.data(), packed_polys_B_.size()) << std::endl; - - - if (check) - { - - // Here, we test to make sure all polynomials have at most t^2 errors - // and fail the test otherwise. - for (size_t i = 0; i < c; i++) - { - for (size_t j = 0; j < c; j++) - { - size_t err_count = 0; - size_t poly_index = i * c + j; - - oc::MatrixView packed_polyA_(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView packed_polyB_(packed_polys_B_.data(poly_index * t), t, packed_block_size); - //block* poly_A = &packed_polys_A[poly_index * packed_poly_size]; - //block* poly_B = &packed_polys_B[poly_index * packed_poly_size]; - - for (size_t p = 0; p < packed_poly_size; p++) - { - block res = packed_polyA_(p) ^ packed_polyB_(p); - if (res != ZeroBlock) - { - auto e = extractF4(res); - for (size_t l = 0; l < 64; l++) - { - //if (((res >> (2 * (63 - l))) & block(0b11)) != block(0)) - err_count += (e[l] | (e[l] >> 1)) & 1; - //if (e[l]) - // err_count++; - } - } - } + // This test evaluates the full PCG.Expand for both parties and + // checks correctness of the resulting OLE correlation. + void foleage_F4ole_test(const CLP& cmd) + { + std::array oles; - // printf("[DEBUG]: Number of non-zero coefficients in poly (%zu,%zu) is %zu\n", i, j, err_count); + auto logn = 6; + u64 n = ipow(3, logn) - 67; + auto blocks = divCeil(n, 128); + bool verbose = cmd.isSet("v"); - if (err_count > t * t) - { - printf("FAIL: Number of non-zero coefficients is %zu > t*t\n", err_count); - throw RTE_LOC; - } - else if (err_count == 0) - { - printf("FAIL: Number of non-zero coefficients in poly (%zu,%zu) is %zu\n", i, j, err_count); - throw RTE_LOC; - } - } - } - } - printf("[..... ]Done with Step 5 (evaluating all DPFs)\n"); + if (cmd.hasValue("t")) + oles[0].mT = oles[1].mT = cmd.get("t"); - //************************************************************************ - // Step 6: Compute an FFT over the shares of (eA x eB) - //************************************************************************ + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); + Timer timer; - // Pack the coefficients into FFT blocks - // - // TODO[optimization]: use AVX and fast matrix transposition algorithms. - // The transpose is the bottleneck of the current implementation and - // therefore improving this step can result in significant performance gains. + oles[0].init(0, n); + oles[1].init(1, n); - if (check) { + auto otCount0 = oles[0].baseOtCount(); + auto otCount1 = oles[1].baseOtCount(); + if (otCount0.mRecvCount != otCount1.mSendCount || + otCount0.mSendCount != otCount1.mRecvCount) + throw RTE_LOC; + std::array>, 2> baseSend; + baseSend[0].resize(otCount0.mSendCount); + baseSend[1].resize(otCount1.mSendCount); + std::array, 2> baseRecv; + std::array baseChoice; - for (size_t j = 0; j < c; j++) + for (u64 i = 0; i < 2; ++i) { - for (size_t k = 0; k < c; k++) + prng0.get(baseSend[i].data(), baseSend[i].size()); + baseRecv[1 ^ i].resize(baseSend[i].size()); + baseChoice[1 ^ i].resize(baseSend[i].size()); + baseChoice[1 ^ i].randomize(prng0); + for (u64 j = 0; j < baseSend[i].size(); ++j) { - std::vector test_poly_A(poly_size); - std::vector test_poly_B(poly_size); - - size_t poly_index = j * c + k; - - oc::MatrixView poly_A(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView poly_B(packed_polys_B_.data(poly_index * t), t, packed_block_size); - - //block* poly_A = &packed_polys_A[poly_index * packed_poly_size]; - //block* poly_B = &packed_polys_B[poly_index * packed_poly_size]; - - for (u64 block_idx = 0, i = 0; block_idx < t; ++block_idx) - { - for (u64 packed_idx = 0; packed_idx < packed_block_size; ++packed_idx) - { - auto coeffA = extractF4(poly_A(block_idx, packed_idx)); - auto coeffB = extractF4(poly_B(block_idx, packed_idx)); - - //auto idx = j * c + k; - //if (idx >= 16) - // throw RTE_LOC; - auto e = std::min(block_size - packed_idx * 64, 64); - for (u64 element_idx = 0; element_idx < e; ++element_idx) - { - test_poly_A[i] = coeffA[/*63 - */element_idx]; - test_poly_B[i] = coeffB[/*63 - */element_idx]; - ++i; - } - } - } - - for (size_t i = 0; i < poly_size; i++) - { - uint8_t exp_coeff = err_polys_cross[j * c * poly_size + k * poly_size + i]; - uint8_t got_coeff = test_poly_A[i] ^ test_poly_B[i]; - - if (got_coeff != exp_coeff) - { - printf("FAIL: incorrect cross coefficient at index %zu (%i =/= %i)\n", i, got_coeff, exp_coeff); - - - - for (size_t i = 0; i < poly_size; i++) - { - int exp_coeff = err_polys_cross[j * c * poly_size + k * poly_size + i]; - std::cout << exp_coeff << " "; - - } - std::cout << "\n"; - for (size_t i = 0; i < poly_size; i++) - { - int got_coeff = test_poly_A[i] ^ test_poly_B[i]; - std::cout << got_coeff << " "; - } - std::cout << "\n"; - - throw RTE_LOC; - } - } - + baseRecv[1 ^ i][j] = baseSend[i][j][baseChoice[1 ^ i][j]]; } } - } - - // TODO[optimization]: for arbitrary values of C, we only need to perform - // C*(C+1)/2 FFTs which can lead to a more efficient implementation. - // Because we assume C=4, we have C*C = 16 which fits perfectly into a - // uint32 packing. - - for (size_t j = 0; j < c; j++) - { - for (size_t k = 0; k < c; k++) - { - size_t poly_index = (j * c + k);// *packed_poly_size; - - oc::MatrixView polyA(packed_polys_A_.data(poly_index * t), t, packed_block_size); - oc::MatrixView polyB(packed_polys_B_.data(poly_index * t), t, packed_block_size); - - u64 i = 0; - for (u64 block_idx = 0; block_idx < t; ++block_idx) - { - for (u64 packed_idx = 0; packed_idx < packed_block_size; ++packed_idx) - { - auto coeffA = extractF4(polyA(block_idx, packed_idx)); - auto coeffB = extractF4(polyB(block_idx, packed_idx)); - //auto idx = j * c + k; - //if (idx >= 16) - // throw RTE_LOC; - auto e = std::min(block_size - packed_idx * 64, 64); - - for (u64 element_idx = 0; element_idx < e; ++element_idx) - { - fft_uA[i] |= u32{ coeffA[/*63 - */element_idx] } << (2 * poly_index); - fft_uB[i] |= u32{ coeffB[/*63 - */element_idx] } << (2 * poly_index); - ++i; - } - } - } - } + oles[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + oles[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); } - //std::cout << "Cin0 " << hash(fft_uA.data(), fft_uA.size()) << std::endl; - //std::cout << "Cin1 " << hash(fft_uB.data(), fft_uB.size()) << std::endl; - - fft_recursive_uint32(fft_uA, n, poly_size / 3); - fft_recursive_uint32(fft_uB, n, poly_size / 3); - - //std::cout << "Cfft0 " << hash(fft_uA.data(), fft_uA.size()) << std::endl; - //std::cout << "Cfft1 " << hash(fft_uB.data(), fft_uB.size()) << std::endl; - - - printf("[...... ]Done with Step 6 (computing FFTs)\n"); - - //************************************************************************ - // Step 7: Compute shares of z = - //************************************************************************ - multiply_fft_32(fft_a2, fft_uA, res_poly_mat_A, poly_size); - multiply_fft_32(fft_a2, fft_uB, res_poly_mat_B, poly_size); - //std::cout << "C0 " << hash(res_poly_mat_A.data(), res_poly_mat_A.size()) << std::endl; - //std::cout << "C1 " << hash(res_poly_mat_B.data(), res_poly_mat_B.size()) << std::endl; - - //size_t num_ffts = c * c; - - // XOR the (packed) columns into the accumulator. - // Specifically, we perform column-wise XORs to get the result. - u32 lsbMask, msbMask; - setBytes(lsbMask, 0b01010101); - setBytes(msbMask, 0b10101010); - for (size_t i = 0; i < poly_size; i++) - { - //auto resA = extractF4(res_poly_mat_A[i]); - //auto resB = extractF4(res_poly_mat_B[i]); - - z_poly_A[i] = - (popcount(res_poly_mat_A[i] & lsbMask) & 1) | - ((popcount(res_poly_mat_A[i] & msbMask) & 1) << 1); - - z_poly_B[i] = - (popcount(res_poly_mat_B[i] & lsbMask) & 1) | - ((popcount(res_poly_mat_B[i] & msbMask) & 1) << 1); - - //u8 aSum = 0; - - //for (size_t j = 0; j < c * c; j++) - //{ - // aSum ^= resA[j]; - //} + auto sock = coproto::LocalAsyncSocket::makePair(); + std::vector + ALsb(blocks), + AMsb(blocks), + BLsb(blocks), + BMsb(blocks), + C0Lsb(blocks), + C0Msb(blocks), + C1Lsb(blocks), + C1Msb(blocks); - //if ((aSum & 1) != aLsb) - // throw RTE_LOC; - //if (((aSum>>1) & 1) != aMsb) - // throw RTE_LOC; + if (verbose) + oles[0].setTimer(timer); - //for (size_t j = 0; j < c * c; j++) - //{ - // z_poly_A[i] ^= resA[j]; - // z_poly_B[i] ^= resB[j]; - //} - } + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]), + oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]))); + std::get<0>(r).result(); + std::get<1>(r).result(); // Now we check that we got the correct OLE correlations and fail // the test otherwise. - for (size_t i = 0; i < poly_size; i++) + for (size_t i = 0; i < blocks; i++) { - uint8_t res = z_poly_A[i] ^ z_poly_B[i]; - uint8_t exp = mult_f4(x_poly_A[i], x_poly_B[i]); - - // printf("[DEBUG]: Got: (%i,%i), Expected: (%i, %i)\n", - // (res >> 1) & 1, res & 1, (exp >> 1) & 1, exp & 1); + auto Lsb = C0Lsb[i] ^ C1Lsb[i]; + auto Msb = C0Msb[i] ^ C1Msb[i]; + block mLsb, mMsb; + F4Multiply( + ALsb[i], AMsb[i], + BLsb[i], BMsb[i], + mLsb, mMsb); - if (res != exp) - { - printf("FAIL: Incorrect correlation output at index %zu\n", i); - printf("Got: (%i,%i), Expected: (%i, %i)\n", - (res >> 1) & 1, res & 1, (exp >> 1) & 1, exp & 1); + if (Lsb != mLsb) + throw RTE_LOC; + if (Msb != mMsb) throw RTE_LOC; - - } } - time = clock() - time; - double time_taken = ((double)time) / (CLOCKS_PER_SEC / 1000.0); // ms - - printf("[.......]Done with Step 7 (recovering shares)\n\n"); - - printf("Time elapsed %f ms\n", time_taken); - + if (verbose) + std::cout << "Time taken: \n" << timer << std::endl; } - - - // This test evaluates the full PCG.Expand for both parties and - // checks correctness of the resulting OLE correlation. - void foleage_F4ole_test(const CLP& cmd) + void foleage_Triple_test(const CLP& cmd) { - std::array oles; + std::array oles; - auto logn = 10; + auto logn = 5; u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); bool verbose = cmd.isSet("v"); @@ -1432,7 +110,6 @@ namespace osuCrypto if (cmd.hasValue("t")) oles[0].mT = oles[1].mT = cmd.get("t"); - //PRNG prng(block(342342)); PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); Timer timer; @@ -1469,22 +146,21 @@ namespace osuCrypto } auto sock = coproto::LocalAsyncSocket::makePair(); - std::vector - ALsb(blocks), - AMsb(blocks), - BLsb(blocks), - BMsb(blocks), - C0Lsb(blocks), - C0Msb(blocks), - C1Lsb(blocks), - C1Msb(blocks); + std::array, 2> + A, B, C; + for (u64 i = 0; i < 2; ++i) + { + A[i].resize(blocks); + B[i].resize(blocks); + C[i].resize(blocks); + } if (verbose) oles[0].setTimer(timer); auto r = macoro::sync_wait(macoro::when_all_ready( - oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]), - oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]))); + oles[0].expand(A[0], B[0], C[0], prng0, sock[0]), + oles[1].expand(A[1], B[1], C[1], prng1, sock[1]))); std::get<0>(r).result(); std::get<1>(r).result(); @@ -1492,27 +168,87 @@ namespace osuCrypto // the test otherwise. for (size_t i = 0; i < blocks; i++) { - auto Lsb = C0Lsb[i] ^ C1Lsb[i]; - auto Msb = C0Msb[i] ^ C1Msb[i]; - block mLsb, mMsb; - f4Mult( - ALsb[i], AMsb[i], - BLsb[i], BMsb[i], - mLsb, mMsb); - - if (Lsb != mLsb) - throw RTE_LOC; - if (Msb != mMsb) + auto a = A[0][i] ^ A[1][i]; + auto b = B[0][i] ^ B[1][i]; + auto c = C[0][i] ^ C[1][i]; + if ((a & b) != c) throw RTE_LOC; } if (verbose) std::cout << "Time taken: \n" << timer << std::endl; } + + void foleage_GenBase_test(const CLP& cmd) + { + for (auto type : { SilentBaseType::Base, SilentBaseType::BaseExtend }) + { + + std::array oles; + PRNG prng0(block(2424523452345, 111124521521455324)); + PRNG prng1(block(6474567454546, 567546754674345444)); + + // insecure but makes the but makes the test run faster. + oles[0].mT = 3; + oles[1].mT = 3; + + u64 n = 1000; + oles[0].init(0, n); + oles[1].init(1, n); + + auto blocks = divCeil(n, 128); + + auto sock = coproto::LocalAsyncSocket::makePair(); + std::vector + ALsb(blocks), + AMsb(blocks), + BLsb(blocks), + BMsb(blocks), + C0Lsb(blocks), + C0Msb(blocks), + C1Lsb(blocks), + C1Msb(blocks); + + // baseExtend is the default and will be called by expand. + if (type == SilentBaseType::Base) + { + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].genBaseOts(prng0, sock[0], type), + oles[1].genBaseOts(prng1, sock[1], type))); + std::get<0>(r).result(); + std::get<1>(r).result(); + } + + auto r = macoro::sync_wait(macoro::when_all_ready( + oles[0].expand(ALsb, AMsb, C0Lsb, C0Msb, prng0, sock[0]), + oles[1].expand(BLsb, BMsb, C1Lsb, C1Msb, prng1, sock[1]))); + std::get<0>(r).result(); + std::get<1>(r).result(); + + // Now we check that we got the correct OLE correlations and fail + // the test otherwise. + for (size_t i = 0; i < blocks; i++) + { + auto Lsb = C0Lsb[i] ^ C1Lsb[i]; + auto Msb = C0Msb[i] ^ C1Msb[i]; + block mLsb, mMsb; + F4Multiply( + ALsb[i], AMsb[i], + BLsb[i], BMsb[i], + mLsb, mMsb); + + if (Lsb != mLsb) + throw RTE_LOC; + if (Msb != mMsb) + throw RTE_LOC; + } + } + } + void foleage_tensor_test(const CLP& cmd) { - std::array oles; + std::array oles; //bool verbose = cmd.isSet("v"); @@ -1556,7 +292,7 @@ namespace osuCrypto u8 ci = coeff[0][i]; u8 cj = coeff[1][j]; - auto exp = mult_f4(ci, cj); + auto exp = F4Multiply(ci, cj); auto act = prod[0][p] ^ prod[1][p]; if (exp != act) throw RTE_LOC; diff --git a/libOTe_Tests/Foleage_Tests.h b/libOTe_Tests/Foleage_Tests.h index 813a5b2a..ce9c94e4 100644 --- a/libOTe_Tests/Foleage_Tests.h +++ b/libOTe_Tests/Foleage_Tests.h @@ -2,13 +2,15 @@ #include "cryptoTools/Common/CLP.h" namespace osuCrypto { - void foleage_transpose_test(const oc::CLP& cmd); - void foleage_fft_test(const oc::CLP& cmd); + //void foleage_transpose_test(const oc::CLP& cmd); + //void foleage_fft_test(const oc::CLP& cmd); - void foleage_spfss_test(); - void foleage_dpf_test(); - void foleage_pcg_test(const CLP& cmd); + //void foleage_spfss_test(); + //void foleage_dpf_test(); + //void foleage_pcg_test(const CLP& cmd); void foleage_F4ole_test(const CLP& cmd); + void foleage_Triple_test(const CLP& cmd); + void foleage_GenBase_test(const CLP& cmd); void foleage_tensor_test(const CLP& cmd); diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 705c5c0b..471d111c 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -1,10 +1,10 @@ #include "RegularDpf_Tests.h" -#include "libOTe/Tools/Dpf/RegularDpf.h" +#include "libOTe/Dpf/RegularDpf.h" #include "coproto/Socket/LocalAsyncSock.h" -#include "libOTe/Tools/Dpf/SparseDpf.h" +#include "libOTe/Dpf/SparseDpf.h" #include #include -#include "libOTe/Tools/Dpf/TriDpf.h" +#include "libOTe/Dpf/TriDpf.h" using namespace oc; diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 193ea1b2..66cf724c 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -65,15 +65,10 @@ namespace tests_libOTe tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); - - tc.add("foleage_transpose_test ", foleage_transpose_test); - tc.add("foleage_fft_test ", foleage_fft_test); - tc.add("foleage_dpf_test ", foleage_dpf_test); - tc.add("foleage_spfss_test ", foleage_spfss_test); - tc.add("foleage_pcg_test ", foleage_pcg_test); - tc.add("foleage_tensor_test ", foleage_tensor_test); tc.add("foleage_F4ole_test ", foleage_F4ole_test); + tc.add("foleage_Triple_test ", foleage_Triple_test); + tc.add("foleage_GenBase_test ", foleage_GenBase_test); tc.add("Bot_Simplest_Test ", Bot_Simplest_Test); From 4d5b0fac82cd7031546e7f0db0ff5a333e84888d Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:00:18 -0800 Subject: [PATCH 24/48] fixed sparse dpf --- libOTe/Dpf/RegularDpf.h | 2 +- libOTe/Triple/Foleage/FoleageTriple.cpp | 6 +- libOTe/TwoChooseOne/TcoOtDefines.h | 2 +- libOTe_Tests/RegularDpf_Tests.cpp | 77 ++++++++++++++++++++++++- libOTe_Tests/RegularDpf_Tests.h | 1 + libOTe_Tests/UnitTests.cpp | 2 + 6 files changed, 84 insertions(+), 6 deletions(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 76fc38f9..caa8bbfc 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -609,7 +609,7 @@ namespace osuCrypto } else { - auto& sd = s[mDepth & 1]; + auto& sd = s[mDepth % 3]; auto& td = tags; for (u64 i = 0; i < mDomain; ++i) { diff --git a/libOTe/Triple/Foleage/FoleageTriple.cpp b/libOTe/Triple/Foleage/FoleageTriple.cpp index 74975c74..62237077 100644 --- a/libOTe/Triple/Foleage/FoleageTriple.cpp +++ b/libOTe/Triple/Foleage/FoleageTriple.cpp @@ -144,7 +144,7 @@ namespace osuCrypto co_await mOtExtSender->send(sendMsg, prng, sock); choice = BitVector(choice.data(), choice.size() - extSenderCount, extSenderCount); - setBaseOts(sendMsg, span(recvMsg).subspan(extSenderCount), choice); + setBaseOts(sendMsg, span(recvMsg).subspan(extSenderCount), choice); #else throw std::runtime_error("ENABLE_SOFTSPOKEN_OT = false, must enable soft spoken. " LOCATION); #endif @@ -207,7 +207,7 @@ namespace osuCrypto std::vector recvMsg(choice.size()); co_await mOtExtRecver->receive(choice, recvMsg, prng, sock); - setBaseOts(span(sendMsg).subspan(extRecverCount), recvMsg, choice); + setBaseOts(span>(sendMsg).subspan(extRecverCount), recvMsg, choice); #else throw std::runtime_error("ENABLE_SOFTSPOKEN_OT = false, must enable soft spoken. " LOCATION); #endif @@ -557,7 +557,7 @@ namespace osuCrypto { // XOR the (packed) columns into the accumulator. // Specifically, we perform column-wise XORs to get the result. - u32 lsbMask, msbMask; + u32 lsbMask; setBytes(lsbMask, 0b01010101); for (size_t i = 0; i < outSize; i++) { diff --git a/libOTe/TwoChooseOne/TcoOtDefines.h b/libOTe/TwoChooseOne/TcoOtDefines.h index 410b1215..9be38fac 100644 --- a/libOTe/TwoChooseOne/TcoOtDefines.h +++ b/libOTe/TwoChooseOne/TcoOtDefines.h @@ -22,7 +22,7 @@ namespace osuCrypto const u64 commStepSize(512); // TODO: try increasing this for optimization. const u64 superBlkShift(3); const u64 superBlkSize(1 << superBlkShift); - const u64 gKosChallengeRepititions(4); + const u64 gKosChallengeRepititions(1); enum class SilentBaseType { // Use a standalone base OT protocol to generate the required base OTs diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 471d111c..30fc1b8f 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -190,6 +190,81 @@ void RegularDpf_Proto_Test(const CLP& cmd) } } +void RegularDpf_Puncture_Test(const oc::CLP& cmd) +{ + + PRNG prng(block(231234, 321312)); + u64 domain = cmd.getOr("domain", 211); + u64 numPoints = cmd.getOr("numPoints", 7); + std::vector points0(numPoints); + std::vector points1(numPoints); + for (u64 i = 0; i < numPoints; ++i) + { + points1[i] = prng.get(); + points0[i] = (prng.get() % domain) ^ points1[i]; + } + + std::array dpf; + dpf[0].init(0, domain, numPoints); + dpf[1].init(1, domain, numPoints); + + auto baseCount = dpf[0].baseOtCount(); + + std::array, 2> baseRecv; + std::array>, 2> baseSend; + std::array baseChoice; + baseRecv[0].resize(baseCount); + baseRecv[1].resize(baseCount); + baseSend[0].resize(baseCount); + baseSend[1].resize(baseCount); + baseChoice[0].resize(baseCount); + baseChoice[1].resize(baseCount); + baseChoice[0].randomize(prng); + baseChoice[1].randomize(prng); + for (u64 i = 0; i < baseCount; ++i) + { + baseSend[0][i] = prng.get(); + baseSend[1][i] = prng.get(); + baseRecv[0][i] = baseSend[1][i][baseChoice[0][i]]; + baseRecv[1][i] = baseSend[0][i][baseChoice[1][i]]; + } + dpf[0].setBaseOts(baseSend[0], baseRecv[0], baseChoice[0]); + dpf[1].setBaseOts(baseSend[1], baseRecv[1], baseChoice[1]); + + std::array, 2> output; + std::array, 2> tags; + output[0].resize(numPoints, domain); + output[1].resize(numPoints, domain); + tags[0].resize(numPoints, domain); + tags[1].resize(numPoints, domain); + + auto sock = coproto::LocalAsyncSocket::makePair(); + macoro::sync_wait(macoro::when_all_ready( + dpf[0].expand(points0, {}, prng.get(), [&](auto k, auto i, auto v, block t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), + dpf[1].expand(points1, {}, prng.get(), [&](auto k, auto i, auto v, block t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) + )); + + + for (u64 i = 0; i < domain; ++i) + { + for (u64 k = 0; k < numPoints; ++k) + { + auto p = points0[k] ^ points1[k]; + auto act = output[0][k][i] ^ output[1][k][i]; + auto t = i == p ? 1 : 0; + auto tAct = tags[0][k][i] ^ tags[1][k][i]; + if (t == 0 && act != ZeroBlock) + throw RTE_LOC; + + if (t == 1 && act == ZeroBlock) + throw RTE_LOC; + + if (t != tAct) + throw RTE_LOC; + } + } +} + void RegularDpf_keyGen_Test(const oc::CLP& cmd) { @@ -327,7 +402,7 @@ void SparseDpf_Proto_Test(const oc::CLP& cmd) std::sort(sparsePoints[i].begin(), sparsePoints[i].end()); index[i] = prng.get() % sparsePoints.cols(); value[i] = prng.get(); - auto alpha = sparsePoints(i, index[i]); + //auto alpha = sparsePoints(i, index[i]); //std::cout << "alpha " << alpha << " " << oc::BitVector((u8*)&alpha, log2ceil(domain)) << std::endl; points[0][i] = prng.get(); points[1][i] = points[0][i] ^ sparsePoints(i, index[i]); diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/RegularDpf_Tests.h index 7edad019..c82f2281 100644 --- a/libOTe_Tests/RegularDpf_Tests.h +++ b/libOTe_Tests/RegularDpf_Tests.h @@ -4,6 +4,7 @@ void RegularDpf_Multiply_Test(const oc::CLP& cmd); void RegularDpf_Proto_Test(const oc::CLP& cmd); +void RegularDpf_Puncture_Test(const oc::CLP& cmd); void RegularDpf_keyGen_Test(const oc::CLP& cmd); void SparseDpf_Proto_Test(const oc::CLP& cmd); void TritDpf_Proto_Test(const oc::CLP& cmd); \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 66cf724c..14712112 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -61,6 +61,7 @@ namespace tests_libOTe tc.add("RegularDpf_Multiply_Test ", RegularDpf_Multiply_Test); tc.add("RegularDpf_Proto_Test ", RegularDpf_Proto_Test); + tc.add("RegularDpf_Puncture_Test ", RegularDpf_Puncture_Test); tc.add("RegularDpf_keyGen_Test ", RegularDpf_keyGen_Test); tc.add("SparseDpf_Proto_Test ", SparseDpf_Proto_Test); tc.add("TritDpf_Proto_Test ", TritDpf_Proto_Test); @@ -68,6 +69,7 @@ namespace tests_libOTe tc.add("foleage_tensor_test ", foleage_tensor_test); tc.add("foleage_F4ole_test ", foleage_F4ole_test); tc.add("foleage_Triple_test ", foleage_Triple_test); + tc.add("foleage_GenBase_test ", foleage_GenBase_test); From b6671961acef9df199f811903e194a5347abaabe Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:01:25 -0800 Subject: [PATCH 25/48] foleage ci --- .github/workflows/build-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 5a791acd..2bba8f1a 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -6,9 +6,9 @@ name: CI on: # Triggers the workflow on push or pull request events but only for the master branch push: - branches: [ master, stage, components, cpp20 ] + branches: [ master, stage, foliage, cpp20 ] pull_request: - branches: [ master, stage, components ] + branches: [ master, stage, foliage ] # Allows you to run this workflow manually from the Actions tab workflow_dispatch: From 6bc720b2a1f92bd982766e2ef58c1cb6e54383c3 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:18:40 -0800 Subject: [PATCH 26/48] foleage mac compile --- cryptoTools | 2 +- libOTe/Dpf/TriDpf.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cryptoTools b/cryptoTools index 2bf5fe84..96de1fc6 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 2bf5fe84e19cadd9aeea5c191a08ac59e65b54e7 +Subproject commit 96de1fc6e808cfbe2aec0495a5a74d99535adb9d diff --git a/libOTe/Dpf/TriDpf.h b/libOTe/Dpf/TriDpf.h index d0abb00e..df9d03cb 100644 --- a/libOTe/Dpf/TriDpf.h +++ b/libOTe/Dpf/TriDpf.h @@ -764,8 +764,8 @@ namespace osuCrypto static block tagBit(const block& b) { auto bit = b & block(0, 1); - auto mask = _mm_sub_epi64(_mm_set1_epi64x(0), bit); - return _mm_unpacklo_epi64(mask, mask); + auto mask = block(0,0).sub_epi64(bit); + return mask.unpacklo_epi64(mask); } }; From f0b1a867b7da392b3a2e85ba497a94569d4b8de7 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:27:42 -0800 Subject: [PATCH 27/48] removing deprecated calls --- cryptoTools | 2 +- libOTe/Dpf/RegularDpf.h | 4 ++-- libOTe/Tools/Pprf/RegularPprf.h | 32 ++++++++++++++++---------------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/cryptoTools b/cryptoTools index 96de1fc6..8239552d 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 96de1fc6e808cfbe2aec0495a5a74d99535adb9d +Subproject commit 8239552d5b1919f6d8372268459680bc225037a8 diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index caa8bbfc..014746d0 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -479,7 +479,7 @@ namespace osuCrypto // (s0', s1') = H(s) mAesFixedKey.ecbEncBlocks<8>(¤tSeed[j][k], &temp[0]); SIMD8(q, childSeed[j * 2 + 0][k + q] = AES::roundEnc(temp[q], currentSeed[j][k + q])); - SIMD8(q, childSeed[j * 2 + 1][k + q] = temp[q] + currentSeed[j][k + q]); + SIMD8(q, childSeed[j * 2 + 1][k + q] = temp[q].add_epi64(currentSeed[j][k + q])); // z = z ^ s' SIMD8(q, z[0][k + q] ^= childSeed[j * 2 + 0][k + q]); @@ -499,7 +499,7 @@ namespace osuCrypto temp[0] = mAesFixedKey.ecbEncBlock(currentSeed[j][k]); childSeed[j * 2 + 0][k] = AES::roundEnc(temp[0], currentSeed[j][k]); - childSeed[j * 2 + 1][k] = temp[0] + currentSeed[j][k]; + childSeed[j * 2 + 1][k] = temp[0].add_epi64(currentSeed[j][k]); z[0][k] ^= childSeed[j * 2 + 0][k]; z[1][k] ^= childSeed[j * 2 + 1][k]; diff --git a/libOTe/Tools/Pprf/RegularPprf.h b/libOTe/Tools/Pprf/RegularPprf.h index f19f4c17..b5f745a9 100644 --- a/libOTe/Tools/Pprf/RegularPprf.h +++ b/libOTe/Tools/Pprf/RegularPprf.h @@ -335,14 +335,14 @@ namespace osuCrypto sums[0][7] = sums[0][7] ^ child0[7]; // child1 = AES(parent) + parent - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; + child1[0] = child1[0].add_epi64(parent[0]); + child1[1] = child1[1].add_epi64(parent[1]); + child1[2] = child1[2].add_epi64(parent[2]); + child1[3] = child1[3].add_epi64(parent[3]); + child1[4] = child1[4].add_epi64(parent[4]); + child1[5] = child1[5].add_epi64(parent[5]); + child1[6] = child1[6].add_epi64(parent[6]); + child1[7] = child1[7].add_epi64(parent[7]); sums[1][0] = sums[1][0] ^ child1[0]; sums[1][1] = sums[1][1] ^ child1[1]; @@ -953,14 +953,14 @@ namespace osuCrypto mySums[0][7] = mySums[0][7] ^ child0[7]; // child1 = AES(parent) + parent - child1[0] = child1[0] + parent[0]; - child1[1] = child1[1] + parent[1]; - child1[2] = child1[2] + parent[2]; - child1[3] = child1[3] + parent[3]; - child1[4] = child1[4] + parent[4]; - child1[5] = child1[5] + parent[5]; - child1[6] = child1[6] + parent[6]; - child1[7] = child1[7] + parent[7]; + child1[0] = child1[0].add_epi64(parent[0]); + child1[1] = child1[1].add_epi64(parent[1]); + child1[2] = child1[2].add_epi64(parent[2]); + child1[3] = child1[3].add_epi64(parent[3]); + child1[4] = child1[4].add_epi64(parent[4]); + child1[5] = child1[5].add_epi64(parent[5]); + child1[6] = child1[6].add_epi64(parent[6]); + child1[7] = child1[7].add_epi64(parent[7]); mySums[1][0] = mySums[1][0] ^ child1[0]; mySums[1][1] = mySums[1][1] ^ child1[1]; From ea8a06807b4081333df25088eec88d2dce55c365 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:38:12 -0800 Subject: [PATCH 28/48] removing deprecated calls --- libOTe/Tools/Tools.cpp | 100 ++++++++++++------------- libOTe/Triple/Foleage/fft/FoleageFft.h | 2 - 2 files changed, 48 insertions(+), 54 deletions(-) diff --git a/libOTe/Tools/Tools.cpp b/libOTe/Tools/Tools.cpp index 6e3bff73..2836ed60 100644 --- a/libOTe/Tools/Tools.cpp +++ b/libOTe/Tools/Tools.cpp @@ -22,12 +22,8 @@ using std::array; namespace osuCrypto { - //bool gUseBgicksPprf(true); - -//using namespace std; - -// Utility function to do modular exponentiation. -// It returns (x^y) % p + // Utility function to do modular exponentiation. + // It returns (x^y) % p u64 power(u64 x, u64 y, u64 p) { u64 res = 1; // Initialize result @@ -105,14 +101,14 @@ namespace osuCrypto { bool isPrime(u64 n) { - PRNG prng(ZeroBlock); + PRNG prng(oc::sysRandomSeed()); return isPrime(n, prng); } u64 nextPrime(u64 n) { - PRNG prng(ZeroBlock); + PRNG prng(oc::sysRandomSeed()); while (isPrime(n, prng) == false) ++n; @@ -335,8 +331,8 @@ namespace osuCrypto { outU16View[16 * x + 7 - j][y] = in[0].movemask_epi8(); outU16View[16 * x + 15 - j][y] = in[1].movemask_epi8(); - in[0] = (in[0] << 1); - in[1] = (in[1] << 1); + in[0] = in[0].slli_epi64(1); + in[1] = in[1].slli_epi64(1); } } @@ -494,14 +490,14 @@ namespace osuCrypto { out7 -= out.stride(); // shift the 128 values so that the top bit is now the next one. - t.blks[0] = (t.blks[0] << 1); - t.blks[1] = (t.blks[1] << 1); - t.blks[2] = (t.blks[2] << 1); - t.blks[3] = (t.blks[3] << 1); - t.blks[4] = (t.blks[4] << 1); - t.blks[5] = (t.blks[5] << 1); - t.blks[6] = (t.blks[6] << 1); - t.blks[7] = (t.blks[7] << 1); + t.blks[0] = t.blks[0].slli_epi64(1); + t.blks[1] = t.blks[1].slli_epi64(1); + t.blks[2] = t.blks[2].slli_epi64(1); + t.blks[3] = t.blks[3].slli_epi64(1); + t.blks[4] = t.blks[4].slli_epi64(1); + t.blks[5] = t.blks[5].slli_epi64(1); + t.blks[6] = t.blks[6].slli_epi64(1); + t.blks[7] = t.blks[7].slli_epi64(1); } } } @@ -550,7 +546,7 @@ namespace osuCrypto { auto out0 = outStart + (chunkSize * subBlockHight + hh) * 8 * out.stride() + w * 2; out0 -= out.stride() * skip; - t.blks[0] = (t.blks[0] << int(skip)); + t.blks[0] = t.blks[0].slli_epi64(skip); for (int j = 0; j < rem; j++) { @@ -558,7 +554,7 @@ namespace osuCrypto { out0 -= out.stride(); - t.blks[0] = (t.blks[0] << 1); + t.blks[0] = t.blks[0].slli_epi64(1); } } } @@ -623,14 +619,14 @@ namespace osuCrypto { out6 -= out.stride(); out7 -= out.stride(); - t.blks[0] = (t.blks[0] << 1); - t.blks[1] = (t.blks[1] << 1); - t.blks[2] = (t.blks[2] << 1); - t.blks[3] = (t.blks[3] << 1); - t.blks[4] = (t.blks[4] << 1); - t.blks[5] = (t.blks[5] << 1); - t.blks[6] = (t.blks[6] << 1); - t.blks[7] = (t.blks[7] << 1); + t.blks[0] = t.blks[0].slli_epi64(1); + t.blks[1] = t.blks[1].slli_epi64(1); + t.blks[2] = t.blks[2].slli_epi64(1); + t.blks[3] = t.blks[3].slli_epi64(1); + t.blks[4] = t.blks[4].slli_epi64(1); + t.blks[5] = t.blks[5].slli_epi64(1); + t.blks[6] = t.blks[6].slli_epi64(1); + t.blks[7] = t.blks[7].slli_epi64(1); } } else @@ -655,14 +651,14 @@ namespace osuCrypto { out6 -= out.stride(); out7 -= out.stride(); - t.blks[0] = (t.blks[0] << 1); - t.blks[1] = (t.blks[1] << 1); - t.blks[2] = (t.blks[2] << 1); - t.blks[3] = (t.blks[3] << 1); - t.blks[4] = (t.blks[4] << 1); - t.blks[5] = (t.blks[5] << 1); - t.blks[6] = (t.blks[6] << 1); - t.blks[7] = (t.blks[7] << 1); + t.blks[0] = t.blks[0].slli_epi64(1); + t.blks[1] = t.blks[1].slli_epi64(1); + t.blks[2] = t.blks[2].slli_epi64(1); + t.blks[3] = t.blks[3].slli_epi64(1); + t.blks[4] = t.blks[4].slli_epi64(1); + t.blks[5] = t.blks[5].slli_epi64(1); + t.blks[6] = t.blks[6].slli_epi64(1); + t.blks[7] = t.blks[7].slli_epi64(1); } } } @@ -936,14 +932,14 @@ namespace osuCrypto { auto x16_7 = x * 16 + 7; auto x16_15 = x * 16 + 15; - block b0 = (in[0] << 0); - block b1 = (in[0] << 1); - block b2 = (in[0] << 2); - block b3 = (in[0] << 3); - block b4 = (in[0] << 4); - block b5 = (in[0] << 5); - block b6 = (in[0] << 6); - block b7 = (in[0] << 7); + block b0 = in[0].slli_epi64(0); + block b1 = in[0].slli_epi64(1); + block b2 = in[0].slli_epi64(2); + block b3 = in[0].slli_epi64(3); + block b4 = in[0].slli_epi64(4); + block b5 = in[0].slli_epi64(5); + block b6 = in[0].slli_epi64(6); + block b7 = in[0].slli_epi64(7); outU16View[x16_7 - 0][i8y] = b0.movemask_epi8(); outU16View[x16_7 - 1][i8y] = b1.movemask_epi8(); @@ -954,14 +950,14 @@ namespace osuCrypto { outU16View[x16_7 - 6][i8y] = b6.movemask_epi8(); outU16View[x16_7 - 7][i8y] = b7.movemask_epi8(); - b0 = (in[1] << 0); - b1 = (in[1] << 1); - b2 = (in[1] << 2); - b3 = (in[1] << 3); - b4 = (in[1] << 4); - b5 = (in[1] << 5); - b6 = (in[1] << 6); - b7 = (in[1] << 7); + b0 = in[1].slli_epi64(0); + b1 = in[1].slli_epi64(1); + b2 = in[1].slli_epi64(2); + b3 = in[1].slli_epi64(3); + b4 = in[1].slli_epi64(4); + b5 = in[1].slli_epi64(5); + b6 = in[1].slli_epi64(6); + b7 = in[1].slli_epi64(7); outU16View[x16_15 - 0][i8y] = b0.movemask_epi8(); outU16View[x16_15 - 1][i8y] = b1.movemask_epi8(); diff --git a/libOTe/Triple/Foleage/fft/FoleageFft.h b/libOTe/Triple/Foleage/fft/FoleageFft.h index 3e8c280e..bcae8b01 100644 --- a/libOTe/Triple/Foleage/fft/FoleageFft.h +++ b/libOTe/Triple/Foleage/fft/FoleageFft.h @@ -5,9 +5,7 @@ #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/MatrixView.h" #include "libOTe/Triple/Foleage/FoleageUtils.h" -#include -//#include "libOTe/Tools/Foleage/utils.h" namespace osuCrypto { //typedef __int128 int128_t; From 5ffb297912450893a881d708a551feb094d0d303 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:43:15 -0800 Subject: [PATCH 29/48] mac compile fixes --- cryptoTools | 2 +- libOTe/Tools/Tools.cpp | 2 +- libOTe/Triple/Foleage/FoleageTriple.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cryptoTools b/cryptoTools index 8239552d..6f762c94 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 8239552d5b1919f6d8372268459680bc225037a8 +Subproject commit 6f762c941110b2aa90d2e2c5dc026d48e2e38bba diff --git a/libOTe/Tools/Tools.cpp b/libOTe/Tools/Tools.cpp index 2836ed60..15c3362b 100644 --- a/libOTe/Tools/Tools.cpp +++ b/libOTe/Tools/Tools.cpp @@ -706,7 +706,7 @@ namespace osuCrypto { out0 -= out.stride(); - t.blks[0] = (t.blks[0] << 1); + t.blks[0] = t.blks[0].slli_epi64(1); } } } diff --git a/libOTe/Triple/Foleage/FoleageTriple.cpp b/libOTe/Triple/Foleage/FoleageTriple.cpp index 62237077..aebe4310 100644 --- a/libOTe/Triple/Foleage/FoleageTriple.cpp +++ b/libOTe/Triple/Foleage/FoleageTriple.cpp @@ -361,7 +361,7 @@ namespace osuCrypto setTimePoint("input Mult"); // compress the resume and set the output. - auto outSize = std::min(mN, ALsb.size() * 128); + auto outSize = std::min(mN, ALsb.size() * 128); std::vector A(mN); for (u64 i = 0; i < outSize; ++i) { From 1568af32f03ddc0ec05abd881a7b8220518e7210 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 11:48:52 -0800 Subject: [PATCH 30/48] mac compile fixes --- libOTe/Dpf/RegularDpf.h | 4 ++-- libOTe_Tests/OT_Tests.cpp | 2 +- libOTe_Tests/Pprf_Tests.cpp | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 014746d0..e2775d9c 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -176,8 +176,8 @@ namespace osuCrypto static block tagBit(const block& b) { auto bit = b & block(0, 1); - auto mask = _mm_sub_epi64(_mm_set1_epi64x(0), bit); - return _mm_unpacklo_epi64(mask, mask); + auto mask = block(0,0).sub_epi64(bit); + return mask.unpacklo_epi64(mask); } }; diff --git a/libOTe_Tests/OT_Tests.cpp b/libOTe_Tests/OT_Tests.cpp index 1c8346e5..fd40c6d3 100644 --- a/libOTe_Tests/OT_Tests.cpp +++ b/libOTe_Tests/OT_Tests.cpp @@ -254,7 +254,7 @@ namespace tests_libOTe for (u64 i = 0; i < 10000; ++i) { transpose128(data.data()); - data[0] += block::allSame((u64)1); + data[0] = data[0].add_epi64(block::allSame((u64)1)); } // Add a check just to make sure this doesn't get compiled out. diff --git a/libOTe_Tests/Pprf_Tests.cpp b/libOTe_Tests/Pprf_Tests.cpp index c9db73aa..7b556685 100644 --- a/libOTe_Tests/Pprf_Tests.cpp +++ b/libOTe_Tests/Pprf_Tests.cpp @@ -382,7 +382,7 @@ void Tools_Pprf_inter_test(const CLP& cmd) for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true, false }) for (auto e : { true, false }) { Tools_Pprf_test_impl(d, n, p, f, e, v); - Tools_Pprf_test_impl(d, n, p, f, e, v); + Tools_Pprf_test_impl(d, n, p, f, e, v); } } @@ -397,7 +397,7 @@ void Tools_Pprf_ByLeafIndex_test(const CLP& cmd) for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) for (auto e : { true/*, false */}) { Tools_Pprf_test_impl(d, n, p, f, e, v); - Tools_Pprf_test_impl(d, n, p, f, e, v); + Tools_Pprf_test_impl(d, n, p, f, e, v); } #else throw UnitTestSkipped("ENABLE_SILENTOT not defined."); @@ -416,7 +416,7 @@ void Tools_Pprf_ByTreeIndex_test(const oc::CLP& cmd) for (auto d : { 32,3242 }) for (auto n : { 8, 19}) for (auto p : { true/*, false*/ }) { Tools_Pprf_test_impl(d, n, p, f, false, v); - Tools_Pprf_test_impl(d, n, p, f, false, v); + Tools_Pprf_test_impl(d, n, p, f, false, v); } #else @@ -435,7 +435,7 @@ void Tools_Pprf_callback_test(const oc::CLP& cmd) for (auto d : { 32,3242 }) for (auto n : { 8, 128 }) for (auto p : { true/*, false */}) { Tools_Pprf_test_impl(d, n, p, f, false, v); - Tools_Pprf_test_impl(d, n, p, f, false, v); + Tools_Pprf_test_impl(d, n, p, f, false, v); } #else throw UnitTestSkipped("ENABLE_SILENTOT not defined."); From 1d72b981efd188198ff93cfc5106c0a38330dae4 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 14:29:09 -0800 Subject: [PATCH 31/48] mac ci debug --- frontend/main.cpp | 3 +-- libOTe/Dpf/RegularDpf.h | 2 +- libOTe/Tools/Tools.cpp | 2 +- libOTe_Tests/RegularDpf_Tests.cpp | 4 ++++ 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/frontend/main.cpp b/frontend/main.cpp index 7feac005..b7f83289 100644 --- a/frontend/main.cpp +++ b/frontend/main.cpp @@ -16,20 +16,19 @@ #include "benchmark.h" #include "ExampleBase.h" -#include "benchmark.h" #include "ExampleTwoChooseOne.h" #include "ExampleNChooseOne.h" #include "ExampleSilent.h" #include "ExampleVole.h" #include "ExampleMessagePassing.h" #include "libOTe/Tools/LDPC/Util.h" -#include "cryptoTools/Crypto/RandomOracle.h" #include "libOTe/Tools/EACode/EAChecker.h" #include "libOTe/Tools/ExConvCode/ExConvChecker.h" #include "libOTe/TwoChooseOne/Iknp/IknpOtExtSender.h" #include "libOTe/TwoChooseOne/Iknp/IknpOtExtReceiver.h" + using namespace osuCrypto; #ifdef ENABLE_IKNP void minimal() diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index e2775d9c..04d25f76 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -715,7 +715,7 @@ namespace osuCrypto auto seed = seeds[p][a]; auto temp = mAesFixedKey.ecbEncBlock(seed); seeds[p][0] = AES::roundEnc(temp, seed); - seeds[p][1] = temp + seed; + seeds[p][1] = temp.add_epi64(seed); } } } diff --git a/libOTe/Tools/Tools.cpp b/libOTe/Tools/Tools.cpp index 15c3362b..95007e86 100644 --- a/libOTe/Tools/Tools.cpp +++ b/libOTe/Tools/Tools.cpp @@ -691,7 +691,7 @@ namespace osuCrypto { auto out0 = outStart + (chunkSize * subBlockHight + hh) * 8 * out.stride() + w * 2; out0 -= out.stride() * skip; - t.blks[0] = (t.blks[0] << int(skip)); + t.blks[0] = t.blks[0].slli_epi64(skip); for (int j = 0; j < rem; j++) { diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 30fc1b8f..1b854471 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -258,6 +258,10 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) if (t == 1 && act == ZeroBlock) throw RTE_LOC; + if (t) + { + std::cout << act <<" " << output[0][k][i]<<" ^ "< Date: Wed, 26 Feb 2025 14:31:07 -0800 Subject: [PATCH 32/48] mac ci debug --- libOTe_Tests/RegularDpf_Tests.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 1b854471..9c6110cd 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -245,6 +245,7 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) )); + bool failed = false; for (u64 i = 0; i < domain; ++i) { for (u64 k = 0; k < numPoints; ++k) @@ -257,7 +258,7 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) throw RTE_LOC; if (t == 1 && act == ZeroBlock) - throw RTE_LOC; + failed = true; if (t) { std::cout << act <<" " << output[0][k][i]<<" ^ "< Date: Wed, 26 Feb 2025 14:41:36 -0800 Subject: [PATCH 33/48] mac ci debug --- .github/workflows/build-test.yml | 286 +++++++++++++++---------------- libOTe/Dpf/RegularDpf.h | 7 +- 2 files changed, 145 insertions(+), 148 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 2bba8f1a..bb802c25 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -16,121 +16,121 @@ on: # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: # This workflow contains a single job called "build" - build-ubuntu: - # The type of runner that the job will run on - runs-on: ubuntu-latest - timeout-minutes: 30 - - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - with: - submodules: recursive - - # Runs a set of commands using the runners shell + # build-ubuntu: + # # The type of runner that the job will run on + # runs-on: ubuntu-latest + # timeout-minutes: 30 + + # # Steps represent a sequence of tasks that will be executed as part of the job + # steps: + # # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + # - uses: actions/checkout@v2 + # with: + # submodules: recursive + + # # Runs a set of commands using the runners shell - #- name: build relic - # run: python3 build.py -DENABLE_BOOST=OFF -DENABLE_SODIUM=OFF -DENABLE_ASAN=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + # #- name: build relic + # # run: python3 build.py -DENABLE_BOOST=OFF -DENABLE_SODIUM=OFF -DENABLE_ASAN=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - #- name: build bitpolymul - # run: python3 build.py --bitpolymul --par=4 -DVERBOSE_FETCH=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + # #- name: build bitpolymul + # # run: python3 build.py --bitpolymul --par=4 -DVERBOSE_FETCH=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - - name: build libOTe - run: python3 build.py --par=4 -D ENABLE_ALL_OT=ON -DENABLE_CIRCUITS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_ASAN=ON -DENABLE_MOCK_OT=true + # - name: build libOTe + # run: python3 build.py --par=4 -D ENABLE_ALL_OT=ON -DENABLE_CIRCUITS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_ASAN=ON -DENABLE_MOCK_OT=true - - name: unit tests - run: | - ./out/build/linux/frontend/frontend_libOTe -u + # - name: unit tests + # run: | + # ./out/build/linux/frontend/frontend_libOTe -u - - name: find source tree - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -D CMAKE_PREFIX_PATH=../../ - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: find source tree + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -D CMAKE_PREFIX_PATH=../../ + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. - - name: hint test - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -D LIBOTE_HINT=../.. - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. - - - name: install prefix test - run: | - python3 build.py --install=~/install -DCMAKE_BUILD_TYPE=RelWithDebInfo - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: hint test + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -D LIBOTE_HINT=../.. + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. + + # - name: install prefix test + # run: | + # python3 build.py --install=~/install -DCMAKE_BUILD_TYPE=RelWithDebInfo + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. - - name: install test - run: | - python3 build.py --install --sudo -DCMAKE_BUILD_TYPE=RelWithDebInfo - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: install test + # run: | + # python3 build.py --install --sudo -DCMAKE_BUILD_TYPE=RelWithDebInfo + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. - - name: build libOTe w/ sodium - run: | - rm ./out/build/linux/frontend/frontend_libOTe - python3 build.py --par=4 -D ENABLE_ALL_OT=ON -D ENABLE_SODIUM=ON -DENABLE_RELIC=OFF -DPRINT_LOG_ON_FAIL=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + # - name: build libOTe w/ sodium + # run: | + # rm ./out/build/linux/frontend/frontend_libOTe + # python3 build.py --par=4 -D ENABLE_ALL_OT=ON -D ENABLE_SODIUM=ON -DENABLE_RELIC=OFF -DPRINT_LOG_ON_FAIL=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - - name: unit tests - run: ./out/build/linux/frontend/frontend_libOTe -u + # - name: unit tests + # run: ./out/build/linux/frontend/frontend_libOTe -u - - name: find source tree - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DSODIUM=ON -D CMAKE_PREFIX_PATH=../../ - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: find source tree + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DSODIUM=ON -D CMAKE_PREFIX_PATH=../../ + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. - - name: hint test - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DSODIUM=ON -D LIBOTE_HINT=../.. - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. - - - name: install prefix test - run: | - python3 build.py --install=~/install - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: hint test + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DSODIUM=ON -D LIBOTE_HINT=../.. + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. + + # - name: install prefix test + # run: | + # python3 build.py --install=~/install + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. - - name: install test - run: | - python3 build.py --install --sudo - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - cmake --build out/ - ./out/main - rm -rf out/ - cd ../.. + # - name: install test + # run: | + # python3 build.py --install --sudo + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + # cmake --build out/ + # ./out/main + # rm -rf out/ + # cd ../.. # This workflow contains a single job called "build" build-osx: @@ -151,7 +151,7 @@ jobs: run: python3 build.py -DENABLE_BOOST=OFF -DVERBOSE_FETCH=ON -DENABLE_SSE=OFF -DENABLE_MOCK_OT=true -D ENABLE_ALL_OT=ON - name: unit tests - run: ./out/build/osx/frontend/frontend_libOTe -u + run: ./out/build/osx/frontend/frontend_libOTe -u RegularDpf_Puncture_Test - name: find source tree @@ -194,52 +194,52 @@ jobs: cd ../.. - build-windows: - # The type of runner that the job will run on - runs-on: windows-2022 - timeout-minutes: 30 + # build-windows: + # # The type of runner that the job will run on + # runs-on: windows-2022 + # timeout-minutes: 30 - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v2 - with: - submodules: recursive - - uses: seanmiddleditch/gha-setup-ninja@v3 - - uses: ilammy/msvc-dev-cmd@v1 - - # Runs a set of commands using the runners shell - - name: build libOTe - run: python3 build.py --par=1 -D ENABLE_ALL_OT=ON -DENABLE_MOCK_OT=true -G Ninja + # # Steps represent a sequence of tasks that will be executed as part of the job + # steps: + # # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + # - uses: actions/checkout@v2 + # with: + # submodules: recursive + # - uses: seanmiddleditch/gha-setup-ninja@v3 + # - uses: ilammy/msvc-dev-cmd@v1 + + # # Runs a set of commands using the runners shell + # - name: build libOTe + # run: python3 build.py --par=1 -D ENABLE_ALL_OT=ON -DENABLE_MOCK_OT=true -G Ninja - - name: unit test - run: ./out/build/x64-Release/frontend/frontend_libOTe.exe -u + # - name: unit test + # run: ./out/build/x64-Release/frontend/frontend_libOTe.exe -u - - name: find source tree - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -D CMAKE_PREFIX_PATH=../../ - cmake --build out/ --config Release - ./out/Release/main.exe - rm -r -fo out/ - cd ../.. + # - name: find source tree + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -D CMAKE_PREFIX_PATH=../../ + # cmake --build out/ --config Release + # ./out/Release/main.exe + # rm -r -fo out/ + # cd ../.. - - name: hint test - run: | - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -D LIBOTE_HINT=../.. - cmake --build out/ --config Release - ./out/Release/main.exe - rm -r -fo out/ - cd ../.. - - - name: install prefix test - run: | - python3 build.py --install=~/install - cd libOTe_Tests/cmakeTests - cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=~/install - cmake --build out/ --config Release - ./out/Release/main.exe - rm -r -fo out/ - cd ../.. + # - name: hint test + # run: | + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -D LIBOTE_HINT=../.. + # cmake --build out/ --config Release + # ./out/Release/main.exe + # rm -r -fo out/ + # cd ../.. + + # - name: install prefix test + # run: | + # python3 build.py --install=~/install + # cd libOTe_Tests/cmakeTests + # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=~/install + # cmake --build out/ --config Release + # ./out/Release/main.exe + # rm -r -fo out/ + # cd ../.. diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 04d25f76..7161f8fc 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -615,12 +615,9 @@ namespace osuCrypto { auto sdi = getRow(sd, i); auto tdi = getRow(td, i); - for (u64 k = 0; k < numPoints8; k += 8) - { - SIMD8(q, output(k + q, i, sdi[k + q], tdi[k + q])); - } - for (u64 k = numPoints8; k < mNumPoints; ++k) + for (u64 k = 0; k < mNumPoints; ++k) { + std::cout << "k " << k << " i " << i << " " << sdi[k] << std::endl; output(k, i, sdi[k], tdi[k]); } } From 5da5b1b4341e5e912d86b4a2dbdac9f7e468e9ac Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 15:02:29 -0800 Subject: [PATCH 34/48] mac compile fixes --- CMakePresets.json | 4 ++-- libOTe/Dpf/RegularDpf.h | 8 +++++++- libOTe_Tests/RegularDpf_Tests.cpp | 4 ++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index 23a7fc71..f2a12563 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -70,8 +70,8 @@ "FETCH_AUTO": "ON", "ENABLE_CIRCUITS": true, "VERBOSE_FETCH": true, - "ENABLE_SSE": true, - "ENABLE_AVX": true, + "ENABLE_SSE": false, + "ENABLE_AVX": false, "ENABLE_ASAN": true, "CMAKE_INSTALL_PREFIX": "${sourceDir}/out/install/${presetName}", "CMAKE_PREFIX_PATH": "${sourceDir}/../out/install/${presetName}" diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 7161f8fc..70f402d9 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -504,6 +504,9 @@ namespace osuCrypto z[0][k] ^= childSeed[j * 2 + 0][k]; z[1][k] ^= childSeed[j * 2 + 1][k]; + std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " split " << childSeed[j * 2 + 0][k] << " " << childSeed[j * 2 + 1][k] << std::endl; + + currentSeed[j][k] = tagBit(currentSeed[j][k]); } } @@ -554,6 +557,9 @@ namespace osuCrypto tag[j][k] = tagBit(temp[0]); currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); diff[k] ^= currentSeed[j][k]; + + std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << currentSeed[j][k] << std::endl; + } } } @@ -617,7 +623,7 @@ namespace osuCrypto auto tdi = getRow(td, i); for (u64 k = 0; k < mNumPoints; ++k) { - std::cout << "k " << k << " i " << i << " " << sdi[k] << std::endl; + std::cout<<"p " << mPartyIdx << " k " << k << " i " << i << " out " << sdi[k] << std::endl; output(k, i, sdi[k], tdi[k]); } } diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 9c6110cd..2b30e67e 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -194,8 +194,8 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); - u64 domain = cmd.getOr("domain", 211); - u64 numPoints = cmd.getOr("numPoints", 7); + u64 domain = cmd.getOr("domain", 8); + u64 numPoints = cmd.getOr("numPoints", 1); std::vector points0(numPoints); std::vector points1(numPoints); for (u64 i = 0; i < numPoints; ++i) From 1bc509ca56ee51b2e99fd1dc48a8a1701ab95ae7 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 15:07:43 -0800 Subject: [PATCH 35/48] mac compile fixes --- libOTe/Dpf/RegularDpf.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 70f402d9..4f4812d3 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -254,6 +254,9 @@ namespace osuCrypto coproto::Socket& sock, RegularDpfKey* outputKey) { + + std::cout <<"p " << mPartyIdx << " seed " << seed << std::endl; + if (inputKey == nullptr) { if (points.size() != mNumPoints) @@ -317,6 +320,9 @@ namespace osuCrypto sc0[k] = basePeng.get(); sc1[k] = basePeng.get(); + std::cout << "p " << mPartyIdx << " k " << k << " root " << sc0[k] << " " << sc1[k] << std::endl; + + tag[k] = block::allSame(-mPartyIdx); z[0][k] = sc0[k]; From 3225f82bcbbfd4cca665462f9d98e349b5161d26 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 15:52:47 -0800 Subject: [PATCH 36/48] mac compile fixes --- libOTe_Tests/RegularDpf_Tests.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 2b30e67e..c7c78d8b 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -194,6 +194,7 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); + std::cout << "pp " << prng.get() << std::endl; u64 domain = cmd.getOr("domain", 8); u64 numPoints = cmd.getOr("numPoints", 1); std::vector points0(numPoints); From f8939a94e22463ba9dcfbad337398fea47cb80fd Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 15:56:01 -0800 Subject: [PATCH 37/48] mac compile fixes --- libOTe_Tests/RegularDpf_Tests.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index c7c78d8b..319fbec0 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -238,11 +238,13 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) output[1].resize(numPoints, domain); tags[0].resize(numPoints, domain); tags[1].resize(numPoints, domain); + auto seed0 = prng.get(); + auto seed1 = prng.get(); auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( - dpf[0].expand(points0, {}, prng.get(), [&](auto k, auto i, auto v, block t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), - dpf[1].expand(points1, {}, prng.get(), [&](auto k, auto i, auto v, block t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) + dpf[0].expand(points0, {}, seed0, [&](auto k, auto i, auto v, block t) { output[0](k, i) = v; tags[0](k, i) = t.get(0) & 1; }, sock[0]), + dpf[1].expand(points1, {}, seed1, [&](auto k, auto i, auto v, block t) { output[1](k, i) = v; tags[1](k, i) = t.get(0) & 1; }, sock[1]) )); From 8cabf66e28c34ee99f1b0bc8ad172808dabd4b71 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 16:00:34 -0800 Subject: [PATCH 38/48] mac ci debug --- libOTe_Tests/RegularDpf_Tests.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 319fbec0..77eb3364 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -238,8 +238,8 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) output[1].resize(numPoints, domain); tags[0].resize(numPoints, domain); tags[1].resize(numPoints, domain); - auto seed0 = prng.get(); - auto seed1 = prng.get(); + auto seed0 = prng.get(); + auto seed1 = prng.get(); auto sock = coproto::LocalAsyncSocket::makePair(); macoro::sync_wait(macoro::when_all_ready( From 72401f9fb95f17cce178051eedb78ea1fd952c1d Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 16:07:02 -0800 Subject: [PATCH 39/48] mac ci debug --- libOTe/Dpf/RegularDpf.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 4f4812d3..36f7fa42 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -559,12 +559,16 @@ namespace osuCrypto { for (u64 j = 0; j < 2; ++j) { + std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " prnt " << currentSeed[j][k] << " ^ " << parentTag[k] << " & " << sigma[j][k] << std::endl; temp[0] = currentSeed[j][k] ^ (parentTag[k] & sigma[j][k]); tag[j][k] = tagBit(temp[0]); + + currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); + diff[k] ^= currentSeed[j][k]; - std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << currentSeed[j][k] << std::endl; + std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << currentSeed[j][k] << " " << temp[0] << std::endl; } } From d1091db6c7ae4e4088c9ff90d8ce11a7438dcf17 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 16:15:31 -0800 Subject: [PATCH 40/48] mac ci debug --- .github/workflows/build-test.yml | 2 +- cryptoTools | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index bb802c25..e535edbe 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -151,7 +151,7 @@ jobs: run: python3 build.py -DENABLE_BOOST=OFF -DVERBOSE_FETCH=ON -DENABLE_SSE=OFF -DENABLE_MOCK_OT=true -D ENABLE_ALL_OT=ON - name: unit tests - run: ./out/build/osx/frontend/frontend_libOTe -u RegularDpf_Puncture_Test + run: ./out/build/osx/frontend/frontend_libOTe -u RegularDpf_Puncture_Test aes - name: find source tree diff --git a/cryptoTools b/cryptoTools index 6f762c94..a70c9b70 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit 6f762c941110b2aa90d2e2c5dc026d48e2e38bba +Subproject commit a70c9b70211005acf77aaa0a50ed37d0eaf458fa From deb786078e7c173dd9332bbf97ac466237af82e7 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 16:38:33 -0800 Subject: [PATCH 41/48] mac ci debug --- cryptoTools | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cryptoTools b/cryptoTools index a70c9b70..ea446267 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit a70c9b70211005acf77aaa0a50ed37d0eaf458fa +Subproject commit ea4462677a02be8db27f3f8f0f5412c7c3391648 From d5adf47a2309ff7d4defc4dd4f7ee50ac2045f15 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 18:01:21 -0800 Subject: [PATCH 42/48] mac ci debug --- cryptoTools | 2 +- libOTe/Dpf/RegularDpf.h | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cryptoTools b/cryptoTools index ea446267..a99176e2 160000 --- a/cryptoTools +++ b/cryptoTools @@ -1 +1 @@ -Subproject commit ea4462677a02be8db27f3f8f0f5412c7c3391648 +Subproject commit a99176e27a2a0430770d44b759b90f6781f8e8f6 diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 36f7fa42..79677a10 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -564,12 +564,12 @@ namespace osuCrypto tag[j][k] = tagBit(temp[0]); - currentSeed[j][k] = AES::roundFn(temp[0], temp[0]); + auto rr = AES::roundFn(temp[0], temp[0]); - diff[k] ^= currentSeed[j][k]; - - std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << currentSeed[j][k] << " " << temp[0] << std::endl; + diff[k] ^= rr; + std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << rr << " " << temp[0] << std::endl; + currentSeed[j][k] = rr; } } } From 28c7cb536f7b4011f073bbc98d5da5f453aaf8e0 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 20:31:20 -0800 Subject: [PATCH 43/48] mac ci debug --- libOTe/Dpf/RegularDpf.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 79677a10..997ee6a6 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -564,8 +564,8 @@ namespace osuCrypto tag[j][k] = tagBit(temp[0]); - auto rr = AES::roundFn(temp[0], temp[0]); - + //auto rr = AES::roundFn(temp[0], temp[0]); + auto rr = temp[0]; diff[k] ^= rr; std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << rr << " " << temp[0] << std::endl; From 54bcd804bbabfb8dca0452c68b69d66f00ec51b4 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 20:37:11 -0800 Subject: [PATCH 44/48] mac ci debug --- libOTe/Dpf/RegularDpf.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 997ee6a6..2e16741a 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -550,7 +550,7 @@ namespace osuCrypto { SIMD8(q, temp[q] = currentSeed[j][k + q] ^ (parentTag[k + q] & sigma[j][k + q])); SIMD8(q, tag[j][k + q] = tagBit(temp[q])); - SIMD8(q, currentSeed[j][k + q] = AES::roundFn(temp[q], temp[q])); + SIMD8(q, currentSeed[j][k + q] = AES::roundEnc(temp[q], temp[q])); SIMD8(q, diff[k + q] ^= currentSeed[j][k + q]); } } @@ -564,8 +564,8 @@ namespace osuCrypto tag[j][k] = tagBit(temp[0]); - //auto rr = AES::roundFn(temp[0], temp[0]); - auto rr = temp[0]; + auto rr = AES::roundEnc(temp[0], temp[0]); + //auto rr = temp[0]; diff[k] ^= rr; std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << rr << " " << temp[0] << std::endl; From c5a0c62b131ad979b1d81fe3a3bcc635118b71c0 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 20:54:03 -0800 Subject: [PATCH 45/48] ci working --- .github/workflows/build-test.yml | 286 +++++++++++++++--------------- libOTe/Dpf/RegularDpf.h | 14 -- libOTe_Tests/RegularDpf_Tests.cpp | 5 - 3 files changed, 143 insertions(+), 162 deletions(-) diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index e535edbe..771ede1f 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -16,121 +16,121 @@ on: # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: # This workflow contains a single job called "build" - # build-ubuntu: - # # The type of runner that the job will run on - # runs-on: ubuntu-latest - # timeout-minutes: 30 - - # # Steps represent a sequence of tasks that will be executed as part of the job - # steps: - # # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - # - uses: actions/checkout@v2 - # with: - # submodules: recursive - - # # Runs a set of commands using the runners shell + build-ubuntu: + # The type of runner that the job will run on + runs-on: ubuntu-latest + timeout-minutes: 30 + + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + with: + submodules: recursive + + # Runs a set of commands using the runners shell - # #- name: build relic - # # run: python3 build.py -DENABLE_BOOST=OFF -DENABLE_SODIUM=OFF -DENABLE_ASAN=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + #- name: build relic + # run: python3 build.py -DENABLE_BOOST=OFF -DENABLE_SODIUM=OFF -DENABLE_ASAN=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - # #- name: build bitpolymul - # # run: python3 build.py --bitpolymul --par=4 -DVERBOSE_FETCH=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + #- name: build bitpolymul + # run: python3 build.py --bitpolymul --par=4 -DVERBOSE_FETCH=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - # - name: build libOTe - # run: python3 build.py --par=4 -D ENABLE_ALL_OT=ON -DENABLE_CIRCUITS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_ASAN=ON -DENABLE_MOCK_OT=true + - name: build libOTe + run: python3 build.py --par=4 -D ENABLE_ALL_OT=ON -DENABLE_CIRCUITS=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DENABLE_ASAN=ON -DENABLE_MOCK_OT=true - # - name: unit tests - # run: | - # ./out/build/linux/frontend/frontend_libOTe -u + - name: unit tests + run: | + ./out/build/linux/frontend/frontend_libOTe -u - # - name: find source tree - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -D CMAKE_PREFIX_PATH=../../ - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: find source tree + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -D CMAKE_PREFIX_PATH=../../ + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. - # - name: hint test - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -D LIBOTE_HINT=../.. - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. - - # - name: install prefix test - # run: | - # python3 build.py --install=~/install -DCMAKE_BUILD_TYPE=RelWithDebInfo - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: hint test + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -D LIBOTE_HINT=../.. + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. + + - name: install prefix test + run: | + python3 build.py --install=~/install -DCMAKE_BUILD_TYPE=RelWithDebInfo + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. - # - name: install test - # run: | - # python3 build.py --install --sudo -DCMAKE_BUILD_TYPE=RelWithDebInfo - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: install test + run: | + python3 build.py --install --sudo -DCMAKE_BUILD_TYPE=RelWithDebInfo + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. - # - name: build libOTe w/ sodium - # run: | - # rm ./out/build/linux/frontend/frontend_libOTe - # python3 build.py --par=4 -D ENABLE_ALL_OT=ON -D ENABLE_SODIUM=ON -DENABLE_RELIC=OFF -DPRINT_LOG_ON_FAIL=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + - name: build libOTe w/ sodium + run: | + rm ./out/build/linux/frontend/frontend_libOTe + python3 build.py --par=4 -D ENABLE_ALL_OT=ON -D ENABLE_SODIUM=ON -DENABLE_RELIC=OFF -DPRINT_LOG_ON_FAIL=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - # - name: unit tests - # run: ./out/build/linux/frontend/frontend_libOTe -u + - name: unit tests + run: ./out/build/linux/frontend/frontend_libOTe -u - # - name: find source tree - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DSODIUM=ON -D CMAKE_PREFIX_PATH=../../ - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: find source tree + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=RelWithDebInfo -DSODIUM=ON -D CMAKE_PREFIX_PATH=../../ + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. - # - name: hint test - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DSODIUM=ON -D LIBOTE_HINT=../.. - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. - - # - name: install prefix test - # run: | - # python3 build.py --install=~/install - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: hint test + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DSODIUM=ON -D LIBOTE_HINT=../.. + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. + + - name: install prefix test + run: | + python3 build.py --install=~/install + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_PREFIX_PATH=~/install + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. - # - name: install test - # run: | - # python3 build.py --install --sudo - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo - # cmake --build out/ - # ./out/main - # rm -rf out/ - # cd ../.. + - name: install test + run: | + python3 build.py --install --sudo + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DSODIUM=ON -DCMAKE_BUILD_TYPE=RelWithDebInfo + cmake --build out/ + ./out/main + rm -rf out/ + cd ../.. # This workflow contains a single job called "build" build-osx: @@ -151,7 +151,7 @@ jobs: run: python3 build.py -DENABLE_BOOST=OFF -DVERBOSE_FETCH=ON -DENABLE_SSE=OFF -DENABLE_MOCK_OT=true -D ENABLE_ALL_OT=ON - name: unit tests - run: ./out/build/osx/frontend/frontend_libOTe -u RegularDpf_Puncture_Test aes + run: ./out/build/osx/frontend/frontend_libOTe -u - name: find source tree @@ -194,52 +194,52 @@ jobs: cd ../.. - # build-windows: - # # The type of runner that the job will run on - # runs-on: windows-2022 - # timeout-minutes: 30 + build-windows: + # The type of runner that the job will run on + runs-on: windows-2022 + timeout-minutes: 30 - # # Steps represent a sequence of tasks that will be executed as part of the job - # steps: - # # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it - # - uses: actions/checkout@v2 - # with: - # submodules: recursive - # - uses: seanmiddleditch/gha-setup-ninja@v3 - # - uses: ilammy/msvc-dev-cmd@v1 - - # # Runs a set of commands using the runners shell - # - name: build libOTe - # run: python3 build.py --par=1 -D ENABLE_ALL_OT=ON -DENABLE_MOCK_OT=true -G Ninja + # Steps represent a sequence of tasks that will be executed as part of the job + steps: + # Checks-out your repository under $GITHUB_WORKSPACE, so your job can access it + - uses: actions/checkout@v2 + with: + submodules: recursive + - uses: seanmiddleditch/gha-setup-ninja@v3 + - uses: ilammy/msvc-dev-cmd@v1 + + # Runs a set of commands using the runners shell + - name: build libOTe + run: python3 build.py -D ENABLE_ALL_OT=ON -DENABLE_MOCK_OT=true -G Ninja - # - name: unit test - # run: ./out/build/x64-Release/frontend/frontend_libOTe.exe -u + - name: unit test + run: ./out/build/x64-Release/frontend/frontend_libOTe.exe -u - # - name: find source tree - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -D CMAKE_PREFIX_PATH=../../ - # cmake --build out/ --config Release - # ./out/Release/main.exe - # rm -r -fo out/ - # cd ../.. + - name: find source tree + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -D CMAKE_PREFIX_PATH=../../ + cmake --build out/ --config Release + ./out/Release/main.exe + rm -r -fo out/ + cd ../.. - # - name: hint test - # run: | - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -D LIBOTE_HINT=../.. - # cmake --build out/ --config Release - # ./out/Release/main.exe - # rm -r -fo out/ - # cd ../.. - - # - name: install prefix test - # run: | - # python3 build.py --install=~/install - # cd libOTe_Tests/cmakeTests - # cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=~/install - # cmake --build out/ --config Release - # ./out/Release/main.exe - # rm -r -fo out/ - # cd ../.. + - name: hint test + run: | + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -D LIBOTE_HINT=../.. + cmake --build out/ --config Release + ./out/Release/main.exe + rm -r -fo out/ + cd ../.. + + - name: install prefix test + run: | + python3 build.py --install=~/install + cd libOTe_Tests/cmakeTests + cmake -S . -B out/ -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH=~/install + cmake --build out/ --config Release + ./out/Release/main.exe + rm -r -fo out/ + cd ../.. diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 2e16741a..403e8020 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -255,8 +255,6 @@ namespace osuCrypto RegularDpfKey* outputKey) { - std::cout <<"p " << mPartyIdx << " seed " << seed << std::endl; - if (inputKey == nullptr) { if (points.size() != mNumPoints) @@ -319,10 +317,6 @@ namespace osuCrypto { sc0[k] = basePeng.get(); sc1[k] = basePeng.get(); - - std::cout << "p " << mPartyIdx << " k " << k << " root " << sc0[k] << " " << sc1[k] << std::endl; - - tag[k] = block::allSame(-mPartyIdx); z[0][k] = sc0[k]; @@ -510,9 +504,6 @@ namespace osuCrypto z[0][k] ^= childSeed[j * 2 + 0][k]; z[1][k] ^= childSeed[j * 2 + 1][k]; - std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " split " << childSeed[j * 2 + 0][k] << " " << childSeed[j * 2 + 1][k] << std::endl; - - currentSeed[j][k] = tagBit(currentSeed[j][k]); } } @@ -559,16 +550,12 @@ namespace osuCrypto { for (u64 j = 0; j < 2; ++j) { - std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " prnt " << currentSeed[j][k] << " ^ " << parentTag[k] << " & " << sigma[j][k] << std::endl; temp[0] = currentSeed[j][k] ^ (parentTag[k] & sigma[j][k]); tag[j][k] = tagBit(temp[0]); auto rr = AES::roundEnc(temp[0], temp[0]); - //auto rr = temp[0]; diff[k] ^= rr; - - std::cout << "p " << mPartyIdx << " k " << k << " j " << j << " leaf " << rr << " " << temp[0] << std::endl; currentSeed[j][k] = rr; } } @@ -633,7 +620,6 @@ namespace osuCrypto auto tdi = getRow(td, i); for (u64 k = 0; k < mNumPoints; ++k) { - std::cout<<"p " << mPartyIdx << " k " << k << " i " << i << " out " << sdi[k] << std::endl; output(k, i, sdi[k], tdi[k]); } } diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/RegularDpf_Tests.cpp index 77eb3364..34bf4e02 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/RegularDpf_Tests.cpp @@ -194,7 +194,6 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) { PRNG prng(block(231234, 321312)); - std::cout << "pp " << prng.get() << std::endl; u64 domain = cmd.getOr("domain", 8); u64 numPoints = cmd.getOr("numPoints", 1); std::vector points0(numPoints); @@ -262,10 +261,6 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) if (t == 1 && act == ZeroBlock) failed = true; - if (t) - { - std::cout << act <<" " << output[0][k][i]<<" ^ "< Date: Wed, 26 Feb 2025 20:58:57 -0800 Subject: [PATCH 46/48] dpf fix --- libOTe/Dpf/RegularDpf.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libOTe/Dpf/RegularDpf.h b/libOTe/Dpf/RegularDpf.h index 403e8020..1d3fff20 100644 --- a/libOTe/Dpf/RegularDpf.h +++ b/libOTe/Dpf/RegularDpf.h @@ -729,7 +729,7 @@ namespace osuCrypto for (u64 p = 0; p < 2; ++p) { - seeds[p][a] = AES::roundFn(seeds[p][a], seeds[p][a]); + seeds[p][a] = AES::roundEnc(seeds[p][a], seeds[p][a]); } auto diff = seeds[0][a] ^ seeds[1][a]; From b0b57507be376866ee32148dd30e502845c429b6 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 22:01:57 -0800 Subject: [PATCH 47/48] added conditional compile options --- cmake/buildOptions.cmake | 57 +++++--- cmake/buildOptions.cmake.in | 11 +- frontend/benchmark.h | 23 ++- frontend/main.cpp | 2 +- libOTe/Dpf/DpfMult.h | 5 +- libOTe/Dpf/RegularDpf.h | 7 +- libOTe/Dpf/SparseDpf.h | 4 + libOTe/Dpf/{TriDpf.h => TernaryDpf.h} | 131 ++--------------- libOTe/Triple/Foleage/FoleageTriple.cpp | 20 ++- libOTe/Triple/Foleage/FoleageTriple.h | 21 ++- libOTe/Triple/Foleage/FoleageUtils.h | 136 +++++++++++++++++- libOTe/Triple/Foleage/fft/FoleageFft.cpp | 42 ++++-- libOTe/Triple/Foleage/fft/FoleageFft.h | 23 ++- libOTe/config.h.in | 13 ++ libOTe_Tests/CMakeLists.txt | 2 +- .../{RegularDpf_Tests.cpp => Dpf_Tests.cpp} | 46 ++++-- .../{RegularDpf_Tests.h => Dpf_Tests.h} | 0 libOTe_Tests/Foleage_Tests.cpp | 27 +++- libOTe_Tests/UnitTests.cpp | 2 +- 19 files changed, 374 insertions(+), 198 deletions(-) rename libOTe/Dpf/{TriDpf.h => TernaryDpf.h} (89%) rename libOTe_Tests/{RegularDpf_Tests.cpp => Dpf_Tests.cpp} (94%) rename libOTe_Tests/{RegularDpf_Tests.h => Dpf_Tests.h} (100%) diff --git a/cmake/buildOptions.cmake b/cmake/buildOptions.cmake index aa5eedf2..773f4ec8 100644 --- a/cmake/buildOptions.cmake +++ b/cmake/buildOptions.cmake @@ -60,15 +60,6 @@ if(DEFINED ENABLE_ALL_OT) set(ENABLE_SIMPLESTOT_ASM ${oc_BB} CACHE BOOL "" FORCE) set(ENABLE_MR_KYBER ${oc_BB} CACHE BOOL "" FORCE) - # requires sse - if(ENABLE_SSE) - set(oc_BB ${ENABLE_ALL_OT}) - else() - set(oc_BB OFF) - endif() - set(ENABLE_SILENTOT ${oc_BB} CACHE BOOL "" FORCE) - - # general set(ENABLE_KOS ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) set(ENABLE_IKNP ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) @@ -77,17 +68,18 @@ if(DEFINED ENABLE_ALL_OT) set(ENABLE_OOS ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) set(ENABLE_KKRT ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) set(ENABLE_SILENTOT ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) - set(ENABLE_SILENT_VOLE ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + set(ENABLE_SILENT_VOLE ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + set(ENABLE_FOLEAGE ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + set(ENABLE_REGULAR_DPF ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + set(ENABLE_TERNARY_DPF ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + set(ENABLE_SPARSE_DPF ${ENABLE_ALL_OT} CACHE BOOL "" FORCE) + unset(ENABLE_ALL_OT CACHE) endif() -if(APPLE) - option(ENABLE_BITPOLYMUL "Build with bit poly mul inegration" FALSE) -else() - option(ENABLE_BITPOLYMUL "Build with bit poly mul inegration" TRUE) -endif() +option(ENABLE_BITPOLYMUL "Build with bit poly mul inegration" FALSE) option(ENABLE_MOCK_OT "Build the insecure mock base OT" OFF) @@ -111,11 +103,13 @@ option(ENABLE_KKRT "Build the KKRT 1-oo-N OT-Ext protocol." OFF) option(ENABLE_PPRF "Build the PPRF protocol." OFF) option(ENABLE_SILENT_VOLE "Build the Silent Vole protocol." OFF) -option(ENABLE_INSECURE_SILVER "Build with silver codes." OFF) -option(ENABLE_LDPC "Build with ldpc functions." OFF) -if(ENABLE_INSECURE_SILVER) - set(ENABLE_LDPC ON) -endif() +option(ENABLE_FOLEAGE "Build the Foleage OLE protocol." OFF) + + +option(ENABLE_REGULAR_DPF "Build the Regular DPF protocol." OFF) +option(ENABLE_TERNARY_DPF "Build the Ternary DPF protocol." OFF) +option(ENABLE_SPARSE_DPF "Build the Sparse DPF protocol." OFF) + option(NO_KOS_WARNING "Build with no kos security warning." OFF) @@ -133,6 +127,14 @@ if(ENABLE_IKNP) set(ENABLE_KOS true) endif() +if(ENABLE_FOLEAGE) + set(ENABLE_TERNARY_DPF true) +endif() + +if(ENABLE_SPARSE_DPF) + set(ENABLE_REGULAR_DPF true) +endif() + message(STATUS "General Options\n=======================================================") message(STATUS "Option: VERBOSE_FETCH = ${VERBOSE_FETCH}") @@ -160,7 +162,17 @@ message(STATUS "1-out-of-2 Delta-OT Extension protocols\n======================= message(STATUS "Option: ENABLE_DELTA_KOS = ${ENABLE_DELTA_KOS}\n\n") message(STATUS "Vole protocols\n=======================================================") -message(STATUS "Option: ENABLE_SILENT_VOLE = ${ENABLE_SILENT_VOLE}\n\n") +message(STATUS "Option: ENABLE_SILENT_VOLE = ${ENABLE_SILENT_VOLE}\n\n") + + +message(STATUS "DPF protocols\n=======================================================") +message(STATUS "Option: ENABLE_REGULAR_DPF = ${ENABLE_REGULAR_DPF}") +message(STATUS "Option: ENABLE_SPARSE_DPF = ${ENABLE_SPARSE_DPF}") +message(STATUS "Option: ENABLE_TERNARY_DPF = ${ENABLE_TERNARY_DPF}") +message(STATUS "Option: ENABLE_PPRF = ${ENABLE_PPRF}\n\n") + +message(STATUS "OLE and Triple protocols\n=======================================================") +message(STATUS "Option: ENABLE_FOLEAGE = ${ENABLE_FOLEAGE}\n\n") message(STATUS "1-out-of-N OT Extension protocols\n=======================================================") message(STATUS "Option: ENABLE_OOS = ${ENABLE_OOS}") @@ -168,8 +180,7 @@ message(STATUS "Option: ENABLE_KKRT = ${ENABLE_KKRT}\n\n") message(STATUS "other \n=======================================================") -message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}") -message(STATUS "Option: ENABLE_PPRF = ${ENABLE_PPRF}\n\n") +message(STATUS "Option: NO_KOS_WARNING = ${NO_KOS_WARNING}\n\n") ############################################# # Config Checks # diff --git a/cmake/buildOptions.cmake.in b/cmake/buildOptions.cmake.in index f3329773..9cf876f9 100644 --- a/cmake/buildOptions.cmake.in +++ b/cmake/buildOptions.cmake.in @@ -68,7 +68,12 @@ set(ENABLE_DELTA_KOS @ENABLE_DELTA_KOS@) set(ENABLE_OOS @ENABLE_OOS@) set(ENABLE_KKRT @ENABLE_KKRT@) set(ENABLE_SILENT_VOLE @ENABLE_SILENT_VOLE@) -set(NO_SILVER_WARNING @NO_SILVER_WARNING@) + +set(ENABLE_FOLEAGE @ENABLE_FOLEAGE@) + +set(ENABLE_REGULAR_DPF @ENABLE_REGULAR_DPF@) +set(ENABLE_SPARSE_DPF @ENABLE_SPARSE_DPF@) +set(ENABLE_TERNARY_DPF @ENABLE_TERNARY_DPF@) set(ENABLE_PPRF @ENABLE_PPRF@) @@ -127,5 +132,9 @@ set(libOTe_delta_kos_FOUND ${ENABLE_DELTA_KOS}) set(libOTe_silent_vole_FOUND ${ENABLE_SILENT_VOLE}) set(libOTe_oos_FOUND ${ENABLE_OOS}) set(libOTe_kkrt_FOUND ${ENABLE_KKRT}) +set(libOTe_foleage_FOUND ${ENABLE_FOLEAGE}) +set(libOTe_regular_dpf_FOUND ${ENABLE_REGULAR_DPF}) +set(libOTe_ternary_dpf_FOUND ${ENABLE_TERNARY_DPF}) +set(libOTe_sparse_dpf_FOUND ${ENABLE_SPARSE_DPF}) diff --git a/frontend/benchmark.h b/frontend/benchmark.h index 855daf72..3aa1b90c 100644 --- a/frontend/benchmark.h +++ b/frontend/benchmark.h @@ -16,7 +16,7 @@ #include "libOTe/Tools/TungstenCode/TungstenCode.h" #include "libOTe/Tools/ExConvCodeOld/ExConvCodeOld.h" #include "libOTe/Dpf/RegularDpf.h" -#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Dpf/TernaryDpf.h" #include "libOTe/Triple/Foleage/FoleageTriple.h" namespace osuCrypto @@ -695,6 +695,7 @@ namespace osuCrypto void RegularDpfBenchmark(const oc::CLP& cmd) { +#ifdef ENABLE_REGULAR_DPF PRNG prng(block(231234, 321312)); u64 trials = cmd.getOr("t", 100); u64 domain = 1ull << cmd.getOr("d", 10); @@ -765,14 +766,15 @@ namespace osuCrypto if (cmd.isSet("v")) std::cout << timer << std::endl; +#else + std::cout << "ENABLE_REGULAR_DPF = false" << std::endl; +#endif } - - void TriDpfBenchmark(const oc::CLP& cmd) + void TernaryDpfBenchmark(const oc::CLP& cmd) { - //using F = FoleageF4x243; - //using Ctx = FoleageCoeffCtx; +#ifdef ENABLE_TERNARY_DPF using F = block; using Ctx = CoeffCtxGF2; Timer timer; @@ -804,7 +806,7 @@ namespace osuCrypto for (u64 i = 0; i < trials; ++i) { - std::array, 2> dpf; + std::array, 2> dpf; dpf[0].init(0, domain, numPoints); dpf[1].init(1, domain, numPoints); @@ -852,6 +854,9 @@ namespace osuCrypto } std::cout << timer << std::endl; +#else + std::cout << "ENABLE_TERNARY_DPF = false" << std::endl; +#endif } @@ -862,7 +867,8 @@ namespace osuCrypto // checks correctness of the resulting OLE correlation. void FoleageBenchmark(const CLP& cmd) { - +#ifdef ENABLE_FOLEAGE + auto logn = cmd.getOr("nn", 10); u64 n = ipow(3, logn); auto blocks = divCeil(n, 128); @@ -947,5 +953,8 @@ namespace osuCrypto } work = {}; std::cout << "n="<> (i * 2)) & 3; - } - - return r; - } - - void fromInt(u64 v) - { - mVal = 0; - for (u64 i = 0; i < 32; ++i) - { - mVal |= (v % 3) << (i * 2); - v /= 3; - } - } - - F3x32 lower(u64 digits) - { - F3x32 r; - r.mVal = mVal & ((1ull << (2 * digits)) - 1); - return r; - } - F3x32 upper(u64 digits) - { - F3x32 r; - r.mVal = mVal >> (2 * digits); - return r; - } - - // returns the i'th Z_3 element. - u8 operator[](u64 i) const - { - return (mVal >> (i * 2)) & 3; - } - }; - - inline std::ostream& operator<<(std::ostream& o, const F3x32& t) - { - u64 m = 0; - u64 v = t.mVal; - while (v) - { - ++m; - v >>= 2; - } - if (!m) - o << "0"; - else - { - for (u64 i = m - 1; i < m; --i) - { - o << int(t[i]); - } - } - return o; - } - template< typename F, typename CoeffCtx = DefaultCoeffCtx > - struct TriDpf + struct TernaryDpf { using VecF = typename CoeffCtx::template Vec; @@ -245,11 +131,11 @@ namespace osuCrypto for (u64 j = 0; j < mDepth; ++j) { if ((v & 3) == 3) - throw std::runtime_error("TriDpf: invalid point sharing. Expects the input points to be shared over Z_3^D where each Z_3 elements takes up 2 bits of a the value. " LOCATION); + throw std::runtime_error("TernaryDpf: invalid point sharing. Expects the input points to be shared over Z_3^D where each Z_3 elements takes up 2 bits of a the value. " LOCATION); v >>= 2; } if (v) - throw std::runtime_error("TriDpf: invalid point sharing. point is larger than 3^D " LOCATION); + throw std::runtime_error("TernaryDpf: invalid point sharing. point is larger than 3^D " LOCATION); } u64 numPoints8 = mNumPoints / 8 * 8; @@ -268,7 +154,7 @@ namespace osuCrypto auto ret = MatrixView((T*)allocIter, rows, cols); allocIter += sizeof(T) * ret.size(); if (allocIter > allocation.data() + allocSize) - throw std::runtime_error("TriDpf: allocation error. " LOCATION); + throw std::runtime_error("TernaryDpf: allocation error. " LOCATION); return ret; }; @@ -771,4 +657,5 @@ namespace osuCrypto } -#undef SIMD8 \ No newline at end of file +#undef SIMD8 +#endif diff --git a/libOTe/Triple/Foleage/FoleageTriple.cpp b/libOTe/Triple/Foleage/FoleageTriple.cpp index aebe4310..191abe5c 100644 --- a/libOTe/Triple/Foleage/FoleageTriple.cpp +++ b/libOTe/Triple/Foleage/FoleageTriple.cpp @@ -1,8 +1,21 @@ +// © 2025 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Code partially authored by: +// Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan - Schreiber + +#include "libOTe/config.h" +#if defined(ENABLE_FOLEAGE) + #include "FoleageTriple.h" #include "libOTe/Triple/Foleage/FoleageUtils.h" #include "libOTe/Triple/Foleage/fft/FoleageFft.h" #include "cryptoTools/Common/BitIterator.h" -#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Dpf/TernaryDpf.h" #include "libOTe/Base/BaseOT.h" namespace osuCrypto { @@ -533,7 +546,7 @@ namespace osuCrypto } setTimePoint("transpose"); - fft_recursive_uint32(fft, mLog3N, mN / 3); + foleageFftUint32(fft, mLog3N, mN / 3); setTimePoint("product fft"); F4Multiply(mFftASquared, fft, fftRes, mN); setTimePoint("product mult"); @@ -738,4 +751,5 @@ namespace osuCrypto //} -} \ No newline at end of file +} +#endif diff --git a/libOTe/Triple/Foleage/FoleageTriple.h b/libOTe/Triple/Foleage/FoleageTriple.h index fe58d5d4..da63b0d6 100644 --- a/libOTe/Triple/Foleage/FoleageTriple.h +++ b/libOTe/Triple/Foleage/FoleageTriple.h @@ -1,11 +1,25 @@ #pragma once +// © 2025 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Code partially authored by: +// Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan - Schreiber + + +#include "libOTe/config.h" +#if defined(ENABLE_FOLEAGE) + #include "cryptoTools/Common/Defines.h" #include "cryptoTools/Common/Matrix.h" #include "cryptoTools/Common/Aligned.h" #include "coproto/Socket/Socket.h" #include "cryptoTools/Crypto/PRNG.h" #include "cryptoTools/Common/Timer.h" -#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Dpf/TernaryDpf.h" #include "libOTe/TwoChooseOne/SoftSpokenOT/SoftSpokenShOtExt.h" namespace osuCrypto @@ -72,7 +86,7 @@ namespace osuCrypto Matrix mSparsePositions; // a dpf used to construct the F4x243 leaf value of the larger DPF. - TriDpf mDpfLeaf; + TernaryDpf mDpfLeaf; #ifdef ENABLE_SOFTSPOKEN_OT std::optional> mOtExtRecver; @@ -91,7 +105,7 @@ namespace osuCrypto }; // the main DPF which outputs 243 F4 elements for each leaf. - TriDpf mDpf; + TernaryDpf mDpf; // The base OTs used to tensor the coefficients of the sparse polynomial. std::vector mRecvOts; @@ -198,3 +212,4 @@ namespace osuCrypto } +#endif \ No newline at end of file diff --git a/libOTe/Triple/Foleage/FoleageUtils.h b/libOTe/Triple/Foleage/FoleageUtils.h index d3834a00..8e6799a4 100644 --- a/libOTe/Triple/Foleage/FoleageUtils.h +++ b/libOTe/Triple/Foleage/FoleageUtils.h @@ -1,4 +1,18 @@ #pragma once +// © 2025 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Code partially authored by: +// Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan - Schreiber + + +#include "libOTe/config.h" +#if defined(ENABLE_FOLEAGE) || defined(ENABLE_TERNARY_DPF) + #include "cryptoTools/Crypto/AES.h" #include "cryptoTools/Crypto/PRNG.h" #include "cryptoTools/Crypto/RandomOracle.h" @@ -9,6 +23,125 @@ namespace osuCrypto { + + // a value representing (Z_3)^32. + // The value is stored in 2 bits per Z_3 element. + struct F3x32 + { + u64 mVal; + + F3x32() = default; + F3x32(const F3x32&) = default; + + F3x32(u64 v) + { + fromInt(v); + } + + F3x32& operator=(const F3x32&) = default; + + F3x32 operator+(const F3x32& t) const + { + F3x32 r; + r.mVal = 0; + for (u64 i = 0; i < 32; ++i) + { + auto a = t[i]; + auto b = (*this)[i]; + auto c = (a + b) % 3; + + r.mVal |= u64(c) << (i * 2); + } + return r; + } + + + F3x32 operator-(const F3x32& t) const + { + F3x32 r; + r.mVal = 0; + for (u64 i = 0; i < 32; ++i) + { + auto a = t[i]; + auto b = (*this)[i]; + auto c = (b + 3 - a) % 3; + + r.mVal |= u64(c) << (i * 2); + } + return r; + } + + + bool operator==(const F3x32& t) const + { + return mVal == t.mVal; + } + + + u64 toInt() const + { + u64 r = 0; + for (u64 i = 31; i < 32; --i) + { + r *= 3; + r += (mVal >> (i * 2)) & 3; + } + + return r; + } + + void fromInt(u64 v) + { + mVal = 0; + for (u64 i = 0; i < 32; ++i) + { + mVal |= (v % 3) << (i * 2); + v /= 3; + } + } + + F3x32 lower(u64 digits) + { + F3x32 r; + r.mVal = mVal & ((1ull << (2 * digits)) - 1); + return r; + } + F3x32 upper(u64 digits) + { + F3x32 r; + r.mVal = mVal >> (2 * digits); + return r; + } + + // returns the i'th Z_3 element. + u8 operator[](u64 i) const + { + return (mVal >> (i * 2)) & 3; + } + }; + + inline std::ostream& operator<<(std::ostream& o, const F3x32& t) + { + u64 m = 0; + u64 v = t.mVal; + while (v) + { + ++m; + v >>= 2; + } + if (!m) + o << "0"; + else + { + for (u64 i = m - 1; i < m; --i) + { + o << int(t[i]); + } + } + return o; + } + + // Multiplies two elements of F4 // and returns the result. inline uint8_t F4Multiply(uint8_t a, uint8_t b) @@ -263,4 +396,5 @@ namespace osuCrypto } return ret; } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/libOTe/Triple/Foleage/fft/FoleageFft.cpp b/libOTe/Triple/Foleage/fft/FoleageFft.cpp index bb8fb491..d3c7b40b 100644 --- a/libOTe/Triple/Foleage/fft/FoleageFft.cpp +++ b/libOTe/Triple/Foleage/fft/FoleageFft.cpp @@ -1,32 +1,45 @@ +// © 2025 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Code partially authored by: +// Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan - Schreiber + + +#include "libOTe/config.h" +#if defined(ENABLE_FOLEAGE) + #include #include #include "libOTe/Triple/Foleage/fft/FoleageFft.h" namespace osuCrypto { - void fft_recursive_uint64( + void foleageFftUint64( span coeffs, const size_t num_vars, const size_t num_coeffs) { - // coeffs (coeffs_h, coeffs_l) are parsed as L(left)|M(middle)|R(right) if (num_vars > 1) { // apply FFT on all left coefficients - fft_recursive_uint64( + foleageFftUint64( coeffs, num_vars - 1, num_coeffs / 3); // apply FFT on all middle coefficients - fft_recursive_uint64( + foleageFftUint64( coeffs.subspan(num_coeffs), num_vars - 1, num_coeffs / 3); // apply FFT on all right coefficients - fft_recursive_uint64( + foleageFftUint64( coeffs.subspan(2 * num_coeffs), num_vars - 1, num_coeffs / 3); @@ -80,7 +93,7 @@ namespace osuCrypto { } } - void fft_recursive_uint32( + void foleageFftUint32( span coeffs, const size_t num_vars, const size_t num_coeffs) @@ -90,19 +103,19 @@ namespace osuCrypto { if (num_vars > 1) { // apply FFT on all left coefficients - fft_recursive_uint32( + foleageFftUint32( coeffs, num_vars - 1, num_coeffs / 3); // apply FFT on all middle coefficients - fft_recursive_uint32( + foleageFftUint32( coeffs.subspan(num_coeffs), num_vars - 1, num_coeffs / 3); // apply FFT on all right coefficients - fft_recursive_uint32( + foleageFftUint32( coeffs.subspan(2 * num_coeffs), num_vars - 1, num_coeffs / 3); @@ -156,7 +169,7 @@ namespace osuCrypto { } } - void fft_recursive_uint16( + void foleageFftUint16( span coeffs, const size_t num_vars, const size_t num_coeffs) @@ -166,19 +179,19 @@ namespace osuCrypto { if (num_vars > 1) { // apply FFT on all left coefficients - fft_recursive_uint16( + foleageFftUint16( coeffs, num_vars - 1, num_coeffs / 3); // apply FFT on all middle coefficients - fft_recursive_uint16( + foleageFftUint16( coeffs.subspan(num_coeffs), num_vars - 1, num_coeffs / 3); // apply FFT on all right coefficients - fft_recursive_uint16( + foleageFftUint16( coeffs.subspan(2 * num_coeffs), num_vars - 1, num_coeffs / 3); @@ -307,4 +320,5 @@ namespace osuCrypto { coeffsM[j] = tM; } } -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/libOTe/Triple/Foleage/fft/FoleageFft.h b/libOTe/Triple/Foleage/fft/FoleageFft.h index bcae8b01..c0b98799 100644 --- a/libOTe/Triple/Foleage/fft/FoleageFft.h +++ b/libOTe/Triple/Foleage/fft/FoleageFft.h @@ -1,4 +1,16 @@ #pragma once +// © 2025 Peter Rindal. +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +// Code partially authored by: +// Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan - Schreiber + +#include "libOTe/config.h" +#if defined(ENABLE_FOLEAGE) #include #include @@ -8,23 +20,20 @@ namespace osuCrypto { - //typedef __int128 int128_t; - //typedef unsigned __int128 uint128_t; - // FFT for (up to) 32 polynomials over F4 - void fft_recursive_uint64( + void foleageFftUint64( span coeffs, const size_t num_vars, const size_t num_coeffs); // FFT for (up to) 16 polynomials over F4 - void fft_recursive_uint32( + void foleageFftUint32( span coeffs, const size_t num_vars, const size_t num_coeffs); // FFT for (up to) 8 polynomials over F4 - void fft_recursive_uint16( + void foleageFftUint16( span coeffs, const size_t num_vars, const size_t num_coeffs); @@ -37,3 +46,5 @@ namespace osuCrypto { } + +#endif \ No newline at end of file diff --git a/libOTe/config.h.in b/libOTe/config.h.in index 6a0832f2..284b738e 100644 --- a/libOTe/config.h.in +++ b/libOTe/config.h.in @@ -40,6 +40,19 @@ // build the library with SoftSpokenOT enabled #cmakedefine ENABLE_SOFTSPOKEN_OT @ENABLE_SOFTSPOKEN_OT@ +// build the library with Foleage enabled +#cmakedefine ENABLE_FOLEAGE @ENABLE_FOLEAGE@ + +// build the library with regular dpf enabled +#cmakedefine ENABLE_REGULAR_DPF @ENABLE_REGULAR_DPF@ + +// build the library with ternary dpf enabled +#cmakedefine ENABLE_TERNARY_DPF @ENABLE_TERNARY_DPF@ + +// build the library with sparse dpf enabled +#cmakedefine ENABLE_SPARSE_DPF @ENABLE_SPARSE_DPF@ + + // build the library with KOS Delta-OT-ext enabled diff --git a/libOTe_Tests/CMakeLists.txt b/libOTe_Tests/CMakeLists.txt index 8ee65aa7..20b1a580 100644 --- a/libOTe_Tests/CMakeLists.txt +++ b/libOTe_Tests/CMakeLists.txt @@ -8,7 +8,7 @@ set(SRCS NcoOT_Tests.cpp OT_Tests.cpp Pprf_Tests.cpp - RegularDpf_Tests.cpp + Dpf_Tests.cpp SilentOT_Tests.cpp SoftSpoken_Tests.cpp TungstenCode_Tests.cpp diff --git a/libOTe_Tests/RegularDpf_Tests.cpp b/libOTe_Tests/Dpf_Tests.cpp similarity index 94% rename from libOTe_Tests/RegularDpf_Tests.cpp rename to libOTe_Tests/Dpf_Tests.cpp index 34bf4e02..13996d46 100644 --- a/libOTe_Tests/RegularDpf_Tests.cpp +++ b/libOTe_Tests/Dpf_Tests.cpp @@ -1,15 +1,18 @@ -#include "RegularDpf_Tests.h" +#include "Dpf_Tests.h" #include "libOTe/Dpf/RegularDpf.h" #include "coproto/Socket/LocalAsyncSock.h" #include "libOTe/Dpf/SparseDpf.h" #include #include -#include "libOTe/Dpf/TriDpf.h" +#include "libOTe/Dpf/TernaryDpf.h" +#include "cryptoTools/Common/TestCollection.h" +#include "libOTe/Tools/CoeffCtx.h" using namespace oc; void RegularDpf_Multiply_Test(const CLP& cmd) { +#if defined(ENABLE_REGULAR_DPF) || defined(ENABLE_SPARSE_DPF) u64 n = 13; PRNG prng(block(231234, 321312)); std::array dpf; @@ -110,10 +113,15 @@ void RegularDpf_Multiply_Test(const CLP& cmd) } } } +#else + throw UnitTestSkipped("ENABLE_REGULAR_DPF and ENABLE_SPARSE_DPF not defined."); +#endif } void RegularDpf_Proto_Test(const CLP& cmd) { +#ifdef ENABLE_REGULAR_DPF + PRNG prng(block(231234, 321312)); u64 domain = cmd.getOr("domain", 211); u64 numPoints = cmd.getOr("numPoints", 11); @@ -188,11 +196,14 @@ void RegularDpf_Proto_Test(const CLP& cmd) throw RTE_LOC; } } +#else + throw UnitTestSkipped("ENABLE_REGULAR_DPF not defined."); +#endif } void RegularDpf_Puncture_Test(const oc::CLP& cmd) { - +#ifdef ENABLE_REGULAR_DPF PRNG prng(block(231234, 321312)); u64 domain = cmd.getOr("domain", 8); u64 numPoints = cmd.getOr("numPoints", 1); @@ -269,10 +280,15 @@ void RegularDpf_Puncture_Test(const oc::CLP& cmd) if (failed) throw RTE_LOC; + +#else + throw UnitTestSkipped("ENABLE_REGULAR_DPF not defined."); +#endif } void RegularDpf_keyGen_Test(const oc::CLP& cmd) { +#ifdef ENABLE_REGULAR_DPF PRNG prng(block(231234, 321312)); u64 domain = cmd.getOr("domain", 211); @@ -373,10 +389,16 @@ void RegularDpf_keyGen_Test(const oc::CLP& cmd) throw RTE_LOC; } } + +#else + throw UnitTestSkipped("ENABLE_REGULAR_DPF not defined."); +#endif } void SparseDpf_Proto_Test(const oc::CLP& cmd) { +#ifdef ENABLE_SPARSE_DPF + PRNG prng(block(32324, 2342)); u64 numPoints = 1; u64 domain = 1773; @@ -473,11 +495,15 @@ void SparseDpf_Proto_Test(const oc::CLP& cmd) throw RTE_LOC; } } +#else + throw UnitTestSkipped("ENABLE_SPARSE_DPF not defined."); +#endif } template -void TritDpf_Proto_Test_(const oc::CLP& cmd) +void TernaryDpf_Proto_Test_(const oc::CLP& cmd) { +#ifdef ENABLE_TERNARY_DPF PRNG prng(block(231234, 321312)); u64 depth = cmd.getOr("depth", 3); @@ -500,7 +526,7 @@ void TritDpf_Proto_Test_(const oc::CLP& cmd) //ctx.minus(points0[i], points[i], points1[i];) } - std::array, 2> dpf; + std::array, 2> dpf; dpf[0].init(0, domain, numPoints); dpf[1].init(1, domain, numPoints); @@ -567,13 +593,13 @@ void TritDpf_Proto_Test_(const oc::CLP& cmd) throw RTE_LOC; } } - +#else + throw UnitTestSkipped("ENABLE_TERNARY_DPF not defined."); +#endif } void TritDpf_Proto_Test(const oc::CLP& cmd) { - TritDpf_Proto_Test_(cmd); - TritDpf_Proto_Test_(cmd); - //TritDpf_Proto_Test_(cmd); - + TernaryDpf_Proto_Test_(cmd); + TernaryDpf_Proto_Test_(cmd); } diff --git a/libOTe_Tests/RegularDpf_Tests.h b/libOTe_Tests/Dpf_Tests.h similarity index 100% rename from libOTe_Tests/RegularDpf_Tests.h rename to libOTe_Tests/Dpf_Tests.h diff --git a/libOTe_Tests/Foleage_Tests.cpp b/libOTe_Tests/Foleage_Tests.cpp index bd24f546..3afe07ef 100644 --- a/libOTe_Tests/Foleage_Tests.cpp +++ b/libOTe_Tests/Foleage_Tests.cpp @@ -5,6 +5,9 @@ #include "libOTe/Triple/Foleage/FoleageTriple.h" #include "coproto/Socket/LocalAsyncSock.h" #include "cryptoTools/Common/Timer.h" +#include "cryptoTools/Common/TestCollection.h" + + namespace osuCrypto { @@ -12,6 +15,7 @@ namespace osuCrypto // checks correctness of the resulting OLE correlation. void foleage_F4ole_test(const CLP& cmd) { +#ifdef ENABLE_FOLEAGE std::array oles; auto logn = 6; @@ -97,9 +101,16 @@ namespace osuCrypto if (verbose) std::cout << "Time taken: \n" << timer << std::endl; + +#else + throw UnitTestSkipped("ENABLE_FOLEAGE not defined."); +#endif } + void foleage_Triple_test(const CLP& cmd) { +#ifdef ENABLE_FOLEAGE + std::array oles; auto logn = 5; @@ -177,10 +188,16 @@ namespace osuCrypto if (verbose) std::cout << "Time taken: \n" << timer << std::endl; +#else + throw UnitTestSkipped("ENABLE_FOLEAGE not defined."); +#endif } void foleage_GenBase_test(const CLP& cmd) { +#ifdef ENABLE_FOLEAGE + // This test checks the base OTs are generated correctly. + for (auto type : { SilentBaseType::Base, SilentBaseType::BaseExtend }) { @@ -243,15 +260,17 @@ namespace osuCrypto throw RTE_LOC; } } +#else + throw UnitTestSkipped("ENABLE_FOLEAGE not defined."); +#endif } void foleage_tensor_test(const CLP& cmd) { +#ifdef ENABLE_FOLEAGE std::array oles; - //bool verbose = cmd.isSet("v"); - PRNG prng0(block(2424523452345, 111124521521455324)); PRNG prng1(block(6474567454546, 567546754674345444)); @@ -298,6 +317,8 @@ namespace osuCrypto throw RTE_LOC; } } - +#else + throw UnitTestSkipped("ENABLE_FOLEAGE not defined."); +#endif } } \ No newline at end of file diff --git a/libOTe_Tests/UnitTests.cpp b/libOTe_Tests/UnitTests.cpp index 14712112..bc05e74f 100644 --- a/libOTe_Tests/UnitTests.cpp +++ b/libOTe_Tests/UnitTests.cpp @@ -16,7 +16,7 @@ #include "libOTe/Tools/LDPC/Mtx.h" #include "libOTe_Tests/Pprf_Tests.h" #include "libOTe_Tests/TungstenCode_Tests.h" -#include "libOTe_Tests/RegularDpf_Tests.h" +#include "libOTe_Tests/Dpf_Tests.h" #include "libOTe_Tests/Foleage_Tests.h" using namespace osuCrypto; From af5aca1199ea3b33823ba26657d780858c2f9c58 Mon Sep 17 00:00:00 2001 From: Peter Rindal Date: Wed, 26 Feb 2025 22:22:42 -0800 Subject: [PATCH 48/48] readme --- README.md | 63 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 8903b7d5..f56a88d3 100644 --- a/README.md +++ b/README.md @@ -24,20 +24,35 @@ Malicious OT extension: Vole: * Generic subfield noisy VOLE (semi-honest) [[BCGIKRS19]](https://eprint.iacr.org/2019/1159.pdf) -* Generic subfield silent VOLE (malicious/semi-honest) [[BCGIKRS19]](https://eprint.iacr.org/2019/1159.pdf),[[RRT23]](https://eprint.iacr.org/2023/882). +* Generic subfield silent VOLE (malicious/semi-honest) [[BCGIKRS19]](https://eprint.iacr.org/2019/1159.pdf),[[RRT23]](https://eprint.iacr.org/2023/882). +OLE and Beaver Triples: +* Foleage Binary Beaver Triples and F4 OLE (semi-honest) [[BBCCDS2024]](https://eprint.iacr.org/2024/429.pdf). + +Distributed Point Functions: +* Distributed Point Function (DPF)[[BGI18]](https://eprint.iacr.org/2018/707.pdf) with [Distributed] Key Generation (DKG) [[Ds17]](https://eprint.iacr.org/2017/827.pdf). +* Ternary Distributed Point Function (DPF)[[BBCCDS2024]](https://eprint.iacr.org/2024/429.pdf) with Distributed Key Generation (DKG). +* Sparse Distributed Point Function (DPF) with Distributed Key Generation (DKG). + ## Introduction -This library provides several different classes of OT protocols. First is the +This library provides several different classes of OT, VOLE and Beaver Triple generation protocols. First is the base OT protocol of [CO15, MR19, MRR21]. These protocol bootstraps all the other -OT extension protocols. Within the OT extension protocols, we have 1-out-of-2, -1-out-of-N, and VOLE both in the semi-honest and malicious settings. See The `frontend` or `libOTe_Tests` folder for examples. +protocols. Within the OT extension protocols, we have 1-out-of-2, +1-out-of-N, and VOLE both in the semi-honest and malicious settings. Binary beaver triples can be +generating using the Foleage protocol. The library also includes a distributed point function (DPF) +protocol with distributed key generation (DKG) for secure computation. See The `frontend` or `libOTe_Tests` +folder for examples. All implementations are highly optimized using fast SSE instructions and vectorization to obtain optimal performance both in the single and multi-threaded setting. Networking can be performed using both the sockets provided by the library and -external socket classes. The simplest integration can be achieved via the [message passing interface](https://github.com/osu-crypto/libOTe/blob/master/frontend/ExampleMessagePassing.h) where the user is given the protocol messages that need to be sent/received. Users can also integrate their own socket type for maximum performance. See the [coproto](https://github.com/Visa-Research/coproto/blob/main/frontend/SocketTutorial.cpp) tutorial for examples. +external socket classes. The simplest integration can be achieved via the +[message passing interface](https://github.com/osu-crypto/libOTe/blob/master/frontend/ExampleMessagePassing.h) +where the user is given the protocol messages that need to be sent/received. +Users can also integrate their own socket type for maximum performance. +See the [coproto](https://github.com/Visa-Research/coproto/blob/main/frontend/SocketTutorial.cpp) tutorial for examples. ## Build @@ -47,7 +62,8 @@ There is one mandatory dependency on [coproto](https://github.com/Visa-Research/ and three **optional dependencies** on [libsodium](https://doc.libsodium.org/), [Relic](https://github.com/relic-toolkit/relic), or [SimplestOT](https://github.com/osu-crypto/libOTe/tree/master/SimplestOT) (Unix only) -for Base OTs. [Boost Asio](https://www.boost.org/doc/libs/1_84_0/doc/html/boost_asio.html) tcp networking and [OpenSSL](https://www.openssl.org/) support can optionally be enabled. +for Base OTs. [Boost Asio](https://www.boost.org/doc/libs/1_84_0/doc/html/boost_asio.html) +tcp networking and [OpenSSL](https://www.openssl.org/) support can optionally be enabled. CMake 3.15+ is required and the build script assumes python 3. The library can be built with libsodium, all OT protocols enabled and boost asio TCP networking as @@ -60,10 +76,14 @@ The main executable with examples is ``` ./out/build//frontend/frontend_libOTe ``` -where `` is the build directory, eg `linux`, `x64-Release`, `osx`, etc. **Unit Tests** and **example code** can be run with this excutable. Run the program with no options for a list of available options. +where `` is the build directory, eg `linux`, `x64-Release`, `osx`, etc. +**Unit Tests** and **example code** can be run with this excutable. +Run the program with no options for a list of available options. ### Build Options -LibOTe can be built with various only the selected protocols enabled. `-D ENABLE_ALL_OT=ON` will enable all available protocols depending on platform/dependencies. The `ON`/`OFF` options include +LibOTe can be built with various only the selected protocols enabled. +`-D ENABLE_ALL_OT=ON` will enable all available protocols depending +on platform/dependencies. The `ON`/`OFF` options include **Malicious base OT:** * `ENABLE_SIMPLESTOT` the SimplestOT [[CO15]](https://eprint.iacr.org/2015/267.pdf) protocol (relic or sodium). @@ -81,9 +101,22 @@ LibOTe can be built with various only the selected protocols enabled. `-D ENABLE * `ENABLE_SOFTSPOKEN_OT` the Roy [Roy22](https://eprint.iacr.org/2022/192) semi-honest/malicious protocol. * `ENABLE_SILENTOT` the [[BCGIKRS19]](https://eprint.iacr.org/2019/1159.pdf),[[RRT23]](https://eprint.iacr.org/2023/882) semi-honest/malicious protocol. + **1-out-of-N OT Extension:** + * `ENABLE_KKRT` the Kolesnikov et al [[KKRT16]](https://eprint.iacr.org/2016/799) semi-honest protocol. + * `ENABLE_OOS` the Orrù et al [[OOS16]](http://eprint.iacr.org/2016/933) semi-honest/malicious protocol. + **Vole:** * `ENABLE_SILENT_VOLE` the [[BCGIKRS19]](https://eprint.iacr.org/2019/1159.pdf),[[RRT23]](https://eprint.iacr.org/2023/882) semi-honest/malicious protocol. + ** DPF:** + * `ENABLE_REGULAR_DPF` the Boyle et al [[BGI18]](https://eprint.iacr.org/2018/707.pdf) semi-honest protocol. + * `ENABLE_TERNARY_DPF` the Bombar et al [[BBCCDS2024]](https://eprint.iacr.org/2024/429.pdf) semi-honest protocol. + * `ENABLE_SPARSE_DPF` protocol allowing a sparse set of DPF leaf values. + + **Beaver Triples:** + * `ENABLE_FOLEAGE` the Bombar et al [[BBCCDS2024]](https://eprint.iacr.org/2024/429.pdf) semi-honest protocol. + + Addition options can be set for cryptoTools. See the cmake output. ### Dependencies @@ -188,12 +221,18 @@ find_package(libOTe REQUIRED silent_vole oos kkrt + + foleage + + regular_dpf + ternary_dpf + sparse_dpf ) ``` ## Help -Contact Peter Rindal peterrindal@gmail.com for any assistance on building +Create a github issue or contact Peter Rindal peterrindal@gmail.com for any assistance on building or running the library. ## Citing @@ -226,8 +265,14 @@ or running the library. [ALSZ15] - Gilad Asharov and Yehuda Lindell and Thomas Schneider and Michael Zohner, _More Efficient Oblivious Transfer Extensions with Security for Malicious Adversaries_. [eprint/2015/061](https://eprint.iacr.org/2015/061) +[BGI18] - Elette Boyle, Niv Gilboa, Yuval Ishai, _Function Secret Sharing: Improvements and Extensions_ [eprint/2018/707](https://eprint.iacr.org/2018/707.pdf) + +[Ds17] - Jack Doerner, abhi shelat, _Scaling ORAM for Secure Computation_ [eprint/2017/827](https://eprint.iacr.org/2017/827.pdf) + [CRR21] - Geoffroy Couteau ,Srinivasan Raghuraman and Peter Rindal, _Silver: Silent VOLE and Oblivious Transfer from Hardness of Decoding Structured LDPC Codes_. [Roy22] - Lawrence Roy, SoftSpokenOT: Communication--Computation Tradeoffs in OT Extension. [eprint/2022/192](https://eprint.iacr.org/2022/192) [RRT23] - Srinivasan Raghuraman, Peter Rindal and Titouan Tanguy, _Expand-Convolute Codes for Pseudorandom Correlation Generators from LPN_. [eeprint/2023/882](https://eprint.iacr.org/2023/882) + +[BBCCDS2024] - Maxime Bombar, Dung Bui, Geoffroy Couteau, Alain Couvreur, Clément Ducros, and Sacha Servan-Schreiber, _FOLEAGE: F4 OLE-Based Multi-Party Computation for Boolean Circuits_. [eprint/2024/429](https://eprint.iacr.org/2024/429.pdf) \ No newline at end of file