1/*------------------------------------------------------------------------
2/ OCB Version 3 Reference Code (Optimized C)     Last modified 12-JUN-2013
3/-------------------------------------------------------------------------
4/ Copyright (c) 2013 Ted Krovetz.
5/
6/ Permission to use, copy, modify, and/or distribute this software for any
7/ purpose with or without fee is hereby granted, provided that the above
8/ copyright notice and this permission notice appear in all copies.
9/
10/ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11/ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12/ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
13/ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14/ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
15/ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
16/ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17/
18/ Phillip Rogaway holds patents relevant to OCB. See the following for
19/ his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
20/
21/ Special thanks to Keegan McAllister for suggesting several good improvements
22/
23/ Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
24/------------------------------------------------------------------------- */
25
26/* ----------------------------------------------------------------------- */
27/* Usage notes                                                             */
28/* ----------------------------------------------------------------------- */
29
30/* - When AE_PENDING is passed as the 'final' parameter of any function,
31/    the length parameters must be a multiple of (BPI*16).
32/  - When available, SSE or AltiVec registers are used to manipulate data.
33/    So, when on machines with these facilities, all pointers passed to
34/    any function should be 16-byte aligned.
35/  - Plaintext and ciphertext pointers may be equal (ie, plaintext gets
36/    encrypted in-place), but no other pair of pointers may be equal.
37/  - This code assumes all x86 processors have SSE2 and SSSE3 instructions
38/    when compiling under MSVC. If untrue, alter the #define.
39/  - This code is tested for C99 and recent versions of GCC and MSVC.      */
40
41/* ----------------------------------------------------------------------- */
42/* User configuration options                                              */
43/* ----------------------------------------------------------------------- */
44
45/* Set the AES key length to use and length of authentication tag to produce.
46/  Setting either to 0 requires the value be set at runtime via ae_init().
47/  Some optimizations occur for each when set to a fixed value.            */
48#define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
49#define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init         */
50
51/* This implementation has built-in support for multiple AES APIs. Set any
52/  one of the following to non-zero to specify which to use.               */
53#define USE_OPENSSL_AES 1   /* http://openssl.org                      */
54#define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c     */
55#define USE_AES_NI 0        /* Uses compiler's intrinsics              */
56
57/* During encryption and decryption, various "L values" are required.
58/  The L values can be precomputed during initialization (requiring extra
59/  space in ae_ctx), generated as needed (slightly slowing encryption and
60/  decryption), or some combination of the two. L_TABLE_SZ specifies how many
61/  L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
62/  are used for L values in ae_ctx. Plaintext and ciphertexts shorter than
63/  2^L_TABLE_SZ blocks need no L values calculated dynamically.            */
64#define L_TABLE_SZ 16
65
66/* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
67/  will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
68/  in better performance.                                                  */
69#define L_TABLE_SZ_IS_ENOUGH 1
70
71/* ----------------------------------------------------------------------- */
72/* Includes and compiler specific definitions                              */
73/* ----------------------------------------------------------------------- */
74
75#include "ae.h"
76#include <stdlib.h>
77#include <string.h>
78
79/* Define standard sized integers                                          */
80#if defined(_MSC_VER) && (_MSC_VER < 1600)
81typedef unsigned __int8 uint8_t;
82typedef unsigned __int32 uint32_t;
83typedef unsigned __int64 uint64_t;
84typedef __int64 int64_t;
85#else
86#include <stdint.h>
87#endif
88
89/* Compiler-specific intrinsics and fixes: bswap64, ntz                    */
90#if _MSC_VER
91#define inline __inline                           /* MSVC doesn't recognize "inline" in C */
92#define restrict __restrict                       /* MSVC doesn't recognize "restrict" in C */
93#define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64)  /* Assume SSE2  */
94#define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
95#include <intrin.h>
96#pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
97#define bswap64(x) _byteswap_uint64(x)
/* Count trailing zero bits of a nonzero x (result undefined for x == 0). */
static inline unsigned ntz(unsigned x) {
    unsigned long tz; /* _BitScanForward takes an unsigned long*, which is not
                         compatible with unsigned* (&x) even though the sizes
                         match on Windows -- use a correctly typed temporary */
    _BitScanForward(&tz, x);
    return (unsigned)tz;
}
102#elif __GNUC__
103#define inline __inline__                   /* No "inline" in GCC ansi C mode */
104#define restrict __restrict__               /* No "restrict" in GCC ansi C mode */
105#define bswap64(x) __builtin_bswap64(x)     /* Assuming GCC 4.3+ */
106#define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
107#else /* Assume some C99 features: stdint.h, inline, restrict */
#define bswap32(x)                                                                                 \
    ((((x)&0xff000000u) >> 24) | (((x)&0x00ff0000u) >> 8) | (((x)&0x0000ff00u) << 8) |             \
     (((x)&0x000000ffu) << 24))

/* Reverse the byte order of a 64-bit value: byte-swap each 32-bit half
   with bswap32 and exchange the halves.  Endian-independent. */
