1/*
2 * Copyright (C) 2006 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "LocalSocketImpl"
18
19#include "JNIHelp.h"
20#include "jni.h"
21#include "utils/Log.h"
22#include "utils/misc.h"
23
24#include <stdio.h>
25#include <string.h>
26#include <sys/types.h>
27#include <sys/socket.h>
28#include <sys/un.h>
29#include <arpa/inet.h>
30#include <netinet/in.h>
31#include <stdlib.h>
32#include <errno.h>
33#include <unistd.h>
34#include <sys/ioctl.h>
35
36#include <cutils/sockets.h>
37#include <netinet/tcp.h>
38#include <ScopedUtfChars.h>
39
40namespace android {
41
/**
 * Accepts and discards a value of any type. Called in this file on return
 * codes that are deliberately ignored (e.g. the result of socket_write_all()
 * when an exception is already pending). Has no effect.
 */
template <typename Discarded>
void UNUSED(Discarded /* ignored */) {}
44
45static jfieldID field_inboundFileDescriptors;
46static jfieldID field_outboundFileDescriptors;
47static jclass class_Credentials;
48static jclass class_FileDescriptor;
49static jmethodID method_CredentialsInit;
50
51/* private native void connectLocal(FileDescriptor fd,
52 * String name, int namespace) throws IOException
53 */
54static void
55socket_connect_local(JNIEnv *env, jobject object,
56                        jobject fileDescriptor, jstring name, jint namespaceId)
57{
58    int ret;
59    int fd;
60
61    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
62
63    if (env->ExceptionCheck()) {
64        return;
65    }
66
67    ScopedUtfChars nameUtf8(env, name);
68
69    ret = socket_local_client_connect(
70                fd,
71                nameUtf8.c_str(),
72                namespaceId,
73                SOCK_STREAM);
74
75    if (ret < 0) {
76        jniThrowIOException(env, errno);
77        return;
78    }
79}
80
// Presumably a default listen(2) backlog, judging by the name.
// NOTE(review): not referenced anywhere in this file; candidate for removal
// once confirmed unused.
#define DEFAULT_BACKLOG 4
82
83/* private native void bindLocal(FileDescriptor fd, String name, namespace)
84 * throws IOException;
85 */
86
87static void
88socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
89                jstring name, jint namespaceId)
90{
91    int ret;
92    int fd;
93
94    if (name == NULL) {
95        jniThrowNullPointerException(env, NULL);
96        return;
97    }
98
99    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
100
101    if (env->ExceptionCheck()) {
102        return;
103    }
104
105    ScopedUtfChars nameUtf8(env, name);
106
107    ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
108
109    if (ret < 0) {
110        jniThrowIOException(env, errno);
111        return;
112    }
113}
114
115/**
116 * Processes ancillary data, handling only
117 * SCM_RIGHTS. Creates appropriate objects and sets appropriate
118 * fields in the LocalSocketImpl object. Returns 0 on success
119 * or -1 if an exception was thrown.
120 */
121static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
122{
123    struct cmsghdr *cmsgptr;
124
125    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
126            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
127
128        if (cmsgptr->cmsg_level != SOL_SOCKET) {
129            continue;
130        }
131
132        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
133            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
134            jobjectArray fdArray;
135            int count
136                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
137
138            if (count < 0) {
139                jniThrowException(env, "java/io/IOException",
140                    "invalid cmsg length");
141                return -1;
142            }
143
144            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
145
146            if (fdArray == NULL) {
147                return -1;
148            }
149
150            for (int i = 0; i < count; i++) {
151                jobject fdObject
152                        = jniCreateFileDescriptor(env, pDescriptors[i]);
153
154                if (env->ExceptionCheck()) {
155                    return -1;
156                }
157
158                env->SetObjectArrayElement(fdArray, i, fdObject);
159
160                if (env->ExceptionCheck()) {
161                    return -1;
162                }
163            }
164
165            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
166
167            if (env->ExceptionCheck()) {
168                return -1;
169            }
170        }
171    }
172
173    return 0;
174}
175
176/**
177 * Reads data from a socket into buf, processing any ancillary data
178 * and adding it to thisJ.
179 *
180 * Returns the length of normal data read, or -1 if an exception has
181 * been thrown in this function.
182 */
183static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
184        void *buffer, size_t len)
185{
186    ssize_t ret;
187    struct msghdr msg;
188    struct iovec iv;
189    unsigned char *buf = (unsigned char *)buffer;
190    // Enough buffer for a pile of fd's. We throw an exception if
191    // this buffer is too small.
192    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
193
194    memset(&msg, 0, sizeof(msg));
195    memset(&iv, 0, sizeof(iv));
196
197    iv.iov_base = buf;
198    iv.iov_len = len;
199
200    msg.msg_iov = &iv;
201    msg.msg_iovlen = 1;
202    msg.msg_control = cmsgbuf;
203    msg.msg_controllen = sizeof(cmsgbuf);
204
205    ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
206
207    if (ret < 0 && errno == EPIPE) {
208        // Treat this as an end of stream
209        return 0;
210    }
211
212    if (ret < 0) {
213        jniThrowIOException(env, errno);
214        return -1;
215    }
216
217    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
218        // To us, any of the above flags are a fatal error
219
220        jniThrowException(env, "java/io/IOException",
221                "Unexpected error or truncation during recvmsg()");
222
223        return -1;
224    }
225
226    if (ret >= 0) {
227        socket_process_cmsg(env, thisJ, &msg);
228    }
229
230    return ret;
231}
232
233/**
234 * Writes all the data in the specified buffer to the specified socket.
235 *
236 * Returns 0 on success or -1 if an exception was thrown.
237 */
238static int socket_write_all(JNIEnv *env, jobject object, int fd,
239        void *buf, size_t len)
240{
241    ssize_t ret;
242    struct msghdr msg;
243    unsigned char *buffer = (unsigned char *)buf;
244    memset(&msg, 0, sizeof(msg));
245
246    jobjectArray outboundFds
247            = (jobjectArray)env->GetObjectField(
248                object, field_outboundFileDescriptors);
249
250    if (env->ExceptionCheck()) {
251        return -1;
252    }
253
254    struct cmsghdr *cmsg;
255    int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
256    int fds[countFds];
257    char msgbuf[CMSG_SPACE(countFds)];
258
259    // Add any pending outbound file descriptors to the message
260    if (outboundFds != NULL) {
261
262        if (env->ExceptionCheck()) {
263            return -1;
264        }
265
266        for (int i = 0; i < countFds; i++) {
267            jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
268            if (env->ExceptionCheck()) {
269                return -1;
270            }
271
272            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
273            if (env->ExceptionCheck()) {
274                return -1;
275            }
276        }
277
278        // See "man cmsg" really
279        msg.msg_control = msgbuf;
280        msg.msg_controllen = sizeof msgbuf;
281        cmsg = CMSG_FIRSTHDR(&msg);
282        cmsg->cmsg_level = SOL_SOCKET;
283        cmsg->cmsg_type = SCM_RIGHTS;
284        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
285        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
286    }
287
288    // We only write our msg_control during the first write
289    while (len > 0) {
290        struct iovec iv;
291        memset(&iv, 0, sizeof(iv));
292
293        iv.iov_base = buffer;
294        iv.iov_len = len;
295
296        msg.msg_iov = &iv;
297        msg.msg_iovlen = 1;
298
299        do {
300            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
301        } while (ret < 0 && errno == EINTR);
302
303        if (ret < 0) {
304            jniThrowIOException(env, errno);
305            return -1;
306        }
307
308        buffer += ret;
309        len -= ret;
310
311        // Wipes out any msg_control too
312        memset(&msg, 0, sizeof(msg));
313    }
314
315    return 0;
316}
317
318static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
319{
320    int fd;
321    int err;
322
323    if (fileDescriptor == NULL) {
324        jniThrowNullPointerException(env, NULL);
325        return (jint)-1;
326    }
327
328    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
329
330    if (env->ExceptionCheck()) {
331        return (jint)0;
332    }
333
334    unsigned char buf;
335
336    err = socket_read_all(env, object, fd, &buf, 1);
337
338    if (err < 0) {
339        jniThrowIOException(env, errno);
340        return (jint)0;
341    }
342
343    if (err == 0) {
344        // end of file
345        return (jint)-1;
346    }
347
348    return (jint)buf;
349}
350
351static jint socket_readba (JNIEnv *env, jobject object,
352        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
353{
354    int fd;
355    jbyte* byteBuffer;
356    int ret;
357
358    if (fileDescriptor == NULL || buffer == NULL) {
359        jniThrowNullPointerException(env, NULL);
360        return (jint)-1;
361    }
362
363    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
364        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
365        return (jint)-1;
366    }
367
368    if (len == 0) {
369        // because socket_read_all returns 0 on EOF
370        return 0;
371    }
372
373    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
374
375    if (env->ExceptionCheck()) {
376        return (jint)-1;
377    }
378
379    byteBuffer = env->GetByteArrayElements(buffer, NULL);
380
381    if (NULL == byteBuffer) {
382        // an exception will have been thrown
383        return (jint)-1;
384    }
385
386    ret = socket_read_all(env, object,
387            fd, byteBuffer + off, len);
388
389    // A return of -1 above means an exception is pending
390
391    env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
392
393    return (jint) ((ret == 0) ? -1 : ret);
394}
395
396static void socket_write (JNIEnv *env, jobject object,
397        jint b, jobject fileDescriptor)
398{
399    int fd;
400    int err;
401
402    if (fileDescriptor == NULL) {
403        jniThrowNullPointerException(env, NULL);
404        return;
405    }
406
407    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
408
409    if (env->ExceptionCheck()) {
410        return;
411    }
412
413    err = socket_write_all(env, object, fd, &b, 1);
414    UNUSED(err);
415    // A return of -1 above means an exception is pending
416}
417
418static void socket_writeba (JNIEnv *env, jobject object,
419        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
420{
421    int fd;
422    int err;
423    jbyte* byteBuffer;
424
425    if (fileDescriptor == NULL || buffer == NULL) {
426        jniThrowNullPointerException(env, NULL);
427        return;
428    }
429
430    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
431        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
432        return;
433    }
434
435    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
436
437    if (env->ExceptionCheck()) {
438        return;
439    }
440
441    byteBuffer = env->GetByteArrayElements(buffer,NULL);
442
443    if (NULL == byteBuffer) {
444        // an exception will have been thrown
445        return;
446    }
447
448    err = socket_write_all(env, object, fd,
449            byteBuffer + off, len);
450    UNUSED(err);
451    // A return of -1 above means an exception is pending
452
453    env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
454}
455
456static jobject socket_get_peer_credentials(JNIEnv *env,
457        jobject object, jobject fileDescriptor)
458{
459    int err;
460    int fd;
461
462    if (fileDescriptor == NULL) {
463        jniThrowNullPointerException(env, NULL);
464        return NULL;
465    }
466
467    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
468
469    if (env->ExceptionCheck()) {
470        return NULL;
471    }
472
473    struct ucred creds;
474
475    memset(&creds, 0, sizeof(creds));
476    socklen_t szCreds = sizeof(creds);
477
478    err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
479
480    if (err < 0) {
481        jniThrowIOException(env, errno);
482        return NULL;
483    }
484
485    if (szCreds == 0) {
486        return NULL;
487    }
488
489    return env->NewObject(class_Credentials, method_CredentialsInit,
490            creds.pid, creds.uid, creds.gid);
491}
492
493/*
494 * JNI registration.
495 */
496static const JNINativeMethod gMethods[] = {
497     /* name, signature, funcPtr */
498    {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
499                                                (void*)socket_connect_local},
500    {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
501    {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
502    {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
503    {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
504    {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
505    {"getPeerCredentials_native",
506            "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
507            (void*) socket_get_peer_credentials}
508};
509
510int register_android_net_LocalSocketImpl(JNIEnv *env)
511{
512    jclass clazz;
513
514    clazz = env->FindClass("android/net/LocalSocketImpl");
515
516    if (clazz == NULL) {
517        goto error;
518    }
519
520    field_inboundFileDescriptors = env->GetFieldID(clazz,
521            "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
522
523    if (field_inboundFileDescriptors == NULL) {
524        goto error;
525    }
526
527    field_outboundFileDescriptors = env->GetFieldID(clazz,
528            "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
529
530    if (field_outboundFileDescriptors == NULL) {
531        goto error;
532    }
533
534    class_Credentials = env->FindClass("android/net/Credentials");
535
536    if (class_Credentials == NULL) {
537        goto error;
538    }
539
540    class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
541
542    class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
543
544    if (class_FileDescriptor == NULL) {
545        goto error;
546    }
547
548    class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
549
550    method_CredentialsInit
551            = env->GetMethodID(class_Credentials, "<init>", "(III)V");
552
553    if (method_CredentialsInit == NULL) {
554        goto error;
555    }
556
557    return jniRegisterNativeMethods(env,
558        "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
559
560error:
561    ALOGE("Error registering android.net.LocalSocketImpl");
562    return -1;
563}
564
565};
566