Skip to content

Commit

Permalink
rm copy in ifft
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jul 13, 2024
1 parent fde8352 commit 05e07f7
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 56 deletions.
6 changes: 3 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,9 @@ if(NOT USE_TERNARY)
endif()


if(USE_AVX512)
string(APPEND CMAKE_CXX_FLAGS " -mprefer-vector-width=512")
endif()
# if(USE_AVX512)
# string(APPEND CMAKE_CXX_FLAGS " -mprefer-vector-width=512")
# endif()

if(USE_FFTW3)
set(TFHEpp_DEFINITIONS
Expand Down
12 changes: 4 additions & 8 deletions include/cloudkey.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,23 +416,19 @@ struct EvalKey {
void emplacebkfft(const SecretKey& sk)
{
if constexpr (std::is_same_v<P, lvl01param>) {
bkfftlvl01 = std::make_unique_for_overwrite<
BootstrappingKeyFFT<lvl01param>>();
bkfftlvl01 = std::unique_ptr<BootstrappingKeyFFT<lvl01param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvl01param>());
bkfftgen<lvl01param>(*bkfftlvl01, sk);
}
else if constexpr (std::is_same_v<P, lvlh1param>) {
bkfftlvlh1 = std::make_unique_for_overwrite<
BootstrappingKeyFFT<lvlh1param>>();
bkfftlvlh1 = std::unique_ptr<BootstrappingKeyFFT<lvlh1param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh1param>());
bkfftgen<lvlh1param>(*bkfftlvlh1, sk);
}
else if constexpr (std::is_same_v<P, lvl02param>) {
bkfftlvl02 = std::make_unique_for_overwrite<
BootstrappingKeyFFT<lvl02param>>();
bkfftlvl02 = std::unique_ptr<BootstrappingKeyFFT<lvl02param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvl02param>());
bkfftgen<lvl02param>(*bkfftlvl02, sk);
}
else if constexpr (std::is_same_v<P, lvlh2param>) {
bkfftlvlh2 = std::make_unique_for_overwrite<
BootstrappingKeyFFT<lvlh2param>>();
bkfftlvlh2 = std::unique_ptr<BootstrappingKeyFFT<lvlh2param>>(new (std::align_val_t(64)) BootstrappingKeyFFT<lvlh2param>());
bkfftgen<lvlh2param>(*bkfftlvlh2, sk);
}
else
Expand Down
57 changes: 41 additions & 16 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
std::assume_aligned<16>(res.data());
std::assume_aligned<16>(a.data());
std::assume_aligned<16>(b.data());
double* const res_ptr = std::assume_aligned<32>(res.data());
const double* const a_ptr = std::assume_aligned<32>(a.data());
const double* const b_ptr = std::assume_aligned<32>(b.data());
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
Expand All @@ -160,23 +160,48 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
}
#else
for (int i = 0; i < N / 2; i++) {
double aimbim = a[i + N / 2] * b[i + N / 2];
double arebim = a[i] * b[i + N / 2];
res[i] = std::fma(a[i], b[i], -aimbim);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
double aimbim = a_ptr[i + N / 2] * b_ptr[i + N / 2];
double arebim = a_ptr[i] * b_ptr[i + N / 2];
res_ptr[i] = std::fma(a[i], b_ptr[i], -aimbim);
res_ptr[i + N / 2] = std::fma(a_ptr[i + N / 2], b_ptr[i], arebim);
}
#endif
}

// template <uint32_t N, uint32_t kpo>
// inline void MulInFD(std::array<std::array<double, N>,kpo> &res, const std::array<double, N> &a,
// const std::array<std::array<double, N>,kpo> &b)
// {
// double* const res_ptr = std::assume_aligned<32>(res[0].data());
// const double* const a_ptr = std::assume_aligned<32>(a.data());
// const double* const b_ptr = std::assume_aligned<32>(b[0].data());
// for(int i = 0; i < N / 2; i++){
// for(int j = 0; j < kpo; j++){
// #ifdef USE_INTERLEAVED_FORMAT
// const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i+j*N], b[2*i+1+j*N]);
// res[2*i+j*N] = tmp.real();
// res[2*i+1+j*N] = tmp.imag();
// #else
// // double aimbim = a_ptr[i + N / 2] * b_ptr[i + N / 2+j*N];
// res_ptr[i+j*N] = a_ptr[i + N / 2] * b_ptr[i + N / 2+j*N];
// res_ptr[i+j*N] = std::fma(a[i], b_ptr[i+j*N], -res_ptr[i+j*N]);
// // double arebim = a_ptr[i] * b_ptr[i + N / 2+j*N];
// res_ptr[i + N / 2+j*N] = a_ptr[i] * b_ptr[i + N / 2+j*N];
// res_ptr[i + N / 2+j*N] = std::fma(a_ptr[i + N / 2], b_ptr[i+j*N], res_ptr[i + N / 2+j*N]);
// #endif
// }
// }
// }

