Skip to content

Commit

Permalink
PoC that AEGIS-X(p) can be as fast as AEGIS-X(p-1)
Browse files Browse the repository at this point in the history
Right now, without 512-bit registers, AEGIS-X4 is generally
slower than AEGIS-X2. AEGIS-X2 may also be slower than AEGIS-X1
on architectures with limited registers and AES pipelines.

The reason for that is register spills. We simulate large
vector registers, so actual registers constantly need to be spilled
and restored to/from the stack.

A different strategy is to evaluate the AEGIS instances sequentially,
instead of in parallel.

By doing so, and ignoring initialization/finalization, an intuition
is that AEGIS-X4 has the same cost as the sum of 4 AEGIS-X1 runs.
That is, AEGIS-X4 is not slower than AEGIS-X1 on large messages.

If we need multiple passes over the entire message, memory accesses
would defeat this.

Unless the message is split into small chunks, and AEGIS instances
are sequentially run on individual chunks. Stack spills then happen far
less frequently than when emulating large registers. But also,
once loaded during the first pass, the chunk is likely to be
available in the L1 or L2 caches, ready to be immediately processed
by the next AEGIS instances.

Using that trick, negotiating X2 or X4 would be acceptable most of
the time: if an endpoint has registers/pipelines large enough to
take advantage of them, it will. But if it doesn't, it wouldn't
be significantly slower than using a variant with a lower
parallelism degree.

The downside is a bit of implementation complexity, but also the
fact that the optimal chunk size depends on the architecture and
on the use cases.

We may pick that chunk size to look great on benchmarks, but
AEGIS is about real-world usage, not synthetic benchmarks. So,
the benefits of this approach need to be properly measured.
  • Loading branch information
jedisct1 committed Nov 4, 2024
1 parent 146c9d9 commit 690859b
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 2 deletions.
81 changes: 80 additions & 1 deletion src/aegis128x4/aegis128x4_armcrypto.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
# pragma GCC target("+simd+crypto")
# endif

# define AES_BLOCK_LENGTH 64
# define AES_BLOCK_LENGTH 64
# define AES_BLOCK1_LENGTH 16
# define CHUNK_SIZE 512

