// NetdClient.cpp revision d794e580dbe1a8b4192850b0e117654401514af8
/*
 * Copyright (C) 2014 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "NetdClient.h"

#include <errno.h>
#include <sys/socket.h>
#include <unistd.h>

#include <atomic>

#include "FwmarkClient.h"
#include "FwmarkCommand.h"
#include "resolv_netid.h"

27namespace {
28
29std::atomic_uint netIdForProcess(NETID_UNSET);
30std::atomic_uint netIdForResolv(NETID_UNSET);
31
32typedef int (*Accept4FunctionType)(int, sockaddr*, socklen_t*, int);
33typedef int (*ConnectFunctionType)(int, const sockaddr*, socklen_t);
34typedef int (*SocketFunctionType)(int, int, int);
35typedef unsigned (*NetIdForResolvFunctionType)(unsigned);
36
37// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
38// it's okay that they are read later at runtime without a lock.
39Accept4FunctionType libcAccept4 = 0;
40ConnectFunctionType libcConnect = 0;
41SocketFunctionType libcSocket = 0;
42
43int closeFdAndRestoreErrno(int fd) {
44    int error = errno;
45    close(fd);
46    errno = error;
47    return -1;
48}
49
50int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
51    int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
52    if (acceptedSocket == -1) {
53        return -1;
54    }
55    int family;
56    if (addr) {
57        family = addr->sa_family;
58    } else {
59        socklen_t familyLen = sizeof(family);
60        if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
61            return closeFdAndRestoreErrno(acceptedSocket);
62        }
63    }
64    if (FwmarkClient::shouldSetFwmark(family)) {
65        FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0};
66        if (!FwmarkClient().send(&command, sizeof(command), acceptedSocket)) {
67            return closeFdAndRestoreErrno(acceptedSocket);
68        }
69    }
70    return acceptedSocket;
71}
72
73int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
74    if (sockfd >= 0 && addr && FwmarkClient::shouldSetFwmark(addr->sa_family)) {
75        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0};
76        if (!FwmarkClient().send(&command, sizeof(command), sockfd)) {
77            return -1;
78        }
79    }
80    return libcConnect(sockfd, addr, addrlen);
81}
82
83int netdClientSocket(int domain, int type, int protocol) {
84    int socketFd = libcSocket(domain, type, protocol);
85    if (socketFd == -1) {
86        return -1;
87    }
88    unsigned netId = netIdForProcess;
89    if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
90        if (!setNetworkForSocket(netId, socketFd)) {
91            return closeFdAndRestoreErrno(socketFd);
92        }
93    }
94    return socketFd;
95}
96
97unsigned getNetworkForResolv(unsigned netId) {
98    if (netId != NETID_UNSET) {
99        return netId;
100    }
101    netId = netIdForProcess;
102    if (netId != NETID_UNSET) {
103        return netId;
104    }
105    return netIdForResolv;
106}
107
108bool setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
109    if (netId == NETID_UNSET) {
110        *target = netId;
111        return true;
112    }
113    // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
114    // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
115    // might itself cause another check with the fwmark server, which would be wasteful.
116    int socketFd;
117    if (libcSocket) {
118        socketFd = libcSocket(AF_INET6, SOCK_DGRAM, 0);
119    } else {
120        socketFd = socket(AF_INET6, SOCK_DGRAM, 0);
121    }
122    if (socketFd < 0) {
123        return false;
124    }
125    bool status = setNetworkForSocket(netId, socketFd);
126    closeFdAndRestoreErrno(socketFd);
127    if (status) {
128        *target = netId;
129    }
130    return status;
131}
132
133}  // namespace
134
135// accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
136extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
137    if (function && *function) {
138        libcAccept4 = *function;
139        *function = netdClientAccept4;
140    }
141}
142
143extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
144    if (function && *function) {
145        libcConnect = *function;
146        *function = netdClientConnect;
147    }
148}
149
150extern "C" void netdClientInitSocket(SocketFunctionType* function) {
151    if (function && *function) {
152        libcSocket = *function;
153        *function = netdClientSocket;
154    }
155}
156
157extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
158    if (function) {
159        *function = getNetworkForResolv;
160    }
161}
162
163extern "C" unsigned getNetworkForProcess() {
164    return netIdForProcess;
165}
166
167extern "C" bool setNetworkForSocket(unsigned netId, int socketFd) {
168    if (socketFd < 0) {
169        errno = EBADF;
170        return false;
171    }
172    FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId};
173    return FwmarkClient().send(&command, sizeof(command), socketFd);
174}
175
176extern "C" bool setNetworkForProcess(unsigned netId) {
177    return setNetworkForTarget(netId, &netIdForProcess);
178}
179
180extern "C" bool setNetworkForResolv(unsigned netId) {
181    return setNetworkForTarget(netId, &netIdForResolv);
182}
183
184extern "C" bool protectFromVpn(int socketFd) {
185    if (socketFd < 0) {
186        errno = EBADF;
187        return false;
188    }
189    FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0};
190    return FwmarkClient().send(&command, sizeof(command), socketFd);
191}
192