1/*
2 * Copyright (C) 2014 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "NetdClient.h"
18
19#include <errno.h>
20#include <sys/socket.h>
21#include <unistd.h>
22
23#include <atomic>
24
25#include "Fwmark.h"
26#include "FwmarkClient.h"
27#include "FwmarkCommand.h"
28#include "resolv_netid.h"
29
30namespace {
31
32std::atomic_uint netIdForProcess(NETID_UNSET);
33std::atomic_uint netIdForResolv(NETID_UNSET);
34
// Signatures of the hooks that are exchanged with the caller of the netdClientInit*() functions.
using Accept4FunctionType = int (*)(int, sockaddr*, socklen_t*, int);
using ConnectFunctionType = int (*)(int, const sockaddr*, socklen_t);
using SocketFunctionType = int (*)(int, int, int);
using NetIdForResolvFunctionType = unsigned (*)(unsigned);
39
40// These variables are only modified at startup (when libc.so is loaded) and never afterwards, so
41// it's okay that they are read later at runtime without a lock.
42Accept4FunctionType libcAccept4 = 0;
43ConnectFunctionType libcConnect = 0;
44SocketFunctionType libcSocket = 0;
45
// Closes |fd| and then sets errno to the negation of |error|, so that any errno written by
// close() itself is overwritten. Always returns -1, which lets callers write
// "return closeFdAndSetErrno(...)" on their failure paths.
int closeFdAndSetErrno(int fd, int error) {
    const int errnoValue = -error;
    close(fd);
    errno = errnoValue;
    return -1;
}
51
52int netdClientAccept4(int sockfd, sockaddr* addr, socklen_t* addrlen, int flags) {
53    int acceptedSocket = libcAccept4(sockfd, addr, addrlen, flags);
54    if (acceptedSocket == -1) {
55        return -1;
56    }
57    int family;
58    if (addr) {
59        family = addr->sa_family;
60    } else {
61        socklen_t familyLen = sizeof(family);
62        if (getsockopt(acceptedSocket, SOL_SOCKET, SO_DOMAIN, &family, &familyLen) == -1) {
63            return closeFdAndSetErrno(acceptedSocket, -errno);
64        }
65    }
66    if (FwmarkClient::shouldSetFwmark(family)) {
67        FwmarkCommand command = {FwmarkCommand::ON_ACCEPT, 0, 0};
68        if (int error = FwmarkClient().send(&command, acceptedSocket)) {
69            return closeFdAndSetErrno(acceptedSocket, error);
70        }
71    }
72    return acceptedSocket;
73}
74
75int netdClientConnect(int sockfd, const sockaddr* addr, socklen_t addrlen) {
76    if (sockfd >= 0 && addr && FwmarkClient::shouldSetFwmark(addr->sa_family)) {
77        FwmarkCommand command = {FwmarkCommand::ON_CONNECT, 0, 0};
78        if (int error = FwmarkClient().send(&command, sockfd)) {
79            errno = -error;
80            return -1;
81        }
82    }
83    return libcConnect(sockfd, addr, addrlen);
84}
85
86int netdClientSocket(int domain, int type, int protocol) {
87    int socketFd = libcSocket(domain, type, protocol);
88    if (socketFd == -1) {
89        return -1;
90    }
91    unsigned netId = netIdForProcess;
92    if (netId != NETID_UNSET && FwmarkClient::shouldSetFwmark(domain)) {
93        if (int error = setNetworkForSocket(netId, socketFd)) {
94            return closeFdAndSetErrno(socketFd, error);
95        }
96    }
97    return socketFd;
98}
99
100unsigned getNetworkForResolv(unsigned netId) {
101    if (netId != NETID_UNSET) {
102        return netId;
103    }
104    netId = netIdForProcess;
105    if (netId != NETID_UNSET) {
106        return netId;
107    }
108    return netIdForResolv;
109}
110
111int setNetworkForTarget(unsigned netId, std::atomic_uint* target) {
112    if (netId == NETID_UNSET) {
113        *target = netId;
114        return 0;
115    }
116    // Verify that we are allowed to use |netId|, by creating a socket and trying to have it marked
117    // with the netId. Call libcSocket() directly; else the socket creation (via netdClientSocket())
118    // might itself cause another check with the fwmark server, which would be wasteful.
119    int socketFd;
120    if (libcSocket) {
121        socketFd = libcSocket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
122    } else {
123        socketFd = socket(AF_INET6, SOCK_DGRAM | SOCK_CLOEXEC, 0);
124    }
125    if (socketFd < 0) {
126        return -errno;
127    }
128    int error = setNetworkForSocket(netId, socketFd);
129    if (!error) {
130        *target = netId;
131    }
132    close(socketFd);
133    return error;
134}
135
136}  // namespace
137
138// accept() just calls accept4(..., 0), so there's no need to handle accept() separately.
139extern "C" void netdClientInitAccept4(Accept4FunctionType* function) {
140    if (function && *function) {
141        libcAccept4 = *function;
142        *function = netdClientAccept4;
143    }
144}
145
146extern "C" void netdClientInitConnect(ConnectFunctionType* function) {
147    if (function && *function) {
148        libcConnect = *function;
149        *function = netdClientConnect;
150    }
151}
152
153extern "C" void netdClientInitSocket(SocketFunctionType* function) {
154    if (function && *function) {
155        libcSocket = *function;
156        *function = netdClientSocket;
157    }
158}
159
160extern "C" void netdClientInitNetIdForResolv(NetIdForResolvFunctionType* function) {
161    if (function) {
162        *function = getNetworkForResolv;
163    }
164}
165
166extern "C" int getNetworkForSocket(unsigned* netId, int socketFd) {
167    if (!netId || socketFd < 0) {
168        return -EBADF;
169    }
170    Fwmark fwmark;
171    socklen_t fwmarkLen = sizeof(fwmark.intValue);
172    if (getsockopt(socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
173        return -errno;
174    }
175    *netId = fwmark.netId;
176    return 0;
177}
178
179extern "C" unsigned getNetworkForProcess() {
180    return netIdForProcess;
181}
182
183extern "C" int setNetworkForSocket(unsigned netId, int socketFd) {
184    if (socketFd < 0) {
185        return -EBADF;
186    }
187    FwmarkCommand command = {FwmarkCommand::SELECT_NETWORK, netId, 0};
188    return FwmarkClient().send(&command, socketFd);
189}
190
191extern "C" int setNetworkForProcess(unsigned netId) {
192    return setNetworkForTarget(netId, &netIdForProcess);
193}
194
195extern "C" int setNetworkForResolv(unsigned netId) {
196    return setNetworkForTarget(netId, &netIdForResolv);
197}
198
199extern "C" int protectFromVpn(int socketFd) {
200    if (socketFd < 0) {
201        return -EBADF;
202    }
203    FwmarkCommand command = {FwmarkCommand::PROTECT_FROM_VPN, 0, 0};
204    return FwmarkClient().send(&command, socketFd);
205}
206
207extern "C" int setNetworkForUser(uid_t uid, int socketFd) {
208    if (socketFd < 0) {
209        return -EBADF;
210    }
211    FwmarkCommand command = {FwmarkCommand::SELECT_FOR_USER, 0, uid};
212    return FwmarkClient().send(&command, socketFd);
213}
214
215extern "C" int queryUserAccess(uid_t uid, unsigned netId) {
216    FwmarkCommand command = {FwmarkCommand::QUERY_USER_ACCESS, netId, uid};
217    return FwmarkClient().send(&command, -1);
218}
219