1# Authors:
2#   Bram Cohen
3#   Trevor Perrin - various changes
4#
5# See the LICENSE file for legal information regarding use of this file.
6# Also see Bram Cohen's statement below
7
8"""
A pure Python (slow) implementation of Rijndael with a simple interface.

To import:

from rijndael import rijndael

To set up a key:

r = rijndael(key, block_size = 16)

key must be a sequence of byte values (e.g. bytes or bytearray) of length
16, 24, or 32.
block_size must be 16, 24, or 32. The default is 16.

To use:

ciphertext = r.encrypt(plaintext)
plaintext = r.decrypt(ciphertext)

plaintext and ciphertext are single blocks of exactly block_size bytes;
both methods return a bytearray.

If the key, the block size, or a block has the wrong length, a ValueError is raised.
28"""
29
30# ported from the Java reference code by Bram Cohen, bram@gawth.com, April 2001
31# this code is public domain, unless someone makes
32# an intellectual property claim against the reference
33# code, in which case it can be made public domain by
34# deleting all the comments and renaming all the variables
35
36import copy
37import string
38
# ShiftRows offsets, indexed [SC][row] -> [encrypt_shift, decrypt_shift].
# SC is 0, 1 or 2 for blocks of 4, 6 or 8 32-bit words (16, 24 or 32 bytes);
# rijndael.encrypt reads column 0 and rijndael.decrypt reads column 1.
shifts = [[[0, 0], [1, 3], [2, 2], [3, 1]],
          [[0, 0], [1, 5], [2, 4], [3, 3]],
          [[0, 0], [1, 7], [3, 5], [4, 4]]]

# Number of cipher rounds, indexed [key size][block size], both in bytes.
num_rounds = {16: {16: 10, 24: 12, 32: 14}, 24: {16: 12, 24: 12, 32: 14}, 32: {16: 14, 24: 14, 32: 14}}

# 8x8 bit matrix of the S-box affine transform (applied below as B + A*x,
# where + and * are bitwise XOR and AND).
A = [[1, 1, 1, 1, 1, 0, 0, 0],
     [0, 1, 1, 1, 1, 1, 0, 0],
     [0, 0, 1, 1, 1, 1, 1, 0],
     [0, 0, 0, 1, 1, 1, 1, 1],
     [1, 0, 0, 0, 1, 1, 1, 1],
     [1, 1, 0, 0, 0, 1, 1, 1],
     [1, 1, 1, 0, 0, 0, 1, 1],
     [1, 1, 1, 1, 0, 0, 0, 1]]

# produce log and alog tables, needed for multiplying in the
# field GF(2^8) (generator = 3).
# alog[k] = 3**k: each step computes x*3 = (x << 1) ^ x and reduces an
# overflow past 8 bits modulo 0x11B.  alog has 256 entries and alog[255] == 1.
alog = [1]
for i in range(255):
    j = (alog[-1] << 1) ^ alog[-1]
    if j & 0x100 != 0:
        j ^= 0x11B
    alog.append(j)

# log inverts alog on the non-zero elements.  log[0] is only a placeholder
# (zero has no logarithm), so callers such as mul() must special-case 0.
log = [0] * 256
for i in range(1, 255):
    log[alog[i]] = i
67
def mul(a, b):
    """Multiply two elements of GF(2^8) using the log/antilog tables.

    Zero has no logarithm, so a product involving 0 is returned directly;
    otherwise the result is alog[(log[a] + log[b]) mod 255].
    """
    if a and b:
        return alog[(log[a & 0xFF] + log[b & 0xFF]) % 255]
    return 0
73
# substitution box based on F^{-1}(x)
# box[i] holds the 8 bits (most significant first) of the multiplicative
# inverse of i in GF(2^8), computed as alog[255 - log[i]].  box[0] stays all
# zero and box[1] is set by hand because 1 is its own inverse.
box = [[0] * 8 for i in range(256)]
box[1][7] = 1
for i in range(2, 256):
    j = alog[255 - log[i]]
    for t in range(8):
        box[i][t] = (j >> (7 - t)) & 0x01

# constant bit vector added by the affine transform
B = [0, 1, 1, 0, 0, 0, 1, 1]

# affine transform:  cox[i] <- B + A*box[i]  (bitwise, so + is XOR)
cox = [[0] * 8 for i in range(256)]
for i in range(256):
    for t in range(8):
        cox[i][t] = B[t]
        for j in range(8):
            cox[i][t] ^= A[t][j] * box[i][j]

# S-boxes and inverse S-boxes: pack each bit row of cox back into a byte
# (cox[i][0] is the most significant bit) and record Si so that Si[S[i]] == i.
S =  [0] * 256
Si = [0] * 256
for i in range(256):
    S[i] = cox[i][0] << 7
    for t in range(1, 8):
        S[i] ^= cox[i][t] << (7-t)
    Si[S[i] & 0xFF] = i