static inline uint64_t bswap64(uint64_t x) {
    uint32_t lo = (uint32_t)x;
    uint32_t hi = (uint32_t)(x >> 32);
    return ((uint64_t)bswap32(lo) << 32) | (uint64_t)bswap32(hi);
}
122
123#if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
/* Trailing-zero count via table lookup.  x is expected to be a positive
   multiple of 4 (block counters advance BPI blocks at a time), so the low
   two bits are zero and x/4 indexes this 128-entry table, which covers
   x up to 512 = 2^9 blocks -- sufficient when L_TABLE_SZ <= 9. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2,
        3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2,
        4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2,
        3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2,
        5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    return tz_table[x / 4];
}
133#else                                           /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
/* Trailing-zero count via de Bruijn multiplication: x & -x isolates the
   lowest set bit, multiplying by the de Bruijn constant 0x077CB531 puts a
   unique 5-bit pattern into the top bits, which indexes the table. */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[32] = {0,  1,  28, 2,  29, 14, 24, 3,  30, 22, 20,
                                               15, 25, 17, 4,  8,  31, 27, 13, 23, 21, 19,
                                               16, 7,  26, 12, 18, 6,  11, 5,  10, 9};
    return tz_table[((uint32_t)((x & -x) * 0x077CB531u)) >> 27];
}
140#endif
141#endif
142
143/* ----------------------------------------------------------------------- */
144/* Define blocks and operations -- Patch if incorrect on your compiler.    */
145/* ----------------------------------------------------------------------- */
146
147#if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
148#include <xmmintrin.h> /* SSE instructions and _mm_malloc */
149#include <emmintrin.h> /* SSE2 instructions               */
150typedef __m128i block;
151#define xor_block(x, y) _mm_xor_si128(x, y)
152#define zero_block() _mm_setzero_si128()
153#define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
154#if __SSSE3__ || USE_AES_NI
155#include <tmmintrin.h> /* SSSE3 instructions              */
156#define swap_if_le(b)                                                                              \
157    _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
158#else
/* Byte-reverse a 16-byte block without SSSE3: reverse the 32-bit words,
   swap the 16-bit halves within each word, then swap the bytes within
   each 16-bit lane using paired 8-bit shifts (the XOR acts as an OR
   because the shifted-out bits are zero). */
