1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.security.AlgorithmParameters;
20import java.security.InvalidAlgorithmParameterException;
21import java.security.InvalidKeyException;
22import java.security.InvalidParameterException;
23import java.security.Key;
24import java.security.KeyFactory;
25import java.security.NoSuchAlgorithmException;
26import java.security.SecureRandom;
27import java.security.SignatureException;
28import java.security.interfaces.RSAPrivateCrtKey;
29import java.security.interfaces.RSAPrivateKey;
30import java.security.interfaces.RSAPublicKey;
31import java.security.spec.AlgorithmParameterSpec;
32import java.security.spec.InvalidKeySpecException;
33import java.security.spec.PKCS8EncodedKeySpec;
34import java.security.spec.X509EncodedKeySpec;
35import java.util.Arrays;
36import java.util.Locale;
37import javax.crypto.BadPaddingException;
38import javax.crypto.Cipher;
39import javax.crypto.CipherSpi;
40import javax.crypto.IllegalBlockSizeException;
41import javax.crypto.NoSuchPaddingException;
42import javax.crypto.ShortBufferException;
43import javax.crypto.spec.SecretKeySpec;
44import org.conscrypt.util.EmptyArray;
45
46public abstract class OpenSSLCipherRSA extends CipherSpi {
47    /**
48     * The current OpenSSL key we're operating on.
49     */
50    private OpenSSLKey key;
51
52    /**
53     * Current key type: private or public.
54     */
55    private boolean usingPrivateKey;
56
57    /**
58     * Current cipher mode: encrypting or decrypting.
59     */
60    private boolean encrypting;
61
62    /**
63     * Buffer for operations
64     */
65    private byte[] buffer;
66
67    /**
68     * Current offset in the buffer.
69     */
70    private int bufferOffset;
71
72    /**
73     * Flag that indicates an exception should be thrown when the input is too
74     * large during doFinal.
75     */
76    private boolean inputTooLarge;
77
78    /**
79     * Current padding mode
80     */
81    private int padding = NativeCrypto.RSA_PKCS1_PADDING;
82
83    protected OpenSSLCipherRSA(int padding) {
84        this.padding = padding;
85    }
86
87    @Override
88    protected void engineSetMode(String mode) throws NoSuchAlgorithmException {
89        final String modeUpper = mode.toUpperCase(Locale.ROOT);
90        if ("NONE".equals(modeUpper) || "ECB".equals(modeUpper)) {
91            return;
92        }
93
94        throw new NoSuchAlgorithmException("mode not supported: " + mode);
95    }
96
97    @Override
98    protected void engineSetPadding(String padding) throws NoSuchPaddingException {
99        final String paddingUpper = padding.toUpperCase(Locale.ROOT);
100        if ("PKCS1PADDING".equals(paddingUpper)) {
101            this.padding = NativeCrypto.RSA_PKCS1_PADDING;
102            return;
103        }
104        if ("NOPADDING".equals(paddingUpper)) {
105            this.padding = NativeCrypto.RSA_NO_PADDING;
106            return;
107        }
108
109        throw new NoSuchPaddingException("padding not supported: " + padding);
110    }
111
112    @Override
113    protected int engineGetBlockSize() {
114        if (encrypting) {
115            return paddedBlockSizeBytes();
116        }
117        return keySizeBytes();
118    }
119
120    @Override
121    protected int engineGetOutputSize(int inputLen) {
122        if (encrypting) {
123            return keySizeBytes();
124        }
125        return paddedBlockSizeBytes();
126    }
127
128    private int paddedBlockSizeBytes() {
129        int paddedBlockSizeBytes = keySizeBytes();
130        if (padding == NativeCrypto.RSA_PKCS1_PADDING) {
131            paddedBlockSizeBytes--;  // for 0 prefix
132            paddedBlockSizeBytes -= 10;  // PKCS1 padding header length
133        }
134        return paddedBlockSizeBytes;
135    }
136
137    private int keySizeBytes() {
138        if (key == null) {
139            throw new IllegalStateException("cipher is not initialized");
140        }
141        return NativeCrypto.RSA_size(this.key.getPkeyContext());
142    }
143
144    @Override
145    protected byte[] engineGetIV() {
146        return null;
147    }
148
149    @Override
150    protected AlgorithmParameters engineGetParameters() {
151        return null;
152    }
153
154    private void engineInitInternal(int opmode, Key key) throws InvalidKeyException {
155        if (opmode == Cipher.ENCRYPT_MODE || opmode == Cipher.WRAP_MODE) {
156            encrypting = true;
157        } else if (opmode == Cipher.DECRYPT_MODE || opmode == Cipher.UNWRAP_MODE) {
158            encrypting = false;
159        } else {
160            throw new InvalidParameterException("Unsupported opmode " + opmode);
161        }
162
163        if (key instanceof OpenSSLRSAPrivateKey) {
164            OpenSSLRSAPrivateKey rsaPrivateKey = (OpenSSLRSAPrivateKey) key;
165            usingPrivateKey = true;
166            this.key = rsaPrivateKey.getOpenSSLKey();
167        } else if (key instanceof RSAPrivateCrtKey) {
168            RSAPrivateCrtKey rsaPrivateKey = (RSAPrivateCrtKey) key;
169            usingPrivateKey = true;
170            this.key = OpenSSLRSAPrivateCrtKey.getInstance(rsaPrivateKey);
171        } else if (key instanceof RSAPrivateKey) {
172            RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) key;
173            usingPrivateKey = true;
174            this.key = OpenSSLRSAPrivateKey.getInstance(rsaPrivateKey);
175        } else if (key instanceof OpenSSLRSAPublicKey) {
176            OpenSSLRSAPublicKey rsaPublicKey = (OpenSSLRSAPublicKey) key;
177            usingPrivateKey = false;
178            this.key = rsaPublicKey.getOpenSSLKey();
179        } else if (key instanceof RSAPublicKey) {
180            RSAPublicKey rsaPublicKey = (RSAPublicKey) key;
181            usingPrivateKey = false;
182            this.key = OpenSSLRSAPublicKey.getInstance(rsaPublicKey);
183        } else {
184            throw new InvalidKeyException("Need RSA private or public key");
185        }
186
187        buffer = new byte[NativeCrypto.RSA_size(this.key.getPkeyContext())];
188        inputTooLarge = false;
189    }
190
191    @Override
192    protected void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
193        engineInitInternal(opmode, key);
194    }
195
196    @Override
197    protected void engineInit(int opmode, Key key, AlgorithmParameterSpec params,
198            SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
199        if (params != null) {
200            throw new InvalidAlgorithmParameterException("unknown param type: "
201                    + params.getClass().getName());
202        }
203
204        engineInitInternal(opmode, key);
205    }
206
207    @Override
208    protected void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random)
209            throws InvalidKeyException, InvalidAlgorithmParameterException {
210        if (params != null) {
211            throw new InvalidAlgorithmParameterException("unknown param type: "
212                    + params.getClass().getName());
213        }
214
215        engineInitInternal(opmode, key);
216    }
217
218    @Override
219    protected byte[] engineUpdate(byte[] input, int inputOffset, int inputLen) {
220        if (bufferOffset + inputLen > buffer.length) {
221            inputTooLarge = true;
222            return EmptyArray.BYTE;
223        }
224
225        System.arraycopy(input, inputOffset, buffer, bufferOffset, inputLen);
226        bufferOffset += inputLen;
227        return EmptyArray.BYTE;
228    }
229
230    @Override
231    protected int engineUpdate(byte[] input, int inputOffset, int inputLen, byte[] output,
232            int outputOffset) throws ShortBufferException {
233        engineUpdate(input, inputOffset, inputLen);
234        return 0;
235    }
236
237    @Override
238    protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen)
239            throws IllegalBlockSizeException, BadPaddingException {
240        if (input != null) {
241            engineUpdate(input, inputOffset, inputLen);
242        }
243
244        if (inputTooLarge) {
245            throw new IllegalBlockSizeException("input must be under " + buffer.length + " bytes");
246        }
247
248        final byte[] tmpBuf;
249        if (bufferOffset != buffer.length) {
250            if (padding == NativeCrypto.RSA_NO_PADDING) {
251                tmpBuf = new byte[buffer.length];
252                System.arraycopy(buffer, 0, tmpBuf, buffer.length - bufferOffset, bufferOffset);
253            } else {
254                tmpBuf = Arrays.copyOf(buffer, bufferOffset);
255            }
256        } else {
257            tmpBuf = buffer;
258        }
259
260        byte[] output = new byte[buffer.length];
261        int resultSize;
262        if (encrypting) {
263            if (usingPrivateKey) {
264                resultSize = NativeCrypto.RSA_private_encrypt(tmpBuf.length, tmpBuf, output,
265                                                              key.getPkeyContext(), padding);
266            } else {
267                resultSize = NativeCrypto.RSA_public_encrypt(tmpBuf.length, tmpBuf, output,
268                                                             key.getPkeyContext(), padding);
269            }
270        } else {
271            try {
272                if (usingPrivateKey) {
273                    resultSize = NativeCrypto.RSA_private_decrypt(tmpBuf.length, tmpBuf, output,
274                                                                  key.getPkeyContext(), padding);
275                } else {
276                    resultSize = NativeCrypto.RSA_public_decrypt(tmpBuf.length, tmpBuf, output,
277                                                                 key.getPkeyContext(), padding);
278                }
279            } catch (SignatureException e) {
280                IllegalBlockSizeException newE = new IllegalBlockSizeException();
281                newE.initCause(e);
282                throw newE;
283            }
284        }
285        if (!encrypting && resultSize != output.length) {
286            output = Arrays.copyOf(output, resultSize);
287        }
288
289        bufferOffset = 0;
290        return output;
291    }
292
293    @Override
294    protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output,
295            int outputOffset) throws ShortBufferException, IllegalBlockSizeException,
296            BadPaddingException {
297        byte[] b = engineDoFinal(input, inputOffset, inputLen);
298
299        final int lastOffset = outputOffset + b.length;
300        if (lastOffset > output.length) {
301            throw new ShortBufferException("output buffer is too small " + output.length + " < "
302                    + lastOffset);
303        }
304
305        System.arraycopy(b, 0, output, outputOffset, b.length);
306        return b.length;
307    }
308
309    @Override
310    protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException {
311        try {
312            byte[] encoded = key.getEncoded();
313            return engineDoFinal(encoded, 0, encoded.length);
314        } catch (BadPaddingException e) {
315            IllegalBlockSizeException newE = new IllegalBlockSizeException();
316            newE.initCause(e);
317            throw newE;
318        }
319    }
320
321    @Override
322    protected Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm,
323            int wrappedKeyType) throws InvalidKeyException, NoSuchAlgorithmException {
324        try {
325            byte[] encoded = engineDoFinal(wrappedKey, 0, wrappedKey.length);
326            if (wrappedKeyType == Cipher.PUBLIC_KEY) {
327                KeyFactory keyFactory = KeyFactory.getInstance(wrappedKeyAlgorithm);
328                return keyFactory.generatePublic(new X509EncodedKeySpec(encoded));
329            } else if (wrappedKeyType == Cipher.PRIVATE_KEY) {
330                KeyFactory keyFactory = KeyFactory.getInstance(wrappedKeyAlgorithm);
331                return keyFactory.generatePrivate(new PKCS8EncodedKeySpec(encoded));
332            } else if (wrappedKeyType == Cipher.SECRET_KEY) {
333                return new SecretKeySpec(encoded, wrappedKeyAlgorithm);
334            } else {
335                throw new UnsupportedOperationException("wrappedKeyType == " + wrappedKeyType);
336            }
337        } catch (IllegalBlockSizeException e) {
338            throw new InvalidKeyException(e);
339        } catch (BadPaddingException e) {
340            throw new InvalidKeyException(e);
341        } catch (InvalidKeySpecException e) {
342            throw new InvalidKeyException(e);
343        }
344    }
345
346    public static class PKCS1 extends OpenSSLCipherRSA {
347        public PKCS1() {
348            super(NativeCrypto.RSA_PKCS1_PADDING);
349        }
350    }
351
352    public static class Raw extends OpenSSLCipherRSA {
353        public Raw() {
354            super(NativeCrypto.RSA_NO_PADDING);
355        }
356    }
357}
358