1# Authors:
2#   Trevor Perrin
3#   Martin von Loewis - python 3 port
4#
5# See the LICENSE file for legal information regarding use of this file.
6
7"""cryptomath module
8
9This module has basic math/crypto code."""
10from __future__ import print_function
11import os
12import math
13import base64
14import binascii
15
16from .compat import *
17
18
19# **************************************************************************
20# Load Optional Modules
21# **************************************************************************
22
# Optional third-party modules.  Each *Loaded flag records whether the
# corresponding import succeeded.

# M2Crypto/OpenSSL
try:
    from M2Crypto import m2
except ImportError:
    m2cryptoLoaded = False
else:
    m2cryptoLoaded = True

# GMPY
try:
    import gmpy
except ImportError:
    gmpyLoaded = False
else:
    gmpyLoaded = True

# pycrypto
try:
    import Crypto.Cipher.AES
except ImportError:
    pycryptoLoaded = False
else:
    pycryptoLoaded = True
44
45
46# **************************************************************************
47# PRNG Functions
48# **************************************************************************
49
# Check that os.urandom works: random data is essentially incompressible, so
# 1000 random bytes must not shrink to 900 bytes or fewer under zlib.
# An explicit raise (rather than `assert`) keeps this check active when
# Python runs with -O, which strips assert statements.
import zlib
length = len(zlib.compress(os.urandom(1000)))
if length <= 900:
    raise AssertionError("os.urandom output looks non-random")
54
def getRandomBytes(howMany):
    """Return a bytearray of *howMany* random bytes taken from os.urandom."""
    randomData = os.urandom(howMany)
    result = bytearray(randomData)
    # Defensive post-condition: os.urandom must deliver the full count.
    assert(len(result) == howMany)
    return result


# Human-readable name of the randomness source behind getRandomBytes.
prngName = "os.urandom"
61
62# **************************************************************************
63# Simple hash functions
64# **************************************************************************
65
66import hmac
67import hashlib
68
def MD5(b):
    """Return the MD5 digest of *b* as a bytearray.

    *b* is first passed through compat26Str (from .compat; presumably a
    Python 2.6 bytes shim — confirm there).
    """
    digest = hashlib.md5(compat26Str(b)).digest()
    return bytearray(digest)
71
def SHA1(b):
    """Return the SHA-1 digest of *b* as a bytearray.

    *b* is first passed through compat26Str (from .compat; presumably a
    Python 2.6 bytes shim — confirm there).
    """
    digest = hashlib.sha1(compat26Str(b)).digest()
    return bytearray(digest)
74
def HMAC_MD5(k, b):
    """Return HMAC-MD5 of message *b* under key *k* as a bytearray.

    Key and message are each normalised with compatHMAC (from .compat)
    before being handed to the hmac module.
    """
    mac = hmac.new(compatHMAC(k), compatHMAC(b), hashlib.md5)
    return bytearray(mac.digest())
79
def HMAC_SHA1(k, b):
    """Return HMAC-SHA1 of message *b* under key *k* as a bytearray.

    Key and message are each normalised with compatHMAC (from .compat)
    before being handed to the hmac module.
    """
    mac = hmac.new(compatHMAC(k), compatHMAC(b), hashlib.sha1)
    return bytearray(mac.digest())
84
85
86# **************************************************************************
87# Converter Functions
88# **************************************************************************
89
def bytesToNumber(b):
    """Interpret the byte sequence *b* as a big-endian unsigned integer.

    An empty sequence yields 0.
    """
    total = 0
    for byte in b:
        # Horner's rule: move the accumulated value up one byte, add the next.
        total = total * 256 + byte
    # Force-cast to long to appease PyCrypto.
    # https://github.com/trevp/tlslite/issues/15
    return long(total)
100
def numberToByteArray(n, howManyBytes=None):
    """Convert an integer into a big-endian bytearray of a fixed length.

    If howManyBytes is None it defaults to numBytes(n), the minimal length
    that holds n (0 bytes when n is 0).

    The returned bytearray is always exactly howManyBytes long.  If n needs
    fewer bytes it is left-padded with zero bytes; if n needs more, its
    high-order bytes are silently dropped and only the low-order
    howManyBytes bytes are kept.
    """
    if howManyBytes is None:
        howManyBytes = numBytes(n)
    b = bytearray(howManyBytes)
    # Fill from the least-significant end, peeling off one byte per step.
    for count in range(howManyBytes - 1, -1, -1):
        b[count] = int(n % 256)
        n >>= 8
    return b
115
def mpiToNumber(mpi):
    """Decode an OpenSSL-format MPI into a non-negative integer.

    mpi is a 4-byte length header followed by the big-endian magnitude
    (the layout numberToMPI produces).  The header is skipped, not checked.
    Accepts str (Python 2), bytes or bytearray.  An empty magnitude decodes
    to 0.

    Raises AssertionError if the magnitude's top bit is set, i.e. the MPI
    encodes a negative number.
    """
    b = bytearray(mpi[4:])
    # Index the bytearray rather than calling ord(mpi[4]): indexing a
    # bytearray yields an int on both Python 2 and 3, whereas ord() fails on
    # an int (Python 3 bytes / any bytearray input).
    if b and (b[0] & 0x80) != 0:  # Make sure this is a positive number
        raise AssertionError()
    return bytesToNumber(b)