static inline block swap_if_le(block b) {
    block a = _mm_shuffle_epi32(b, _MM_SHUFFLE(0, 1, 2, 3));
    a = _mm_shufflehi_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    a = _mm_shufflelo_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_xor_si128(_mm_srli_epi16(a, 8), _mm_slli_epi16(a, 8));
}
165#endif
/* Shift the 192-bit register-correct Ktop string left by bot (0..63)
   bits and return the top 128 bits, converted back to memory order.
   Each 64-bit lane gets a funnel shift built from two overlapping loads. */
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0));  /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot); /* when bot == 0 this is a 64-bit
                                                     SSE shift, which zeroes the lane
                                                     (hardware semantics, not C UB) */
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    /* Shuffle to memory-correct byte order */
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
/* GF(2^128) doubling of a register-correct block: shift left one bit and
   fold the bit shifted out of the whole value back in as 0x87 (135).
   _mm_srai_epi32 broadcasts each 32-bit lane's top bit; the mask keeps a
   carry-in of 1 for inter-lane carries (135 for the overall carry), and
   the lane rotation delivers each carry to the lane that receives it. */
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    __m128i tmp = _mm_srai_epi32(bl, 31);
    tmp = _mm_and_si128(tmp, mask);
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3));
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
186#elif __ALTIVEC__
187#include <altivec.h>
188typedef vector unsigned block;
189#define xor_block(x, y) vec_xor(x, y)
190#define zero_block() vec_splat_u32(0)
191#define unequal_blocks(x, y) vec_any_ne(x, y)
192#define swap_if_le(b) (b)
193#if __PPC64__
194block gen_offset(uint64_t KtopStr[3], unsigned bot) {
195    union {
196        uint64_t u64[2];
197        block bl;
198    } rval;
199    rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
200    rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
201    return rval.bl;
202}
203#else
/* Special handling: Shifts are mod 32, and no 64-bit types */
/* Shift the 192-bit Ktop string left by bot (0..63) bits using 32-bit
   AltiVec lanes: first pre-shift by whole words with vec_sld, then do a
   per-lane funnel shift for the remaining 0..31 bits. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        lo = vec_sld(hi, lo, 4); /* align lo one word ahead of hi */
    } else {
        /* Shift by one whole word up front and reduce bot accordingly */
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8);
        hi = t;
        bot = bot - 32;
    }
    if (bot == 0) /* done; also avoids a zero-bit funnel shift below */
        return hi;
    *(unsigned*)&bot_vec = bot; /* NOTE(review): writes element 0 only; relies on
                                   vec_splat below to broadcast it -- confirm */
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    return vec_xor(hi, lo); /* combine into the 128-bit funnel-shift result */
}
227#endif
/* GF(2^128) doubling via per-byte carries: arithmetic shift right by 7
   broadcasts each byte's top bit, the mask keeps a carry-in of 1 per
   byte (135 for the byte receiving the reduction), and the rotate-left-
   by-one permute delivers each carry to its destination byte before the
   1-bit left shift. */
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    vector unsigned char t = vec_sra(c, shift7);
    t = vec_and(t, mask);
    t = vec_perm(t, t, perm);
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
240#elif __ARM_NEON__
241#include <arm_neon.h>
242typedef int8x16_t block; /* Yay! Endian-neutral reads! */
243#define xor_block(x, y) veorq_s8(x, y)
244#define zero_block() vdupq_n_s8(0)
/* Nonzero iff the two blocks differ: XOR, then OR the two 64-bit lanes. */
static inline int unequal_blocks(block a, block b) {
    int64x2_t t = veorq_s64((int64x2_t)a, (int64x2_t)b);
    return (vgetq_lane_s64(t, 0) | vgetq_lane_s64(t, 1)) != 0;
}
249#define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
/* KtopStr is reg correct by 64 bits, return mem correct */
/* Shift the 192-bit Ktop string left by bot (0..63) bits and return the
   top 128 bits in memory order. */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1}; /* runtime endianness probe */
    const int64x2_t k64 = {-64, -64};
    /* Copy hi and lo into local variables to ensure proper alignment */
    uint64x2_t hi = vld1q_u64(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = vld1q_u64(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    int64x2_t rs = vqaddq_s64(k64, ls); /* bot - 64; a negative VSHL count
                                           shifts right (NEON semantics) */
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    if (little.endian)
        rval = vrev64q_s8(rval); /* byte-reverse each lane to memory order */
    return rval;
}
/* GF(2^128) doubling via per-byte carries: arithmetic shift right by 7
   broadcasts each byte's top bit, the mask keeps a carry-in of 1 per
   byte (135 for the reduction byte), and the byte rotation delivers each
   carry to its destination before the 1-bit left shift. */
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    block tmp = vshrq_n_s8(b, 7);
    tmp = vandq_s8(tmp, mask);
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
275#else
/* Portable fallback: a block is a pair of 64-bit words, l holding the
   first eight bytes and r the last eight, in memory order. */
typedef struct { uint64_t l, r; } block;

/* XOR two blocks, word by word. */
static inline block xor_block(block x, block y) {
    block out;
    out.l = x.l ^ y.l;
    out.r = x.r ^ y.r;
    return out;
}

/* The all-zero block. */
static inline block zero_block(void) {
    block z;
    z.l = 0;
    z.r = 0;
    return z;
}

#define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
287static inline block swap_if_le(block b) {
288    const union {
289        unsigned x;
290        unsigned char endian;
291    } little = {1};
292    if (little.endian) {
293        block r;
294        r.l = bswap64(b.l);
295        r.r = bswap64(b.r);
296        return r;
297    } else
298        return b;
299}
300
301/* KtopStr is reg correct by 64 bits, return mem correct */
302block gen_offset(uint64_t KtopStr[3], unsigned bot) {
303    block rval;
304    if (bot != 0) {
305        rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
306        rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
307    } else {
308        rval.l = KtopStr[0];
309        rval.r = KtopStr[1];
310    }
311    return swap_if_le(rval);
312}
313
314#if __GNUC__ && __arm__
static inline block double_block(block b) {
    /* GF(2^128) doubling using the ARM carry flag: add each 32-bit word
       to itself with the carry propagating across all 128 bits, then XOR
       135 into the lowest word iff the final carry was set.  %H selects
       the high register of a 64-bit operand pair (GCC ARM constraint). */
    __asm__("adds %1,%1,%1\n\t"
            "adcs %H1,%H1,%H1\n\t"
            "adcs %0,%0,%0\n\t"
            "adcs %H0,%H0,%H0\n\t"
            "it cs\n\t"
            "eorcs %1,%1,#135"
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
327#else
328static inline block double_block(block b) {
329    uint64_t t = (uint64_t)((int64_t)b.l >> 63);
330    b.l = (b.l + b.l) ^ (b.r >> 63);
331    b.r = (b.r + b.r) ^ (t & 135);
332    return b;
333}
334#endif
335
336#endif
337
338/* ----------------------------------------------------------------------- */
339/* AES - Code uses OpenSSL API. Other implementations get mapped to it.    */
340/* ----------------------------------------------------------------------- */
341
342/*---------------*/
343#if USE_OPENSSL_AES
344/*---------------*/
345
346#include <openssl/aes.h> /* http://openssl.org/ */
347
348/* How to ECB encrypt an array of blocks, in place                         */
349static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
350    while (nblks) {
351        --nblks;
352        AES_encrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
353    }
354}
355
356static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
357    while (nblks) {
358        --nblks;
359        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
360    }
361}
362
363#define BPI 4 /* Number of blocks in buffer per ECB call */
364
365/*-------------------*/
366#elif USE_REFERENCE_AES
367/*-------------------*/
368
369#include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
370#if (OCB_KEY_LEN == 0)
371typedef struct {
372    uint32_t rd_key[60];
373    int rounds;
374} AES_KEY;
375#define ROUNDS(ctx) ((ctx)->rounds)
376#define AES_set_encrypt_key(x, y, z)                                                               \
377    do {                                                                                           \
378        rijndaelKeySetupEnc((z)->rd_key, x, y);                                                    \
379        (z)->rounds = y / 32 + 6;                                                                  \
380    } while (0)
381#define AES_set_decrypt_key(x, y, z)                                                               \
382    do {                                                                                           \
383        rijndaelKeySetupDec((z)->rd_key, x, y);                                                    \
384        (z)->rounds = y / 32 + 6;                                                                  \
385    } while (0)
386#else
387typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
388#define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
389#define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, x, y)
390#define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, x, y)
391#endif
392#define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), x, y)
393#define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), x, y)
394
395static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
396    while (nblks) {
397        --nblks;
398        AES_encrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
399    }
400}
401
402void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
403    while (nblks) {
404        --nblks;
405        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
406    }
407}
408
409#define BPI 4 /* Number of blocks in buffer per ECB call */
410
411/*----------*/
412#elif USE_AES_NI
413/*----------*/
414
415#include <wmmintrin.h>
416
417#if (OCB_KEY_LEN == 0)
418typedef struct {
419    __m128i rd_key[15];
420    int rounds;
421} AES_KEY;
422#define ROUNDS(ctx) ((ctx)->rounds)
423#else
424typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
425#define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
426#endif
427
/* One AES key-schedule assist round: v4 feeds AESKEYGENASSIST with the
   given round constant, v1 accumulates the sliding XOR of its previous
   words via two shuffles through v3, and the shuffled assist word v2 is
   folded in last.  All arguments must be plain __m128i lvalues. */
#define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const)                                      \
    v2 = _mm_aeskeygenassist_si128(v4, aes_const);                                                 \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16));         \
    v1 = _mm_xor_si128(v1, v3);                                                                    \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140));        \
    v1 = _mm_xor_si128(v1, v3);                                                                    \
    v2 = _mm_shuffle_epi32(v2, shuff_const);                                                       \
    v1 = _mm_xor_si128(v1, v2)

