/*------------------------------------------------------------------------
/ OCB Version 3 Reference Code (Optimized C)     Last modified 12-JUN-2013
/-------------------------------------------------------------------------
/ Copyright (c) 2013 Ted Krovetz.
/
/ Permission to use, copy, modify, and/or distribute this software for any
/ purpose with or without fee is hereby granted, provided that the above
/ copyright notice and this permission notice appear in all copies.
/
/ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
/ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
/ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
/ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
/ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
/ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
/ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
/
/ Phillip Rogaway holds patents relevant to OCB. See the following for
/ his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
/
/ Special thanks to Keegan McAllister for suggesting several good improvements
/
/ Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
/------------------------------------------------------------------------- */

/* ----------------------------------------------------------------------- */
/* Usage notes                                                             */
/* ----------------------------------------------------------------------- */

/* - When AE_PENDING is passed as the 'final' parameter of any function,
/    the length parameters must be a multiple of (BPI*16).
/  - When available, SSE or AltiVec registers are used to manipulate data.
/    So, when on machines with these facilities, all pointers passed to
/    any function should be 16-byte aligned.
/  - Plaintext and ciphertext pointers may be equal (i.e., plaintext gets
/    encrypted in place), but no other pair of pointers may be equal.
/  - This code assumes all x86 processors have SSE2 and SSSE3 instructions
/    when compiling under MSVC. If untrue, alter the #define.
/  - This code targets C99 and is tested with recent versions of GCC and
/    MSVC. A minimal usage sketch follows these notes.                     */

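/* A minimal usage sketch (kept out of the build, like the test program at
   the end of this file; assumes the ae.h declarations and AE_* constants
   used there). One-shot encryption with the tag appended to the
   ciphertext, then decryption/verification. Real code on SSE/AltiVec
   builds should also 16-byte align these buffers, per the notes above.   */
#if 0
static void usage_sketch(void)
{
    unsigned char key[16] = {0};
    unsigned char nonce[12] = {0};
    unsigned char pt[64] = {0};
    unsigned char ct[64 + 16]; /* ciphertext plus appended 16-byte tag */
    unsigned char out[64];
    int ct_len, pt_len;
    ae_ctx* ctx = ae_allocate(NULL);

    ae_init(ctx, key, 16, 12, 16); /* 16-byte key, 12-byte nonce, 16-byte tag */
    /* tag pointer is NULL, so the tag is appended to ct; returns 64 + 16 */
    ct_len = ae_encrypt(ctx, nonce, pt, 64, NULL, 0, ct, NULL, AE_FINALIZE);
    /* returns 64 on success, AE_INVALID if authentication fails */
    pt_len = ae_decrypt(ctx, nonce, ct, ct_len, NULL, 0, out, NULL, AE_FINALIZE);
    ae_clear(ctx);
    ae_free(ctx);
}
#endif
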
/* ----------------------------------------------------------------------- */
/* User configuration options                                              */
/* ----------------------------------------------------------------------- */

/* Set the AES key length to use and length of authentication tag to produce.
/  Setting either to 0 requires the value be set at runtime via ae_init().
/  Some optimizations occur for each when set to a fixed value.            */
#define OCB_KEY_LEN 16 /* 0, 16, 24 or 32. 0 means set in ae_init */
#define OCB_TAG_LEN 16 /* 0 to 16. 0 means set in ae_init         */

/* This implementation has built-in support for multiple AES APIs. Set any
/  one of the following to non-zero to specify which to use.               */
#define USE_OPENSSL_AES 1   /* http://openssl.org                      */
#define USE_REFERENCE_AES 0 /* Internet search: rijndael-alg-fst.c     */
#define USE_AES_NI 0        /* Uses compiler's intrinsics              */

/* During encryption and decryption, various "L values" are required.
/  The L values can be precomputed during initialization (requiring extra
/  space in ae_ctx), generated as needed (slightly slowing encryption and
/  decryption), or some combination of the two. L_TABLE_SZ specifies how many
/  L values to precompute. L_TABLE_SZ must be at least 3. L_TABLE_SZ*16 bytes
/  are used for L values in ae_ctx. Plaintext and ciphertexts shorter than
/  2^L_TABLE_SZ blocks need no L values calculated dynamically.            */
#define L_TABLE_SZ 16

/* Set L_TABLE_SZ_IS_ENOUGH non-zero iff you know that all plaintexts
/  will be shorter than 2^(L_TABLE_SZ+4) bytes in length. This results
/  in better performance.                                                  */
#define L_TABLE_SZ_IS_ENOUGH 1

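/* Worked example of the two settings above: with L_TABLE_SZ = 16, any
/  message shorter than 2^16 blocks = 2^(16+4) bytes (1 MiB) needs no
/  dynamically computed L values, and L_TABLE_SZ_IS_ENOUGH promises that
/  every message stays under that bound, so getL() below can collapse to
/  a plain table lookup.                                                   */
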
/* ----------------------------------------------------------------------- */
/* Includes and compiler specific definitions                              */
/* ----------------------------------------------------------------------- */

#include "ae.h"
#include <stdlib.h>
#include <string.h>

/* Define standard sized integers                                          */
#if defined(_MSC_VER) && (_MSC_VER < 1600)
typedef unsigned __int8 uint8_t;
typedef unsigned __int32 uint32_t;
typedef unsigned __int64 uint64_t;
typedef __int64 int64_t;
#else
#include <stdint.h>
#endif

/* Compiler-specific intrinsics and fixes: bswap64, ntz                    */
#if _MSC_VER
#define inline __inline                           /* MSVC doesn't recognize "inline" in C */
#define restrict __restrict                       /* MSVC doesn't recognize "restrict" in C */
#define __SSE2__ (_M_IX86 || _M_AMD64 || _M_X64)  /* Assume SSE2  */
#define __SSSE3__ (_M_IX86 || _M_AMD64 || _M_X64) /* Assume SSSE3 */
#include <intrin.h>
#pragma intrinsic(_byteswap_uint64, _BitScanForward, memcpy)
#define bswap64(x) _byteswap_uint64(x)
static inline unsigned ntz(unsigned x) {
    unsigned long index; /* _BitScanForward expects an unsigned long* */
    _BitScanForward(&index, x);
    return (unsigned)index;
}
#elif __GNUC__
#define inline __inline__                   /* No "inline" in GCC ansi C mode */
#define restrict __restrict__               /* No "restrict" in GCC ansi C mode */
#define bswap64(x) __builtin_bswap64(x)     /* Assuming GCC 4.3+ */
#define ntz(x) __builtin_ctz((unsigned)(x)) /* Assuming GCC 3.4+ */
#else /* Assume some C99 features: stdint.h, inline, restrict */
#define bswap32(x)                                                                                 \
    ((((x)&0xff000000u) >> 24) | (((x)&0x00ff0000u) >> 8) | (((x)&0x0000ff00u) << 8) |             \
     (((x)&0x000000ffu) << 24))

static inline uint64_t bswap64(uint64_t x) {
    union {
        uint64_t u64;
        uint32_t u32[2];
    } in, out;
    in.u64 = x;
    out.u32[0] = bswap32(in.u32[1]);
    out.u32[1] = bswap32(in.u32[0]);
    return out.u64;
}

#if (L_TABLE_SZ <= 9) && (L_TABLE_SZ_IS_ENOUGH) /* < 2^13 byte texts */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[] = {
        0, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2,
        3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2,
        4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 8, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2,
        3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2, 7, 2, 3, 2, 4, 2, 3, 2,
        5, 2, 3, 2, 4, 2, 3, 2, 6, 2, 3, 2, 4, 2, 3, 2, 5, 2, 3, 2, 4, 2, 3, 2};
    return tz_table[x / 4];
}
#else                                           /* From http://supertech.csail.mit.edu/papers/debruijn.pdf */
static inline unsigned ntz(unsigned x) {
    static const unsigned char tz_table[32] = {0,  1,  28, 2,  29, 14, 24, 3,  30, 22, 20,
                                               15, 25, 17, 4,  8,  31, 27, 13, 23, 21, 19,
                                               16, 7,  26, 12, 18, 6,  11, 5,  10, 9};
    return tz_table[((uint32_t)((x & -x) * 0x077CB531u)) >> 27];
}
#endif
#endif

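/* Sanity-check sketch (not compiled): for the arguments this file actually
   passes -- nonzero block counts that are multiples of BPI -- each ntz
   variant above must agree with a naive trailing-zero count. The small
   table variant is only defined for multiples of 4 up to 508, matching
   its "< 2^13 byte texts" guard. */
#if 0
#include <assert.h>
static unsigned ntz_naive(unsigned x) {
    unsigned n = 0;
    while ((x & 1) == 0) {
        x >>= 1;
        n++;
    }
    return n;
}
static void ntz_selftest(void) {
    unsigned x;
    for (x = 4; x <= 508; x += 4)
        assert(ntz(x) == ntz_naive(x));
}
#endif
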
/* ----------------------------------------------------------------------- */
/* Define blocks and operations -- Patch if incorrect on your compiler.    */
/* ----------------------------------------------------------------------- */

#if __SSE2__ && !KEYMASTER_CLANG_TEST_BUILD
#include <xmmintrin.h> /* SSE instructions and _mm_malloc */
#include <emmintrin.h> /* SSE2 instructions               */
typedef __m128i block;
#define xor_block(x, y) _mm_xor_si128(x, y)
#define zero_block() _mm_setzero_si128()
#define unequal_blocks(x, y) (_mm_movemask_epi8(_mm_cmpeq_epi8(x, y)) != 0xffff)
#if __SSSE3__ || USE_AES_NI
#include <tmmintrin.h> /* SSSE3 instructions              */
#define swap_if_le(b)                                                                              \
    _mm_shuffle_epi8(b, _mm_set_epi8(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15))
#else
static inline block swap_if_le(block b) {
    block a = _mm_shuffle_epi32(b, _MM_SHUFFLE(0, 1, 2, 3));
    a = _mm_shufflehi_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    a = _mm_shufflelo_epi16(a, _MM_SHUFFLE(2, 3, 0, 1));
    return _mm_xor_si128(_mm_srli_epi16(a, 8), _mm_slli_epi16(a, 8));
}
#endif
static inline block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block hi = _mm_load_si128((__m128i*)(KtopStr + 0));  /* hi = B A */
    block lo = _mm_loadu_si128((__m128i*)(KtopStr + 1)); /* lo = C B */
    __m128i lshift = _mm_cvtsi32_si128(bot);
    __m128i rshift = _mm_cvtsi32_si128(64 - bot);
    lo = _mm_xor_si128(_mm_sll_epi64(hi, lshift), _mm_srl_epi64(lo, rshift));
#if __SSSE3__ || USE_AES_NI
    return _mm_shuffle_epi8(lo, _mm_set_epi8(8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7));
#else
    return swap_if_le(_mm_shuffle_epi32(lo, _MM_SHUFFLE(1, 0, 3, 2)));
#endif
}
static inline block double_block(block bl) {
    const __m128i mask = _mm_set_epi32(135, 1, 1, 1);
    __m128i tmp = _mm_srai_epi32(bl, 31);
    tmp = _mm_and_si128(tmp, mask);
    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2, 1, 0, 3));
    bl = _mm_slli_epi32(bl, 1);
    return _mm_xor_si128(bl, tmp);
}
#elif __ALTIVEC__
#include <altivec.h>
typedef vector unsigned block;
#define xor_block(x, y) vec_xor(x, y)
#define zero_block() vec_splat_u32(0)
#define unequal_blocks(x, y) vec_any_ne(x, y)
#define swap_if_le(b) (b)
#if __PPC64__
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    union {
        uint64_t u64[2];
        block bl;
    } rval;
    rval.u64[0] = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
    rval.u64[1] = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
    return rval.bl;
}
#else
/* Special handling: Shifts are mod 32, and no 64-bit types */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const vector unsigned k32 = {32, 32, 32, 32};
    vector unsigned hi = *(vector unsigned*)(KtopStr + 0);
    vector unsigned lo = *(vector unsigned*)(KtopStr + 2);
    vector unsigned bot_vec;
    if (bot < 32) {
        lo = vec_sld(hi, lo, 4);
    } else {
        vector unsigned t = vec_sld(hi, lo, 4);
        lo = vec_sld(hi, lo, 8);
        hi = t;
        bot = bot - 32;
    }
    if (bot == 0)
        return hi;
    *(unsigned*)&bot_vec = bot;
    vector unsigned lshift = vec_splat(bot_vec, 0);
    vector unsigned rshift = vec_sub(k32, lshift);
    hi = vec_sl(hi, lshift);
    lo = vec_sr(lo, rshift);
    return vec_xor(hi, lo);
}
#endif
static inline block double_block(block b) {
    const vector unsigned char mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    const vector unsigned char perm = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0};
    const vector unsigned char shift7 = vec_splat_u8(7);
    const vector unsigned char shift1 = vec_splat_u8(1);
    vector unsigned char c = (vector unsigned char)b;
    vector unsigned char t = vec_sra(c, shift7);
    t = vec_and(t, mask);
    t = vec_perm(t, t, perm);
    c = vec_sl(c, shift1);
    return (block)vec_xor(c, t);
}
#elif __ARM_NEON__
#include <arm_neon.h>
typedef int8x16_t block; /* Yay! Endian-neutral reads! */
#define xor_block(x, y) veorq_s8(x, y)
#define zero_block() vdupq_n_s8(0)
static inline int unequal_blocks(block a, block b) {
    int64x2_t t = veorq_s64((int64x2_t)a, (int64x2_t)b);
    return (vgetq_lane_s64(t, 0) | vgetq_lane_s64(t, 1)) != 0;
}
#define swap_if_le(b) (b) /* Using endian-neutral int8x16_t */
/* KtopStr is register correct in 64-bit chunks; the return is memory correct */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    const int64x2_t k64 = {-64, -64};
    uint64x2_t hi = *(uint64x2_t*)(KtopStr + 0); /* hi = A B */
    uint64x2_t lo = *(uint64x2_t*)(KtopStr + 1); /* lo = B C */
    int64x2_t ls = vdupq_n_s64(bot);
    int64x2_t rs = vqaddq_s64(k64, ls);
    block rval = (block)veorq_u64(vshlq_u64(hi, ls), vshlq_u64(lo, rs));
    if (little.endian)
        rval = vrev64q_s8(rval);
    return rval;
}
static inline block double_block(block b) {
    const block mask = {135, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
    block tmp = vshrq_n_s8(b, 7);
    tmp = vandq_s8(tmp, mask);
    tmp = vextq_s8(tmp, tmp, 1); /* Rotate high byte to end */
    b = vshlq_n_s8(b, 1);
    return veorq_s8(tmp, b);
}
#else
typedef struct { uint64_t l, r; } block;
static inline block xor_block(block x, block y) {
    x.l ^= y.l;
    x.r ^= y.r;
    return x;
}
static inline block zero_block(void) {
    const block t = {0, 0};
    return t;
}
#define unequal_blocks(x, y) ((((x).l ^ (y).l) | ((x).r ^ (y).r)) != 0)
static inline block swap_if_le(block b) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    if (little.endian) {
        block r;
        r.l = bswap64(b.l);
        r.r = bswap64(b.r);
        return r;
    } else
        return b;
}

/* KtopStr is register correct in 64-bit chunks; the return is memory correct */
block gen_offset(uint64_t KtopStr[3], unsigned bot) {
    block rval;
    if (bot != 0) {
        rval.l = (KtopStr[0] << bot) | (KtopStr[1] >> (64 - bot));
        rval.r = (KtopStr[1] << bot) | (KtopStr[2] >> (64 - bot));
    } else {
        rval.l = KtopStr[0];
        rval.r = KtopStr[1];
    }
    return swap_if_le(rval);
}

#if __GNUC__ && __arm__
static inline block double_block(block b) {
    __asm__("adds %1,%1,%1\n\t"
            "adcs %H1,%H1,%H1\n\t"
            "adcs %0,%0,%0\n\t"
            "adcs %H0,%H0,%H0\n\t"
            "it cs\n\t"
            "eorcs %1,%1,#135"
            : "+r"(b.l), "+r"(b.r)
            :
            : "cc");
    return b;
}
#else
static inline block double_block(block b) {
    uint64_t t = (uint64_t)((int64_t)b.l >> 63);
    b.l = (b.l + b.l) ^ (b.r >> 63);
    b.r = (b.r + b.r) ^ (t & 135);
    return b;
}
#endif

#endif

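/* Worked example of the GF(2^128) doubling that every double_block variant
   above implements: shift the 128-bit value left by one bit and, if the
   shifted-out bit was set, XOR the reduction constant 135 (0x87) into the
   low byte. A sketch on plain uint64_t halves, independent of which block
   representation was selected: */
#if 0
static void double_block_demo(void) {
    uint64_t hi = 0x8000000000000000ull, lo = 0; /* only the top bit set */
    uint64_t carry = hi >> 63;                   /* bit shifted out */
    hi = (hi << 1) | (lo >> 63);
    lo = (lo << 1) ^ (carry ? 135 : 0);
    /* now hi == 0 and lo == 0x87: the reduction polynomial kicked in */
}
#endif
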
/* ----------------------------------------------------------------------- */
/* AES - Code uses OpenSSL API. Other implementations get mapped to it.    */
/* ----------------------------------------------------------------------- */

/*---------------*/
#if USE_OPENSSL_AES
/*---------------*/

#include <openssl/aes.h> /* http://openssl.org/ */

/* How to ECB encrypt an array of blocks, in place                         */
static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_encrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}

static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}

#define BPI 4 /* Number of blocks in buffer per ECB call */

/*-------------------*/
#elif USE_REFERENCE_AES
/*-------------------*/

#include "rijndael-alg-fst.h" /* Barreto's Public-Domain Code */
#if (OCB_KEY_LEN == 0)
typedef struct {
    uint32_t rd_key[60];
    int rounds;
} AES_KEY;
#define ROUNDS(ctx) ((ctx)->rounds)
#define AES_set_encrypt_key(x, y, z)                                                               \
    do {                                                                                           \
        rijndaelKeySetupEnc((z)->rd_key, x, y);                                                    \
        (z)->rounds = y / 32 + 6;                                                                  \
    } while (0)
#define AES_set_decrypt_key(x, y, z)                                                               \
    do {                                                                                           \
        rijndaelKeySetupDec((z)->rd_key, x, y);                                                    \
        (z)->rounds = y / 32 + 6;                                                                  \
    } while (0)
#else
typedef struct { uint32_t rd_key[OCB_KEY_LEN + 28]; } AES_KEY;
#define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
#define AES_set_encrypt_key(x, y, z) rijndaelKeySetupEnc((z)->rd_key, x, y)
#define AES_set_decrypt_key(x, y, z) rijndaelKeySetupDec((z)->rd_key, x, y)
#endif
#define AES_encrypt(x, y, z) rijndaelEncrypt((z)->rd_key, ROUNDS(z), x, y)
#define AES_decrypt(x, y, z) rijndaelDecrypt((z)->rd_key, ROUNDS(z), x, y)

static void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_encrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}

static void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    while (nblks) {
        --nblks;
        AES_decrypt((unsigned char*)(blks + nblks), (unsigned char*)(blks + nblks), key);
    }
}

#define BPI 4 /* Number of blocks in buffer per ECB call */

/*----------*/
#elif USE_AES_NI
/*----------*/

#include <wmmintrin.h>

#if (OCB_KEY_LEN == 0)
typedef struct {
    __m128i rd_key[15];
    int rounds;
} AES_KEY;
#define ROUNDS(ctx) ((ctx)->rounds)
#else
typedef struct { __m128i rd_key[7 + OCB_KEY_LEN / 4]; } AES_KEY;
#define ROUNDS(ctx) (6 + OCB_KEY_LEN / 4)
#endif

#define EXPAND_ASSIST(v1, v2, v3, v4, shuff_const, aes_const)                                      \
    v2 = _mm_aeskeygenassist_si128(v4, aes_const);                                                 \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 16));         \
    v1 = _mm_xor_si128(v1, v3);                                                                    \
    v3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(v3), _mm_castsi128_ps(v1), 140));        \
    v1 = _mm_xor_si128(v1, v3);                                                                    \
    v2 = _mm_shuffle_epi32(v2, shuff_const);                                                       \
    v1 = _mm_xor_si128(v1, v2)

#define EXPAND192_STEP(idx, aes_const)                                                             \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, aes_const);                                                  \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                                 \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                            \
    kp[idx] = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(tmp), _mm_castsi128_ps(x0), 68));   \
    kp[idx + 1] =                                                                                  \
        _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(x0), _mm_castsi128_ps(x3), 78));          \
    EXPAND_ASSIST(x0, x1, x2, x3, 85, (aes_const * 2));                                            \
    x3 = _mm_xor_si128(x3, _mm_slli_si128(x3, 4));                                                 \
    x3 = _mm_xor_si128(x3, _mm_shuffle_epi32(x0, 255));                                            \
    kp[idx + 2] = x0;                                                                              \
    tmp = x3

static void AES_128_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2;
    __m128i* kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 1);
    kp[1] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 2);
    kp[2] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 4);
    kp[3] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 8);
    kp[4] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 16);
    kp[5] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 32);
    kp[6] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 64);
    kp[7] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 128);
    kp[8] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 27);
    kp[9] = x0;
    EXPAND_ASSIST(x0, x1, x2, x0, 255, 54);
    kp[10] = x0;
}

static void AES_192_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, tmp, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    tmp = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND192_STEP(1, 1);
    EXPAND192_STEP(4, 4);
    EXPAND192_STEP(7, 16);
    EXPAND192_STEP(10, 64);
}

static void AES_256_Key_Expansion(const unsigned char* userkey, void* key) {
    __m128i x0, x1, x2, x3, *kp = (__m128i*)key;
    kp[0] = x0 = _mm_loadu_si128((__m128i*)userkey);
    kp[1] = x3 = _mm_loadu_si128((__m128i*)(userkey + 16));
    x2 = _mm_setzero_si128();
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 1);
    kp[2] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 1);
    kp[3] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 2);
    kp[4] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 2);
    kp[5] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 4);
    kp[6] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 4);
    kp[7] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 8);
    kp[8] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 8);
    kp[9] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 16);
    kp[10] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 16);
    kp[11] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 32);
    kp[12] = x0;
    EXPAND_ASSIST(x3, x1, x2, x0, 170, 32);
    kp[13] = x3;
    EXPAND_ASSIST(x0, x1, x2, x3, 255, 64);
    kp[14] = x0;
}

static int AES_set_encrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    if (bits == 128) {
        AES_128_Key_Expansion(userKey, key);
    } else if (bits == 192) {
        AES_192_Key_Expansion(userKey, key);
    } else if (bits == 256) {
        AES_256_Key_Expansion(userKey, key);
    }
#if (OCB_KEY_LEN == 0)
    key->rounds = 6 + bits / 32;
#endif
    return 0;
}

static void AES_set_decrypt_key_fast(AES_KEY* dkey, const AES_KEY* ekey) {
    int j = 0;
    int i = ROUNDS(ekey);
#if (OCB_KEY_LEN == 0)
    dkey->rounds = i;
#endif
    dkey->rd_key[i--] = ekey->rd_key[j++];
    while (i)
        dkey->rd_key[i--] = _mm_aesimc_si128(ekey->rd_key[j++]);
    dkey->rd_key[i] = ekey->rd_key[j];
}

static int AES_set_decrypt_key(const unsigned char* userKey, const int bits, AES_KEY* key) {
    AES_KEY temp_key;
    AES_set_encrypt_key(userKey, bits, &temp_key);
    AES_set_decrypt_key_fast(key, &temp_key);
    return 0;
}

static inline void AES_encrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesenc_si128(tmp, sched[j]);
    tmp = _mm_aesenclast_si128(tmp, sched[j]);
    _mm_store_si128((__m128i*)out, tmp);
}

static inline void AES_decrypt(const unsigned char* in, unsigned char* out, const AES_KEY* key) {
    int j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    __m128i tmp = _mm_load_si128((__m128i*)in);
    tmp = _mm_xor_si128(tmp, sched[0]);
    for (j = 1; j < rnds; j++)
        tmp = _mm_aesdec_si128(tmp, sched[j]);
    tmp = _mm_aesdeclast_si128(tmp, sched[j]);
    _mm_store_si128((__m128i*)out, tmp);
}

static inline void AES_ecb_encrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesenc_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesenclast_si128(blks[i], sched[j]);
}

static inline void AES_ecb_decrypt_blks(block* blks, unsigned nblks, AES_KEY* key) {
    unsigned i, j, rnds = ROUNDS(key);
    const __m128i* sched = ((__m128i*)(key->rd_key));
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_xor_si128(blks[i], sched[0]);
    for (j = 1; j < rnds; ++j)
        for (i = 0; i < nblks; ++i)
            blks[i] = _mm_aesdec_si128(blks[i], sched[j]);
    for (i = 0; i < nblks; ++i)
        blks[i] = _mm_aesdeclast_si128(blks[i], sched[j]);
}

#define BPI 8 /* Number of blocks in buffer per ECB call   */
/* Set to 4 for Westmere, 8 for Sandy Bridge */

#endif

/* ----------------------------------------------------------------------- */
/* Define OCB context structure.                                           */
/* ----------------------------------------------------------------------- */

/*------------------------------------------------------------------------
/ Each item in the OCB context is stored either "memory correct" or
/ "register correct". On big-endian machines, this is identical. On
/ little-endian machines, one must choose whether the byte-string
/ is in the correct order when it resides in memory or in registers.
/ It must be register correct whenever it is to be manipulated
/ arithmetically, but must be memory correct whenever it interacts
/ with the plaintext or ciphertext.
/------------------------------------------------------------------------- */

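/*------------------------------------------------------------------------
/ For example: on a little-endian machine the byte string b0 b1 ... b15 is
/ "memory correct" exactly as stored, while bswap64() on each half (what
/ swap_if_le() does on the portable path) yields the "register correct"
/ form, with b0 in the most significant byte of the first 64-bit word.
/------------------------------------------------------------------------- */
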
struct _ae_ctx {
    block offset;        /* Memory correct               */
    block checksum;      /* Memory correct               */
    block Lstar;         /* Memory correct               */
    block Ldollar;       /* Memory correct               */
    block L[L_TABLE_SZ]; /* Memory correct               */
    block ad_checksum;   /* Memory correct               */
    block ad_offset;     /* Memory correct               */
    block cached_Top;    /* Memory correct               */
    uint64_t KtopStr[3]; /* Register correct, each item  */
    uint32_t ad_blocks_processed;
    uint32_t blocks_processed;
    AES_KEY decrypt_key;
    AES_KEY encrypt_key;
#if (OCB_TAG_LEN == 0)
    unsigned tag_len;
#endif
};

/* ----------------------------------------------------------------------- */
/* L table lookup (or on-the-fly generation)                               */
/* ----------------------------------------------------------------------- */

#if L_TABLE_SZ_IS_ENOUGH
#define getL(_ctx, _tz) ((_ctx)->L[_tz])
#else
static block getL(const ae_ctx* ctx, unsigned tz) {
    if (tz < L_TABLE_SZ)
        return ctx->L[tz];
    else {
        unsigned i;
        /* Bring L[MAX] into registers, make it register correct */
        block rval = swap_if_le(ctx->L[L_TABLE_SZ - 1]);
        rval = double_block(rval);
        for (i = L_TABLE_SZ; i < tz; i++)
            rval = double_block(rval);
        return swap_if_le(rval); /* To memory correct */
    }
}
#endif

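/* The table consulted above holds the doubling chain computed in ae_init()
   below: Lstar = E_K(0^128), Ldollar = double(Lstar), L[0] = double(Ldollar),
   and L[i] = double(L[i-1]); so getL(ctx, tz) is Lstar doubled tz+2 times. */
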
/* ----------------------------------------------------------------------- */
/* Public functions                                                        */
/* ----------------------------------------------------------------------- */

/* 32-bit SSE2 and Altivec systems need to be forced to allocate memory
   on 16-byte alignments. (I believe all major 64-bit systems do already.) */

ae_ctx* ae_allocate(void* misc) {
    void* p;
    (void)misc; /* misc unused in this implementation */
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    p = _mm_malloc(sizeof(ae_ctx), 16);
#elif (__ALTIVEC__ && !__PPC64__)
    if (posix_memalign(&p, 16, sizeof(ae_ctx)) != 0)
        p = NULL;
#else
    p = malloc(sizeof(ae_ctx));
#endif
    return (ae_ctx*)p;
}

void ae_free(ae_ctx* ctx) {
#if (__SSE2__ && !_M_X64 && !_M_AMD64 && !__amd64__)
    _mm_free(ctx);
#else
    free(ctx);
#endif
}

/* ----------------------------------------------------------------------- */

int ae_clear(ae_ctx* ctx) /* Zero ae_ctx and undo initialization          */
{
    memset(ctx, 0, sizeof(ae_ctx));
    return AE_SUCCESS;
}

int ae_ctx_sizeof(void) {
    return (int)sizeof(ae_ctx);
}

/* ----------------------------------------------------------------------- */

int ae_init(ae_ctx* ctx, const void* key, int key_len, int nonce_len, int tag_len) {
    unsigned i;
    block tmp_blk;

    if (nonce_len != 12)
        return AE_NOT_SUPPORTED;

/* Initialize encryption & decryption keys */
#if (OCB_KEY_LEN > 0)
    key_len = OCB_KEY_LEN;
#endif
    AES_set_encrypt_key((unsigned char*)key, key_len * 8, &ctx->encrypt_key);
#if USE_AES_NI
    AES_set_decrypt_key_fast(&ctx->decrypt_key, &ctx->encrypt_key);
#else
    AES_set_decrypt_key((unsigned char*)key, (int)(key_len * 8), &ctx->decrypt_key);
#endif

    /* Zero things that need zeroing */
    ctx->cached_Top = ctx->ad_checksum = zero_block();
    ctx->ad_blocks_processed = 0;

    /* Compute key-dependent values */
    AES_encrypt((unsigned char*)&ctx->cached_Top, (unsigned char*)&ctx->Lstar, &ctx->encrypt_key);
    tmp_blk = swap_if_le(ctx->Lstar);
    tmp_blk = double_block(tmp_blk);
    ctx->Ldollar = swap_if_le(tmp_blk);
    tmp_blk = double_block(tmp_blk);
    ctx->L[0] = swap_if_le(tmp_blk);
    for (i = 1; i < L_TABLE_SZ; i++) {
        tmp_blk = double_block(tmp_blk);
        ctx->L[i] = swap_if_le(tmp_blk);
    }

#if (OCB_TAG_LEN == 0)
    ctx->tag_len = tag_len;
#else
    (void)tag_len; /* Suppress "unused variable" warning */
#endif

    return AE_SUCCESS;
}

/* ----------------------------------------------------------------------- */

static block gen_offset_from_nonce(ae_ctx* ctx, const void* nonce) {
    const union {
        unsigned x;
        unsigned char endian;
    } little = {1};
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    unsigned idx;

/* Replace cached nonce Top if needed */
#if (OCB_TAG_LEN > 0)
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((OCB_TAG_LEN * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((OCB_TAG_LEN * 8 % 128) << 25);
#else
    if (little.endian)
        tmp.u32[0] = 0x01000000 + ((ctx->tag_len * 8 % 128) << 1);
    else
        tmp.u32[0] = 0x00000001 + ((ctx->tag_len * 8 % 128) << 25);
#endif
    tmp.u32[1] = ((uint32_t*)nonce)[0];
    tmp.u32[2] = ((uint32_t*)nonce)[1];
    tmp.u32[3] = ((uint32_t*)nonce)[2];
    idx = (unsigned)(tmp.u8[15] & 0x3f);           /* Get low 6 bits of nonce  */
    tmp.u8[15] = tmp.u8[15] & 0xc0;                /* Zero low 6 bits of nonce */
    if (unequal_blocks(tmp.bl, ctx->cached_Top)) { /* Cached?       */
        ctx->cached_Top = tmp.bl;                  /* Update cache, KtopStr    */
        AES_encrypt(tmp.u8, (unsigned char*)&ctx->KtopStr, &ctx->encrypt_key);
        if (little.endian) { /* Make Register Correct    */
            ctx->KtopStr[0] = bswap64(ctx->KtopStr[0]);
            ctx->KtopStr[1] = bswap64(ctx->KtopStr[1]);
        }
        ctx->KtopStr[2] = ctx->KtopStr[0] ^ (ctx->KtopStr[0] << 8) ^ (ctx->KtopStr[1] >> 56);
    }
    return gen_offset(ctx->KtopStr, idx);
}

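/* Worked example of the nonce formatting above (cf. RFC 7253): with a
   16-byte tag and 12-byte nonce N, the block handed to AES is the 7-bit
   field (tag bits mod 128) = 0, then zeros, then a 1 bit, then N -- the
   bytes 00 00 00 01 N[0..11]. Its low six bits ("bottom") select the
   offset's bit position and are zeroed before encryption, so nonces that
   differ only in those bits reuse the cached Ktop without a new AES call;
   KtopStr[2] holds the extra 64 bits of Stretch = Ktop xor (Ktop << 8)
   that the bottom-bit shift in gen_offset() may pull in. */
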
static void process_ad(ae_ctx* ctx, const void* ad, int ad_len, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block ad_offset, ad_checksum;
    const block* adp = (block*)ad;
    unsigned i, k, tz, remaining;

    ad_offset = ctx->ad_offset;
    ad_checksum = ctx->ad_checksum;
    i = ad_len / (BPI * 16);
    if (i) {
        unsigned ad_block_num = ctx->ad_blocks_processed;
        do {
            block ta[BPI], oa[BPI];
            ad_block_num += BPI;
            tz = ntz(ad_block_num);
            oa[0] = xor_block(ad_offset, ctx->L[0]);
            ta[0] = xor_block(oa[0], adp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], adp[1]);
            oa[2] = xor_block(ad_offset, ctx->L[1]);
            ta[2] = xor_block(oa[2], adp[2]);
#if BPI == 4
            ad_offset = xor_block(oa[2], getL(ctx, tz));
            ta[3] = xor_block(ad_offset, adp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], adp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], adp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], adp[5]);
            oa[6] = xor_block(ad_offset, ctx->L[2]);
            ta[6] = xor_block(oa[6], adp[6]);
            ad_offset = xor_block(oa[6], getL(ctx, tz));
            ta[7] = xor_block(ad_offset, adp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            ad_checksum = xor_block(ad_checksum, ta[0]);
            ad_checksum = xor_block(ad_checksum, ta[1]);
            ad_checksum = xor_block(ad_checksum, ta[2]);
            ad_checksum = xor_block(ad_checksum, ta[3]);
#if (BPI == 8)
            ad_checksum = xor_block(ad_checksum, ta[4]);
            ad_checksum = xor_block(ad_checksum, ta[5]);
            ad_checksum = xor_block(ad_checksum, ta[6]);
            ad_checksum = xor_block(ad_checksum, ta[7]);
#endif
            adp += BPI;
        } while (--i);
        ctx->ad_blocks_processed = ad_block_num;
        ctx->ad_offset = ad_offset;
        ctx->ad_checksum = ad_checksum;
    }

    if (final) {
        block ta[BPI];

        /* Process remaining associated data, compute its tag contribution */
        remaining = ((unsigned)ad_len) % (BPI * 16);
        if (remaining) {
            k = 0;
#if (BPI == 8)
            if (remaining >= 64) {
                tmp.bl = xor_block(ad_offset, ctx->L[0]);
                ta[0] = xor_block(tmp.bl, adp[0]);
                tmp.bl = xor_block(tmp.bl, ctx->L[1]);
                ta[1] = xor_block(tmp.bl, adp[1]);
                ad_offset = xor_block(ad_offset, ctx->L[1]);
                ta[2] = xor_block(ad_offset, adp[2]);
                ad_offset = xor_block(ad_offset, ctx->L[2]);
                ta[3] = xor_block(ad_offset, adp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                ad_offset = xor_block(ad_offset, getL(ctx, ntz(k + 2)));
                ta[k + 1] = xor_block(ad_offset, adp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                ad_offset = xor_block(ad_offset, ctx->L[0]);
                ta[k] = xor_block(ad_offset, adp[k]);
                remaining = remaining - 16;
                ++k;
            }
            if (remaining) {
                ad_offset = xor_block(ad_offset, ctx->Lstar);
                tmp.bl = zero_block();
                memcpy(tmp.u8, adp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                ta[k] = xor_block(ad_offset, tmp.bl);
                ++k;
            }
            AES_ecb_encrypt_blks(ta, k, &ctx->encrypt_key);
            switch (k) {
#if (BPI == 8)
            case 8:
                ad_checksum = xor_block(ad_checksum, ta[7]);
            case 7:
                ad_checksum = xor_block(ad_checksum, ta[6]);
            case 6:
                ad_checksum = xor_block(ad_checksum, ta[5]);
            case 5:
                ad_checksum = xor_block(ad_checksum, ta[4]);
#endif
            case 4:
                ad_checksum = xor_block(ad_checksum, ta[3]);
            case 3:
                ad_checksum = xor_block(ad_checksum, ta[2]);
            case 2:
                ad_checksum = xor_block(ad_checksum, ta[1]);
            case 1:
                ad_checksum = xor_block(ad_checksum, ta[0]);
            }
            ctx->ad_checksum = ad_checksum;
        }
    }
}

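/* Note on the oa[] values in the wide loops above and below: OCB's offset
   recurrence is Offset_i = Offset_{i-1} xor L[ntz(i)]. Starting at a block
   index divisible by BPI the XORs telescope; for BPI == 8 the cumulative
   values are L0, L0^L1, L1, L1^L2, L0^L1^L2, L0^L2, L2, then L2^L[ntz].
   That is why oa[2] can be formed directly as offset^L[1], oa[6] as
   offset^L[2] (in ae_encrypt/ae_decrypt the running offset still sits in
   oa[7] from the previous iteration), and only the last block needs ntz. */
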
/* ----------------------------------------------------------------------- */

int ae_encrypt(ae_ctx* ctx, const void* nonce, const void* pt, int pt_len, const void* ad,
               int ad_len, void* ct, void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    const block* ptp = (block*)pt;

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Encrypt plaintext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = pt_len / (BPI * 16);
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset;
        do {
            block ta[BPI];
            block_num += BPI;
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ptp[0]);
            checksum = xor_block(checksum, ptp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ptp[1]);
            checksum = xor_block(checksum, ptp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ptp[2]);
            checksum = xor_block(checksum, ptp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ptp[3]);
            checksum = xor_block(checksum, ptp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ptp[4]);
            checksum = xor_block(checksum, ptp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ptp[5]);
            checksum = xor_block(checksum, ptp[5]);
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ptp[6]);
            checksum = xor_block(checksum, ptp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ptp[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            AES_ecb_encrypt_blks(ta, BPI, &ctx->encrypt_key);
            ctp[0] = xor_block(ta[0], oa[0]);
            ctp[1] = xor_block(ta[1], oa[1]);
            ctp[2] = xor_block(ta[2], oa[2]);
            ctp[3] = xor_block(ta[3], oa[3]);
#if (BPI == 8)
            ctp[4] = xor_block(ta[4], oa[4]);
            ctp[5] = xor_block(ta[5], oa[5]);
            ctp[6] = xor_block(ta[6], oa[6]);
            ctp[7] = xor_block(ta[7], oa[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining plaintext and compute its tag contribution    */
        unsigned remaining = ((unsigned)pt_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ptp[0]);
                checksum = xor_block(checksum, ptp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ptp[1]);
                checksum = xor_block(checksum, ptp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ptp[2]);
                checksum = xor_block(checksum, ptp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ptp[3]);
                checksum = xor_block(checksum, ptp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ptp[k + 1]);
                checksum = xor_block(checksum, ptp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ptp[k]);
                checksum = xor_block(checksum, ptp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                tmp.bl = zero_block();
                memcpy(tmp.u8, ptp + k, remaining);
                tmp.u8[remaining] = (unsigned char)0x80u;
                checksum = xor_block(checksum, tmp.bl);
                ta[k] = offset = xor_block(offset, ctx->Lstar);
                ++k;
            }
        }
        offset = xor_block(offset, ctx->Ldollar); /* Part of tag gen */
        ta[k] = xor_block(offset, checksum);      /* Part of tag gen */
        AES_ecb_encrypt_blks(ta, k + 1, &ctx->encrypt_key);
        offset = xor_block(ta[k], ctx->ad_checksum); /* Part of tag gen */
        if (remaining) {
            --k;
            tmp.bl = xor_block(tmp.bl, ta[k]);
            memcpy(ctp + k, tmp.u8, remaining);
        }
        switch (k) {
#if (BPI == 8)
        case 7:
            ctp[6] = xor_block(ta[6], oa[6]);
        case 6:
            ctp[5] = xor_block(ta[5], oa[5]);
        case 5:
            ctp[4] = xor_block(ta[4], oa[4]);
        case 4:
            ctp[3] = xor_block(ta[3], oa[3]);
#endif
        case 3:
            ctp[2] = xor_block(ta[2], oa[2]);
        case 2:
            ctp[1] = xor_block(ta[1], oa[1]);
        case 1:
            ctp[0] = xor_block(ta[0], oa[0]);
        }

        /* Tag is placed at the correct location */
        if (tag) {
#if (OCB_TAG_LEN == 16)
            *(block*)tag = offset;
#elif (OCB_TAG_LEN > 0)
            memcpy((char*)tag, &offset, OCB_TAG_LEN);
#else
            memcpy((char*)tag, &offset, ctx->tag_len);
#endif
        } else {
#if (OCB_TAG_LEN > 0)
            memcpy((char*)ct + pt_len, &offset, OCB_TAG_LEN);
            pt_len += OCB_TAG_LEN;
#else
            memcpy((char*)ct + pt_len, &offset, ctx->tag_len);
            pt_len += ctx->tag_len;
#endif
        }
    }
    return (int)pt_len;
}

/* ----------------------------------------------------------------------- */

/* Compare two regions of memory, taking a constant amount of time for a
   given buffer size -- under certain assumptions about the compiler
   and machine, of course.

   Use this to avoid timing side-channel attacks.

   Returns 0 for memory regions with equal contents; non-zero otherwise. */
static int constant_time_memcmp(const void* av, const void* bv, size_t n) {
    const uint8_t* a = (const uint8_t*)av;
    const uint8_t* b = (const uint8_t*)bv;
    uint8_t result = 0;
    size_t i;

    for (i = 0; i < n; i++) {
        result |= *a ^ *b;
        a++;
        b++;
    }

    return (int)result;
}

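/* Usage sketch (hypothetical helper): compare two 16-byte tags without an
   early exit, so timing does not reveal the position of a mismatch. */
#if 0
static int tags_match(const unsigned char a[16], const unsigned char b[16]) {
    return constant_time_memcmp(a, b, 16) == 0;
}
#endif
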
int ae_decrypt(ae_ctx* ctx, const void* nonce, const void* ct, int ct_len, const void* ad,
               int ad_len, void* pt, const void* tag, int final) {
    union {
        uint32_t u32[4];
        uint8_t u8[16];
        block bl;
    } tmp;
    block offset, checksum;
    unsigned i, k;
    block* ctp = (block*)ct;
    block* ptp = (block*)pt;

    /* Reduce ct_len if the tag is bundled with the ciphertext */
    if ((final) && (!tag))
#if (OCB_TAG_LEN > 0)
        ct_len -= OCB_TAG_LEN;
#else
        ct_len -= ctx->tag_len;
#endif

    /* Non-null nonce means start of new message, init per-message values */
    if (nonce) {
        ctx->offset = gen_offset_from_nonce(ctx, nonce);
        ctx->ad_offset = ctx->checksum = zero_block();
        ctx->ad_blocks_processed = ctx->blocks_processed = 0;
        if (ad_len >= 0)
            ctx->ad_checksum = zero_block();
    }

    /* Process associated data */
    if (ad_len > 0)
        process_ad(ctx, ad, ad_len, final);

    /* Decrypt ciphertext data BPI blocks at a time */
    offset = ctx->offset;
    checksum = ctx->checksum;
    i = ct_len / (BPI * 16);
    if (i) {
        block oa[BPI];
        unsigned block_num = ctx->blocks_processed;
        oa[BPI - 1] = offset;
        do {
            block ta[BPI];
            block_num += BPI;
            oa[0] = xor_block(oa[BPI - 1], ctx->L[0]);
            ta[0] = xor_block(oa[0], ctp[0]);
            oa[1] = xor_block(oa[0], ctx->L[1]);
            ta[1] = xor_block(oa[1], ctp[1]);
            oa[2] = xor_block(oa[1], ctx->L[0]);
            ta[2] = xor_block(oa[2], ctp[2]);
#if BPI == 4
            oa[3] = xor_block(oa[2], getL(ctx, ntz(block_num)));
            ta[3] = xor_block(oa[3], ctp[3]);
#elif BPI == 8
            oa[3] = xor_block(oa[2], ctx->L[2]);
            ta[3] = xor_block(oa[3], ctp[3]);
            oa[4] = xor_block(oa[1], ctx->L[2]);
            ta[4] = xor_block(oa[4], ctp[4]);
            oa[5] = xor_block(oa[0], ctx->L[2]);
            ta[5] = xor_block(oa[5], ctp[5]);
            oa[6] = xor_block(oa[7], ctx->L[2]);
            ta[6] = xor_block(oa[6], ctp[6]);
            oa[7] = xor_block(oa[6], getL(ctx, ntz(block_num)));
            ta[7] = xor_block(oa[7], ctp[7]);
#endif
            AES_ecb_decrypt_blks(ta, BPI, &ctx->decrypt_key);
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#if (BPI == 8)
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
            ptp[7] = xor_block(ta[7], oa[7]);
            checksum = xor_block(checksum, ptp[7]);
#endif
            ptp += BPI;
            ctp += BPI;
        } while (--i);
        ctx->offset = offset = oa[BPI - 1];
        ctx->blocks_processed = block_num;
        ctx->checksum = checksum;
    }

    if (final) {
        block ta[BPI + 1], oa[BPI];

        /* Process remaining ciphertext and compute its tag contribution   */
        unsigned remaining = ((unsigned)ct_len) % (BPI * 16);
        k = 0; /* How many blocks in ta[] need ECBing */
        if (remaining) {
#if (BPI == 8)
            if (remaining >= 64) {
                oa[0] = xor_block(offset, ctx->L[0]);
                ta[0] = xor_block(oa[0], ctp[0]);
                oa[1] = xor_block(oa[0], ctx->L[1]);
                ta[1] = xor_block(oa[1], ctp[1]);
                oa[2] = xor_block(oa[1], ctx->L[0]);
                ta[2] = xor_block(oa[2], ctp[2]);
                offset = oa[3] = xor_block(oa[2], ctx->L[2]);
                ta[3] = xor_block(offset, ctp[3]);
                remaining -= 64;
                k = 4;
            }
#endif
            if (remaining >= 32) {
                oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(oa[k], ctp[k]);
                offset = oa[k + 1] = xor_block(oa[k], ctx->L[1]);
                ta[k + 1] = xor_block(offset, ctp[k + 1]);
                remaining -= 32;
                k += 2;
            }
            if (remaining >= 16) {
                offset = oa[k] = xor_block(offset, ctx->L[0]);
                ta[k] = xor_block(offset, ctp[k]);
                remaining -= 16;
                ++k;
            }
            if (remaining) {
                block pad;
                offset = xor_block(offset, ctx->Lstar);
                AES_encrypt((unsigned char*)&offset, tmp.u8, &ctx->encrypt_key);
                pad = tmp.bl;
                memcpy(tmp.u8, ctp + k, remaining);
                tmp.bl = xor_block(tmp.bl, pad);
                tmp.u8[remaining] = (unsigned char)0x80u;
                memcpy(ptp + k, tmp.u8, remaining);
                checksum = xor_block(checksum, tmp.bl);
            }
        }
        AES_ecb_decrypt_blks(ta, k, &ctx->decrypt_key);
        switch (k) {
#if (BPI == 8)
        case 7:
            ptp[6] = xor_block(ta[6], oa[6]);
            checksum = xor_block(checksum, ptp[6]);
        case 6:
            ptp[5] = xor_block(ta[5], oa[5]);
            checksum = xor_block(checksum, ptp[5]);
        case 5:
            ptp[4] = xor_block(ta[4], oa[4]);
            checksum = xor_block(checksum, ptp[4]);
        case 4:
            ptp[3] = xor_block(ta[3], oa[3]);
            checksum = xor_block(checksum, ptp[3]);
#endif
        case 3:
            ptp[2] = xor_block(ta[2], oa[2]);
            checksum = xor_block(checksum, ptp[2]);
        case 2:
            ptp[1] = xor_block(ta[1], oa[1]);
            checksum = xor_block(checksum, ptp[1]);
        case 1:
            ptp[0] = xor_block(ta[0], oa[0]);
            checksum = xor_block(checksum, ptp[0]);
        }

        /* Calculate expected tag */
        offset = xor_block(offset, ctx->Ldollar);
        tmp.bl = xor_block(offset, checksum);
        AES_encrypt(tmp.u8, tmp.u8, &ctx->encrypt_key);
        tmp.bl = xor_block(tmp.bl, ctx->ad_checksum); /* Full tag */

        /* Compare with proposed tag, change ct_len if invalid */
        if ((OCB_TAG_LEN == 16) && tag) {
            if (unequal_blocks(tmp.bl, *(block*)tag))
                ct_len = AE_INVALID;
        } else {
#if (OCB_TAG_LEN > 0)
            int len = OCB_TAG_LEN;
#else
            int len = ctx->tag_len;
#endif
            if (tag) {
                if (constant_time_memcmp(tag, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            } else {
                if (constant_time_memcmp((char*)ct + ct_len, tmp.u8, len) != 0)
                    ct_len = AE_INVALID;
            }
        }
    }
    return ct_len;
}

/* ----------------------------------------------------------------------- */
/* Simple test program                                                     */
/* ----------------------------------------------------------------------- */

#if 0

#include <stdio.h>
#include <time.h>

#if __GNUC__
#define ALIGN(n) __attribute__((aligned(n)))
#elif _MSC_VER
#define ALIGN(n) __declspec(align(n))
#else /* Not GNU/Microsoft: delete alignment uses.     */
#define ALIGN(n)
#endif

static void pbuf(void *p, unsigned len, const void *s)
{
    unsigned i;
    if (s)
        printf("%s", (char *)s);
    for (i = 0; i < len; i++)
        printf("%02X", (unsigned)(((unsigned char *)p)[i]));
    printf("\n");
}

static void vectors(ae_ctx *ctx, int len)
{
    ALIGN(16) char pt[128];
    ALIGN(16) char ct[144];
    ALIGN(16) char nonce[] = {0,1,2,3,4,5,6,7,8,9,10,11};
    int i;
    for (i=0; i < 128; i++) pt[i] = i;
    i = ae_encrypt(ctx,nonce,pt,len,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,0,pt,len,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",0,len); pbuf(ct, i, NULL);
    i = ae_encrypt(ctx,nonce,pt,len,pt,0,ct,NULL,AE_FINALIZE);
    printf("P=%d,A=%d: ",len,0); pbuf(ct, i, NULL);
}

void validate()
{
    ALIGN(16) char pt[1024];
    ALIGN(16) char ct[1024];
    ALIGN(16) char tag[16];
    ALIGN(16) char nonce[12] = {0,};
    ALIGN(16) char key[32] = {0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31};
    ae_ctx ctx;
    char *val_buf, *next;
    int i, len;

    val_buf = (char *)malloc(22400 + 16);
    next = val_buf = (char *)(((size_t)val_buf + 16) & ~((size_t)15));

    if (0) {
        ae_init(&ctx, key, 16, 12, 16);
        /* pbuf(&ctx, sizeof(ctx), "CTX: "); */
        vectors(&ctx,0);
        vectors(&ctx,8);
        vectors(&ctx,16);
        vectors(&ctx,24);
        vectors(&ctx,32);
        vectors(&ctx,40);
    }

    memset(key,0,32);
    memset(pt,0,128);
    ae_init(&ctx, key, OCB_KEY_LEN, 12, OCB_TAG_LEN);

    /* RFC Vector test */
    for (i = 0; i < 128; i++) {
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i;

        if (0) {
            ae_encrypt(&ctx,nonce,pt,i,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,i,pt,0,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,i,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        } else {
            ae_encrypt(&ctx,nonce,pt,first,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt+first,second,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt+first+second,third,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,first,pt,0,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first,second,pt,0,ct+first,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt+first+second,third,pt,0,ct+first+second,NULL,AE_FINALIZE);
            memcpy(next,ct,(size_t)i+OCB_TAG_LEN);
            next = next+i+OCB_TAG_LEN;

            ae_encrypt(&ctx,nonce,pt,0,pt,first,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first,second,ct,NULL,AE_PENDING);
            ae_encrypt(&ctx,NULL,pt,0,pt+first+second,third,ct,NULL,AE_FINALIZE);
            memcpy(next,ct,OCB_TAG_LEN);
            next = next+OCB_TAG_LEN;
        }

    }
    nonce[11] = 0;
    ae_encrypt(&ctx,nonce,NULL,0,val_buf,next-val_buf,ct,tag,AE_FINALIZE);
    pbuf(tag,OCB_TAG_LEN,0);

    /* Encrypt/Decrypt test */
    for (i = 0; i < 128; i++) {
        int first = ((i/3)/(BPI*16))*(BPI*16);
        int second = first;
        int third = i - (first + second);

        nonce[11] = i%128;

        if (1) {
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,tag,AE_FINALIZE);
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,-1,ct,tag,AE_FINALIZE);
            len = ae_decrypt(&ctx,nonce,ct,len,val_buf,-1,pt,tag,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (len != i) { printf("Length error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        } else {
            len = ae_encrypt(&ctx,nonce,val_buf,i,val_buf,i,ct,NULL,AE_FINALIZE);
            ae_decrypt(&ctx,nonce,ct,first,val_buf,first,pt,NULL,AE_PENDING);
            ae_decrypt(&ctx,NULL,ct+first,second,val_buf+first,second,pt+first,NULL,AE_PENDING);
            len = ae_decrypt(&ctx,NULL,ct+first+second,len-(first+second),val_buf+first+second,third,pt+first+second,NULL,AE_FINALIZE);
            if (len == -1) { printf("Authentication error: %d\n", i); return; }
            if (memcmp(val_buf,pt,i)) { printf("Decrypt error: %d\n", i); return; }
        }

    }
    printf("Decrypt: PASS\n");
}

int main()
{
    validate();
    return 0;
}
#endif

#if USE_AES_NI
char infoString[] = "OCB3 (AES-NI)";
#elif USE_REFERENCE_AES
char infoString[] = "OCB3 (Reference)";
#elif USE_OPENSSL_AES
char infoString[] = "OCB3 (OpenSSL)";
#endif