121
def numberToMPI(n):
    """Encode the non-negative integer n as an OpenSSL-style MPI (bytes).

    Layout: a 4-byte big-endian length, optionally one zero byte, then the
    big-endian magnitude.  The zero byte is added when the bit length of n is
    a multiple of 8 (so the top bit of the magnitude would otherwise look like
    a sign bit); this includes n == 0, which encodes as length 1 followed by
    a single zero byte.
    """
    magnitude = numberToByteArray(n)
    padding = 1 if numBits(n) % 8 == 0 else 0
    length = len(magnitude) + padding
    header = bytearray([(length >> shift) & 0xFF for shift in (24, 16, 8, 0)])
    return bytes(header + bytearray(padding) + magnitude)
136
137
138# **************************************************************************
139# Misc. Utility Functions
140# **************************************************************************
141
# Number of significant bits in a single leading hexadecimal digit.
_FIRST_HEX_DIGIT_BITS = {
    '0': 0, '1': 1, '2': 2, '3': 2,
    '4': 3, '5': 3, '6': 3, '7': 3,
    '8': 4, '9': 4, 'a': 4, 'b': 4,
    'c': 4, 'd': 4, 'e': 4, 'f': 4,
}


def numBits(n):
    """Return the number of bits needed to represent the integer n.

    numBits(0) is 0.  The width is taken from n's hexadecimal form: 4 bits
    for every digit after the first, plus the significant bits of the leading
    digit.  This stays exact for arbitrarily large n (no float log()).
    Negative n raises KeyError (its hex form starts with '-').
    """
    if n == 0:
        return 0
    s = "%x" % n
    return ((len(s) - 1) * 4) + _FIRST_HEX_DIGIT_BITS[s[0]]
153
def numBytes(n):
    """Return the number of bytes needed to hold the non-negative integer n.

    numBytes(0) is 0.  Uses exact integer ceiling division of the bit count
    rather than float math.ceil(bits / 8.0).
    """
    if n == 0:
        return 0
    return (numBits(n) + 7) // 8
159
160# **************************************************************************
161# Big Number Math
162# **************************************************************************
163
def getRandomNumber(low, high):
    """Return a uniformly random integer n with low <= n < high.

    Candidates are drawn from getRandomBytes using just enough bits to cover
    high, and rejected until one falls in range.  Candidates are never
    negative, so this assumes 0 <= low.  NOTE(review): confirm no caller
    passes a negative bound; high == 0 would loop forever.

    Raises AssertionError if low >= high.
    """
    if low >= high:
        raise AssertionError()
    howManyBits = numBits(high)
    howManyBytes = numBytes(high)
    lastBits = howManyBits % 8
    while True:
        # Named randomBytes so the builtin `bytes` is not shadowed.
        randomBytes = getRandomBytes(howManyBytes)
        if lastBits:
            # Clear the unused high bits of the leading byte so every
            # candidate lies in [0, 2**howManyBits).
            randomBytes[0] = randomBytes[0] % (1 << lastBits)
        n = bytesToNumber(randomBytes)
        if low <= n < high:
            return n