/* One AES-192 schedule step: two EXPAND_ASSIST applications generate new
   key material in x0/x3, which is repacked (via tmp) into three 128-bit
   round keys kp[idx..idx+2]. */
#define EXPAND192_STEP(idx, aes_const)                                                             \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const);                                                  \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                                 \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                            \
    kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68));   \
    kp[idx + 1] =                                                                                  \
        _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78));          \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2));                                            \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                                 \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                            \
    kp[idx + 2] = x0;                                                                              \
    tmp = x3
449
/* Expand a 16-byte AES-128 key into 11 round keys at kp[0..10], using
   AESKEYGENASSIST with round constants 1,2,4,...,128,27,54. */
static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey); /* round key 0 = raw key */
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}
476
/* Expand a 24-byte AES-192 key into 13 round keys at kp[0..12]; each
   EXPAND192_STEP runs two assist rounds and emits three round keys. */
static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16)); /* high 8 key bytes in low half */
    x2 = _mm_setzero_si128();
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}
487
/* Expand a 32-byte AES-256 key into 15 round keys at kp[0..14].  The two
   key halves alternate: shuffle constant 255 with the round constant for
   even keys, shuffle constant 170 for the odd (second-half) keys. */
static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}
520
521static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
522    if (bits == 128) {
523        AES_128_Key_Expansion(userKey, key);
524    } else if (bits == 192) {
525        AES_192_Key_Expansion(userKey, key);
526    } else if (bits == 256) {
527        AES_256_Key_Expansion(userKey, key);
528    }
529#if (OCB_KEY_LEN == 0)
530    key->rounds = 6 + bits / 32;
531#endif
532    return 0;
533}
534
/* Build an equivalent-inverse-cipher decryption schedule from an
   encryption schedule: round keys are copied in reverse order, applying
   InvMixColumns (AESIMC) to all but the first and last. */
static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    int j = 0;
    int i = ROUNDS(ekey);
#if (OCB_KEY_LEN == 0)
    dkey->rounds = i;
#endif
    dkey->rd_key[i--] = ekey->rd_key[j++]; /* last decrypt key = first encrypt key */
    while (i)
        dkey->rd_key[i--] = _mm_aesimc_si128(ekey->rd_key[j++]);
    dkey->rd_key[i] = ekey->rd_key[j]; /* first decrypt key = last encrypt key */
}
546
/* Standard-API wrapper: expand userKey into a temporary encryption
   schedule, then convert it into key's decryption schedule. */
static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY temp_key;
    AES_set_encrypt_key(userKey, bits, &temp_key);
    AES_set_decrypt_key_fast(key, &temp_key);
    return 0;
}
553
/* Encrypt one 16-byte block with AES-NI.  Note in/out are accessed with
   aligned load/store, so both must be 16-byte aligned. */
static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]); /* initial AddRoundKey */
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesenc_si128(tmp, sched[j]);
    tmp = _mm_aesenclast_si128(tmp, sched[j]); /* final round, no MixColumns */
    _mm_store_si128((__m128i*)out, tmp);
}
564
/* Decrypt one 16-byte block with AES-NI.  Requires a schedule produced by
   AES_set_decrypt_key/_fast; in/out must be 16-byte aligned. */
static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]); /* initial AddRoundKey */
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesdec_si128(tmp, sched[j]);
    tmp = _mm_aesdeclast_si128(tmp, sched[j]); /* final round */
    _mm_store_si128((__m128i*)out, tmp);
}
575
/* ECB encrypt nblks blocks in place, processing all blocks one round at
   a time (interleaving the independent AESENC chains helps keep the
   AES units pipelined). */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesenc_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesenclast_si128(blks[i], sched[j]);
}
587
/* ECB decrypt nblks blocks in place, one round at a time across all
   blocks (mirrors AES_ecb_encrypt_blks). */
static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesdec_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesdeclast_si128(blks[i], sched[j]);
}
599
600#define BPI 8 /* Number of blocks in buffer per ECB call   */
601/* Set to 4 for Westmere, 8 for Sandy Bridge */
602
603#endif
604
605/* ----------------------------------------------------------------------- */
606/* Define OCB context structure.                                           */
607/* ----------------------------------------------------------------------- */
608
609/*------------------------------------------------------------------------
610/ Each item in the OCB context is stored either "memory correct" or
611/ "register correct". On big-endian machines, this is identical. On
612/ little-endian machines, one must choose whether the byte-string
613/ is in the correct order when it resides in memory or in registers.
614/ It must be register correct whenever it is to be manipulated
615/ arithmetically, but must be memory correct whenever it interacts
616/ with the plaintext or ciphertext.
617/------------------------------------------------------------------------- */
618
/* All per-key OCB state.  "Memory correct" vs "register correct" is
   explained in the note above. */
struct _ae_ctx {
    block offset;        /* Memory correct               */
    block checksum;      /* Memory correct               */
    block Lstar;         /* Memory correct; E_K(0^128)   */
    block Ldollar;       /* Memory correct; double(L*)   */
    block L[L_TABLE_SZ]; /* Memory correct; L[0] = double(L$), L[i] = double(L[i-1]) */
    block ad_checksum;   /* Memory correct               */
    block ad_offset;     /* Memory correct               */
    block cached_Top;    /* Memory correct; last nonce "Top" block fed to AES */
    uint64_t KtopStr[3]; /* Register correct, each item; Ktop (2 words) + stretch word */
    uint32_t ad_blocks_processed; /* running AD block count for incremental calls */
    uint32_t blocks_processed;    /* presumably the text-block counterpart -- used
                                     by code outside this excerpt */
    AES_KEY decrypt_key;
    AES_KEY encrypt_key;
#if (OCB_TAG_LEN == 0)
    unsigned tag_len; /* runtime tag length in bytes, set by ae_init */
#endif
};
637
638/* ----------------------------------------------------------------------- */
639/* L table lookup (or on-the-fly generation)                               */
640/* ----------------------------------------------------------------------- */
641
642#if L_TABLE_SZ_IS_ENOUGH
643#define getL(_ctx, _tz) ((_ctx)->L[_tz])
644#else
/* Return L[tz], memory correct.  Indices inside the precomputed table are
   returned directly; larger indices are derived by repeated GF(2^128)
   doubling of the last table entry (done register correct). */
