1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "NetdClient.h"
18
#include <arpa/inet.h>
#include <errno.h>
#include <math.h>
#include <stddef.h>
#include <sys/socket.h>
#include <unistd.h>
24
25#include <atomic>
26
27#include "Fwmark.h"
28#include "FwmarkClient.h"
29#include "FwmarkCommand.h"
30#include "resolv_netid.h"
31#include "Stopwatch.h"
32
33namespace {
34
35std::atomic_uint netIdForProcess(NETID_UNSET);
36std::atomic_uint netIdForResolv(NETID_UNSET);
37
38typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
39typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
40typedef int (*SocketFunctionType)(int, int, int);
41typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
42
43// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
44// it's okay that they are read later at runtime without a lock.
45Accept4FunctionType libcAccept4 = 0;
46ConnectFunctionType libcConnect = 0;
47SocketFunctionType libcSocket = 0;
48
49int closeFdAndSetErrno(int fd, int error) {
50    close(fd);
51    errno = -error;
52    return -1;
53}
54
55int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
56    int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
57    if (acceptedSocket == -1) {
58        return -1;
59    }
60    int family;
61    if (addr) {
62        family = addr->sa_family;
63    } else {
64        socklen_t familyLen = sizeof(family);
65        if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
66            return closeFdAndSetErrno(acceptedSocket, -errno);
67        }
68    }
69    if (FwmarkClient::shouldSetFwmark(family)) {
70        FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0};
71        if (int error = FwmarkClient().send(&command, acceptedSocket, nullptr)) {
72            return closeFdAndSetErrno(acceptedSocket, error);
73        }
74    }
75    return acceptedSocket;
76}
77
78int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
79    const bool shouldSetFwmark = (sockfd >= 0) && addr
80            && FwmarkClient::shouldSetFwmark(addr->sa_family);
81    if (shouldSetFwmark) {
82        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0};
83        if (int error = FwmarkClient().send(&command, sockfd, nullptr)) {
84            errno = -error;
85            return -1;
86        }
87    }
88    // Latency measurement does not include time of sending commands to Fwmark
89    Stopwatch s;
90    const int ret = libcConnect(sockfd, addr, addrlen);
91    // Save errno so it isn't clobbered by sending ON_CONNECT_COMPLETE
92    const int connectErrno = errno;
93    const unsigned latencyMs = lround(s.timeTaken());
94    // Send an ON_CONNECT_COMPLETE command that includes sockaddr and connect latency for reporting
95    if (shouldSetFwmark && FwmarkClient::shouldReportConnectComplete(addr->sa_family)) {
96        FwmarkConnectInfo connectInfo(ret == 0 ? 0 : connectErrno, latencyMs, addr);
97        // TODO: get the netId from the socket mark once we have continuous benchmark runs
98        FwmarkCommand command = {FwmarkCommand::ON_CONNECT_COMPLETE, /* netId (ignored) */ 0,
99                /* uid (filled in by the server) */ 0};
100        // Ignore return value since it's only used for logging
101        FwmarkClient().send(&command, sockfd, &connectInfo);
102    }
103    errno = connectErrno;
104    return ret;
105}
106
107int netdClientSocket(int domain, int type, int protocol) {
108    int socketFd = libcSocket(domain, type, protocol);
109    if (socketFd == -1) {
110        return -1;
111    }
112    unsigned netId = netIdForProcess;
113    if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
114        if (int error = setNetworkForSocket(netId, socketFd)) {
115            return closeFdAndSetErrno(socketFd, error);
116        }
117    }
118    return socketFd;
119}
120
121unsigned getNetworkForResolv(unsigned netId) {
122    if (netId != NETID_UNSET) {
123        return netId;
124    }
125    netId = netIdForProcess;
126    if (netId != NETID_UNSET) {
127        return netId;
128    }
129    return netIdForResolv;
130}
131
132int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
133    if (netId == NETID_UNSET) {
134        *target = netId;
135        return 0;
136    }
137    // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
138    // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
139    // might itself cause another check with the fwmark server, which would be wasteful.
140    int socketFd;
141    if (libcSocket) {
142        socketFd = libcSocket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
143    } else {
144        socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
145    }
146    if (socketFd < 0) {
147        return -errno;
148    }
149    int error = setNetworkForSocket(netId, socketFd);
150    if (!error) {
151        *target = netId;
152    }
153    close(socketFd);
154    return error;
155}
156
157}  // namespace
158
159// accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
160extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
161    if (function && *function) {
162        libcAccept4 = *function;
163        *function = netdClientAccept4;
164    }
165}
166
167extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
168    if (function && *function) {
169        libcConnect = *function;
170        *function = netdClientConnect;
171    }
172}
173
174extern "C" void netdClientInitSocket(SocketFunctionType* function) {
175    if (function && *function) {
176        libcSocket = *function;
177        *function = netdClientSocket;
178    }
179}
180
181extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
182    if (function) {
183        *function = getNetworkForResolv;
184    }
185}
186
187extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
188    if (!netId || socketFd < 0) {
189        return -EBADF;
190    }
191    Fwmark fwmark;
192    socklen_t fwmarkLen = sizeof(fwmark.intValue);
193    if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
194        return -errno;
195    }
196    *netId = fwmark.netId;
197    return 0;
198}
199
200extern "C" unsigned getNetworkForProcess() {
201    return netIdForProcess;
202}
203
204extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
205    if (socketFd < 0) {
206        return -EBADF;
207    }
208    FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0};
209    return FwmarkClient().send(&command, socketFd, nullptr);
210}
211
212extern "C" int setNetworkForProcess(unsigned netId) {
213    return setNetworkForTarget(netId, &netIdForProcess);
214}
215
216extern "C" int setNetworkForResolv(unsigned netId) {
217    return setNetworkForTarget(netId, &netIdForResolv);
218}
219
220extern "C" int protectFromVpn(int socketFd) {
221    if (socketFd < 0) {
222        return -EBADF;
223    }
224    FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0};
225    return FwmarkClient().send(&command, socketFd, nullptr);
226}
227
228extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
229    if (socketFd < 0) {
230        return -EBADF;
231    }
232    FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid};
233    return FwmarkClient().send(&command, socketFd, nullptr);
234}
235
236extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
237    FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid};
238    return FwmarkClient().send(&command, -1, nullptr);
239}
240