Skip to content

Commit

Permalink
trying to fix invcb
Browse files Browse the repository at this point in the history
  • Loading branch information
nindanaoto committed Jul 15, 2024
1 parent 99fb42d commit 72f8475
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 42 deletions.
47 changes: 8 additions & 39 deletions include/mulfft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,6 @@ template <uint32_t N>
inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
double* const res_ptr = std::assume_aligned<32>(res.data());
const double* const a_ptr = std::assume_aligned<32>(a.data());
const double* const b_ptr = std::assume_aligned<32>(b.data());
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
Expand All @@ -160,48 +157,20 @@ inline void MulInFD(std::array<double, N> &res, const std::array<double, N> &a,
}
#else
for (int i = 0; i < N / 2; i++) {
double aimbim = a_ptr[i + N / 2] * b_ptr[i + N / 2];
double arebim = a_ptr[i] * b_ptr[i + N / 2];
res_ptr[i] = std::fma(a[i], b_ptr[i], -aimbim);
res_ptr[i + N / 2] = std::fma(a_ptr[i + N / 2], b_ptr[i], arebim);
double aimbim = a[i + N / 2] * b[i + N / 2];
double arebim = a[i] * b[i + N / 2];
res[i] = std::fma(a[i], b[i], -aimbim);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], arebim);
}
#endif
}

// template <uint32_t N, uint32_t kpo>
// inline void MulInFD(std::array<std::array<double, N>,kpo> &res, const std::array<double, N> &a,
// const std::array<std::array<double, N>,kpo> &b)
// {
// double* const res_ptr = std::assume_aligned<32>(res[0].data());
// const double* const a_ptr = std::assume_aligned<32>(a.data());
// const double* const b_ptr = std::assume_aligned<32>(b[0].data());
// for(int i = 0; i < N / 2; i++){
// for(int j = 0; j < kpo; j++){
// #ifdef USE_INTERLEAVED_FORMAT
// const std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i+j*N], b[2*i+1+j*N]);
// res[2*i+j*N] = tmp.real();
// res[2*i+1+j*N] = tmp.imag();
// #else
// // double aimbim = a_ptr[i + N / 2] * b_ptr[i + N / 2+j*N];
// res_ptr[i+j*N] = a_ptr[i + N / 2] * b_ptr[i + N / 2+j*N];
// res_ptr[i+j*N] = std::fma(a[i], b_ptr[i+j*N], -res_ptr[i+j*N]);
// // double arebim = a_ptr[i] * b_ptr[i + N / 2+j*N];
// res_ptr[i + N / 2+j*N] = a_ptr[i] * b_ptr[i + N / 2+j*N];
// res_ptr[i + N / 2+j*N] = std::fma(a_ptr[i + N / 2], b_ptr[i+j*N], res_ptr[i + N / 2+j*N]);
// #endif
// }
// }
// }

// Be careful about memory access (we assume b has a relatively high memory
// access cost)
template <uint32_t N>
inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
const std::array<double, N> &b)
{
double* const res_ptr = std::assume_aligned<32>(res.data());
const double* const a_ptr = std::assume_aligned<32>(a.data());
const double* const b_ptr = std::assume_aligned<32>(b.data());
#ifdef USE_INTERLEAVED_FORMAT
for(int i = 0; i < N / 2; i++){
std::complex tmp = std::complex(a[2*i], a[2*i+1]) * std::complex(b[2*i], b[2*i+1]);
Expand All @@ -210,12 +179,12 @@ inline void FMAInFD(std::array<double, N> &res, const std::array<double, N> &a,
}
#else
for (int i = 0; i < N / 2; i++) {
res_ptr[i] = std::fma(a_ptr[i], b_ptr[i], res_ptr[i]);
res_ptr[i + N / 2] = std::fma(a_ptr[i + N / 2], b_ptr[i], res_ptr[i + N / 2]);
res[i] = std::fma(a[i], b[i], res[i]);
res[i + N / 2] = std::fma(a[i + N / 2], b[i], res[i + N / 2]);
}
for (int i = 0; i < N / 2; i++) {
res_ptr[i + N / 2] = std::fma(a_ptr[i], b_ptr[i + N / 2], res_ptr[i + N / 2]);
res_ptr[i] -= a_ptr[i + N / 2] * b_ptr[i + N / 2];
res[i + N / 2] = std::fma(a[i], b[i + N / 2], res[i + N / 2]);
res[i] -= a[i + N / 2] * b[i + N / 2];
}
// for (int i = 0; i < N / 2; i++) {
// res[i] = std::fma(a[i + N / 2], b[i + N / 2], -res[i]);
Expand Down
6 changes: 3 additions & 3 deletions test/invcircuitbootstrapping.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ int main()
for (int j = 0; j < lvl1param::n; j++)
pmu[i][j] = pa[i][j] ? lvl1param::μ : -lvl1param::μ;
for (int i = 0; i < num_test; i++) pzeros[i] = false;
vector<TRLWE<lvl1param>> ca(num_test);
vector<TRLWE<lvl1param>,TFHEpp::AlignedAllocator<TFHEpp::TRLWE<TFHEpp::lvl1param>,64>> ca(num_test);
vector<TLWE<lvl1param>> czeros(num_test);
vector<TRGSWFFT<lvl1param>> bootedTGSW(num_test);
vector<TRGSWFFT<lvl1param>> invbootedTGSW(num_test);
vector<TRGSWFFT<lvl1param>,TFHEpp::AlignedAllocator<TFHEpp::TRGSWFFT<TFHEpp::lvl1param>,64>> bootedTGSW(num_test);
vector<TRGSWFFT<lvl1param>,TFHEpp::AlignedAllocator<TFHEpp::TRGSWFFT<TFHEpp::lvl1param>,64>> invbootedTGSW(num_test);

for (int i = 0; i < num_test; i++)
ca[i] = trlweSymEncrypt<lvl1param>(pmu[i], sk->key.lvl1);
Expand Down

0 comments on commit 72f8475

Please sign in to comment.