static block getL(const ae_ctx* ctx, unsigned tz) {
    if (tz < L_TABLE_SZ)
        return ctx->L[tz];
    else {
        unsigned i;
        /* Bring L[MAX] into registers, make it register correct */
        block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
        rval = double_block(rval);
        for (i = L_TABLE_SZ; i < tz; i++)
            rval = double_block(rval);
        return swap_if_le(rval); /* To memory correct */
    }
}
658#endif
659
660/* ----------------------------------------------------------------------- */
661/* Public functions                                                        */
662/* ----------------------------------------------------------------------- */
663
664/* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
665   on 16-byte alignments. (I believe all major 64-bit systems do already.) */
666
/* Allocate an ae_ctx, forcing 16-byte alignment on platforms where the
   vector code needs it and plain malloc may not provide it (see the note
   above).  Returns NULL on failure.  misc is unused. */
ae_ctx* ae_allocate(void* misc) {
    void* p;
    (void)misc; /* misc unused in this implementation */
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    p = _mm_malloc(sizeof(ae_ctx), 16); /* 32-bit x86 with SSE2 */
#elif(__ALTIVEC__ && !__PPC64__)
    if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
        p = NULL;
#else
    p = malloc(sizeof(ae_ctx)); /* assumed sufficiently aligned on 64-bit ABIs */
#endif
    return (ae_ctx*)p;
}
680
/* Release a context from ae_allocate with the matching deallocator. */
void ae_free(ae_ctx* ctx) {
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    _mm_free(ctx); /* pairs with _mm_malloc in ae_allocate */
#else
    free(ctx); /* covers both the posix_memalign and malloc paths */
#endif
}
688
689/* ----------------------------------------------------------------------- */
690
691int ae_clear(ae_ctx* ctx) /* Zero ae_ctx and undo initialization          */
692{
693    memset(ctx, 0, sizeof(ae_ctx));
694    return AE_SUCCESS;
695}
696
/* Size of ae_ctx in bytes, so callers can provide their own storage. */
int ae_ctx_sizeof(void) {
    return (int)sizeof(ae_ctx);
}
700
701/* ----------------------------------------------------------------------- */
702
/* Initialize ctx with an AES key.  key_len and tag_len are in bytes and
   are ignored when fixed at compile time via OCB_KEY_LEN / OCB_TAG_LEN;
   nonce_len must be 12.  Returns AE_SUCCESS or AE_NOT_SUPPORTED. */
int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    /* Only 96-bit nonces are supported by this implementation */
    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

/* Initialize encryption & decryption keys */
#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN;
#endif
    AES_set_encrypt_key((unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    /* Derive the decrypt schedule from the just-built encrypt schedule,
       avoiding a second full key expansion */
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((unsigned char*)key, (int)(key_len * 8), &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values: L* = E_K(0^128) (cached_Top is still
       all-zero here), then L$ = double(L*), L[0] = double(L$), and
       L[i] = double(L[i-1]).  Doubling is done register correct; results
       are stored memory correct. */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar);
    tmp_blk = double_block(tmp_blk);
    ctx->Ldollar = swap_if_le(tmp_blk);
    tmp_blk = double_block(tmp_blk);
    ctx->L[0] = swap_if_le(tmp_blk);
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk);
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = tag_len;
#else
    (void)tag_len; /* Suppress var not used error */
#endif

    return AE_SUCCESS;
}
745
746/* ----------------------------------------------------------------------- */
747
/* Build the initial offset for a 12-byte nonce.  Formats the 16-byte OCB
   nonce block (tag length, zero padding, a 1 bit, then the nonce), splits
   off its low 6 bits as the shift amount, and -- unless the remaining
   "Top" bits match the cached value -- computes Ktop = E_K(Top) plus the
   stretch word into KtopStr.  Returns the memory-correct offset
   (Ktop || stretch) << bottom. */
static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1}; /* runtime endianness probe */
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    unsigned idx;