// Be careful about memory accesss (We assume b has relatively high memory
// access cost)
template <uint32_t N>
inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
std::assume_aligned<16>(res.data());
std::assume_aligned<16>(a.data());
std::assume_aligned<16>(b.data());
double* const res_ptr = std::assume_aligned<32>(res.data());
const double* const a_ptr = std::assume_aligned<32>(a.data());
const double* const b_ptr = std::assume_aligned<32>(b.data());
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
Expand All @@ -185,12 +210,12 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
}
#else
for (int i = 0; i < N / 2; i++) {
res[i] = std::fma(a[i], b[i], res[i]);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
res_ptr[i] = std::fma(a_ptr[i], b_ptr[i], res_ptr[i]);
res_ptr[i + N / 2] = std::fma(a_ptr[i + N / 2], b_ptr[i], res_ptr[i + N / 2]);
}
for (int i = 0; i < N / 2; i++) {
res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
res[i] -= a[i + N / 2] * b[i + N / 2];
res_ptr[i + N / 2] = std::fma(a_ptr[i], b_ptr[i + N / 2], res_ptr[i + N / 2]);
res_ptr[i] -= a_ptr[i + N / 2] * b_ptr[i + N / 2];
}
// for (int i = 0; i < N / 2; i++) {
// res[i] = std::fma(a[i + N / 2], b[i + N / 2], -res[i]);
Expand All @@ -206,9 +231,9 @@ inline void PolyMul(Polynomial<P> &res, const Polynomial<P> &a,
const Polynomial<P> &b)
{
if constexpr (std::is_same_v<typename P::T, uint32_t>) {
PolynomialInFD<P> ffta;
alignas(64) PolynomialInFD<P> ffta;
TwistIFFT<P>(ffta, a);
PolynomialInFD<P> fftb;
alignas(64) PolynomialInFD<P> fftb;
TwistIFFT<P>(fftb, b);
MulInFD<P::n>(ffta, ffta, fftb);
TwistFFT<P>(res, ffta);
Expand Down
8 changes: 4 additions & 4 deletions include/trgsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ template <class P>
void trgswfftExternalProduct(TRLWE<P> &res, const TRLWE<P> &trlwe,
const TRGSWFFT<P> &trgswfft)
{
DecomposedPolynomial<P> decpoly;
alignas(64) DecomposedPolynomial<P> decpoly;
Decomposition<P>(decpoly, trlwe[0]);
PolynomialInFD<P> decpolyfft;
alignas(64) PolynomialInFD<P> decpolyfft;
// __builtin_prefetch(trgswfft[0].data());
TwistIFFT<P>(decpolyfft, decpoly[0]);
TRLWEInFD<P> restrlwefft;
alignas(64) TRLWEInFD<P> restrlwefft;
for (int m = 0; m < P::k + 1; m++)
MulInFD<P::n>(restrlwefft[m], decpolyfft, trgswfft[0][m]);
for (int i = 1; i < P::l; i++) {
Expand Down Expand Up @@ -246,7 +246,7 @@ constexpr std::array<typename P::T, P::l> hgen()
template <class P>
TRGSWFFT<P> ApplyFFT2trgsw(const TRGSW<P> &trgsw)
{
TRGSWFFT<P> trgswfft;
alignas(64) TRGSWFFT<P> trgswfft;
for (int i = 0; i < (P::k + 1) * P::l; i++)
for (int j = 0; j < (P::k + 1); j++)
TwistIFFT<P>(trgswfft[i][j], trgsw[i][j]);
Expand Down
6 changes: 3 additions & 3 deletions include/trlwe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ TRLWE<P> trlweSymEncryptZero(const uint η, const Key<P> &key)
{
std::uniform_int_distribution<typename P::T> Torusdist(
0, std::numeric_limits<typename P::T>::max());
TRLWE<P> c;
alignas(64) TRLWE<P> c;
for (typename P::T &i : c[P::k])
i = (CenteredBinomial<P>(η) << std::numeric_limits<P>::digits) / P::q;
for (int k = 0; k < P::k; k++) {
for (typename P::T &i : c[k]) i = Torusdist(generator);
std::array<typename P::T, P::n> partkey;
alignas(64) std::array<typename P::T, P::n> partkey;
for (int i = 0; i < P::n; i++) partkey[i] = key[k * P::n + i];
Polynomial<P> temp;
alignas(64) Polynomial<P> temp;
PolyMul<P>(temp, c[k], partkey);
for (int i = 0; i < P::n; i++) c[P::k][i] += temp[i];
}
Expand Down
2 changes: 1 addition & 1 deletion thirdparties/spqlios/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ else()
spqlios-fft-impl.cpp)
endif()

set(SPQLIOS_HEADERS fft_processor_spqlios.h)
set(SPQLIOS_HEADERS fft_processor_spqlios.h x86.h)

if(ENABLE_SHARED)
add_library(spqlios SHARED ${SRCS_FMA} ${SPQLIOS_HEADERS})
Expand Down
46 changes: 25 additions & 21 deletions thirdparties/spqlios/fft_processor_spqlios.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include<params.hpp>

#include "x86.h"
#include "fft_processor_spqlios.h"

using namespace std;
Expand Down Expand Up @@ -41,7 +42,8 @@ FFT_Processor_Spqlios::FFT_Processor_Spqlios(const int32_t N) : _2N(2 * N), N(N)
void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
//for (int32_t i=0; i<N; i++) real_inout_rev[i]=(double)a[i];
{
double *dst = real_inout_rev;
double *dst = res;
// double *dst = real_inout_rev;
const int32_t *ait = a;
const int32_t *aend = a + N;
__asm__ __volatile__ (
Expand All @@ -58,26 +60,27 @@ void FFT_Processor_Spqlios::execute_reverse_int(double *res, const int32_t *a) {
: "%xmm0", "%ymm1", "memory"
);
}
ifft(tables_reverse, real_inout_rev);
ifft(tables_reverse, res);
// ifft(tables_reverse, real_inout_rev);
//for (int32_t i=0; i<N; i++) res[i]=real_inout_rev[i];
{
double *dst = res;
double *sit = real_inout_rev;
double *send = real_inout_rev + N;
__asm__ __volatile__ (
"1:\n"
"vmovapd (%1),%%ymm0\n"
"vmovupd %%ymm0,(%0)\n"
"addq $32,%1\n"
"addq $32,%0\n"
"cmpq %2,%1\n"
"jb 1b\n"
"vzeroall\n"
: "=r"(dst), "=r"(sit), "=r"(send)
: "0"(dst), "1"(sit), "2"(send)
: "%ymm0", "memory"
);
}
// {
// double *dst = res;
// double *sit = real_inout_rev;
// double *send = real_inout_rev + N;
// __asm__ __volatile__ (
// "1:\n"
// "vmovapd (%1),%%ymm0\n"
// "vmovupd %%ymm0,(%0)\n"
// "addq $32,%1\n"
// "addq $32,%0\n"
// "cmpq %2,%1\n"
// "jb 1b\n"
// "vzeroall\n"
// : "=r"(dst), "=r"(sit), "=r"(send)
// : "0"(dst), "1"(sit), "2"(send)
// : "%ymm0", "memory"
// );
// }
}

void FFT_Processor_Spqlios::execute_reverse_torus32(double *res, const uint32_t *a) {
Expand Down Expand Up @@ -124,7 +127,8 @@ void FFT_Processor_Spqlios::execute_direct_torus32(uint32_t *res, const double *
);
}
fft(tables_direct, real_inout_direct);
for (int32_t i = 0; i < N; i++) res[i] = uint32_t(int64_t(real_inout_direct[i]));
// for (int32_t i = 0; i < N; i++) res[i] = uint32_t(int64_t(real_inout_direct[i]));
SPQLIOS::convert_f64_to_u32(res,real_inout_direct,N);
}

void FFT_Processor_Spqlios::execute_direct_torus32_q(uint32_t *res, const double *a, const uint32_t q) {
Expand Down
Loading

0 comments on commit 05e07f7

Please sign in to comment.