177
def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    For non-negative inputs gcd(x, 0) == x and gcd(0, 0) == 0.  Negative
    inputs can produce a negative result.
    """
    larger, smaller = (a, b) if a >= b else (b, a)
    while smaller:
        larger, smaller = smaller, larger % smaller
    return larger
183
def lcm(a, b):
    """Return the least common multiple of a and b.

    Returns 0 when either argument is 0.  The general formula would divide
    by gcd(0, 0) == 0 in that case.
    """
    if a == 0 or b == 0:
        return 0
    return (a * b) // gcd(a, b)
186
def invMod(a, b):
    """Return the inverse of a modulo b, or 0 if no inverse exists.

    Uses the Extended Euclidean Algorithm.  The result is reduced into
    [0, b) with Python's % operator.
    """
    remainder, prevRemainder = a, b
    coeff, prevCoeff = 1, 0
    while remainder:
        quotient = prevRemainder // remainder
        remainder, prevRemainder = prevRemainder - quotient * remainder, remainder
        coeff, prevCoeff = prevCoeff - quotient * coeff, coeff
    # prevRemainder now holds gcd(a, b); an inverse exists only if it is 1.
    return prevCoeff % b if prevRemainder == 1 else 0
199
200
if gmpyLoaded:
    def powMod(base, power, modulus):
        """Return base**power mod modulus, computed on gmpy integers.

        The result is converted back with long() from .compat.
        NOTE(review): a negative power goes straight to gmpy here; unlike
        the fallback this branch does not use invMod — confirm gmpy's
        behaviour.
        """
        result = pow(gmpy.mpz(base), gmpy.mpz(power), gmpy.mpz(modulus))
        return long(result)
else:
    def powMod(base, power, modulus):
        """Return base**power mod modulus using the builtin pow().

        A negative power is handled by raising base to -power and taking the
        modular inverse with invMod, which gives 0 when no inverse exists.
        """
        if power >= 0:
            return pow(base, power, modulus)
        return invMod(pow(base, -power, modulus), modulus)
217
def makeSieve(n):
    """Return the sorted list of all primes below n (Sieve of Eratosthenes).

    Every composite m < n has a prime factor p with p * p <= m < n, so
    crossing out multiples of each p with p * p < n is sufficient.  The
    bound is computed with integers.  The earlier
    range(2, int(sqrt(n))) stopped one short and left squares of the largest
    needed prime (e.g. 9 for n=10, 961 for n=1000) in the result.
    """
    sieve = list(range(n))
    count = 2
    while count * count < n:
        if sieve[count]:
            # Smaller multiples of count were already crossed out by
            # smaller primes, so start at count * count.
            for multiple in range(count * count, n, count):
                sieve[multiple] = 0
        count += 1
    return [x for x in sieve[2:] if x]


# Pre-calculate the 168 primes below 1000 for trial division.
sieve = makeSieve(1000)
232
def isPrime(n, iterations=5, display=False):
    """Probabilistically test whether the integer n is prime.

    Steps:
    1. Trial division against the module-level sieve of primes below 1000.
    2. If n survives, run `iterations` rounds of Rabin-Miller (per
       Ferguson & Schneier).  The first round uses base 2 (per HAC); every
       later round uses its own random base in [2, n-2].

    iterations=0 means trial division only.  If display is true, "*" is
    printed when Rabin-Miller starts.

    Returns False if n < 2 or n is certainly composite.  Returns True if n is
    a small prime or passed every round.
    """
    if n < 2:
        return False
    # Trial division with the sieve.  Reaching a sieve prime >= n means n had
    # no smaller prime factor, so n is itself prime.
    for x in sieve:
        if x >= n:
            return True
        if n % x == 0:
            return False
    # Passed trial division, proceed to Rabin-Miller.
    if display:
        print("*", end=' ')
    # Write n - 1 as s * 2**t with s odd.
    s, t = n - 1, 0
    while s % 2 == 0:
        s, t = s // 2, t + 1
    for count in range(iterations):
        # Choose the base at the top of the round so that a round ending in
        # `continue` (a**s == 1) cannot make the next round reuse the same
        # base.  Bases are taken from [2, n-2]; n-1 is a trivial witness.
        a = 2 if count == 0 else getRandomNumber(2, n - 1)
        v = powMod(a, s, n)
        if v == 1:
            continue
        i = 0
        while v != n - 1:
            if i == t - 1:
                return False
            v, i = powMod(v, 2, n), i + 1
    return True
259
def getRandomPrime(bits, display=False):
    """Return a random prime below 2**bits whose two top bits are set.

    Candidates start at or above 1.5 * 2**(bits-1), so the two most
    significant bits are set.  When p and q are used for RSA, n = p*q then
    has its MSB set.  Candidates are kept at 29 mod 30 (30 = lcm(2, 3, 5)),
    so they are never divisible by 2, 3 or 5.  If display is true, "." is
    printed per candidate.

    Raises AssertionError if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    low = ((2 ** (bits - 1)) * 3) // 2
    high = 2 ** bits - 30

    def freshCandidate():
        # Random start in [low, high), bumped up to the next value = 29 mod 30.
        candidate = getRandomNumber(low, high)
        return candidate + 29 - (candidate % 30)

    p = freshCandidate()
    while True:
        if display:
            print(".", end=' ')
        p += 30
        if p >= high:
            p = freshCandidate()
        if isPrime(p, display=display):
            return p
280
# Unused at the moment...
def getRandomSafePrime(bits, display=False):
    """Return a random safe prime p = 2*q + 1 (q also prime) below 2**bits.

    q is drawn from [1.5 * 2**(bits-2), 2**(bits-1) - 30) and kept at
    29 mod 30 (30 = lcm(2, 3, 5)).  This puts p at or above
    1.5 * 2**(bits-1), so p's two most significant bits are set.  If display
    is true, "." is printed per candidate.

    Raises AssertionError if bits < 10.
    """
    if bits < 10:
        raise AssertionError()
    low = (2 ** (bits - 2)) * 3 // 2
    high = (2 ** (bits - 1)) - 30

    def freshCandidate():
        # Random start in [low, high), bumped up to the next value = 29 mod 30.
        candidate = getRandomNumber(low, high)
        return candidate + 29 - (candidate % 30)

    q = freshCandidate()
    while True:
        if display:
            print(".", end=' ')
        q += 30
        if q >= high:
            q = freshCandidate()
        # Ideas from Tom Wu's SRP code: cheap trial division on q first
        # (0 Rabin-Miller rounds), then full tests on p and finally on q.
        if not isPrime(q, 0, display=display):
            continue
        p = (2 * q) + 1
        if isPrime(p, display=display) and isPrime(q, display=display):
            return p
307