1/*
2 * Copyright (C) 2006 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#define LOG_TAG "LocalSocketImpl"
18
19#include "JNIHelp.h"
20#include "jni.h"
21#include "utils/Log.h"
22#include "utils/misc.h"
23
24#include <stdio.h>
25#include <string.h>
26#include <sys/types.h>
27#include <sys/socket.h>
28#include <sys/un.h>
29#include <arpa/inet.h>
30#include <netinet/in.h>
31#include <stdlib.h>
32#include <errno.h>
33#include <unistd.h>
34#include <sys/ioctl.h>
35
36#include <cutils/sockets.h>
37#include <netinet/tcp.h>
38#include <ScopedUtfChars.h>
39
40namespace android {
41
// Accepts any value by copy and does nothing with it; used to mark a
// variable as intentionally unused.
template <typename T>
void UNUSED(T)
{
}
44
45static jfieldID field_inboundFileDescriptors;
46static jfieldID field_outboundFileDescriptors;
47static jclass class_Credentials;
48static jclass class_FileDescriptor;
49static jmethodID method_CredentialsInit;
50
51/* private native void connectLocal(FileDescriptor fd,
52 * String name, int namespace) throws IOException
53 */
54static void
55socket_connect_local(JNIEnv *env, jobject object,
56                        jobject fileDescriptor, jstring name, jint namespaceId)
57{
58    int ret;
59    int fd;
60
61    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
62
63    if (env->ExceptionCheck()) {
64        return;
65    }
66
67    ScopedUtfChars nameUtf8(env, name);
68
69    ret = socket_local_client_connect(
70                fd,
71                nameUtf8.c_str(),
72                namespaceId,
73                SOCK_STREAM);
74
75    if (ret < 0) {
76        jniThrowIOException(env, errno);
77        return;
78    }
79}
80
81#define DEFAULT_BACKLOG 4
82
83/* private native void bindLocal(FileDescriptor fd, String name, namespace)
84 * throws IOException;
85 */
86
87static void
88socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
89                jstring name, jint namespaceId)
90{
91    int ret;
92    int fd;
93
94    if (name == NULL) {
95        jniThrowNullPointerException(env, NULL);
96        return;
97    }
98
99    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
100
101    if (env->ExceptionCheck()) {
102        return;
103    }
104
105    ScopedUtfChars nameUtf8(env, name);
106
107    ret = socket_local_server_bind(fd, nameUtf8.c_str(), namespaceId);
108
109    if (ret < 0) {
110        jniThrowIOException(env, errno);
111        return;
112    }
113}
114
115/**
116 * Processes ancillary data, handling only
117 * SCM_RIGHTS. Creates appropriate objects and sets appropriate
118 * fields in the LocalSocketImpl object. Returns 0 on success
119 * or -1 if an exception was thrown.
120 */
121static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
122{
123    struct cmsghdr *cmsgptr;
124
125    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
126            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
127
128        if (cmsgptr->cmsg_level != SOL_SOCKET) {
129            continue;
130        }
131
132        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
133            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
134            jobjectArray fdArray;
135            int count
136                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
137
138            if (count < 0) {
139                jniThrowException(env, "java/io/IOException",
140                    "invalid cmsg length");
141                return -1;
142            }
143
144            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
145
146            if (fdArray == NULL) {
147                return -1;
148            }
149
150            for (int i = 0; i < count; i++) {
151                jobject fdObject
152                        = jniCreateFileDescriptor(env, pDescriptors[i]);
153
154                if (env->ExceptionCheck()) {
155                    return -1;
156                }
157
158                env->SetObjectArrayElement(fdArray, i, fdObject);
159
160                if (env->ExceptionCheck()) {
161                    return -1;
162                }
163            }
164
165            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
166
167            if (env->ExceptionCheck()) {
168                return -1;
169            }
170        }
171    }
172
173    return 0;
174}
175
176/**
177 * Reads data from a socket into buf, processing any ancillary data
178 * and adding it to thisJ.
179 *
180 * Returns the length of normal data read, or -1 if an exception has
181 * been thrown in this function.
182 */
183static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
184        void *buffer, size_t len)
185{
186    ssize_t ret;
187    struct msghdr msg;
188    struct iovec iv;
189    unsigned char *buf = (unsigned char *)buffer;
190    // Enough buffer for a pile of fd's. We throw an exception if
191    // this buffer is too small.
192    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
193
194    memset(&msg, 0, sizeof(msg));
195    memset(&iv, 0, sizeof(iv));
196
197    iv.iov_base = buf;
198    iv.iov_len = len;
199
200    msg.msg_iov = &iv;
201    msg.msg_iovlen = 1;
202    msg.msg_control = cmsgbuf;
203    msg.msg_controllen = sizeof(cmsgbuf);
204
205    do {
206        ret = recvmsg(fd, &msg, MSG_NOSIGNAL);
207    } while (ret < 0 && errno == EINTR);
208
209    if (ret < 0 && errno == EPIPE) {
210        // Treat this as an end of stream
211        return 0;
212    }
213
214    if (ret < 0) {
215        jniThrowIOException(env, errno);
216        return -1;
217    }
218
219    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
220        // To us, any of the above flags are a fatal error
221
222        jniThrowException(env, "java/io/IOException",
223                "Unexpected error or truncation during recvmsg()");
224
225        return -1;
226    }
227
228    if (ret >= 0) {
229        socket_process_cmsg(env, thisJ, &msg);
230    }
231
232    return ret;
233}
234
235/**
236 * Writes all the data in the specified buffer to the specified socket.
237 *
238 * Returns 0 on success or -1 if an exception was thrown.
239 */
240static int socket_write_all(JNIEnv *env, jobject object, int fd,
241        void *buf, size_t len)
242{
243    ssize_t ret;
244    struct msghdr msg;
245    unsigned char *buffer = (unsigned char *)buf;
246    memset(&msg, 0, sizeof(msg));
247
248    jobjectArray outboundFds
249            = (jobjectArray)env->GetObjectField(
250                object, field_outboundFileDescriptors);
251
252    if (env->ExceptionCheck()) {
253        return -1;
254    }
255
256    struct cmsghdr *cmsg;
257    int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
258    int fds[countFds];
259    char msgbuf[CMSG_SPACE(countFds)];
260
261    // Add any pending outbound file descriptors to the message
262    if (outboundFds != NULL) {
263
264        if (env->ExceptionCheck()) {
265            return -1;
266        }
267
268        for (int i = 0; i < countFds; i++) {
269            jobject fdObject = env->GetObjectArrayElement(outboundFds, i);
270            if (env->ExceptionCheck()) {
271                return -1;
272            }
273
274            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
275            if (env->ExceptionCheck()) {
276                return -1;
277            }
278        }
279
280        // See "man cmsg" really
281        msg.msg_control = msgbuf;
282        msg.msg_controllen = sizeof msgbuf;
283        cmsg = CMSG_FIRSTHDR(&msg);
284        cmsg->cmsg_level = SOL_SOCKET;
285        cmsg->cmsg_type = SCM_RIGHTS;
286        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
287        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
288    }
289
290    // We only write our msg_control during the first write
291    while (len > 0) {
292        struct iovec iv;
293        memset(&iv, 0, sizeof(iv));
294
295        iv.iov_base = buffer;
296        iv.iov_len = len;
297
298        msg.msg_iov = &iv;
299        msg.msg_iovlen = 1;
300
301        do {
302            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
303        } while (ret < 0 && errno == EINTR);
304
305        if (ret < 0) {
306            jniThrowIOException(env, errno);
307            return -1;
308        }
309
310        buffer += ret;
311        len -= ret;
312
313        // Wipes out any msg_control too
314        memset(&msg, 0, sizeof(msg));
315    }
316
317    return 0;
318}
319
320static jint socket_read (JNIEnv *env, jobject object, jobject fileDescriptor)
321{
322    int fd;
323    int err;
324
325    if (fileDescriptor == NULL) {
326        jniThrowNullPointerException(env, NULL);
327        return (jint)-1;
328    }
329
330    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
331
332    if (env->ExceptionCheck()) {
333        return (jint)0;
334    }
335
336    unsigned char buf;
337
338    err = socket_read_all(env, object, fd, &buf, 1);
339
340    if (err < 0) {
341        jniThrowIOException(env, errno);
342        return (jint)0;
343    }
344
345    if (err == 0) {
346        // end of file
347        return (jint)-1;
348    }
349
350    return (jint)buf;
351}
352
353static jint socket_readba (JNIEnv *env, jobject object,
354        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
355{
356    int fd;
357    jbyte* byteBuffer;
358    int ret;
359
360    if (fileDescriptor == NULL || buffer == NULL) {
361        jniThrowNullPointerException(env, NULL);
362        return (jint)-1;
363    }
364
365    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
366        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
367        return (jint)-1;
368    }
369
370    if (len == 0) {
371        // because socket_read_all returns 0 on EOF
372        return 0;
373    }
374
375    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
376
377    if (env->ExceptionCheck()) {
378        return (jint)-1;
379    }
380
381    byteBuffer = env->GetByteArrayElements(buffer, NULL);
382
383    if (NULL == byteBuffer) {
384        // an exception will have been thrown
385        return (jint)-1;
386    }
387
388    ret = socket_read_all(env, object,
389            fd, byteBuffer + off, len);
390
391    // A return of -1 above means an exception is pending
392
393    env->ReleaseByteArrayElements(buffer, byteBuffer, 0);
394
395    return (jint) ((ret == 0) ? -1 : ret);
396}
397
398static void socket_write (JNIEnv *env, jobject object,
399        jint b, jobject fileDescriptor)
400{
401    int fd;
402    int err;
403
404    if (fileDescriptor == NULL) {
405        jniThrowNullPointerException(env, NULL);
406        return;
407    }
408
409    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
410
411    if (env->ExceptionCheck()) {
412        return;
413    }
414
415    err = socket_write_all(env, object, fd, &b, 1);
416    UNUSED(err);
417    // A return of -1 above means an exception is pending
418}
419
420static void socket_writeba (JNIEnv *env, jobject object,
421        jbyteArray buffer, jint off, jint len, jobject fileDescriptor)
422{
423    int fd;
424    int err;
425    jbyte* byteBuffer;
426
427    if (fileDescriptor == NULL || buffer == NULL) {
428        jniThrowNullPointerException(env, NULL);
429        return;
430    }
431
432    if (off < 0 || len < 0 || (off + len) > env->GetArrayLength(buffer)) {
433        jniThrowException(env, "java/lang/ArrayIndexOutOfBoundsException", NULL);
434        return;
435    }
436
437    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
438
439    if (env->ExceptionCheck()) {
440        return;
441    }
442
443    byteBuffer = env->GetByteArrayElements(buffer,NULL);
444
445    if (NULL == byteBuffer) {
446        // an exception will have been thrown
447        return;
448    }
449
450    err = socket_write_all(env, object, fd,
451            byteBuffer + off, len);
452    UNUSED(err);
453    // A return of -1 above means an exception is pending
454
455    env->ReleaseByteArrayElements(buffer, byteBuffer, JNI_ABORT);
456}
457
458static jobject socket_get_peer_credentials(JNIEnv *env,
459        jobject object, jobject fileDescriptor)
460{
461    int err;
462    int fd;
463
464    if (fileDescriptor == NULL) {
465        jniThrowNullPointerException(env, NULL);
466        return NULL;
467    }
468
469    fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
470
471    if (env->ExceptionCheck()) {
472        return NULL;
473    }
474
475    struct ucred creds;
476
477    memset(&creds, 0, sizeof(creds));
478    socklen_t szCreds = sizeof(creds);
479
480    err = getsockopt(fd, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
481
482    if (err < 0) {
483        jniThrowIOException(env, errno);
484        return NULL;
485    }
486
487    if (szCreds == 0) {
488        return NULL;
489    }
490
491    return env->NewObject(class_Credentials, method_CredentialsInit,
492            creds.pid, creds.uid, creds.gid);
493}
494
495/*
496 * JNI registration.
497 */
498static const JNINativeMethod gMethods[] = {
499     /* name, signature, funcPtr */
500    {"connectLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V",
501                                                (void*)socket_connect_local},
502    {"bindLocal", "(Ljava/io/FileDescriptor;Ljava/lang/String;I)V", (void*)socket_bind_local},
503    {"read_native", "(Ljava/io/FileDescriptor;)I", (void*) socket_read},
504    {"readba_native", "([BIILjava/io/FileDescriptor;)I", (void*) socket_readba},
505    {"writeba_native", "([BIILjava/io/FileDescriptor;)V", (void*) socket_writeba},
506    {"write_native", "(ILjava/io/FileDescriptor;)V", (void*) socket_write},
507    {"getPeerCredentials_native",
508            "(Ljava/io/FileDescriptor;)Landroid/net/Credentials;",
509            (void*) socket_get_peer_credentials}
510};
511
512int register_android_net_LocalSocketImpl(JNIEnv *env)
513{
514    jclass clazz;
515
516    clazz = env->FindClass("android/net/LocalSocketImpl");
517
518    if (clazz == NULL) {
519        goto error;
520    }
521
522    field_inboundFileDescriptors = env->GetFieldID(clazz,
523            "inboundFileDescriptors", "[Ljava/io/FileDescriptor;");
524
525    if (field_inboundFileDescriptors == NULL) {
526        goto error;
527    }
528
529    field_outboundFileDescriptors = env->GetFieldID(clazz,
530            "outboundFileDescriptors", "[Ljava/io/FileDescriptor;");
531
532    if (field_outboundFileDescriptors == NULL) {
533        goto error;
534    }
535
536    class_Credentials = env->FindClass("android/net/Credentials");
537
538    if (class_Credentials == NULL) {
539        goto error;
540    }
541
542    class_Credentials = (jclass)env->NewGlobalRef(class_Credentials);
543
544    class_FileDescriptor = env->FindClass("java/io/FileDescriptor");
545
546    if (class_FileDescriptor == NULL) {
547        goto error;
548    }
549
550    class_FileDescriptor = (jclass)env->NewGlobalRef(class_FileDescriptor);
551
552    method_CredentialsInit
553            = env->GetMethodID(class_Credentials, "<init>", "(III)V");
554
555    if (method_CredentialsInit == NULL) {
556        goto error;
557    }
558
559    return jniRegisterNativeMethods(env,
560        "android/net/LocalSocketImpl", gMethods, NELEM(gMethods));
561
562error:
563    ALOGE("Error registering android.net.LocalSocketImpl");
564    return -1;
565}
566
567};
568