100
# T-boxes
# MixColumns coefficient matrix; every row is a rotation of [2, 1, 1, 3].
G = [[2, 1, 1, 3],
    [3, 2, 1, 1],
    [1, 3, 2, 1],
    [1, 1, 3, 2]]

# Invert G over GF(2^8) by Gauss-Jordan elimination on the augmented 4x8
# matrix [G | I]; once elimination finishes, the right half holds G^-1.
AA = [[0] * 8 for i in range(4)]

for i in range(4):
    for j in range(4):
        AA[i][j] = G[i][j]
    AA[i][i+4] = 1

for i in range(4):
    pivot = AA[i][i]
    if pivot == 0:
        # Find a later row with a non-zero entry in column i and swap it in.
        # The bound is tested before indexing so a singular matrix reports a
        # clear error instead of an IndexError, and the swap happens only
        # after the search has finished.
        t = i + 1
        while t < 4 and AA[t][i] == 0:
            t += 1
        if t == 4:
            raise RuntimeError('G matrix must be invertible')
        AA[i], AA[t] = AA[t], AA[i]
        pivot = AA[i][i]
    # Scale row i by 1/pivot (division done as log subtraction).
    for j in range(8):
        if AA[i][j] != 0:
            AA[i][j] = alog[(255 + log[AA[i][j] & 0xFF] - log[pivot & 0xFF]) % 255]
    # Eliminate column i from every other row: row_t -= AA[t][i] * row_i
    # (subtraction is XOR in GF(2^8); columns left of i in row i are already 0).
    for t in range(4):
        if i != t:
            for j in range(i+1, 8):
                AA[t][j] ^= mul(AA[i][j], AA[t][i])
            AA[t][i] = 0

# Extract the inverse matrix from the right half of AA.
iG = [[0] * 4 for i in range(4)]

for i in range(4):
    for j in range(4):
        iG[i][j] = AA[i][j + 4]
138
def mul4(a, bs):
    """Multiply scalar a by each coefficient in bs over GF(2^8).

    The byte products are packed into a single integer with the first
    coefficient's product in the most significant byte.  Returns 0
    immediately when a is 0.
    """
    if a == 0:
        return 0
    packed = 0
    for coeff in bs:
        product = mul(a, coeff) if coeff != 0 else 0
        packed = (packed << 8) | product
    return packed
148
# Round lookup tables, 256 entries each.  mul4 packs the four GF(2^8) byte
# products big-endian, so every entry is a 32-bit word.
#   T1..T4[x] = S[x]  times rows 0..3 of G   -- used by rijndael.encrypt
#   T5..T8[x] = Si[x] times rows 0..3 of iG  -- used by rijndael.decrypt
#   U1..U4[x] = x     times rows 0..3 of iG  -- used by rijndael.__init__ to
#               apply inverse MixColumn to the decryption round keys
T1 = []
T2 = []
T3 = []
T4 = []
T5 = []
T6 = []
T7 = []
T8 = []
U1 = []
U2 = []
U3 = []
U4 = []

# NOTE: the loop variables t and s (and r below) are left at module level;
# the cleanup `del` statements further down require them to exist.
for t in range(256):
    s = S[t]
    T1.append(mul4(s, G[0]))
    T2.append(mul4(s, G[1]))
    T3.append(mul4(s, G[2]))
    T4.append(mul4(s, G[3]))

    s = Si[t]
    T5.append(mul4(s, iG[0]))
    T6.append(mul4(s, iG[1]))
    T7.append(mul4(s, iG[2]))
    T8.append(mul4(s, iG[3]))

    U1.append(mul4(t, iG[0]))
    U2.append(mul4(t, iG[1]))
    U3.append(mul4(t, iG[2]))
    U4.append(mul4(t, iG[3]))

# round constants: rcon[k] = 2**k in GF(2^8) (30 entries); the key schedule
# in rijndael.__init__ consumes one per expansion step
rcon = [1]
r = 1
for t in range(1, 30):
    r = mul(2, r)
    rcon.append(r)