/* Replace cached nonce Top if needed */
#if (OCB_TAG_LEN > 0)
    /* First 4 bytes encode (taglen*8 mod 128) in the top bits of byte 0,
       then zeros, then a 1 in byte 3 -- per the OCB3 nonce format */
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
#else
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
#endif
    /* NOTE(review): nonce is read through a uint32_t*, which assumes the
       caller's buffer is 4-byte aligned -- confirm callers comply */
    tmp.u32[1] = ((uint32_t*)nonce)[0];
    tmp.u32[2] = ((uint32_t*)nonce)[1];
    tmp.u32[3] = ((uint32_t*)nonce)[2];
    idx = (unsigned)(tmp.u8[15] & 0x3f);           /* Get low 6 bits of nonce  */
    tmp.u8[15] = tmp.u8[15] & 0xc0;                /* Zero low 6 bits of nonce */
    if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached?       */
        ctx->cached_Top = tmp.bl;                  /* Update cache, KtopStr    */
        AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
        if (little.endian) { /* Make Register Correct    */
            ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
            ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
        }
        /* stretch word = first 64 bits of Ktop ^ (Ktop << 8) */
        ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
    }
    return gen_offset(ctx->KtopStr, idx);
}
788
/* Accumulate ad_len bytes of associated data into ctx->ad_checksum.
 * Each AD block is xored with its per-block offset and enciphered, and
 * the cipher outputs are xored together to form the AD checksum that is
 * later folded into the tag.  Full (BPI*16)-byte chunks are processed
 * first; when 'final' is nonzero the trailing partial chunk is consumed,
 * with any fractional final block padded 10* and offset by ctx->Lstar.
 * With final==0 (AE_PENDING), ad_len must be a multiple of BPI*16 (see
 * the usage notes at the top of the file). */
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block ad_offset, ad_checksum;
    const block* adp = (block*)ad;
    unsigned i, k, tz, remaining;

    ad_offset = ctx->ad_offset;
    ad_checksum = ctx->ad_checksum;
    i = ad_len / (BPI * 16);
    if (i) {
        unsigned ad_block_num = ctx->ad_blocks_processed;
        do {
            block ta[BPI], oa[BPI];
            ad_block_num += BPI;
            tz = ntz(ad_block_num);
            /* Offsets follow offset_i = offset_{i-1} ^ L[ntz(i)]; runs of
               consecutive xors are algebraically collapsed below (e.g.
               oa[2] = ad_offset ^ L[1] because the two L[0] terms cancel). */
            oa[0] = xor_block(ad_offset, ctx->L[0]);
            ta[0] = xor_block(oa[0], adp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], adp[1]);
            oa[2] = xor_block(ad_offset, ctx->L[1]);
            ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
            ad_offset = xor_block(oa[2], getL(ctx, tz));
            ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], adp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], adp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], adp[5]);
            oa[6] = xor_block(ad_offset, ctx->L[2]);
            ta[6] = xor_block(oa[6], adp[6]);
            ad_offset = xor_block(oa[6], getL(ctx, tz));
            ta[7] = xor_block(ad_offset, adp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            /* Fold the enciphered blocks into the running AD checksum */
            ad_checksum = xor_block(ad_checksum, ta[0]);
            ad_checksum = xor_block(ad_checksum, ta[1]);
            ad_checksum = xor_block(ad_checksum, ta[2]);
            ad_checksum = xor_block(ad_checksum, ta[3]);
#if (BPI == 8)
            ad_checksum = xor_block(ad_checksum, ta[4]);
            ad_checksum = xor_block(ad_checksum, ta[5]);
            ad_checksum = xor_block(ad_checksum, ta[6]);
            ad_checksum = xor_block(ad_checksum, ta[7]);
#endif
            adp += BPI;
        } while (--i);
        ctx->ad_blocks_processed = ad_block_num;
        ctx->ad_offset = ad_offset;
        ctx->ad_checksum = ad_checksum;
    }

    if (final) {
        block ta[BPI];

        /* Process remaining associated data, compute its tag contribution */
        remaining = ((unsigned)ad_len) % (BPI * 16);
        if (remaining) {
            k = 0;
#if (BPI == 8)
            if (remaining >= 64) {
                tmp.bl = xor_block(ad_offset, ctx->L[0]);
                ta[0] = xor_block(tmp.bl, adp[0]);
                tmp.bl = xor_block(tmp.bl, ctx->L[1]);
                ta[1] = xor_block(tmp.bl, adp[1]);
                ad_offset = xor_block(ad_offset, ctx->L[1]);
                ta[2] = xor_block(ad_offset, adp[2]);
                ad_offset = xor_block(ad_offset, ctx->L[2]);
                ta[3] = xor_block(ad_offset, adp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
                ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                remaining = remaining - 16;
                ++k;
            }
            if (remaining) {
                /* Fractional final block: pad with 10* and offset by L_* */
                ad_offset = xor_block(ad_offset, ctx->Lstar);
                tmp.bl = zero_block();
                memcpy(tmp.u8, adp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                ta[k] = xor_block(ad_offset, tmp.bl);
                ++k;
            }
            AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
            /* Deliberate fallthrough: xor in exactly k enciphered blocks */
            switch (k) {
#if (BPI == 8)
            case 8:
                ad_checksum = xor_block(ad_checksum, ta[7]);
                /* fallthrough */
            case 7:
                ad_checksum = xor_block(ad_checksum, ta[6]);
                /* fallthrough */
            case 6:
                ad_checksum = xor_block(ad_checksum, ta[5]);
                /* fallthrough */
            case 5:
                ad_checksum = xor_block(ad_checksum, ta[4]);
                /* fallthrough */
#endif
            case 4:
                ad_checksum = xor_block(ad_checksum, ta[3]);
                /* fallthrough */
            case 3:
                ad_checksum = xor_block(ad_checksum, ta[2]);
                /* fallthrough */
            case 2:
                ad_checksum = xor_block(ad_checksum, ta[1]);
                /* fallthrough */
            case 1:
                ad_checksum = xor_block(ad_checksum, ta[0]);
            }
            ctx->ad_checksum = ad_checksum;
        }
    }
}
915
916/* ----------------------------------------------------------------------- */
917
/* Encrypt and authenticate pt_len bytes of plaintext, folding ad_len
 * bytes of associated data into the tag.  A non-NULL nonce begins a new
 * message; pass nonce==NULL with final==AE_PENDING to stream further
 * (BPI*16)-byte multiples of an ongoing message.  On the finalizing call
 * the tag is written to 'tag' if non-NULL, otherwise appended to ct.
 * Returns the number of bytes written to ct (including any appended tag). */
int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    const block* ptp = (block*)pt;

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Encrypt plaintext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = pt_len / (BPI * 16);
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset; /* Running offset is carried across iterations in oa[BPI-1] */
        do {
            block ta[BPI];
            block_num += BPI;
            /* Offsets follow offset_i = offset_{i-1} ^ L[ntz(i)]; the
               consecutive xors are algebraically collapsed below. */
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ptp[0]);
            checksum = xor_block(checksum, ptp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ptp[1]);
            checksum = xor_block(checksum, ptp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ptp[2]);
            checksum = xor_block(checksum, ptp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ptp[4]);
            checksum = xor_block(checksum, ptp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ptp[5]);
            checksum = xor_block(checksum, ptp[5]);
            /* oa[7] here still holds the previous iteration's base offset */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ptp[6]);
            checksum = xor_block(checksum, ptp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ptp[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            ctp[0] = xor_block(ta[0], oa[0]);
            ctp[1] = xor_block(ta[1], oa[1]);
            ctp[2] = xor_block(ta[2], oa[2]);
            ctp[3] = xor_block(ta[3], oa[3]);
#if (BPI == 8)
            ctp[4] = xor_block(ta[4], oa[4]);
            ctp[5] = xor_block(ta[5], oa[5]);
            ctp[6] = xor_block(ta[6], oa[6]);
            ctp[7] = xor_block(ta[7], oa[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining plaintext and compute its tag contribution    */
        unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ptp[0]);
                checksum = xor_block(checksum, ptp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ptp[1]);
                checksum = xor_block(checksum, ptp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ptp[2]);
                checksum = xor_block(checksum, ptp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ptp[3]);
                checksum = xor_block(checksum, ptp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ptp[k + 1]);
                checksum = xor_block(checksum, ptp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Fractional final block: pad checksum input with 10*;
                   ta[k] enciphers offset^L_* to make the xor pad. */
                tmp.bl = zero_block();
                memcpy(tmp.u8, ptp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                checksum = xor_block(checksum, tmp.bl);
                ta[k] = offset = xor_block(offset, ctx->Lstar);
                ++k;
            }
        }
        offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
        ta[k] = xor_block(offset, checksum);      /* Part of tag gen */
        AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
        offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
        if (remaining) {
            --k;
            tmp.bl = xor_block(tmp.bl, ta[k]);
            memcpy(ctp + k, tmp.u8, remaining);
        }
        /* Deliberate fallthrough: emit exactly k whole ciphertext blocks */
        switch (k) {
#if (BPI == 8)
        case 7:
            ctp[6] = xor_block(ta[6], oa[6]);
            /* fallthrough */
        case 6:
            ctp[5] = xor_block(ta[5], oa[5]);
            /* fallthrough */
        case 5:
            ctp[4] = xor_block(ta[4], oa[4]);
            /* fallthrough */
        case 4:
            ctp[3] = xor_block(ta[3], oa[3]);
            /* fallthrough */
#endif
        case 3:
            ctp[2] = xor_block(ta[2], oa[2]);
            /* fallthrough */
        case 2:
            ctp[1] = xor_block(ta[1], oa[1]);
            /* fallthrough */
        case 1:
            ctp[0] = xor_block(ta[0], oa[0]);
        }

        /* Write the tag to the caller's buffer, or append it to ct */
        if (tag) {
#if (OCB_TAG_LEN == 16)
            *(block*)tag = offset;
#elif(OCB_TAG_LEN > 0)
            memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
            memcpy((char*)tag, &offset, ctx->tag_len);
#endif
        } else {
#if (OCB_TAG_LEN > 0)
            memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
            pt_len += OCB_TAG_LEN;
#else
            memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
            pt_len += ctx->tag_len;
#endif
        }
    }
    return (int)pt_len;
}
1104
1105/* ----------------------------------------------------------------------- */
1106
/* Timing-safe comparison of n bytes at av and bv.

   The loop ORs together the bytewise XOR of the two buffers, so the time
   taken does not depend on where (or whether) the regions differ -- under
   the usual assumptions about compiler and hardware.  Use this instead of
   memcmp when comparing secrets (e.g. authentication tags) to avoid
   timing side channels.

   Returns 0 iff the two regions hold identical contents. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
    const uint8_t* x = (const uint8_t*)av;
    const uint8_t* y = (const uint8_t*)bv;
    uint8_t diff = 0;
    size_t idx;

    for (idx = 0; idx < n; idx++) {
        diff |= (uint8_t)(x[idx] ^ y[idx]);
    }

    return (int)diff;
}
1128
/* Decrypt and verify ct_len bytes of ciphertext (plus the tag, which is
 * either supplied separately in 'tag' or appended to ct when tag==NULL),
 * folding ad_len bytes of associated data into the expected tag.  A
 * non-NULL nonce begins a new message; pass nonce==NULL with
 * final==AE_PENDING to stream (BPI*16)-byte multiples.  Returns the
 * plaintext length on success, or AE_INVALID when authentication fails
 * (plaintext output is still written -- callers must discard it). */
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    block* ptp = (block*)pt;

    /* Reduce ct_len if the tag is bundled at the end of ct */
    if ((final) && (!tag))
#if (OCB_TAG_LEN > 0)
        ct_len -= OCB_TAG_LEN;
#else
        ct_len -= ctx->tag_len;
#endif

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Decrypt ciphertext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = ct_len / (BPI * 16);
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset; /* Running offset is carried across iterations in oa[BPI-1] */
        do {
            block ta[BPI];
            block_num += BPI;
            /* Offsets follow offset_i = offset_{i-1} ^ L[ntz(i)]; the
               consecutive xors are algebraically collapsed below. */
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ctp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ctp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ctp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ctp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ctp[5]);
            /* oa[7] here still holds the previous iteration's base offset */
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ctp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ctp[7]);
#endif
            AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
            /* Checksum accumulates the recovered plaintext blocks */
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
            ptp[7] = xor_block(ta[7], oa[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining ciphertext and compute its tag contribution   */
        unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ctp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ctp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ctp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ctp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ctp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ctp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ctp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                /* Fractional final block: decrypt by xoring with the pad
                   E_K(offset ^ L_*), then pad the plaintext 10* for the
                   checksum contribution. */
                block pad;
                offset = xor_block(offset, ctx->Lstar);
                AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
                pad = tmp.bl;
                memcpy(tmp.u8, ctp + k, remaining);
                tmp.bl = xor_block(tmp.bl, pad);
                tmp.u8[remaining] = (unsigned char)0x80u;
                memcpy(ptp + k, tmp.u8, remaining);
                checksum = xor_block(checksum, tmp.bl);
            }
        }
        AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
        /* Deliberate fallthrough: recover exactly k whole plaintext blocks */
        switch (k) {
#if (BPI == 8)
        case 7:
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
            /* fallthrough */
        case 6:
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
            /* fallthrough */
        case 5:
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
            /* fallthrough */
        case 4:
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
            /* fallthrough */
#endif
        case 3:
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
            /* fallthrough */
        case 2:
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
            /* fallthrough */
        case 1:
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
        }

        /* Calculate expected tag */
        offset = xor_block(offset, ctx->Ldollar);
        tmp.bl = xor_block(offset, checksum);
        AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
        tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

        /* Compare with proposed tag, change ct_len if invalid.  The
           constant-time compare avoids leaking the mismatch position. */
        if ((OCB_TAG_LEN == 16) && tag) {
            if (unequal_blocks(tmp.bl, *(block*)tag))
                ct_len = AE_INVALID;
        } else {
#if (OCB_TAG_LEN > 0)
            int len = OCB_TAG_LEN;
#else
            int len = ctx->tag_len;
#endif
            if (tag) {
                if (constant_time_memcmp(tag, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            } else {
                if (constant_time_memcmp((char*)ct + ct_len, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            }
        }
    }
    return ct_len;
}
1322
1323/* ----------------------------------------------------------------------- */
1324/* Simple test program                                                     */
1325/* ----------------------------------------------------------------------- */
1326
1327#if 0
1328
1329#include <stdio.h>
1330#include <time.h>
1331
1332#if __GNUC__
1333#define ALIGN(n) __attribute__((aligned(n)))
1334#elif _MSC_VER
1335#define ALIGN(n) __declspec(align(n))
1336#else /* Not GNU/Microsoft: delete alignment uses.     */
1337#define ALIGN(n)
1338#endif
1339
/* Print an optional label s (when non-NULL) followed by len bytes of p
   rendered as uppercase hexadecimal, terminated by a newline. */
static void pbuf(void *p, unsigned len, const void *s)
{
    const unsigned char *bytes = (const unsigned char *)p;
    unsigned idx;

    if (s != NULL)
        printf("%s", (const char *)s);
    for (idx = 0; idx < len; idx++)
        printf("%02X", (unsigned)bytes[idx]);
    printf("\n");
}
1349
/* Print test vectors for a len-byte message: ciphertext+tag for the
   plaintext+AD, AD-only, and plaintext-only cases, all with the same
   fixed 12-byte nonce.  Plaintext bytes are 0,1,2,... */
static void vectors(ae_ctx *ctx, int len)
{
    ALIGN(16) char pt[128];
    ALIGN(16) char ct[144];
    ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
    int i;
    for (i=0; i < 128; i++) pt[i] = i;
    i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
}
1364
/* Self-test driver: (1) optionally prints small known-answer vectors,
   (2) runs the RFC-style "grand tag" test -- encrypt messages of length
   0..127 (both one-shot and in AE_PENDING increments), concatenate all
   outputs, and print the tag of the concatenation, and (3) runs an
   encrypt/decrypt round-trip test checking authentication, length and
   plaintext recovery. */
void validate()
{
    ALIGN(16) char pt[1024];
    ALIGN(16) char ct[1024];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[12] = {0,};
    ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
    ae_ctx ctx;
    char *val_buf, *next;
    int i, len;

    /* 16-byte-align val_buf (required when SSE/AltiVec is in use).
       NOTE(review): the unaligned malloc pointer is overwritten, so it
       can never be freed -- acceptable for a throwaway test harness. */
    val_buf = (char *)malloc(22400 + 16);
    next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));

    if (0) {
		ae_init(&ctx, key, 16, 12, 16);
		/* pbuf(&ctx, sizeof(ctx), "CTX: "); */
		vectors(&ctx,0);
		vectors(&ctx,8);
		vectors(&ctx,16);
		vectors(&ctx,24);
		vectors(&ctx,32);
		vectors(&ctx,40);
    }

    memset(key,0,32);
    memset(pt,0,128);
    ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);

    /* RFC Vector test */
    for (i = 0; i < 128; i++) {
        /* Split each message into first+second+third so the streaming
           (AE_PENDING) path is exercised with (BPI*16)-multiples. */
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i;

        if (0) {
            ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        } else {
            ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        }

    }
    nonce[11] = 0;
    /* The "grand tag": authenticate the concatenation of all outputs */
    ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
    pbuf(tag,OCB_TAG_LEN,0);


    /* Encrypt/Decrypt test */
    for (i = 0; i < 128; i++) {
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i%128;

        if (1) {
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
            len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (len != i) { printf("Length error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        } else {
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
            ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
            ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
            len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        }

    }
    printf("Decrypt: PASS\n");
}
1467
/* Entry point of the embedded self-test program (compiled only when the
   surrounding '#if 0' is enabled). */
int main()
{
    validate();
    return 0;
}
1473#endif
1474
/* Human-readable name of the AES backend compiled into this build. */
#if USE_AES_NI
char infoString[] = "OCB3 (AES-NI)";
#elif USE_REFERENCE_AES
char infoString[] = "OCB3 (Reference)";
#elif USE_OPENSSL_AES
char infoString[] = "OCB3 (OpenSSL)";
#endif
1482