typedef struct {
uint8x16_t b0;
Expand All @@ -35,6 +37,83 @@ typedef struct {
uint8x16_t b3;
} aes_block_t;

/* -- */

typedef uint8x16_t aes_block1_t;

# define AES_BLOCK1_XOR(A, B) veorq_u8((A), (B))
# define AES_BLOCK1_AND(A, B) vandq_u8((A), (B))
# define AES_BLOCK1_LOAD(A) vld1q_u8(A)
# define AES_BLOCK1_LOAD_64x2(A, B) vreinterpretq_u8_u64(vsetq_lane_u64((A), vmovq_n_u64(B), 1))
# define AES_BLOCK1_STORE(A, B) vst1q_u8((A), (B))
# define AES_ENC1(A, B) veorq_u8(vaesmcq_u8(vaeseq_u8((A), vmovq_n_u8(0))), (B))

static inline void
aegis128x4_update_b0(aes_block_t *const state, const aes_block1_t d1, const aes_block1_t d2)
{
aes_block1_t tmp;

tmp = state[7].b0;
state[7].b0 = AES_ENC1(state[6].b0, state[7].b0);
state[6].b0 = AES_ENC1(state[5].b0, state[6].b0);
state[5].b0 = AES_ENC1(state[4].b0, state[5].b0);
state[4].b0 = AES_BLOCK1_XOR(AES_ENC1(state[3].b0, state[4].b0), d2);
state[3].b0 = AES_ENC1(state[2].b0, state[3].b0);
state[2].b0 = AES_ENC1(state[1].b0, state[2].b0);
state[1].b0 = AES_ENC1(state[0].b0, state[1].b0);
state[0].b0 = AES_BLOCK1_XOR(AES_ENC1(tmp, state[0].b0), d1);
}

static inline void
aegis128x4_update_b1(aes_block_t *const state, const aes_block1_t d1, const aes_block1_t d2)
{
aes_block1_t tmp;

tmp = state[7].b1;
state[7].b1 = AES_ENC1(state[6].b1, state[7].b1);
state[6].b1 = AES_ENC1(state[5].b1, state[6].b1);
state[5].b1 = AES_ENC1(state[4].b1, state[5].b1);
state[4].b1 = AES_BLOCK1_XOR(AES_ENC1(state[3].b1, state[4].b1), d2);
state[3].b1 = AES_ENC1(state[2].b1, state[3].b1);
state[2].b1 = AES_ENC1(state[1].b1, state[2].b1);
state[1].b1 = AES_ENC1(state[0].b1, state[1].b1);
state[0].b1 = AES_BLOCK1_XOR(AES_ENC1(tmp, state[0].b1), d1);
}

static inline void
aegis128x4_update_b2(aes_block_t *const state, const aes_block1_t d1, const aes_block1_t d2)
{
aes_block1_t tmp;

tmp = state[7].b2;
state[7].b2 = AES_ENC1(state[6].b2, state[7].b2);
state[6].b2 = AES_ENC1(state[5].b2, state[6].b2);
state[5].b2 = AES_ENC1(state[4].b2, state[5].b2);
state[4].b2 = AES_BLOCK1_XOR(AES_ENC1(state[3].b2, state[4].b2), d2);
state[3].b2 = AES_ENC1(state[2].b2, state[3].b2);
state[2].b2 = AES_ENC1(state[1].b2, state[2].b2);
state[1].b2 = AES_ENC1(state[0].b2, state[1].b2);
state[0].b2 = AES_BLOCK1_XOR(AES_ENC1(tmp, state[0].b2), d1);
}

static inline void
aegis128x4_update_b3(aes_block_t *const state, const aes_block1_t d1, const aes_block1_t d2)
{
aes_block1_t tmp;

tmp = state[7].b3;
state[7].b3 = AES_ENC1(state[6].b3, state[7].b3);
state[6].b3 = AES_ENC1(state[5].b3, state[6].b3);
state[5].b3 = AES_ENC1(state[4].b3, state[5].b3);
state[4].b3 = AES_BLOCK1_XOR(AES_ENC1(state[3].b3, state[4].b3), d2);
state[3].b3 = AES_ENC1(state[2].b3, state[3].b3);
state[2].b3 = AES_ENC1(state[1].b3, state[2].b3);
state[1].b3 = AES_ENC1(state[0].b3, state[1].b3);
state[0].b3 = AES_BLOCK1_XOR(AES_ENC1(tmp, state[0].b3), d1);
}

/* -- */

static inline aes_block_t
AES_BLOCK_XOR(const aes_block_t a, const aes_block_t b)
{
Expand Down
108 changes: 107 additions & 1 deletion src/aegis128x4/aegis128x4_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,88 @@ aegis128x4_absorb(const uint8_t *const src, aes_block_t *const state)
aegis128x4_update(state, msg0, msg1);
}

#ifdef CHUNK_SIZE
static void
aegis128x4_enc_b0(uint8_t *const dst, const uint8_t *const src, aes_block_t *const state)
{
aes_block1_t msg0, msg1;
aes_block1_t tmp0, tmp1;

msg0 = AES_BLOCK1_LOAD(src);
msg1 = AES_BLOCK1_LOAD(src + AES_BLOCK_LENGTH);
tmp0 = AES_BLOCK1_XOR(msg0, state[6].b0);
tmp0 = AES_BLOCK1_XOR(tmp0, state[1].b0);
tmp1 = AES_BLOCK1_XOR(msg1, state[5].b0);
tmp1 = AES_BLOCK1_XOR(tmp1, state[2].b0);
tmp0 = AES_BLOCK1_XOR(tmp0, AES_BLOCK1_AND(state[2].b0, state[3].b0));
tmp1 = AES_BLOCK1_XOR(tmp1, AES_BLOCK1_AND(state[6].b0, state[7].b0));
AES_BLOCK1_STORE(dst, tmp0);
AES_BLOCK1_STORE(dst + AES_BLOCK_LENGTH, tmp1);

aegis128x4_update_b0(state, msg0, msg1);
}

static void
aegis128x4_enc_b1(uint8_t *const dst, const uint8_t *const src, aes_block_t *const state)
{
aes_block1_t msg0, msg1;
aes_block1_t tmp0, tmp1;

msg0 = AES_BLOCK1_LOAD(src);
msg1 = AES_BLOCK1_LOAD(src + AES_BLOCK_LENGTH);
tmp0 = AES_BLOCK1_XOR(msg0, state[6].b1);
tmp0 = AES_BLOCK1_XOR(tmp0, state[1].b1);
tmp1 = AES_BLOCK1_XOR(msg1, state[5].b1);
tmp1 = AES_BLOCK1_XOR(tmp1, state[2].b1);
tmp0 = AES_BLOCK1_XOR(tmp0, AES_BLOCK1_AND(state[2].b1, state[3].b1));
tmp1 = AES_BLOCK1_XOR(tmp1, AES_BLOCK1_AND(state[6].b1, state[7].b1));
AES_BLOCK1_STORE(dst, tmp0);
AES_BLOCK1_STORE(dst + AES_BLOCK_LENGTH, tmp1);

aegis128x4_update_b1(state, msg0, msg1);
}

static void
aegis128x4_enc_b2(uint8_t *const dst, const uint8_t *const src, aes_block_t *const state)
{
aes_block1_t msg0, msg1;
aes_block1_t tmp0, tmp1;

msg0 = AES_BLOCK1_LOAD(src);
msg1 = AES_BLOCK1_LOAD(src + AES_BLOCK_LENGTH);
tmp0 = AES_BLOCK1_XOR(msg0, state[6].b2);
tmp0 = AES_BLOCK1_XOR(tmp0, state[1].b2);
tmp1 = AES_BLOCK1_XOR(msg1, state[5].b2);
tmp1 = AES_BLOCK1_XOR(tmp1, state[2].b2);
tmp0 = AES_BLOCK1_XOR(tmp0, AES_BLOCK1_AND(state[2].b2, state[3].b2));
tmp1 = AES_BLOCK1_XOR(tmp1, AES_BLOCK1_AND(state[6].b2, state[7].b2));
AES_BLOCK1_STORE(dst, tmp0);
AES_BLOCK1_STORE(dst + AES_BLOCK_LENGTH, tmp1);

aegis128x4_update_b2(state, msg0, msg1);
}

static void
aegis128x4_enc_b3(uint8_t *const dst, const uint8_t *const src, aes_block_t *const state)
{
aes_block1_t msg0, msg1;
aes_block1_t tmp0, tmp1;

msg0 = AES_BLOCK1_LOAD(src);
msg1 = AES_BLOCK1_LOAD(src + AES_BLOCK_LENGTH);
tmp0 = AES_BLOCK1_XOR(msg0, state[6].b3);
tmp0 = AES_BLOCK1_XOR(tmp0, state[1].b3);
tmp1 = AES_BLOCK1_XOR(msg1, state[5].b3);
tmp1 = AES_BLOCK1_XOR(tmp1, state[2].b3);
tmp0 = AES_BLOCK1_XOR(tmp0, AES_BLOCK1_AND(state[2].b3, state[3].b3));
tmp1 = AES_BLOCK1_XOR(tmp1, AES_BLOCK1_AND(state[6].b3, state[7].b3));
AES_BLOCK1_STORE(dst, tmp0);
AES_BLOCK1_STORE(dst + AES_BLOCK_LENGTH, tmp1);

aegis128x4_update_b3(state, msg0, msg1);
}
#endif

static void
aegis128x4_enc(uint8_t *const dst, const uint8_t *const src, aes_block_t *const state)
{
Expand Down Expand Up @@ -209,7 +291,31 @@ encrypt_detached(uint8_t *c, uint8_t *mac, size_t maclen, const uint8_t *m, size
memcpy(src, ad + i, adlen % RATE);
aegis128x4_absorb(src, state);
}
for (i = 0; i + RATE <= mlen; i += RATE) {
i = 0;

#ifdef CHUNK_SIZE
{
const size_t mlenx = mlen - mlen % CHUNK_SIZE;
size_t j;

for (; i < mlenx; i += CHUNK_SIZE) {
for (j = 0; j < CHUNK_SIZE; j += RATE) {
aegis128x4_enc_b0(c + i + j, m + i + j, state);
}
for (j = AES_BLOCK1_LENGTH; j < CHUNK_SIZE; j += RATE) {
aegis128x4_enc_b1(c + i + j, m + i + j, state);
}
for (j = AES_BLOCK1_LENGTH * 2; j < CHUNK_SIZE; j += RATE) {
aegis128x4_enc_b2(c + i + j, m + i + j, state);
}
for (j = AES_BLOCK1_LENGTH * 3; j < CHUNK_SIZE; j += RATE) {
aegis128x4_enc_b3(c + i + j, m + i + j, state);
}
}
}
#endif

for (; i + RATE <= mlen; i += RATE) {
aegis128x4_enc(c + i, m + i, state);
}
if (mlen % RATE) {
Expand Down

0 comments on commit 690859b

Please sign in to comment.