1#include <alloca.h>
2#include <errno.h>
3#include <malloc.h>
4#include <pthread.h>
5#include <signal.h>
6#include <string.h>
7#include <arpa/inet.h>
8#include <sys/socket.h>
9#include <sys/types.h>
10
11#define LOG_TAG "SocketClient"
12#include <cutils/log.h>
13
14#include <sysutils/SocketClient.h>
15
16SocketClient::SocketClient(int socket, bool owned) {
17    init(socket, owned, false);
18}
19
20SocketClient::SocketClient(int socket, bool owned, bool useCmdNum) {
21    init(socket, owned, useCmdNum);
22}
23
24void SocketClient::init(int socket, bool owned, bool useCmdNum) {
25    mSocket = socket;
26    mSocketOwned = owned;
27    mUseCmdNum = useCmdNum;
28    pthread_mutex_init(&mWriteMutex, NULL);
29    pthread_mutex_init(&mRefCountMutex, NULL);
30    mPid = -1;
31    mUid = -1;
32    mGid = -1;
33    mRefCount = 1;
34    mCmdNum = 0;
35
36    struct ucred creds;
37    socklen_t szCreds = sizeof(creds);
38    memset(&creds, 0, szCreds);
39
40    int err = getsockopt(socket, SOL_SOCKET, SO_PEERCRED, &creds, &szCreds);
41    if (err == 0) {
42        mPid = creds.pid;
43        mUid = creds.uid;
44        mGid = creds.gid;
45    }
46}
47
48SocketClient::~SocketClient() {
49    if (mSocketOwned) {
50        close(mSocket);
51    }
52}
53
54int SocketClient::sendMsg(int code, const char *msg, bool addErrno) {
55    return sendMsg(code, msg, addErrno, mUseCmdNum);
56}
57
58int SocketClient::sendMsg(int code, const char *msg, bool addErrno, bool useCmdNum) {
59    char *buf;
60    int ret = 0;
61
62    if (addErrno) {
63        if (useCmdNum) {
64            ret = asprintf(&buf, "%d %d %s (%s)", code, getCmdNum(), msg, strerror(errno));
65        } else {
66            ret = asprintf(&buf, "%d %s (%s)", code, msg, strerror(errno));
67        }
68    } else {
69        if (useCmdNum) {
70            ret = asprintf(&buf, "%d %d %s", code, getCmdNum(), msg);
71        } else {
72            ret = asprintf(&buf, "%d %s", code, msg);
73        }
74    }
75    // Send the zero-terminated message
76    if (ret != -1) {
77        ret = sendMsg(buf);
78        free(buf);
79    }
80    return ret;
81}
82
83// send 3-digit code, null, binary-length, binary data
84int SocketClient::sendBinaryMsg(int code, const void *data, int len) {
85
86    // 4 bytes for the code & null + 4 bytes for the len
87    char buf[8];
88    // Write the code
89    snprintf(buf, 4, "%.3d", code);
90    // Write the len
91    uint32_t tmp = htonl(len);
92    memcpy(buf + 4, &tmp, sizeof(uint32_t));
93
94    struct iovec vec[2];
95    vec[0].iov_base = (void *) buf;
96    vec[0].iov_len = sizeof(buf);
97    vec[1].iov_base = (void *) data;
98    vec[1].iov_len = len;
99
100    pthread_mutex_lock(&mWriteMutex);
101    int result = sendDataLockedv(vec, (len > 0) ? 2 : 1);
102    pthread_mutex_unlock(&mWriteMutex);
103
104    return result;
105}
106
107// Sends the code (c-string null-terminated).
108int SocketClient::sendCode(int code) {
109    char buf[4];
110    snprintf(buf, sizeof(buf), "%.3d", code);
111    return sendData(buf, sizeof(buf));
112}
113
114char *SocketClient::quoteArg(const char *arg) {
115    int len = strlen(arg);
116    char *result = (char *)malloc(len * 2 + 3);
117    char *current = result;
118    const char *end = arg + len;
119    char *oldresult;
120
121    if(result == NULL) {
122        SLOGW("malloc error (%s)", strerror(errno));
123        return NULL;
124    }
125
126    *(current++) = '"';
127    while (arg < end) {
128        switch (*arg) {
129        case '\\':
130        case '"':
131            *(current++) = '\\'; // fallthrough
132        default:
133            *(current++) = *(arg++);
134        }
135    }
136    *(current++) = '"';
137    *(current++) = '\0';
138    oldresult = result; // save pointer in case realloc fails
139    result = (char *)realloc(result, current-result);
140    return result ? result : oldresult;
141}
142
143
144int SocketClient::sendMsg(const char *msg) {
145    // Send the message including null character
146    if (sendData(msg, strlen(msg) + 1) != 0) {
147        SLOGW("Unable to send msg '%s'", msg);
148        return -1;
149    }
150    return 0;
151}
152
153int SocketClient::sendData(const void *data, int len) {
154    struct iovec vec[1];
155    vec[0].iov_base = (void *) data;
156    vec[0].iov_len = len;
157
158    pthread_mutex_lock(&mWriteMutex);
159    int rc = sendDataLockedv(vec, 1);
160    pthread_mutex_unlock(&mWriteMutex);
161
162    return rc;
163}
164
165int SocketClient::sendDatav(struct iovec *iov, int iovcnt) {
166    pthread_mutex_lock(&mWriteMutex);
167    int rc = sendDataLockedv(iov, iovcnt);
168    pthread_mutex_unlock(&mWriteMutex);
169
170    return rc;
171}
172
173int SocketClient::sendDataLockedv(struct iovec *iov, int iovcnt) {
174
175    if (mSocket < 0) {
176        errno = EHOSTUNREACH;
177        return -1;
178    }
179
180    if (iovcnt <= 0) {
181        return 0;
182    }
183
184    int ret = 0;
185    int e = 0; // SLOGW and sigaction are not inert regarding errno
186    int current = 0;
187
188    struct sigaction new_action, old_action;
189    memset(&new_action, 0, sizeof(new_action));
190    new_action.sa_handler = SIG_IGN;
191    sigaction(SIGPIPE, &new_action, &old_action);
192
193    for (;;) {
194        ssize_t rc = TEMP_FAILURE_RETRY(
195            writev(mSocket, iov + current, iovcnt - current));
196
197        if (rc > 0) {
198            size_t written = rc;
199            while ((current < iovcnt) && (written >= iov[current].iov_len)) {
200                written -= iov[current].iov_len;
201                current++;
202            }
203            if (current == iovcnt) {
204                break;
205            }
206            iov[current].iov_base = (char *)iov[current].iov_base + written;
207            iov[current].iov_len -= written;
208            continue;
209        }
210
211        if (rc == 0) {
212            e = EIO;
213            SLOGW("0 length write :(");
214        } else {
215            e = errno;
216            SLOGW("write error (%s)", strerror(e));
217        }
218        ret = -1;
219        break;
220    }
221
222    sigaction(SIGPIPE, &old_action, &new_action);
223
224    if (e != 0) {
225        errno = e;
226    }
227    return ret;
228}
229
230void SocketClient::incRef() {
231    pthread_mutex_lock(&mRefCountMutex);
232    mRefCount++;
233    pthread_mutex_unlock(&mRefCountMutex);
234}
235
236bool SocketClient::decRef() {
237    bool deleteSelf = false;
238    pthread_mutex_lock(&mRefCountMutex);
239    mRefCount--;
240    if (mRefCount == 0) {
241        deleteSelf = true;
242    } else if (mRefCount < 0) {
243        SLOGE("SocketClient refcount went negative!");
244    }
245    pthread_mutex_unlock(&mRefCountMutex);
246    if (deleteSelf) {
247        delete this;
248    }
249    return deleteSelf;
250}
251