1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <conscrypt/jniutil.h>
18
19#include <conscrypt/compat.h>
20#include <conscrypt/trace.h>
21#include <cstdlib>
22#include <errno.h>
23
24namespace conscrypt {
25namespace jniutil {
26
// Java VM handle; assigned from the |vm| argument of init().
JavaVM *gJavaVM;
// Conscrypt classes; init() fills these in through getGlobalRefToClass().
jclass cryptoUpcallsClass;
jclass openSslInputStreamClass;
jclass nativeRefClass;

// Platform classes; init() fills these in through findClass().
jclass byteArrayClass;
jclass calendarClass;
jclass objectClass;
jclass objectArrayClass;
jclass integerClass;
jclass inputStreamClass;
jclass outputStreamClass;
jclass stringClass;

// Field ID for the "context" field (JNI signature "J") of nativeRefClass; set in init().
jfieldID nativeRef_context;

// Method IDs looked up in init() on the classes declared above.
jmethodID calendar_setMethod;
jmethodID inputStream_readMethod;
jmethodID integer_valueOfMethod;
jmethodID openSslInputStream_readLineMethod;
jmethodID outputStream_writeMethod;
jmethodID outputStream_flushMethod;
49
50void init(JavaVM* vm, JNIEnv* env) {
51    gJavaVM = vm;
52
53    byteArrayClass = findClass(env, "[B");
54    calendarClass = findClass(env, "java/util/Calendar");
55    inputStreamClass = findClass(env, "java/io/InputStream");
56    integerClass = findClass(env, "java/lang/Integer");
57    objectClass = findClass(env, "java/lang/Object");
58    objectArrayClass = findClass(env, "[Ljava/lang/Object;");
59    outputStreamClass = findClass(env, "java/io/OutputStream");
60    stringClass = findClass(env, "java/lang/String");
61
62    cryptoUpcallsClass = getGlobalRefToClass(
63            env, TO_STRING(JNI_JARJAR_PREFIX) "org/conscrypt/CryptoUpcalls");
64    nativeRefClass = getGlobalRefToClass(
65            env, TO_STRING(JNI_JARJAR_PREFIX) "org/conscrypt/NativeRef");
66    openSslInputStreamClass = getGlobalRefToClass(
67            env, TO_STRING(JNI_JARJAR_PREFIX) "org/conscrypt/OpenSSLBIOInputStream");
68
69    nativeRef_context = getFieldRef(env, nativeRefClass, "context", "J");
70
71    calendar_setMethod = getMethodRef(env, calendarClass, "set", "(IIIIII)V");
72    inputStream_readMethod = getMethodRef(env, inputStreamClass, "read", "([B)I");
73    integer_valueOfMethod =
74            env->GetStaticMethodID(integerClass, "valueOf", "(I)Ljava/lang/Integer;");
75    openSslInputStream_readLineMethod =
76            getMethodRef(env, openSslInputStreamClass, "gets", "([B)I");
77    outputStream_writeMethod = getMethodRef(env, outputStreamClass, "write", "([B)V");
78    outputStream_flushMethod = getMethodRef(env, outputStreamClass, "flush", "()V");
79}
80
81void jniRegisterNativeMethods(JNIEnv* env, const char* className, const JNINativeMethod* gMethods,
82                              int numMethods) {
83    ALOGV("Registering %s's %d native methods...", className, numMethods);
84
85    ScopedLocalRef<jclass> c(env, env->FindClass(className));
86    if (c.get() == nullptr) {
87        char* msg;
88        (void)asprintf(&msg, "Native registration unable to find class '%s'; aborting...",
89                       className);
90        env->FatalError(msg);
91    }
92
93    if (env->RegisterNatives(c.get(), gMethods, numMethods) < 0) {
94        char* msg;
95        (void)asprintf(&msg, "RegisterNatives failed for '%s'; aborting...", className);
96        env->FatalError(msg);
97    }
98}
99
100int jniGetFDFromFileDescriptor(JNIEnv* env, jobject fileDescriptor) {
101    ScopedLocalRef<jclass> localClass(env, env->FindClass("java/io/FileDescriptor"));
102#if defined(ANDROID) && !defined(CONSCRYPT_OPENJDK)
103    static jfieldID fid = env->GetFieldID(localClass.get(), "descriptor", "I");
104#else /* !ANDROID || CONSCRYPT_OPENJDK */
105    static jfieldID fid = env->GetFieldID(localClass.get(), "fd", "I");
106#endif
107    if (fileDescriptor != nullptr) {
108        return env->GetIntField(fileDescriptor, fid);
109    } else {
110        return -1;
111    }
112}
113
// Heuristic: reports whether GetByteArrayElements is expected to hand back a copy of an array
// holding |size| bytes.
bool isGetByteArrayElementsLikelyToReturnACopy(size_t size) {
#if !defined(ANDROID) || defined(CONSCRYPT_OPENJDK)
    // OpenJDK-based VMs appear to copy unconditionally, so the size plays no role here.
    static_cast<void>(size);
    return true;
#else
    // ART copies small arrays only; the cut-off applied here is 12 kB, inclusive.
    constexpr size_t kCopyThresholdBytes = 12 * 1024;
    return !(size > kCopyThresholdBytes);
#endif
}
124
125int throwException(JNIEnv* env, const char* className, const char* msg) {
126    jclass exceptionClass = env->FindClass(className);
127
128    if (exceptionClass == nullptr) {
129        ALOGD("Unable to find exception class %s", className);
130        /* ClassNotFoundException now pending */
131        return -1;
132    }
133
134    if (env->ThrowNew(exceptionClass, msg) != JNI_OK) {
135        ALOGD("Failed throwing '%s' '%s'", className, msg);
136        /* an exception, most likely OOM, will now be pending */
137        return -1;
138    }
139
140    env->DeleteLocalRef(exceptionClass);
141    return 0;
142}
143
144int throwRuntimeException(JNIEnv* env, const char* msg) {
145    return conscrypt::jniutil::throwException(env, "java/lang/RuntimeException", msg);
146}
147
148int throwAssertionError(JNIEnv* env, const char* msg) {
149    return conscrypt::jniutil::throwException(env, "java/lang/AssertionError", msg);
150}
151
152int throwNullPointerException(JNIEnv* env, const char* msg) {
153    return conscrypt::jniutil::throwException(env, "java/lang/NullPointerException", msg);
154}
155
156int throwOutOfMemory(JNIEnv* env, const char* message) {
157    return conscrypt::jniutil::throwException(env, "java/lang/OutOfMemoryError", message);
158}
159
160int throwBadPaddingException(JNIEnv* env, const char* message) {
161    JNI_TRACE("throwBadPaddingException %s", message);
162    return conscrypt::jniutil::throwException(env, "javax/crypto/BadPaddingException", message);
163}
164
165int throwSignatureException(JNIEnv* env, const char* message) {
166    JNI_TRACE("throwSignatureException %s", message);
167    return conscrypt::jniutil::throwException(env, "java/security/SignatureException", message);
168}
169
170int throwInvalidKeyException(JNIEnv* env, const char* message) {
171    JNI_TRACE("throwInvalidKeyException %s", message);
172    return conscrypt::jniutil::throwException(env, "java/security/InvalidKeyException", message);
173}
174
175int throwIllegalBlockSizeException(JNIEnv* env, const char* message) {
176    JNI_TRACE("throwIllegalBlockSizeException %s", message);
177    return conscrypt::jniutil::throwException(
178            env, "javax/crypto/IllegalBlockSizeException", message);
179}
180
181int throwNoSuchAlgorithmException(JNIEnv* env, const char* message) {
182    JNI_TRACE("throwUnknownAlgorithmException %s", message);
183    return conscrypt::jniutil::throwException(
184            env, "java/security/NoSuchAlgorithmException", message);
185}
186
187int throwIOException(JNIEnv* env, const char* message) {
188    JNI_TRACE("throwIOException %s", message);
189    return conscrypt::jniutil::throwException(env, "java/io/IOException", message);
190}
191
192int throwParsingException(JNIEnv* env, const char* message) {
193    return conscrypt::jniutil::throwException(env, TO_STRING(JNI_JARJAR_PREFIX)
194                            "org/conscrypt/OpenSSLX509CertificateFactory$ParsingException",
195                            message);
196}
197
198int throwInvalidAlgorithmParameterException(JNIEnv* env, const char* message) {
199    JNI_TRACE("throwInvalidAlgorithmParameterException %s", message);
200    return conscrypt::jniutil::throwException(
201            env, "java/security/InvalidAlgorithmParameterException", message);
202}
203
204int throwForAsn1Error(JNIEnv* env, int reason, const char* message,
205                      int (*defaultThrow)(JNIEnv*, const char*)) {
206    switch (reason) {
207        case ASN1_R_UNSUPPORTED_PUBLIC_KEY_TYPE:
208#if defined(ASN1_R_UNABLE_TO_DECODE_RSA_KEY)
209        case ASN1_R_UNABLE_TO_DECODE_RSA_KEY:
210#endif
211#if defined(ASN1_R_WRONG_PUBLIC_KEY_TYPE)
212        case ASN1_R_WRONG_PUBLIC_KEY_TYPE:
213#endif
214#if defined(ASN1_R_UNABLE_TO_DECODE_RSA_PRIVATE_KEY)
215        case ASN1_R_UNABLE_TO_DECODE_RSA_PRIVATE_KEY:
216#endif
217#if defined(ASN1_R_UNKNOWN_PUBLIC_KEY_TYPE)
218        case ASN1_R_UNKNOWN_PUBLIC_KEY_TYPE:
219#endif
220            return throwInvalidKeyException(env, message);
221            break;
222        case ASN1_R_UNKNOWN_SIGNATURE_ALGORITHM:
223        case ASN1_R_UNKNOWN_MESSAGE_DIGEST_ALGORITHM:
224            return throwNoSuchAlgorithmException(env, message);
225            break;
226    }
227    return defaultThrow(env, message);
228}
229
230int throwForCipherError(JNIEnv* env, int reason, const char* message,
231                        int (*defaultThrow)(JNIEnv*, const char*)) {
232    switch (reason) {
233        case CIPHER_R_BAD_DECRYPT:
234            return throwBadPaddingException(env, message);
235            break;
236        case CIPHER_R_DATA_NOT_MULTIPLE_OF_BLOCK_LENGTH:
237        case CIPHER_R_WRONG_FINAL_BLOCK_LENGTH:
238            return throwIllegalBlockSizeException(env, message);
239            break;
240        case CIPHER_R_AES_KEY_SETUP_FAILED:
241        case CIPHER_R_BAD_KEY_LENGTH:
242        case CIPHER_R_UNSUPPORTED_KEY_SIZE:
243            return throwInvalidKeyException(env, message);
244            break;
245    }
246    return defaultThrow(env, message);
247}
248
249int throwForEvpError(JNIEnv* env, int reason, const char* message,
250                     int (*defaultThrow)(JNIEnv*, const char*)) {
251    switch (reason) {
252        case EVP_R_MISSING_PARAMETERS:
253            return throwInvalidKeyException(env, message);
254            break;
255        case EVP_R_UNSUPPORTED_ALGORITHM:
256#if defined(EVP_R_X931_UNSUPPORTED)
257        case EVP_R_X931_UNSUPPORTED:
258#endif
259            return throwNoSuchAlgorithmException(env, message);
260            break;
261#if defined(EVP_R_WRONG_PUBLIC_KEY_TYPE)
262        case EVP_R_WRONG_PUBLIC_KEY_TYPE:
263            return throwInvalidKeyException(env, message);
264            break;
265#endif
266#if defined(EVP_R_UNKNOWN_MESSAGE_DIGEST_ALGORITHM)
267        case EVP_R_UNKNOWN_MESSAGE_DIGEST_ALGORITHM:
268            return throwNoSuchAlgorithmException(env, message);
269            break;
270#endif
271        default:
272            return defaultThrow(env, message);
273            break;
274    }
275}
276
277int throwForRsaError(JNIEnv* env, int reason, const char* message,
278                     int (*defaultThrow)(JNIEnv*, const char*)) {
279    switch (reason) {
280        case RSA_R_BLOCK_TYPE_IS_NOT_01:
281        case RSA_R_PKCS_DECODING_ERROR:
282#if defined(RSA_R_BLOCK_TYPE_IS_NOT_02)
283        case RSA_R_BLOCK_TYPE_IS_NOT_02:
284#endif
285            return throwBadPaddingException(env, message);
286            break;
287        case RSA_R_BAD_SIGNATURE:
288        case RSA_R_INVALID_MESSAGE_LENGTH:
289        case RSA_R_WRONG_SIGNATURE_LENGTH:
290            return throwSignatureException(env, message);
291            break;
292        case RSA_R_UNKNOWN_ALGORITHM_TYPE:
293            return throwNoSuchAlgorithmException(env, message);
294            break;
295        case RSA_R_MODULUS_TOO_LARGE:
296        case RSA_R_NO_PUBLIC_EXPONENT:
297            return throwInvalidKeyException(env, message);
298            break;
299        case RSA_R_DATA_TOO_LARGE:
300        case RSA_R_DATA_TOO_LARGE_FOR_MODULUS:
301        case RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE:
302            return throwIllegalBlockSizeException(env, message);
303            break;
304    }
305    return defaultThrow(env, message);
306}
307
308int throwForX509Error(JNIEnv* env, int reason, const char* message,
309                      int (*defaultThrow)(JNIEnv*, const char*)) {
310    switch (reason) {
311        case X509_R_UNSUPPORTED_ALGORITHM:
312            return throwNoSuchAlgorithmException(env, message);
313            break;
314        default:
315            return defaultThrow(env, message);
316            break;
317    }
318}
319
320void throwExceptionFromBoringSSLError(JNIEnv* env, CONSCRYPT_UNUSED const char* location,
321                                      int (*defaultThrow)(JNIEnv*, const char*)) {
322    const char* file;
323    int line;
324    const char* data;
325    int flags;
326    // NOLINTNEXTLINE(runtime/int)
327    unsigned long error = ERR_get_error_line_data(&file, &line, &data, &flags);
328
329    if (error == 0) {
330        throwAssertionError(env, "throwExceptionFromBoringSSLError called with no error");
331        return;
332    }
333
334    // If there's an error from BoringSSL it may have been caused by an exception in Java code, so
335    // ensure there isn't a pending exception before we throw a new one.
336    if (!env->ExceptionCheck()) {
337        char message[256];
338        ERR_error_string_n(error, message, sizeof(message));
339        int library = ERR_GET_LIB(error);
340        int reason = ERR_GET_REASON(error);
341        JNI_TRACE("OpenSSL error in %s error=%lx library=%x reason=%x (%s:%d): %s %s", location,
342                  error, library, reason, file, line, message,
343                  (flags & ERR_TXT_STRING) ? data : "(no data)");
344        switch (library) {
345            case ERR_LIB_RSA:
346                throwForRsaError(env, reason, message, defaultThrow);
347                break;
348            case ERR_LIB_ASN1:
349                throwForAsn1Error(env, reason, message, defaultThrow);
350                break;
351            case ERR_LIB_CIPHER:
352                throwForCipherError(env, reason, message, defaultThrow);
353                break;
354            case ERR_LIB_EVP:
355                throwForEvpError(env, reason, message, defaultThrow);
356                break;
357            case ERR_LIB_X509:
358                throwForX509Error(env, reason, message, defaultThrow);
359                break;
360            case ERR_LIB_DSA:
361                throwInvalidKeyException(env, message);
362                break;
363            default:
364                defaultThrow(env, message);
365                break;
366        }
367    }
368
369    ERR_clear_error();
370}
371
372int throwSocketTimeoutException(JNIEnv* env, const char* message) {
373    JNI_TRACE("throwSocketTimeoutException %s", message);
374    return conscrypt::jniutil::throwException(env, "java/net/SocketTimeoutException", message);
375}
376
377int throwSSLHandshakeExceptionStr(JNIEnv* env, const char* message) {
378    JNI_TRACE("throwSSLExceptionStr %s", message);
379    return conscrypt::jniutil::throwException(
380            env, "javax/net/ssl/SSLHandshakeException", message);
381}
382
383int throwSSLExceptionStr(JNIEnv* env, const char* message) {
384    JNI_TRACE("throwSSLExceptionStr %s", message);
385    return conscrypt::jniutil::throwException(env, "javax/net/ssl/SSLException", message);
386}
387
388int throwSSLProtocolExceptionStr(JNIEnv* env, const char* message) {
389    JNI_TRACE("throwSSLProtocolExceptionStr %s", message);
390    return conscrypt::jniutil::throwException(
391            env, "javax/net/ssl/SSLProtocolException", message);
392}
393
394int throwSSLExceptionWithSslErrors(JNIEnv* env, SSL* ssl, int sslErrorCode, const char* message,
395                                   int (*actualThrow)(JNIEnv*, const char*)) {
396    if (message == nullptr) {
397        message = "SSL error";
398    }
399
400    // First consult the SSL error code for the general message.
401    const char* sslErrorStr = nullptr;
402    switch (sslErrorCode) {
403        case SSL_ERROR_NONE:
404            if (ERR_peek_error() == 0) {
405                sslErrorStr = "OK";
406            } else {
407                sslErrorStr = "";
408            }
409            break;
410        case SSL_ERROR_SSL:
411            sslErrorStr = "Failure in SSL library, usually a protocol error";
412            break;
413        case SSL_ERROR_WANT_READ:
414            sslErrorStr = "SSL_ERROR_WANT_READ occurred. You should never see this.";
415            break;
416        case SSL_ERROR_WANT_WRITE:
417            sslErrorStr = "SSL_ERROR_WANT_WRITE occurred. You should never see this.";
418            break;
419        case SSL_ERROR_WANT_X509_LOOKUP:
420            sslErrorStr = "SSL_ERROR_WANT_X509_LOOKUP occurred. You should never see this.";
421            break;
422        case SSL_ERROR_SYSCALL:
423            sslErrorStr = "I/O error during system call";
424            break;
425        case SSL_ERROR_ZERO_RETURN:
426            sslErrorStr = "SSL_ERROR_ZERO_RETURN occurred. You should never see this.";
427            break;
428        case SSL_ERROR_WANT_CONNECT:
429            sslErrorStr = "SSL_ERROR_WANT_CONNECT occurred. You should never see this.";
430            break;
431        case SSL_ERROR_WANT_ACCEPT:
432            sslErrorStr = "SSL_ERROR_WANT_ACCEPT occurred. You should never see this.";
433            break;
434        default:
435            sslErrorStr = "Unknown SSL error";
436    }
437
438    // Prepend either our explicit message or a default one.
439    char* str;
440    if (asprintf(&str, "%s: ssl=%p: %s", message, ssl, sslErrorStr) <= 0) {
441        // problem with asprintf, just throw argument message, log everything
442        int ret = actualThrow(env, message);
443        ALOGV("%s: ssl=%p: %s", message, ssl, sslErrorStr);
444        ERR_clear_error();
445        return ret;
446    }
447
448    char* allocStr = str;
449
450    // For protocol errors, SSL might have more information.
451    if (sslErrorCode == SSL_ERROR_NONE || sslErrorCode == SSL_ERROR_SSL) {
452        // Append each error as an additional line to the message.
453        for (;;) {
454            char errStr[256];
455            const char* file;
456            int line;
457            const char* data;
458            int flags;
459            // NOLINTNEXTLINE(runtime/int)
460            unsigned long err = ERR_get_error_line_data(&file, &line, &data, &flags);
461            if (err == 0) {
462                break;
463            }
464
465            ERR_error_string_n(err, errStr, sizeof(errStr));
466
467            int ret = asprintf(&str, "%s\n%s (%s:%d %p:0x%08x)",
468                               (allocStr == nullptr) ? "" : allocStr, errStr, file, line,
469                               (flags & ERR_TXT_STRING) ? data : "(no data)", flags);
470
471            if (ret < 0) {
472                break;
473            }
474
475            free(allocStr);
476            allocStr = str;
477        }
478        // For errors during system calls, errno might be our friend.
479    } else if (sslErrorCode == SSL_ERROR_SYSCALL) {
480        if (asprintf(&str, "%s, %s", allocStr, strerror(errno)) >= 0) {
481            free(allocStr);
482            allocStr = str;
483        }
484        // If the error code is invalid, print it.
485    } else if (sslErrorCode > SSL_ERROR_WANT_ACCEPT) {
486        if (asprintf(&str, ", error code is %d", sslErrorCode) >= 0) {
487            free(allocStr);
488            allocStr = str;
489        }
490    }
491
492    int ret;
493    if (sslErrorCode == SSL_ERROR_SSL) {
494        ret = throwSSLProtocolExceptionStr(env, allocStr);
495    } else {
496        ret = actualThrow(env, allocStr);
497    }
498
499    ALOGV("%s", allocStr);
500    free(allocStr);
501    ERR_clear_error();
502    return ret;
503}
504
505}  // namespace jniutil
506}  // namespace conscrypt
507