1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.security.AlgorithmParameters;
20import java.security.InvalidAlgorithmParameterException;
21import java.security.InvalidKeyException;
22import java.security.InvalidParameterException;
23import java.security.Key;
24import java.security.KeyFactory;
25import java.security.NoSuchAlgorithmException;
26import java.security.SecureRandom;
27import java.security.SignatureException;
28import java.security.interfaces.RSAPrivateCrtKey;
29import java.security.interfaces.RSAPrivateKey;
30import java.security.interfaces.RSAPublicKey;
31import java.security.spec.AlgorithmParameterSpec;
32import java.security.spec.InvalidKeySpecException;
33import java.security.spec.InvalidParameterSpecException;
34import java.security.spec.MGF1ParameterSpec;
35import java.security.spec.PKCS8EncodedKeySpec;
36import java.security.spec.X509EncodedKeySpec;
37import java.util.Arrays;
38import java.util.Locale;
39import javax.crypto.BadPaddingException;
40import javax.crypto.Cipher;
41import javax.crypto.CipherSpi;
42import javax.crypto.IllegalBlockSizeException;
43import javax.crypto.NoSuchPaddingException;
44import javax.crypto.ShortBufferException;
45import javax.crypto.spec.OAEPParameterSpec;
46import javax.crypto.spec.PSource;
47import javax.crypto.spec.SecretKeySpec;
48
49@Internal
50abstract class OpenSSLCipherRSA extends CipherSpi {
51    /**
52     * The current OpenSSL key we're operating on.
53     */
54    protected OpenSSLKey key;
55
56    /**
57     * Current key type: private or public.
58     */
59    protected boolean usingPrivateKey;
60
61    /**
62     * Current cipher mode: encrypting or decrypting.
63     */
64    protected boolean encrypting;
65
66    /**
67     * Buffer for operations
68     */
69    private byte[] buffer;
70
71    /**
72     * Current offset in the buffer.
73     */
74    private int bufferOffset;
75
76    /**
77     * Flag that indicates an exception should be thrown when the input is too
78     * large during doFinal.
79     */
80    private boolean inputTooLarge;
81
82    /**
83     * Current padding mode
84     */
85    protected int padding = NativeConstants.RSA_PKCS1_PADDING;
86
87    protected OpenSSLCipherRSA(int padding) {
88        this.padding = padding;
89    }
90
91    @Override
92    protected void engineSetMode(String mode) throws NoSuchAlgorithmException {
93        final String modeUpper = mode.toUpperCase(Locale.ROOT);
94        if ("NONE".equals(modeUpper) || "ECB".equals(modeUpper)) {
95            return;
96        }
97
98        throw new NoSuchAlgorithmException("mode not supported: " + mode);
99    }
100
101    @Override
102    protected void engineSetPadding(String padding) throws NoSuchPaddingException {
103        final String paddingUpper = padding.toUpperCase(Locale.ROOT);
104        if ("PKCS1PADDING".equals(paddingUpper)) {
105            this.padding = NativeConstants.RSA_PKCS1_PADDING;
106            return;
107        }
108        if ("NOPADDING".equals(paddingUpper)) {
109            this.padding = NativeConstants.RSA_NO_PADDING;
110            return;
111        }
112
113        throw new NoSuchPaddingException("padding not supported: " + padding);
114    }
115
116    @Override
117    protected int engineGetBlockSize() {
118        if (encrypting) {
119            return paddedBlockSizeBytes();
120        }
121        return keySizeBytes();
122    }
123
124    @Override
125    protected int engineGetOutputSize(int inputLen) {
126        if (encrypting) {
127            return keySizeBytes();
128        }
129        return paddedBlockSizeBytes();
130    }
131
132    protected int paddedBlockSizeBytes() {
133        int paddedBlockSizeBytes = keySizeBytes();
134        if (padding == NativeConstants.RSA_PKCS1_PADDING) {
135            paddedBlockSizeBytes--;  // for 0 prefix
136            paddedBlockSizeBytes -= 10;  // PKCS1 padding header length
137        }
138        return paddedBlockSizeBytes;
139    }
140
141    protected int keySizeBytes() {
142        if (!isInitialized()) {
143            throw new IllegalStateException("cipher is not initialized");
144        }
145        return NativeCrypto.RSA_size(this.key.getNativeRef());
146    }
147
148    /**
149     * Returns {@code true} if the cipher has been initialized.
150     */
151    protected boolean isInitialized() {
152        return key != null;
153    }
154
155    @Override
156    protected byte[] engineGetIV() {
157        return null;
158    }
159
160    @Override
161    protected AlgorithmParameters engineGetParameters() {
162        return null;
163    }
164
165    protected void doCryptoInit(AlgorithmParameterSpec spec)
166            throws InvalidAlgorithmParameterException {}
167
168    protected void engineInitInternal(int opmode, Key key, AlgorithmParameterSpec spec)
169            throws InvalidKeyException, InvalidAlgorithmParameterException {
170        if (opmode == Cipher.ENCRYPT_MODE || opmode == Cipher.WRAP_MODE) {
171            encrypting = true;
172        } else if (opmode == Cipher.DECRYPT_MODE || opmode == Cipher.UNWRAP_MODE) {
173            encrypting = false;
174        } else {
175            throw new InvalidParameterException("Unsupported opmode " + opmode);
176        }
177
178        if (key instanceof OpenSSLRSAPrivateKey) {
179            OpenSSLRSAPrivateKey rsaPrivateKey = (OpenSSLRSAPrivateKey) key;
180            usingPrivateKey = true;
181            this.key = rsaPrivateKey.getOpenSSLKey();
182        } else if (key instanceof RSAPrivateCrtKey) {
183            RSAPrivateCrtKey rsaPrivateKey = (RSAPrivateCrtKey) key;
184            usingPrivateKey = true;
185            this.key = OpenSSLRSAPrivateCrtKey.getInstance(rsaPrivateKey);
186        } else if (key instanceof RSAPrivateKey) {
187            RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) key;
188            usingPrivateKey = true;
189            this.key = OpenSSLRSAPrivateKey.getInstance(rsaPrivateKey);
190        } else if (key instanceof OpenSSLRSAPublicKey) {
191            OpenSSLRSAPublicKey rsaPublicKey = (OpenSSLRSAPublicKey) key;
192            usingPrivateKey = false;
193            this.key = rsaPublicKey.getOpenSSLKey();
194        } else if (key instanceof RSAPublicKey) {
195            RSAPublicKey rsaPublicKey = (RSAPublicKey) key;
196            usingPrivateKey = false;
197            this.key = OpenSSLRSAPublicKey.getInstance(rsaPublicKey);
198        } else {
199            throw new InvalidKeyException("Need RSA private or public key");
200        }
201
202        buffer = new byte[NativeCrypto.RSA_size(this.key.getNativeRef())];
203        bufferOffset = 0;
204        inputTooLarge = false;
205
206        doCryptoInit(spec);
207    }
208
209    @Override
210    protected void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
211        try {
212            engineInitInternal(opmode, key, null);
213        } catch (InvalidAlgorithmParameterException e) {
214            throw new InvalidKeyException("Algorithm parameters rejected when none supplied", e);
215        }
216    }
217
218    @Override
219    protected void engineInit(int opmode, Key key, AlgorithmParameterSpec params,
220            SecureRandom random) throws InvalidKeyException, InvalidAlgorithmParameterException {
221        if (params != null) {
222            throw new InvalidAlgorithmParameterException("unknown param type: "
223                    + params.getClass().getName());
224        }
225
226        engineInitInternal(opmode, key, params);
227    }
228
229    @Override
230    protected void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random)
231            throws InvalidKeyException, InvalidAlgorithmParameterException {
232        if (params != null) {
233            throw new InvalidAlgorithmParameterException("unknown param type: "
234                    + params.getClass().getName());
235        }
236
237        engineInitInternal(opmode, key, null);
238    }
239
240    @Override
241    protected byte[] engineUpdate(byte[] input, int inputOffset, int inputLen) {
242        if (bufferOffset + inputLen > buffer.length) {
243            inputTooLarge = true;
244            return EmptyArray.BYTE;
245        }
246
247        System.arraycopy(input, inputOffset, buffer, bufferOffset, inputLen);
248        bufferOffset += inputLen;
249        return EmptyArray.BYTE;
250    }
251
252    @Override
253    protected int engineUpdate(byte[] input, int inputOffset, int inputLen, byte[] output,
254            int outputOffset) throws ShortBufferException {
255        engineUpdate(input, inputOffset, inputLen);
256        return 0;
257    }
258
259    @Override
260    protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen)
261            throws IllegalBlockSizeException, BadPaddingException {
262        if (input != null) {
263            engineUpdate(input, inputOffset, inputLen);
264        }
265
266        if (inputTooLarge) {
267            throw new IllegalBlockSizeException("input must be under " + buffer.length + " bytes");
268        }
269
270        final byte[] tmpBuf;
271        if (bufferOffset != buffer.length) {
272            if (padding == NativeConstants.RSA_NO_PADDING) {
273                tmpBuf = new byte[buffer.length];
274                System.arraycopy(buffer, 0, tmpBuf, buffer.length - bufferOffset, bufferOffset);
275            } else {
276                tmpBuf = Arrays.copyOf(buffer, bufferOffset);
277            }
278        } else {
279            tmpBuf = buffer;
280        }
281
282        byte[] output = new byte[buffer.length];
283        int resultSize = doCryptoOperation(tmpBuf, output);
284        if (!encrypting && resultSize != output.length) {
285            output = Arrays.copyOf(output, resultSize);
286        }
287
288        bufferOffset = 0;
289        return output;
290    }
291
292    protected abstract int doCryptoOperation(final byte[] tmpBuf, byte[] output)
293            throws BadPaddingException, IllegalBlockSizeException;
294
295    @Override
296    protected int engineDoFinal(byte[] input, int inputOffset, int inputLen, byte[] output,
297            int outputOffset) throws ShortBufferException, IllegalBlockSizeException,
298            BadPaddingException {
299        byte[] b = engineDoFinal(input, inputOffset, inputLen);
300
301        final int lastOffset = outputOffset + b.length;
302        if (lastOffset > output.length) {
303            throw new ShortBufferException("output buffer is too small " + output.length + " < "
304                    + lastOffset);
305        }
306
307        System.arraycopy(b, 0, output, outputOffset, b.length);
308        return b.length;
309    }
310
311    @Override
312    protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException {
313        try {
314            byte[] encoded = key.getEncoded();
315            return engineDoFinal(encoded, 0, encoded.length);
316        } catch (BadPaddingException e) {
317            IllegalBlockSizeException newE = new IllegalBlockSizeException();
318            newE.initCause(e);
319            throw newE;
320        }
321    }
322
323    @Override
324    protected Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm,
325            int wrappedKeyType) throws InvalidKeyException, NoSuchAlgorithmException {
326        try {
327            byte[] encoded = engineDoFinal(wrappedKey, 0, wrappedKey.length);
328            if (wrappedKeyType == Cipher.PUBLIC_KEY) {
329                KeyFactory keyFactory = KeyFactory.getInstance(wrappedKeyAlgorithm);
330                return keyFactory.generatePublic(new X509EncodedKeySpec(encoded));
331            } else if (wrappedKeyType == Cipher.PRIVATE_KEY) {
332                KeyFactory keyFactory = KeyFactory.getInstance(wrappedKeyAlgorithm);
333                return keyFactory.generatePrivate(new PKCS8EncodedKeySpec(encoded));
334            } else if (wrappedKeyType == Cipher.SECRET_KEY) {
335                return new SecretKeySpec(encoded, wrappedKeyAlgorithm);
336            } else {
337                throw new UnsupportedOperationException("wrappedKeyType == " + wrappedKeyType);
338            }
339        } catch (IllegalBlockSizeException e) {
340            throw new InvalidKeyException(e);
341        } catch (BadPaddingException e) {
342            throw new InvalidKeyException(e);
343        } catch (InvalidKeySpecException e) {
344            throw new InvalidKeyException(e);
345        }
346    }
347
348    public abstract static class DirectRSA extends OpenSSLCipherRSA {
349        public DirectRSA(int padding) {
350            super(padding);
351        }
352
353        @Override
354        protected int doCryptoOperation(final byte[] tmpBuf, byte[] output)
355                throws BadPaddingException, IllegalBlockSizeException {
356            int resultSize;
357            if (encrypting) {
358                if (usingPrivateKey) {
359                    resultSize = NativeCrypto.RSA_private_encrypt(
360                            tmpBuf.length, tmpBuf, output, key.getNativeRef(), padding);
361                } else {
362                    resultSize = NativeCrypto.RSA_public_encrypt(
363                            tmpBuf.length, tmpBuf, output, key.getNativeRef(), padding);
364                }
365            } else {
366                try {
367                    if (usingPrivateKey) {
368                        resultSize = NativeCrypto.RSA_private_decrypt(
369                                tmpBuf.length, tmpBuf, output, key.getNativeRef(), padding);
370                    } else {
371                        resultSize = NativeCrypto.RSA_public_decrypt(
372                                tmpBuf.length, tmpBuf, output, key.getNativeRef(), padding);
373                    }
374                } catch (SignatureException e) {
375                    IllegalBlockSizeException newE = new IllegalBlockSizeException();
376                    newE.initCause(e);
377                    throw newE;
378                }
379            }
380            return resultSize;
381        }
382    }
383
384    public static final class PKCS1 extends DirectRSA {
385        public PKCS1() {
386            super(NativeConstants.RSA_PKCS1_PADDING);
387        }
388    }
389
390    public static final class Raw extends DirectRSA {
391        public Raw() {
392            super(NativeConstants.RSA_NO_PADDING);
393        }
394    }
395
396    protected static class OAEP extends OpenSSLCipherRSA {
397        private long oaepMd;
398        private int oaepMdSizeBytes;
399
400        private long mgf1Md;
401
402        private byte[] label;
403
404        private NativeRef.EVP_PKEY_CTX pkeyCtx;
405
406        public OAEP(long defaultMd, int defaultMdSizeBytes) {
407            super(NativeConstants.RSA_PKCS1_OAEP_PADDING);
408            oaepMd = mgf1Md = defaultMd;
409            oaepMdSizeBytes = defaultMdSizeBytes;
410        }
411
412        @Override
413        protected AlgorithmParameters engineGetParameters() {
414            if (!isInitialized()) {
415                return null;
416            }
417
418            try {
419                AlgorithmParameters params = AlgorithmParameters.getInstance("OAEP");
420
421                final PSource pSrc;
422                if (label == null) {
423                    pSrc = PSource.PSpecified.DEFAULT;
424                } else {
425                    pSrc = new PSource.PSpecified(label);
426                }
427
428                params.init(new OAEPParameterSpec(
429                        EvpMdRef.getJcaDigestAlgorithmStandardNameFromEVP_MD(oaepMd),
430                        EvpMdRef.MGF1_ALGORITHM_NAME,
431                        new MGF1ParameterSpec(
432                                EvpMdRef.getJcaDigestAlgorithmStandardNameFromEVP_MD(mgf1Md)),
433                        pSrc));
434                return params;
435            } catch (NoSuchAlgorithmException | InvalidParameterSpecException e) {
436                throw new RuntimeException("No providers of AlgorithmParameters.OAEP available");
437            }
438        }
439
440        @Override
441        protected void engineSetPadding(String padding) throws NoSuchPaddingException {
442            String paddingUpper = padding.toUpperCase(Locale.US);
443            if (paddingUpper.equals("OAEPPadding")) {
444                this.padding = NativeConstants.RSA_PKCS1_OAEP_PADDING;
445                return;
446            }
447
448            throw new NoSuchPaddingException("Only OAEP padding is supported");
449        }
450
451        @Override
452        protected void engineInit(
453                int opmode, Key key, AlgorithmParameterSpec spec, SecureRandom random)
454                throws InvalidKeyException, InvalidAlgorithmParameterException {
455            if (spec != null && !(spec instanceof OAEPParameterSpec)) {
456                throw new InvalidAlgorithmParameterException(
457                        "Only OAEPParameterSpec accepted in OAEP mode");
458            }
459
460            engineInitInternal(opmode, key, spec);
461        }
462
463        @Override
464        protected void engineInit(
465                int opmode, Key key, AlgorithmParameters params, SecureRandom random)
466                throws InvalidKeyException, InvalidAlgorithmParameterException {
467            OAEPParameterSpec spec = null;
468            if (params != null) {
469                try {
470                    spec = params.getParameterSpec(OAEPParameterSpec.class);
471                } catch (InvalidParameterSpecException e) {
472                    throw new InvalidAlgorithmParameterException(
473                            "Only OAEP parameters are supported", e);
474                }
475            }
476
477            engineInitInternal(opmode, key, spec);
478        }
479
480        @Override
481        protected void doCryptoInit(AlgorithmParameterSpec spec)
482                throws InvalidAlgorithmParameterException {
483            pkeyCtx = new NativeRef.EVP_PKEY_CTX(encrypting
484                            ? NativeCrypto.EVP_PKEY_encrypt_init(key.getNativeRef())
485                            : NativeCrypto.EVP_PKEY_decrypt_init(key.getNativeRef()));
486
487            if (spec instanceof OAEPParameterSpec) {
488                readOAEPParameters((OAEPParameterSpec) spec);
489            }
490
491            NativeCrypto.EVP_PKEY_CTX_set_rsa_padding(
492                    pkeyCtx.context, NativeConstants.RSA_PKCS1_OAEP_PADDING);
493            NativeCrypto.EVP_PKEY_CTX_set_rsa_oaep_md(pkeyCtx.context, oaepMd);
494            NativeCrypto.EVP_PKEY_CTX_set_rsa_mgf1_md(pkeyCtx.context, mgf1Md);
495            if (label != null && label.length > 0) {
496                NativeCrypto.EVP_PKEY_CTX_set_rsa_oaep_label(pkeyCtx.context, label);
497            }
498        }
499
500        @Override
501        protected int paddedBlockSizeBytes() {
502            int paddedBlockSizeBytes = keySizeBytes();
503            // Size described in step 2 of decoding algorithm, but extra byte
504            // needed to make sure it's smaller than the RSA key modulus size.
505            // https://tools.ietf.org/html/rfc2437#section-9.1.1.2
506            return paddedBlockSizeBytes - (2 * oaepMdSizeBytes + 2);
507        }
508
509        private void readOAEPParameters(OAEPParameterSpec spec)
510                throws InvalidAlgorithmParameterException {
511            String mgfAlgUpper = spec.getMGFAlgorithm().toUpperCase(Locale.US);
512            AlgorithmParameterSpec mgfSpec = spec.getMGFParameters();
513            if ((!EvpMdRef.MGF1_ALGORITHM_NAME.equals(mgfAlgUpper)
514                        && !EvpMdRef.MGF1_OID.equals(mgfAlgUpper))
515                    || !(mgfSpec instanceof MGF1ParameterSpec)) {
516                throw new InvalidAlgorithmParameterException(
517                        "Only MGF1 supported as mask generation function");
518            }
519
520            MGF1ParameterSpec mgf1spec = (MGF1ParameterSpec) mgfSpec;
521            String oaepAlgUpper = spec.getDigestAlgorithm().toUpperCase(Locale.US);
522            try {
523                oaepMd = EvpMdRef.getEVP_MDByJcaDigestAlgorithmStandardName(oaepAlgUpper);
524                oaepMdSizeBytes =
525                        EvpMdRef.getDigestSizeBytesByJcaDigestAlgorithmStandardName(oaepAlgUpper);
526                mgf1Md = EvpMdRef.getEVP_MDByJcaDigestAlgorithmStandardName(
527                        mgf1spec.getDigestAlgorithm());
528            } catch (NoSuchAlgorithmException e) {
529                throw new InvalidAlgorithmParameterException(e);
530            }
531
532            PSource pSource = spec.getPSource();
533            if (!"PSpecified".equals(pSource.getAlgorithm())
534                    || !(pSource instanceof PSource.PSpecified)) {
535                throw new InvalidAlgorithmParameterException(
536                        "Only PSpecified accepted for PSource");
537            }
538            label = ((PSource.PSpecified) pSource).getValue();
539        }
540
541        @Override
542        protected int doCryptoOperation(byte[] tmpBuf, byte[] output)
543                throws BadPaddingException, IllegalBlockSizeException {
544            if (encrypting) {
545                return NativeCrypto.EVP_PKEY_encrypt(pkeyCtx, output, 0, tmpBuf, 0, tmpBuf.length);
546            } else {
547                return NativeCrypto.EVP_PKEY_decrypt(pkeyCtx, output, 0, tmpBuf, 0, tmpBuf.length);
548            }
549        }
550
551        public static final class SHA1 extends OAEP {
552            public SHA1() {
553                super(EvpMdRef.SHA1.EVP_MD, EvpMdRef.SHA1.SIZE_BYTES);
554            }
555        }
556
557        public static final class SHA224 extends OAEP {
558            public SHA224() {
559                super(EvpMdRef.SHA224.EVP_MD, EvpMdRef.SHA224.SIZE_BYTES);
560            }
561        }
562
563        public static final class SHA256 extends OAEP {
564            public SHA256() {
565                super(EvpMdRef.SHA256.EVP_MD, EvpMdRef.SHA256.SIZE_BYTES);
566            }
567        }
568
569        public static final class SHA384 extends OAEP {
570            public SHA384() {
571                super(EvpMdRef.SHA384.EVP_MD, EvpMdRef.SHA384.SIZE_BYTES);
572            }
573        }
574
575        public static final class SHA512 extends OAEP {
576            public SHA512() {
577                super(EvpMdRef.SHA512.EVP_MD, EvpMdRef.SHA512.SIZE_BYTES);
578            }
579        }
580    }
581}
582