android_net_LocalSocketImpl.cpp revision d2df87eb4424ec25ab5c9a8f47cb09a8f191e87f
1/*
2 * Copyright (C) 2006 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "LocalSocketImpl"
18
19#include "JNIHelp.h"
20#include "jni.h"
21#include "utils/Log.h"
22#include "utils/misc.h"
23
24#include <stdio.h>
25#include <string.h>
26#include <sys/types.h>
27#include <sys/socket.h>
28#include <sys/un.h>
29#include <arpa/inet.h>
30#include <netinet/in.h>
31#include <stdlib.h>
32#include <errno.h>
33#include <unistd.h>
34#include <sys/ioctl.h>
35
36#include <cutils/sockets.h>
37#include <netinet/tcp.h>
38#include <ScopedUtfChars.h>
39
40namespace android {
41
// Swallows a value so the compiler does not warn that it is unused.
// The body intentionally does nothing.
template <typename T>
void UNUSED(T /* ignored */) {}
44
45static jfieldID field_inboundFileDescriptors;
46static jfieldID field_outboundFileDescriptors;
47static jclass class_Credentials;
48static jclass class_FileDescriptor;
49static jmethodID method_CredentialsInit;
50
51/* private native void connectLocal(FileDescriptor fd,
52 * String name, int namespace) throws IOException
53 */
54static void
55socket_connect_local(JNIEnv *env, jobject object,
56                        jobject fileDescriptor, jstring name, jint namespaceId)
57{
58    int ret;
59    int fd;
60
61    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
62
63    if (env->ExceptionCheck()) {
64        return;
65    }
66
67    ScopedUtfChars nameUtf8(env, name);
68
69    ret = socket_local_client_connect(
70                fd,
71                nameUtf8.c_str(),
72                namespaceId,
73                SOCK_STREAM);
74
75    if (ret < 0) {
76        jniThrowIOException(env, errno);
77        return;
78    }
79}
80
81#define DEFAULT_BACKLOG 4
82
83/* private native void bindLocal(FileDescriptor fd, String name, namespace)
84 * throws IOException;
85 */
86
87static void
88socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
89                jstring name, jint namespaceId)
90{
91    int ret;
92    int fd;
93
94    if (name == NULL) {
95        jniThrowNullPointerException(env, NULL);
96        return;
97    }
98
99    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
100
101    if (env->ExceptionCheck()) {
102        return;
103    }
104
105    ScopedUtfChars nameUtf8(env, name);
106
107    ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
108
109    if (ret < 0) {
110        jniThrowIOException(env, errno);
111        return;
112    }
113}
114
115/* private native void listen_native(int fd, int backlog) throws IOException; */
116static void
117socket_listen (JNIEnv *env, jobject object, jobject fileDescriptor, jint backlog)
118{
119    int ret;
120    int fd;
121
122    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
123
124    if (env->ExceptionCheck()) {
125        return;
126    }
127
128    ret = listen(fd, backlog);
129
130    if (ret < 0) {
131        jniThrowIOException(env, errno);
132        return;
133    }
134}
135
136/*    private native FileDescriptor
137**    accept (FileDescriptor fd, LocalSocketImpl s)
138**                                   throws IOException;
139*/
140static jobject
141socket_accept (JNIEnv *env, jobject object, jobject fileDescriptor, jobject s)
142{
143    union {
144        struct sockaddr address;
145        struct sockaddr_un un_address;
146    } sa;
147
148    int ret;
149    int retFD;
150    int fd;
151    socklen_t addrlen;
152
153    if (s == NULL) {
154        jniThrowNullPointerException(env, NULL);
155        return NULL;
156    }
157
158    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
159
160    if (env->ExceptionCheck()) {
161        return NULL;
162    }
163
164    do {
165        addrlen = sizeof(sa);
166        ret = accept(fd, &(sa.address), &addrlen);
167    } while (ret < 0 && errno == EINTR);
168
169    if (ret < 0) {
170        jniThrowIOException(env, errno);
171        return NULL;
172    }
173
174    retFD = ret;
175
176    return jniCreateFileDescriptor(env, retFD);
177}
178
179/* private native void shutdown(FileDescriptor fd, boolean shutdownInput) */
180
181static void
182socket_shutdown (JNIEnv *env, jobject object, jobject fileDescriptor,
183                    jboolean shutdownInput)
184{
185    int ret;
186    int fd;
187
188    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
189
190    if (env->ExceptionCheck()) {
191        return;
192    }
193
194    ret = shutdown(fd, shutdownInput ? SHUT_RD : SHUT_WR);
195
196    if (ret < 0) {
197        jniThrowIOException(env, errno);
198        return;
199    }
200}
201
202/**
203 * Processes ancillary data, handling only
204 * SCM_RIGHTS. Creates appropriate objects and sets appropriate
205 * fields in the LocalSocketImpl object. Returns 0 on success
206 * or -1 if an exception was thrown.
207 */
208static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
209{
210    struct cmsghdr *cmsgptr;
211
212    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
213            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
214
215        if (cmsgptr->cmsg_level != SOL_SOCKET) {
216            continue;
217        }
218
219        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
220            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
221            jobjectArray fdArray;
222            int count
223                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
224
225            if (count < 0) {
226                jniThrowException(env, "java/io/IOException",
227                    "invalid cmsg length");
228                return -1;
229            }
230
231            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
232
233            if (fdArray == NULL) {
234                return -1;
235            }
236
237            for (int i = 0; i < count; i++) {
238                jobject fdObject
239                        = jniCreateFileDescriptor(env, pDescriptors[i]);
240
241                if (env->ExceptionCheck()) {
242                    return -1;
243                }
244
245                env->SetObjectArrayElement(fdArray, i, fdObject);
246
247                if (env->ExceptionCheck()) {
248                    return -1;
249                }
250            }
251
252            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
253
254            if (env->ExceptionCheck()) {
255                return -1;
256            }
257        }
258    }
259
260    return 0;
261}
262
263/**
264 * Reads data from a socket into buf, processing any ancillary data
265 * and adding it to thisJ.
266 *
267 * Returns the length of normal data read, or -1 if an exception has
268 * been thrown in this function.
269 */
270static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
271        void *buffer, size_t len)
272{
273    ssize_t ret;
274    struct msghdr msg;
275    struct iovec iv;
276    unsigned char *buf = (unsigned char *)buffer;
277    // Enough buffer for a pile of fd's. We throw an exception if
278    // this buffer is too small.
279    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
280
281    memset(&msg, 0, sizeof(msg));
282    memset(&iv, 0, sizeof(iv));
283
284    iv.iov_base = buf;
285    iv.iov_len = len;
286
287    msg.msg_iov = &iv;
288    msg.msg_iovlen = 1;
289    msg.msg_control = cmsgbuf;
290    msg.msg_controllen = sizeof(cmsgbuf);
291
292    do {
293        ret = recvmsg(fd, &msg, MSG_NOSIGNAL);
294    } while (ret < 0 && errno == EINTR);
295
296    if (ret < 0 && errno == EPIPE) {
297        // Treat this as an end of stream
298        return 0;
299    }
300
301    if (ret < 0) {
302        jniThrowIOException(env, errno);
303        return -1;
304    }
305
306    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
307        // To us, any of the above flags are a fatal error
308
309        jniThrowException(env, "java/io/IOException",
310                "Unexpected error or truncation during recvmsg()");
311
312        return -1;
313    }
314
315    if (ret >= 0) {
316        socket_process_cmsg(env, thisJ, &msg);
317    }
318
319    return ret;
320}
321
322/**
323 * Writes all the data in the specified buffer to the specified socket.
324 *
325 * Returns 0 on success or -1 if an exception was thrown.
326 */
327static int socket_write_all(JNIEnv *env, jobject object, int fd,
328        void *buf, size_t len)
329{
330    ssize_t ret;
331    struct msghdr msg;
332    unsigned char *buffer = (unsigned char *)buf;
333    memset(&msg, 0, sizeof(msg));
334
335    jobjectArray outboundFds
336            = (jobjectArray)env->GetObjectField(
337                object, field_outboundFileDescriptors);
338
339    if (env->ExceptionCheck()) {
340        return -1;
341    }
342
343    struct cmsghdr *cmsg;
344    int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
345    int fds[countFds];
346    char msgbuf[CMSG_SPACE(countFds)];
347
348    // Add any pending outbound file descriptors to the message
349    if (outboundFds != NULL) {
350
351        if (env->ExceptionCheck()) {
352            return -1;
353        }
354
355        for (int i = 0; i < countFds; i++) {
356            jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
357            if (env->ExceptionCheck()) {
358                return -1;
359            }
360
361            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
362            if (env->ExceptionCheck()) {
363                return -1;
364            }
365        }
366
367        // See "man cmsg" really
368        msg.msg_control = msgbuf;
369        msg.msg_controllen = sizeof msgbuf;
370        cmsg = CMSG_FIRSTHDR(&msg);
371        cmsg->cmsg_level = SOL_SOCKET;
372        cmsg->cmsg_type = SCM_RIGHTS;
373        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
374        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
375    }
376
377    // We only write our msg_control during the first write
378    while (len > 0) {
379        struct iovec iv;
380        memset(&iv, 0, sizeof(iv));
381
382        iv.iov_base = buffer;
383        iv.iov_len = len;
384
385        msg.msg_iov = &iv;
386        msg.msg_iovlen = 1;
387
388        do {
389            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
390        } while (ret < 0 && errno == EINTR);
391
392        if (ret < 0) {
393            jniThrowIOException(env, errno);
394            return -1;
395        }
396
397        buffer += ret;
398        len -= ret;
399
400        // Wipes out any msg_control too
401        memset(&msg, 0, sizeof(msg));
402    }
403
404    return 0;
405}
406
407static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
408{
409    int fd;
410    int err;
411
412    if (fileDescriptor == NULL) {
413        jniThrowNullPointerException(env, NULL);
414        return (jint)-1;
415    }
416
417    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
418
419    if (env->ExceptionCheck()) {
420        return (jint)0;
421    }
422
423    unsigned char buf;
424
425    err = socket_read_all(env, object, fd, &buf, 1);
426
427    if (err < 0) {
428        jniThrowIOException(env, errno);
429        return (jint)0;
430    }
431
432    if (err == 0) {
433        // end of file
434        return (jint)-1;
435    }
436
437    return (jint)buf;
438}
439
440static jint socket_readba (JNIEnv *env, jobject object,
441        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
442{
443    int fd;
444    jbyte* byteBuffer;
445    int ret;
446
447    if (fileDescriptor == NULL || buffer == NULL) {
448        jniThrowNullPointerException(env, NULL);
449        return (jint)-1;
450    }
451
452    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
453        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
454        return (jint)-1;
455    }
456
457    if (len == 0) {
458        // because socket_read_all returns 0 on EOF
459        return 0;
460    }
461
462    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
463
464    if (env->ExceptionCheck()) {
465        return (jint)-1;
466    }
467
468    byteBuffer = env->GetByteArrayElements(buffer, NULL);
469
470    if (NULL == byteBuffer) {
471        // an exception will have been thrown
472        return (jint)-1;
473    }
474
475    ret = socket_read_all(env, object,
476            fd, byteBuffer + off, len);
477
478    // A return of -1 above means an exception is pending
479
480    env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
481
482    return (jint) ((ret == 0) ? -1 : ret);
483}
484
485static void socket_write (JNIEnv *env, jobject object,
486        jint b, jobject fileDescriptor)
487{
488    int fd;
489    int err;
490
491    if (fileDescriptor == NULL) {
492        jniThrowNullPointerException(env, NULL);
493        return;
494    }
495
496    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
497
498    if (env->ExceptionCheck()) {
499        return;
500    }
501
502    err = socket_write_all(env, object, fd, &b, 1);
503    UNUSED(err);
504    // A return of -1 above means an exception is pending
505}
506
507static void socket_writeba (JNIEnv *env, jobject object,
508        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
509{
510    int fd;
511    int err;
512    jbyte* byteBuffer;
513
514    if (fileDescriptor == NULL || buffer == NULL) {
515        jniThrowNullPointerException(env, NULL);
516        return;
517    }
518
519    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
520        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
521        return;
522    }
523
524    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
525
526    if (env->ExceptionCheck()) {
527        return;
528    }
529
530    byteBuffer = env->GetByteArrayElements(buffer,NULL);
531
532    if (NULL == byteBuffer) {
533        // an exception will have been thrown
534        return;
535    }
536
537    err = socket_write_all(env, object, fd,
538            byteBuffer + off, len);
539    UNUSED(err);
540    // A return of -1 above means an exception is pending
541
542    env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
543}
544
545static jobject socket_get_peer_credentials(JNIEnv *env,
546        jobject object, jobject fileDescriptor)
547{
548    int err;
549    int fd;
550
551    if (fileDescriptor == NULL) {
552        jniThrowNullPointerException(env, NULL);
553        return NULL;
554    }
555
556    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
557
558    if (env->ExceptionCheck()) {
559        return NULL;
560    }
561
562    struct ucred creds;
563
564    memset(&creds, 0, sizeof(creds));
565    socklen_t szCreds = sizeof(creds);
566
567    err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
568
569    if (err < 0) {
570        jniThrowIOException(env, errno);
571        return NULL;
572    }
573
574    if (szCreds == 0) {
575        return NULL;
576    }
577
578    return env->NewObject(class_Credentials, method_CredentialsInit,
579            creds.pid, creds.uid, creds.gid);
580}
581
582/*
583 * JNI registration.
584 */
585static JNINativeMethod gMethods[] = {
586     /* name, signature, funcPtr */
587    {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
588                                                (void*)socket_connect_local},
589    {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
590    {"listen_native", "(Ljava/io/FileDescriptor;I)V", (void*)socket_listen},
591    {"accept", "(Ljava/io/FileDescriptor;Landroid/net/LocalSocketImpl;)Ljava/io/FileDescriptor;", (void*)socket_accept},
592    {"shutdown", "(Ljava/io/FileDescriptor;Z)V", (void*)socket_shutdown},
593    {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
594    {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
595    {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
596    {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
597    {"getPeerCredentials_native",
598            "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
599            (void*) socket_get_peer_credentials}
600};
601
602int register_android_net_LocalSocketImpl(JNIEnv *env)
603{
604    jclass clazz;
605
606    clazz = env->FindClass("android/net/LocalSocketImpl");
607
608    if (clazz == NULL) {
609        goto error;
610    }
611
612    field_inboundFileDescriptors = env->GetFieldID(clazz,
613            "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
614
615    if (field_inboundFileDescriptors == NULL) {
616        goto error;
617    }
618
619    field_outboundFileDescriptors = env->GetFieldID(clazz,
620            "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
621
622    if (field_outboundFileDescriptors == NULL) {
623        goto error;
624    }
625
626    class_Credentials = env->FindClass("android/net/Credentials");
627
628    if (class_Credentials == NULL) {
629        goto error;
630    }
631
632    class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
633
634    class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
635
636    if (class_FileDescriptor == NULL) {
637        goto error;
638    }
639
640    class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
641
642    method_CredentialsInit
643            = env->GetMethodID(class_Credentials, "<init>", "(III)V");
644
645    if (method_CredentialsInit == NULL) {
646        goto error;
647    }
648
649    return jniRegisterNativeMethods(env,
650        "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
651
652error:
653    ALOGE("Error registering android.net.LocalSocketImpl");
654    return -1;
655}
656
657};
658