android_net_LocalSocketImpl.cpp revision c1eaeb93379fc8940f595cf21844fa24b8cd1734
1/*
2 * Copyright (C) 2006 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "LocalSocketImpl"
18
19#include "JNIHelp.h"
20#include "jni.h"
21#include "utils/Log.h"
22#include "utils/misc.h"
23
24#include <stdio.h>
25#include <string.h>
26#include <sys/types.h>
27#include <sys/socket.h>
28#include <sys/un.h>
29#include <arpa/inet.h>
30#include <netinet/in.h>
31#include <stdlib.h>
32#include <errno.h>
33#include <unistd.h>
34#include <sys/ioctl.h>
35
36#include <cutils/sockets.h>
37#include <netinet/tcp.h>
38#include <ScopedUtfChars.h>
39
40namespace android {
41
// Swallows a value so the compiler does not warn that it is unused.
// The parameter is left unnamed so this helper does not itself trigger an
// unused-parameter warning.
template <typename T>
void UNUSED(T /* ignored */) {}
44
45static jfieldID field_inboundFileDescriptors;
46static jfieldID field_outboundFileDescriptors;
47static jclass class_Credentials;
48static jclass class_FileDescriptor;
49static jmethodID method_CredentialsInit;
50
51/* private native void connectLocal(FileDescriptor fd,
52 * String name, int namespace) throws IOException
53 */
54static void
55socket_connect_local(JNIEnv *env, jobject object,
56                        jobject fileDescriptor, jstring name, jint namespaceId)
57{
58    int ret;
59    int fd;
60
61    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
62
63    if (env->ExceptionCheck()) {
64        return;
65    }
66
67    ScopedUtfChars nameUtf8(env, name);
68
69    ret = socket_local_client_connect(
70                fd,
71                nameUtf8.c_str(),
72                namespaceId,
73                SOCK_STREAM);
74
75    if (ret < 0) {
76        jniThrowIOException(env, errno);
77        return;
78    }
79}
80
81#define DEFAULT_BACKLOG 4
82
83/* private native void bindLocal(FileDescriptor fd, String name, namespace)
84 * throws IOException;
85 */
86
87static void
88socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
89                jstring name, jint namespaceId)
90{
91    int ret;
92    int fd;
93
94    if (name == NULL) {
95        jniThrowNullPointerException(env, NULL);
96        return;
97    }
98
99    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
100
101    if (env->ExceptionCheck()) {
102        return;
103    }
104
105    ScopedUtfChars nameUtf8(env, name);
106
107    ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
108
109    if (ret < 0) {
110        jniThrowIOException(env, errno);
111        return;
112    }
113}
114
115/* private native void shutdown(FileDescriptor fd, boolean shutdownInput) */
116
117static void
118socket_shutdown (JNIEnv *env, jobject object, jobject fileDescriptor,
119                    jboolean shutdownInput)
120{
121    int ret;
122    int fd;
123
124    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
125
126    if (env->ExceptionCheck()) {
127        return;
128    }
129
130    ret = shutdown(fd, shutdownInput ? SHUT_RD : SHUT_WR);
131
132    if (ret < 0) {
133        jniThrowIOException(env, errno);
134        return;
135    }
136}
137
138/**
139 * Processes ancillary data, handling only
140 * SCM_RIGHTS. Creates appropriate objects and sets appropriate
141 * fields in the LocalSocketImpl object. Returns 0 on success
142 * or -1 if an exception was thrown.
143 */
144static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
145{
146    struct cmsghdr *cmsgptr;
147
148    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
149            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
150
151        if (cmsgptr->cmsg_level != SOL_SOCKET) {
152            continue;
153        }
154
155        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
156            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
157            jobjectArray fdArray;
158            int count
159                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
160
161            if (count < 0) {
162                jniThrowException(env, "java/io/IOException",
163                    "invalid cmsg length");
164                return -1;
165            }
166
167            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
168
169            if (fdArray == NULL) {
170                return -1;
171            }
172
173            for (int i = 0; i < count; i++) {
174                jobject fdObject
175                        = jniCreateFileDescriptor(env, pDescriptors[i]);
176
177                if (env->ExceptionCheck()) {
178                    return -1;
179                }
180
181                env->SetObjectArrayElement(fdArray, i, fdObject);
182
183                if (env->ExceptionCheck()) {
184                    return -1;
185                }
186            }
187
188            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
189
190            if (env->ExceptionCheck()) {
191                return -1;
192            }
193        }
194    }
195
196    return 0;
197}
198
199/**
200 * Reads data from a socket into buf, processing any ancillary data
201 * and adding it to thisJ.
202 *
203 * Returns the length of normal data read, or -1 if an exception has
204 * been thrown in this function.
205 */
206static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
207        void *buffer, size_t len)
208{
209    ssize_t ret;
210    struct msghdr msg;
211    struct iovec iv;
212    unsigned char *buf = (unsigned char *)buffer;
213    // Enough buffer for a pile of fd's. We throw an exception if
214    // this buffer is too small.
215    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
216
217    memset(&msg, 0, sizeof(msg));
218    memset(&iv, 0, sizeof(iv));
219
220    iv.iov_base = buf;
221    iv.iov_len = len;
222
223    msg.msg_iov = &iv;
224    msg.msg_iovlen = 1;
225    msg.msg_control = cmsgbuf;
226    msg.msg_controllen = sizeof(cmsgbuf);
227
228    do {
229        ret = recvmsg(fd, &msg, MSG_NOSIGNAL);
230    } while (ret < 0 && errno == EINTR);
231
232    if (ret < 0 && errno == EPIPE) {
233        // Treat this as an end of stream
234        return 0;
235    }
236
237    if (ret < 0) {
238        jniThrowIOException(env, errno);
239        return -1;
240    }
241
242    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
243        // To us, any of the above flags are a fatal error
244
245        jniThrowException(env, "java/io/IOException",
246                "Unexpected error or truncation during recvmsg()");
247
248        return -1;
249    }
250
251    if (ret >= 0) {
252        socket_process_cmsg(env, thisJ, &msg);
253    }
254
255    return ret;
256}
257
258/**
259 * Writes all the data in the specified buffer to the specified socket.
260 *
261 * Returns 0 on success or -1 if an exception was thrown.
262 */
263static int socket_write_all(JNIEnv *env, jobject object, int fd,
264        void *buf, size_t len)
265{
266    ssize_t ret;
267    struct msghdr msg;
268    unsigned char *buffer = (unsigned char *)buf;
269    memset(&msg, 0, sizeof(msg));
270
271    jobjectArray outboundFds
272            = (jobjectArray)env->GetObjectField(
273                object, field_outboundFileDescriptors);
274
275    if (env->ExceptionCheck()) {
276        return -1;
277    }
278
279    struct cmsghdr *cmsg;
280    int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
281    int fds[countFds];
282    char msgbuf[CMSG_SPACE(countFds)];
283
284    // Add any pending outbound file descriptors to the message
285    if (outboundFds != NULL) {
286
287        if (env->ExceptionCheck()) {
288            return -1;
289        }
290
291        for (int i = 0; i < countFds; i++) {
292            jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
293            if (env->ExceptionCheck()) {
294                return -1;
295            }
296
297            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
298            if (env->ExceptionCheck()) {
299                return -1;
300            }
301        }
302
303        // See "man cmsg" really
304        msg.msg_control = msgbuf;
305        msg.msg_controllen = sizeof msgbuf;
306        cmsg = CMSG_FIRSTHDR(&msg);
307        cmsg->cmsg_level = SOL_SOCKET;
308        cmsg->cmsg_type = SCM_RIGHTS;
309        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
310        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
311    }
312
313    // We only write our msg_control during the first write
314    while (len > 0) {
315        struct iovec iv;
316        memset(&iv, 0, sizeof(iv));
317
318        iv.iov_base = buffer;
319        iv.iov_len = len;
320
321        msg.msg_iov = &iv;
322        msg.msg_iovlen = 1;
323
324        do {
325            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
326        } while (ret < 0 && errno == EINTR);
327
328        if (ret < 0) {
329            jniThrowIOException(env, errno);
330            return -1;
331        }
332
333        buffer += ret;
334        len -= ret;
335
336        // Wipes out any msg_control too
337        memset(&msg, 0, sizeof(msg));
338    }
339
340    return 0;
341}
342
343static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
344{
345    int fd;
346    int err;
347
348    if (fileDescriptor == NULL) {
349        jniThrowNullPointerException(env, NULL);
350        return (jint)-1;
351    }
352
353    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
354
355    if (env->ExceptionCheck()) {
356        return (jint)0;
357    }
358
359    unsigned char buf;
360
361    err = socket_read_all(env, object, fd, &buf, 1);
362
363    if (err < 0) {
364        jniThrowIOException(env, errno);
365        return (jint)0;
366    }
367
368    if (err == 0) {
369        // end of file
370        return (jint)-1;
371    }
372
373    return (jint)buf;
374}
375
376static jint socket_readba (JNIEnv *env, jobject object,
377        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
378{
379    int fd;
380    jbyte* byteBuffer;
381    int ret;
382
383    if (fileDescriptor == NULL || buffer == NULL) {
384        jniThrowNullPointerException(env, NULL);
385        return (jint)-1;
386    }
387
388    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
389        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
390        return (jint)-1;
391    }
392
393    if (len == 0) {
394        // because socket_read_all returns 0 on EOF
395        return 0;
396    }
397
398    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
399
400    if (env->ExceptionCheck()) {
401        return (jint)-1;
402    }
403
404    byteBuffer = env->GetByteArrayElements(buffer, NULL);
405
406    if (NULL == byteBuffer) {
407        // an exception will have been thrown
408        return (jint)-1;
409    }
410
411    ret = socket_read_all(env, object,
412            fd, byteBuffer + off, len);
413
414    // A return of -1 above means an exception is pending
415
416    env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
417
418    return (jint) ((ret == 0) ? -1 : ret);
419}
420
421static void socket_write (JNIEnv *env, jobject object,
422        jint b, jobject fileDescriptor)
423{
424    int fd;
425    int err;
426
427    if (fileDescriptor == NULL) {
428        jniThrowNullPointerException(env, NULL);
429        return;
430    }
431
432    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
433
434    if (env->ExceptionCheck()) {
435        return;
436    }
437
438    err = socket_write_all(env, object, fd, &b, 1);
439    UNUSED(err);
440    // A return of -1 above means an exception is pending
441}
442
443static void socket_writeba (JNIEnv *env, jobject object,
444        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
445{
446    int fd;
447    int err;
448    jbyte* byteBuffer;
449
450    if (fileDescriptor == NULL || buffer == NULL) {
451        jniThrowNullPointerException(env, NULL);
452        return;
453    }
454
455    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
456        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
457        return;
458    }
459
460    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
461
462    if (env->ExceptionCheck()) {
463        return;
464    }
465
466    byteBuffer = env->GetByteArrayElements(buffer,NULL);
467
468    if (NULL == byteBuffer) {
469        // an exception will have been thrown
470        return;
471    }
472
473    err = socket_write_all(env, object, fd,
474            byteBuffer + off, len);
475    UNUSED(err);
476    // A return of -1 above means an exception is pending
477
478    env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
479}
480
481static jobject socket_get_peer_credentials(JNIEnv *env,
482        jobject object, jobject fileDescriptor)
483{
484    int err;
485    int fd;
486
487    if (fileDescriptor == NULL) {
488        jniThrowNullPointerException(env, NULL);
489        return NULL;
490    }
491
492    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
493
494    if (env->ExceptionCheck()) {
495        return NULL;
496    }
497
498    struct ucred creds;
499
500    memset(&creds, 0, sizeof(creds));
501    socklen_t szCreds = sizeof(creds);
502
503    err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
504
505    if (err < 0) {
506        jniThrowIOException(env, errno);
507        return NULL;
508    }
509
510    if (szCreds == 0) {
511        return NULL;
512    }
513
514    return env->NewObject(class_Credentials, method_CredentialsInit,
515            creds.pid, creds.uid, creds.gid);
516}
517
518/*
519 * JNI registration.
520 */
521static JNINativeMethod gMethods[] = {
522     /* name, signature, funcPtr */
523    {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
524                                                (void*)socket_connect_local},
525    {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
526    {"shutdown", "(Ljava/io/FileDescriptor;Z)V", (void*)socket_shutdown},
527    {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
528    {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
529    {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
530    {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
531    {"getPeerCredentials_native",
532            "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
533            (void*) socket_get_peer_credentials}
534};
535
536int register_android_net_LocalSocketImpl(JNIEnv *env)
537{
538    jclass clazz;
539
540    clazz = env->FindClass("android/net/LocalSocketImpl");
541
542    if (clazz == NULL) {
543        goto error;
544    }
545
546    field_inboundFileDescriptors = env->GetFieldID(clazz,
547            "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
548
549    if (field_inboundFileDescriptors == NULL) {
550        goto error;
551    }
552
553    field_outboundFileDescriptors = env->GetFieldID(clazz,
554            "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
555
556    if (field_outboundFileDescriptors == NULL) {
557        goto error;
558    }
559
560    class_Credentials = env->FindClass("android/net/Credentials");
561
562    if (class_Credentials == NULL) {
563        goto error;
564    }
565
566    class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
567
568    class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
569
570    if (class_FileDescriptor == NULL) {
571        goto error;
572    }
573
574    class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
575
576    method_CredentialsInit
577            = env->GetMethodID(class_Credentials, "<init>", "(III)V");
578
579    if (method_CredentialsInit == NULL) {
580        goto error;
581    }
582
583    return jniRegisterNativeMethods(env,
584        "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
585
586error:
587    ALOGE("Error registering android.net.LocalSocketImpl");
588    return -1;
589}
590
591};
592