1# Author: Trevor Perrin
2# See the LICENSE file for legal information regarding use of this file.
3
4"""OpenSSL/M2Crypto RSA implementation."""
5
6from .cryptomath import *
7
8from .rsakey import *
9from .python_rsakey import Python_RSAKey
10
# Copied from M2Crypto's util.py so that this callback is still available
# when only the local copy of the m2 module is loaded.
def password_callback(v, prompt1='Enter private key passphrase:',
                      prompt2='Verify passphrase:'):
    """Interactively read a passphrase from the terminal without echo.

    :param v: when true, the passphrase must be entered a second time and
        the prompts repeat until both entries match.
    :param prompt1: prompt shown for the (first) passphrase entry.
    :param prompt2: prompt shown for the confirmation entry.
    :returns: the entered passphrase, or None if the user interrupts
        either prompt with KeyboardInterrupt (Ctrl-C).
    """
    from getpass import getpass
    while True:
        try:
            passphrase = getpass(prompt1)
            if not v:
                # No confirmation requested: accept the first entry.
                return passphrase
            if getpass(prompt2) == passphrase:
                return passphrase
            # Mismatch: fall through and ask again from the start.
        except KeyboardInterrupt:
            return None
28
29
if m2cryptoLoaded:
    class OpenSSL_RSAKey(RSAKey):
        """RSA key backed by an OpenSSL RSA object accessed through M2Crypto.

        The M2Crypto handle is stored in ``self.rsa`` (None for an empty key).
        The modulus ``n`` and public exponent ``e`` are not stored as Python
        attributes; they are read from the handle on demand.
        """

        def __init__(self, n=0, e=0):
            """Create a key, optionally initialised with a public key.

            :param n: RSA modulus, or 0 to create an empty key.
            :param e: RSA public exponent, or 0 to create an empty key.
            :raises AssertionError: if exactly one of ``n`` and ``e`` is set.
            """
            self.rsa = None
            self._hasPrivateKey = False
            # Either both public components are supplied or neither is.
            if bool(n) != bool(e):
                raise AssertionError()
            if n and e:
                self.rsa = m2.rsa_new()
                m2.rsa_set_n(self.rsa, numberToMPI(n))
                m2.rsa_set_e(self.rsa, numberToMPI(e))

        def __del__(self):
            """Free the underlying OpenSSL RSA object, if one is held."""
            # getattr() with a default keeps this safe on instances whose
            # __init__ never ran (e.g. created via __new__ alone).
            rsa = getattr(self, 'rsa', None)
            if rsa:
                m2.rsa_free(rsa)

        def __getattr__(self, name):
            """Compute ``n`` and ``e`` on demand from the OpenSSL object.

            Only invoked for attributes not found by normal lookup. Returns 0
            for ``n``/``e`` when no key material is loaded.

            :raises AttributeError: for any other attribute name.
            """
            if name == 'e':
                getter = m2.rsa_get_e
            elif name == 'n':
                getter = m2.rsa_get_n
            else:
                raise AttributeError(name)
            if not self.rsa:
                return 0
            return mpiToNumber(getter(self.rsa))

        def hasPrivateKey(self):
            """Return True if this object holds private key material."""
            return self._hasPrivateKey

        def _rawPrivateKeyOp(self, m):
            """Apply the raw (unpadded) private-key RSA operation to ``m``.

            ``m`` is encoded to a byte string of the modulus' length before
            being passed to OpenSSL with ``no_padding``.
            """
            b = numberToByteArray(m, numBytes(self.n))
            s = m2.rsa_private_encrypt(self.rsa, bytes(b), m2.no_padding)
            return bytesToNumber(bytearray(s))

        def _rawPublicKeyOp(self, c):
            """Apply the raw (unpadded) public-key RSA operation to ``c``.

            ``c`` is encoded to a byte string of the modulus' length before
            being passed to OpenSSL with ``no_padding``.
            """
            b = numberToByteArray(c, numBytes(self.n))
            s = m2.rsa_public_decrypt(self.rsa, bytes(b), m2.no_padding)
            return bytesToNumber(bytearray(s))

        def acceptsPassword(self):
            """Return True: ``write`` can encrypt a private key."""
            return True

        def write(self, password=None):
            """Serialise the key in PEM form.

            :param password: if given (and truthy), the private key is written
                encrypted with ``m2.des_ede_cbc()`` under this password.
            :returns: the PEM data read back from the memory BIO.
            :raises AssertionError: if a password is given for a key that has
                no private part.
            """
            # Validate before allocating the BIO so this error path cannot
            # leak it.
            if password and not self._hasPrivateKey:
                raise AssertionError()
            bio = m2.bio_new(m2.bio_s_mem())
            try:
                if self._hasPrivateKey:
                    if password:
                        def password_cb(*args):
                            return password
                        m2.rsa_write_key(self.rsa, bio, m2.des_ede_cbc(),
                                         password_cb)
                    else:
                        def no_password_cb(*args):
                            pass
                        m2.rsa_write_key_no_cipher(self.rsa, bio,
                                                   no_password_cb)
                else:
                    m2.rsa_write_pub_key(self.rsa, bio)
                return m2.bio_read(bio, m2.bio_ctrl_pending(bio))
            finally:
                # Always release the BIO, even if a write call raises.
                m2.bio_free(bio)

        @staticmethod
        def generate(bits):
            """Generate a new RSA private key of ``bits`` bits.

            The public exponent is 65537.
            """
            key = OpenSSL_RSAKey()

            def progress(*args):
                # Key-generation progress callback; accept and ignore any
                # arguments passed to it.
                pass

            # Use e = 65537 rather than 3: very small public exponents enable
            # well-known attacks (e.g. signature forgery against lax
            # PKCS#1 v1.5 verifiers).
            key.rsa = m2.rsa_generate_key(bits, 65537, progress)
            key._hasPrivateKey = True
            return key

        @staticmethod
        def parse(s, passwordCallback=None):
            """Parse a PEM-encoded RSA key.

            Any text before the first ``-----BEGIN `` header is skipped.
            Supported blocks are ``RSA PRIVATE KEY`` (loaded as a private key)
            and ``PUBLIC KEY`` (loaded as a public key).

            :param s: string containing the PEM data.
            :param passwordCallback: optional zero-argument callable returning
                the passphrase for an encrypted private key; if None, the
                module-level ``password_callback`` prompts interactively.
            :returns: a new OpenSSL_RSAKey.
            :raises SyntaxError: if no PEM header is found, the block type is
                unsupported, or reading the key returns None.
            """
            # Skip forward to the first PEM header. After this slice s is
            # guaranteed to start with "-----BEGIN ".
            start = s.find("-----BEGIN ")
            if start == -1:
                raise SyntaxError()
            s = s[start:]

            if passwordCallback is None:
                callback = password_callback
            else:
                def callback(v, prompt1=None, prompt2=None):
                    # Adapt the zero-argument user callback to the signature
                    # used by password_callback.
                    return passwordCallback()

            bio = m2.bio_new(m2.bio_s_mem())
            try:
                m2.bio_write(bio, s)
                key = OpenSSL_RSAKey()
                if s.startswith("-----BEGIN RSA PRIVATE KEY-----"):
                    key.rsa = m2.rsa_read_key(bio, callback)
                    if key.rsa is None:
                        raise SyntaxError()
                    key._hasPrivateKey = True
                elif s.startswith("-----BEGIN PUBLIC KEY-----"):
                    key.rsa = m2.rsa_read_pub_key(bio)
                    if key.rsa is None:
                        raise SyntaxError()
                    key._hasPrivateKey = False
                else:
                    raise SyntaxError()
                return key
            finally:
                m2.bio_free(bio)
137