186
# Drop the scratch names used only while building the lookup tables.  The
# cipher class below needs just S, Si, T1-T8, U1-U4, rcon, shifts and
# num_rounds.  Names are removed in the same left-to-right order as before.
del A, AA, pivot, B, G, box, log, alog
del i, j, r, s, t
del mul, mul4, cox, iG
204
class rijndael:
    """Rijndael block cipher for 16-, 24- or 32-byte keys and blocks.

    The constructor expands the key into round-key schedules; encrypt()
    and decrypt() each transform exactly one block and return a bytearray.
    """

    def __init__(self, key, block_size = 16):
        """Expand ``key`` into encryption and decryption round keys.

        :param key: sequence of byte values (e.g. bytes or bytearray) of
            length 16, 24 or 32.
        :param block_size: block length in bytes: 16, 24 or 32.
        :raises ValueError: if the block size or key length is unsupported.
        """
        if block_size not in (16, 24, 32):
            raise ValueError('Invalid block size: ' + str(block_size))
        if len(key) not in (16, 24, 32):
            raise ValueError('Invalid key size: ' + str(len(key)))
        self.block_size = block_size

        ROUNDS = num_rounds[len(key)][block_size]
        BC = block_size // 4
        # encryption round keys: ROUNDS + 1 rows of BC 32-bit words
        Ke = [[0] * BC for i in range(ROUNDS + 1)]
        # decryption round keys: the same words with rows in reverse order
        Kd = [[0] * BC for i in range(ROUNDS + 1)]
        ROUND_KEY_COUNT = (ROUNDS + 1) * BC
        KC = len(key) // 4

        # copy user material bytes into temporary big-endian 32-bit words
        tk = []
        for i in range(0, KC):
            tk.append((key[i * 4] << 24) | (key[i * 4 + 1] << 16) |
                (key[i * 4 + 2] << 8) | key[i * 4 + 3])

        # copy values into round key arrays
        t = 0
        j = 0
        while j < KC and t < ROUND_KEY_COUNT:
            Ke[t // BC][t % BC] = tk[j]
            Kd[ROUNDS - (t // BC)][t % BC] = tk[j]
            j += 1
            t += 1
        rconpointer = 0
        while t < ROUND_KEY_COUNT:
            # extrapolate using phi (the round key evolution function):
            # rotate the last word by one byte, substitute its bytes through
            # S and add the next round constant
            tt = tk[KC - 1]
            tk[0] ^= (S[(tt >> 16) & 0xFF] & 0xFF) << 24 ^  \
                     (S[(tt >>  8) & 0xFF] & 0xFF) << 16 ^  \
                     (S[ tt        & 0xFF] & 0xFF) <<  8 ^  \
                     (S[(tt >> 24) & 0xFF] & 0xFF)       ^  \
                     (rcon[rconpointer]    & 0xFF) << 24
            rconpointer += 1
            if KC != 8:
                for i in range(1, KC):
                    tk[i] ^= tk[i-1]
            else:
                # 32-byte keys (KC == 8) apply an extra byte substitution
                # (without rotation) halfway through the expansion step
                for i in range(1, KC // 2):
                    tk[i] ^= tk[i-1]
                tt = tk[KC // 2 - 1]
                tk[KC // 2] ^= (S[ tt        & 0xFF] & 0xFF)       ^ \
                              (S[(tt >>  8) & 0xFF] & 0xFF) <<  8 ^ \
                              (S[(tt >> 16) & 0xFF] & 0xFF) << 16 ^ \
                              (S[(tt >> 24) & 0xFF] & 0xFF) << 24
                for i in range(KC // 2 + 1, KC):
                    tk[i] ^= tk[i-1]
            # copy values into round key arrays
            j = 0
            while j < KC and t < ROUND_KEY_COUNT:
                Ke[t // BC][t % BC] = tk[j]
                Kd[ROUNDS - (t // BC)][t % BC] = tk[j]
                j += 1
                t += 1
        # inverse MixColumn on the inner decryption round keys, so that the
        # table-driven rounds in decrypt() can add them directly
        for r in range(1, ROUNDS):
            for j in range(BC):
                tt = Kd[r][j]
                Kd[r][j] = U1[(tt >> 24) & 0xFF] ^ \
                           U2[(tt >> 16) & 0xFF] ^ \
                           U3[(tt >>  8) & 0xFF] ^ \
                           U4[ tt        & 0xFF]
        self.Ke = Ke
        self.Kd = Kd

    @staticmethod
    def _row_shifts(BC, direction):
        """Return the ShiftRows offsets (s1, s2, s3) for state rows 1-3.

        :param BC: block length in 32-bit words (4, 6 or 8).
        :param direction: 0 for the encryption offsets, 1 for decryption.
        """
        if BC == 4:
            SC = 0
        elif BC == 6:
            SC = 1
        else:
            SC = 2
        table = shifts[SC]
        return table[1][direction], table[2][direction], table[3][direction]

    def encrypt(self, plaintext):
        """Encrypt a single block.

        :param plaintext: sequence of exactly block_size byte values.
        :returns: the ciphertext block as a bytearray.
        :raises ValueError: if len(plaintext) != block_size.
        """
        if len(plaintext) != self.block_size:
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(plaintext)))
        Ke = self.Ke

        BC = self.block_size // 4
        ROUNDS = len(Ke) - 1
        s1, s2, s3 = self._row_shifts(BC, 0)
        a = [0] * BC
        # temporary work array
        t = []
        # plaintext to ints + key
        for i in range(BC):
            t.append((plaintext[i * 4    ] << 24 |
                      plaintext[i * 4 + 1] << 16 |
                      plaintext[i * 4 + 2] <<  8 |
                      plaintext[i * 4 + 3]        ) ^ Ke[0][i])
        # apply round transforms
        for r in range(1, ROUNDS):
            for i in range(BC):
                a[i] = (T1[(t[ i           ] >> 24) & 0xFF] ^
                        T2[(t[(i + s1) % BC] >> 16) & 0xFF] ^
                        T3[(t[(i + s2) % BC] >>  8) & 0xFF] ^
                        T4[ t[(i + s3) % BC]        & 0xFF]  ) ^ Ke[r][i]
            t = copy.copy(a)
        # last round is special: no MixColumn, so S is applied directly
        result = []
        for i in range(BC):
            tt = Ke[ROUNDS][i]
            result.append((S[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
            result.append((S[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
            result.append((S[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
            result.append((S[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
        return bytearray(result)

    def decrypt(self, ciphertext):
        """Decrypt a single block.

        :param ciphertext: sequence of exactly block_size byte values.
        :returns: the plaintext block as a bytearray.
        :raises ValueError: if len(ciphertext) != block_size.
        """
        if len(ciphertext) != self.block_size:
            # report the length of the argument actually received
            raise ValueError('wrong block length, expected ' + str(self.block_size) + ' got ' + str(len(ciphertext)))
        Kd = self.Kd

        BC = self.block_size // 4
        ROUNDS = len(Kd) - 1
        s1, s2, s3 = self._row_shifts(BC, 1)
        a = [0] * BC
        # temporary work array
        t = [0] * BC
        # ciphertext to ints + key
        for i in range(BC):
            t[i] = (ciphertext[i * 4    ] << 24 |
                    ciphertext[i * 4 + 1] << 16 |
                    ciphertext[i * 4 + 2] <<  8 |
                    ciphertext[i * 4 + 3]        ) ^ Kd[0][i]
        # apply round transforms
        for r in range(1, ROUNDS):
            for i in range(BC):
                a[i] = (T5[(t[ i           ] >> 24) & 0xFF] ^
                        T6[(t[(i + s1) % BC] >> 16) & 0xFF] ^
                        T7[(t[(i + s2) % BC] >>  8) & 0xFF] ^
                        T8[ t[(i + s3) % BC]        & 0xFF]  ) ^ Kd[r][i]
            t = copy.copy(a)
        # last round is special: no inverse MixColumn, so Si is applied directly
        result = []
        for i in range(BC):
            tt = Kd[ROUNDS][i]
            result.append((Si[(t[ i           ] >> 24) & 0xFF] ^ (tt >> 24)) & 0xFF)
            result.append((Si[(t[(i + s1) % BC] >> 16) & 0xFF] ^ (tt >> 16)) & 0xFF)
            result.append((Si[(t[(i + s2) % BC] >>  8) & 0xFF] ^ (tt >>  8)) & 0xFF)
            result.append((Si[ t[(i + s3) % BC]        & 0xFF] ^  tt       ) & 0xFF)
        return bytearray(result)
363
def encrypt(key, block):
    """Encrypt one block under key; the block size is taken from len(block)."""
    cipher = rijndael(key, len(block))
    return cipher.encrypt(block)
366
def decrypt(key, block):
    """Decrypt one block under key; the block size is taken from len(block)."""
    cipher = rijndael(key, len(block))
    return cipher.decrypt(block)
369
def test():
    """Self-test the cipher; raises AssertionError on any failure.

    First checks the three AES (16-byte block) known-answer vectors from
    FIPS-197 Appendix C, then a decrypt(encrypt(x)) == x round trip for
    every supported key-size / block-size combination.
    """
    # FIPS-197 Appendix C: key = 00 01 02 ..., plaintext = 00 11 22 ... ff
    plaintext = bytearray(i * 0x11 for i in range(16))
    known_answers = (
        (16, '69c4e0d86a7b0430d8cdb78070b4c55a'),
        (24, 'dda97ca4864cdfe06eaf70a0ec0d7191'),
        (32, '8ea2b7ca516745bfeafc49904b496089'),
    )
    for kl, expected in known_answers:
        r = rijndael(bytearray(range(kl)), 16)
        ciphertext = r.encrypt(plaintext)
        assert ciphertext == bytearray.fromhex(expected)
        assert r.decrypt(ciphertext) == plaintext

    def t(kl, bl):
        # key and block must be byte sequences: the cipher shifts their
        # elements as integers and returns a bytearray
        b = bytearray(b'b' * bl)
        r = rijndael(bytearray(b'a' * kl), bl)
        assert r.decrypt(r.encrypt(b)) == b
    for kl in (16, 24, 32):
        for bl in (16, 24, 32):
            t(kl, bl)
384
385