1/*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.conscrypt;
18
19import java.security.InvalidKeyException;
20import java.security.InvalidParameterException;
21import java.security.PrivateKey;
22import java.security.PublicKey;
23import java.security.SignatureException;
24import java.security.SignatureSpi;
25import java.security.interfaces.RSAPrivateCrtKey;
26import java.security.interfaces.RSAPrivateKey;
27import java.security.interfaces.RSAPublicKey;
28
29/**
30 * Implements the JDK Signature interface needed for RAW RSA signature
31 * generation and verification using OpenSSL.
32 */
33public class OpenSSLSignatureRawRSA extends SignatureSpi {
34    /**
35     * The current OpenSSL key we're operating on.
36     */
37    private OpenSSLKey key;
38
39    /**
40     * Buffer to hold value to be signed or verified.
41     */
42    private byte[] inputBuffer;
43
44    /**
45     * Current offset in input buffer.
46     */
47    private int inputOffset;
48
49    /**
50     * Provides a flag to specify when the input is too long.
51     */
52    private boolean inputIsTooLong;
53
54    @Override
55    protected void engineUpdate(byte input) {
56        final int oldOffset = inputOffset++;
57
58        if (inputOffset > inputBuffer.length) {
59            inputIsTooLong = true;
60            return;
61        }
62
63        inputBuffer[oldOffset] = input;
64    }
65
66    @Override
67    protected void engineUpdate(byte[] input, int offset, int len) {
68        final int oldOffset = inputOffset;
69        inputOffset += len;
70
71        if (inputOffset > inputBuffer.length) {
72            inputIsTooLong = true;
73            return;
74        }
75
76        System.arraycopy(input, offset, inputBuffer, oldOffset, len);
77    }
78
79    @Override
80    protected Object engineGetParameter(String param) throws InvalidParameterException {
81        return null;
82    }
83
84    @Override
85    protected void engineInitSign(PrivateKey privateKey) throws InvalidKeyException {
86        if (privateKey instanceof OpenSSLRSAPrivateKey) {
87            OpenSSLRSAPrivateKey rsaPrivateKey = (OpenSSLRSAPrivateKey) privateKey;
88            key = rsaPrivateKey.getOpenSSLKey();
89        } else if (privateKey instanceof RSAPrivateCrtKey) {
90            RSAPrivateCrtKey rsaPrivateKey = (RSAPrivateCrtKey) privateKey;
91            key = OpenSSLRSAPrivateCrtKey.getInstance(rsaPrivateKey);
92        } else if (privateKey instanceof RSAPrivateKey) {
93            RSAPrivateKey rsaPrivateKey = (RSAPrivateKey) privateKey;
94            key = OpenSSLRSAPrivateKey.getInstance(rsaPrivateKey);
95        } else {
96            throw new InvalidKeyException("Need RSA private key");
97        }
98
99        // Allocate buffer according to RSA modulus size.
100        int maxSize = NativeCrypto.RSA_size(key.getPkeyContext());
101        inputBuffer = new byte[maxSize];
102        inputOffset = 0;
103    }
104
105    @Override
106    protected void engineInitVerify(PublicKey publicKey) throws InvalidKeyException {
107        if (publicKey instanceof OpenSSLRSAPublicKey) {
108            OpenSSLRSAPublicKey rsaPublicKey = (OpenSSLRSAPublicKey) publicKey;
109            key = rsaPublicKey.getOpenSSLKey();
110        } else if (publicKey instanceof RSAPublicKey) {
111            RSAPublicKey rsaPublicKey = (RSAPublicKey) publicKey;
112            key = OpenSSLRSAPublicKey.getInstance(rsaPublicKey);
113        } else {
114            throw new InvalidKeyException("Need RSA public key");
115        }
116
117        // Allocate buffer according to RSA modulus size.
118        int maxSize = NativeCrypto.RSA_size(key.getPkeyContext());
119        inputBuffer = new byte[maxSize];
120        inputOffset = 0;
121    }
122
123    @Override
124    protected void engineSetParameter(String param, Object value) throws InvalidParameterException {
125    }
126
127    @Override
128    protected byte[] engineSign() throws SignatureException {
129        if (key == null) {
130            // This can't actually happen, but you never know...
131            throw new SignatureException("Need RSA private key");
132        }
133
134        if (inputIsTooLong) {
135            throw new SignatureException("input length " + inputOffset + " != "
136                    + inputBuffer.length + " (modulus size)");
137        }
138
139        byte[] outputBuffer = new byte[inputBuffer.length];
140        try {
141            NativeCrypto.RSA_private_encrypt(inputOffset, inputBuffer, outputBuffer,
142                    key.getPkeyContext(), NativeCrypto.RSA_PKCS1_PADDING);
143            return outputBuffer;
144        } catch (Exception ex) {
145            throw new SignatureException(ex);
146        } finally {
147            inputOffset = 0;
148        }
149    }
150
151    @Override
152    protected boolean engineVerify(byte[] sigBytes) throws SignatureException {
153        if (key == null) {
154            // This can't actually happen, but you never know...
155            throw new SignatureException("Need RSA public key");
156        }
157
158        if (inputIsTooLong) {
159            return false;
160        }
161
162        byte[] outputBuffer = new byte[inputBuffer.length];
163        try {
164            final int resultSize;
165            try {
166                resultSize = NativeCrypto.RSA_public_decrypt(sigBytes.length, sigBytes,
167                        outputBuffer, key.getPkeyContext(), NativeCrypto.RSA_PKCS1_PADDING);
168            } catch (SignatureException e) {
169                throw e;
170            } catch (Exception e) {
171                return false;
172            }
173            /* Make this constant time by comparing every byte. */
174            boolean matches = (resultSize == inputOffset);
175            for (int i = 0; i < resultSize; i++) {
176                if (inputBuffer[i] != outputBuffer[i]) {
177                    matches = false;
178                }
179            }
180            return matches;
181        } catch (Exception ex) {
182            throw new SignatureException(ex);
183        } finally {
184            inputOffset = 0;
185        }
186    }
187}
188