SockDiag.cpp revision f32fc598b01ba8d59873b0a1085716fd84678b54
1/*
2 * Copyright (C) 2016 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include <errno.h>
18#include <netdb.h>
19#include <string.h>
20#include <netinet/in.h>
21#include <netinet/tcp.h>
22#include <sys/socket.h>
23#include <sys/uio.h>
24
25#include <linux/netlink.h>
26#include <linux/sock_diag.h>
27#include <linux/inet_diag.h>
28
29#define LOG_TAG "Netd"
30
31#include <cutils/log.h>
32
33#include "NetdConstants.h"
34#include "SockDiag.h"
35
36#include <chrono>
37
38#ifndef SOCK_DESTROY
39#define SOCK_DESTROY 21
40#endif
41
namespace {

// Deleter that returns a getaddrinfo() result list to freeaddrinfo().
struct AddrinfoDeleter {
  void operator()(addrinfo *a) {
    if (a != nullptr) {
      freeaddrinfo(a);
    }
  }
};

// Owning handle for a getaddrinfo() result list; the list is freed when the handle dies.
using ScopedAddrinfo = std::unique_ptr<addrinfo, AddrinfoDeleter>;

// Peeks, without blocking, at the next message queued on |fd|.
//
// Returns:
//   0         if nothing is queued, or if the queued data is not an ack-sized NLMSG_ERROR
//             (in which case it is left on the socket for the caller to read);
//   -errno    if the peek fails for any reason other than EAGAIN;
//   err.error if an NLMSG_ERROR message was found; that message is consumed.
int checkError(int fd) {
    struct {
        nlmsghdr h;
        nlmsgerr err;
    } __attribute__((__packed__)) ack;

    const ssize_t peeked = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
    if (peeked == -1) {
        // Either there was nothing to read (fine) or the read itself failed.
        if (errno == EAGAIN) {
            return 0;
        }
        return -errno;
    }

    const bool gotError =
            peeked == static_cast<ssize_t>(sizeof(ack)) && ack.h.nlmsg_type == NLMSG_ERROR;
    if (!gotError) {
        // Some other reply is waiting. It is the caller's to deal with.
        return 0;
    }

    // Take the error message off the queue before reporting it.
    recv(fd, &ack, sizeof(ack), 0);
    return ack.err.error;
}

}  // namespace
70
71bool SockDiag::open() {
72    if (hasSocks()) {
73        return false;
74    }
75
76    mSock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_INET_DIAG);
77    mWriteSock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_INET_DIAG);
78    if (!hasSocks()) {
79        closeSocks();
80        return false;
81    }
82
83    sockaddr_nl nl = { .nl_family = AF_NETLINK };
84    if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
85        (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
86        closeSocks();
87        return false;
88    }
89
90    return true;
91}
92
93int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
94    addrinfo hints = { .ai_flags = AI_NUMERICHOST };
95    addrinfo *res;
96    in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
97    int ret;
98
99    // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
100    // doing string conversions when they're not necessary.
101    if ((ret = getaddrinfo(addrstr, nullptr, &hints, &res)) != 0) {
102        return -EINVAL;
103    }
104
105    // So we don't have to call freeaddrinfo on every failure path.
106    ScopedAddrinfo resP(res);
107
108    void *addr;
109    uint8_t addrlen;
110    if (res->ai_family == AF_INET && family == AF_INET) {
111        in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
112        addr = &ina;
113        addrlen = sizeof(ina);
114    } else if (res->ai_family == AF_INET && family == AF_INET6) {
115        in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
116        mapped.s6_addr32[3] = ina.s_addr;
117        addr = &mapped;
118        addrlen = sizeof(mapped);
119    } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
120        in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
121        addr = &in6a;
122        addrlen = sizeof(in6a);
123    } else {
124        return -EAFNOSUPPORT;
125    }
126
127    uint8_t prefixlen = addrlen * 8;
128    uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
129    uint8_t nojump = yesjump + 4;
130    uint32_t states = ~(1 << TCP_TIME_WAIT);
131
132    struct {
133        nlmsghdr nlh;
134        inet_diag_req_v2 req;
135        nlattr nla;
136        inet_diag_bc_op op;
137        inet_diag_hostcond cond;
138    } __attribute__((__packed__)) request = {
139        .nlh = {
140            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
141            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
142        },
143        .req = {
144            .sdiag_family = family,
145            .sdiag_protocol = proto,
146            .idiag_states = states,
147        },
148        .nla = {
149            .nla_type = INET_DIAG_REQ_BYTECODE,
150        },
151        .op = {
152            INET_DIAG_BC_S_COND,
153            yesjump,
154            nojump,
155        },
156        .cond = {
157            family,
158            prefixlen,
159            -1,
160            {}
161        },
162    };
163
164    request.nlh.nlmsg_len = sizeof(request) + addrlen;
165    request.nla.nla_len = sizeof(request.nla) + sizeof(request.op) + sizeof(request.cond) + addrlen;
166
167    struct iovec iov[] = {
168        { &request, sizeof(request) },
169        { addr, addrlen },
170    };
171
172    if (writev(mSock, iov, ARRAY_SIZE(iov)) != (int) request.nlh.nlmsg_len) {
173        return -errno;
174    }
175
176    return checkError(mSock);
177}
178
179int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
180    char buf[kBufferSize];
181
182    ssize_t bytesread;
183    do {
184        bytesread = read(mSock, buf, sizeof(buf));
185
186        if (bytesread < 0) {
187            return -errno;
188        }
189
190        uint32_t len = bytesread;
191        for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf);
192             NLMSG_OK(nlh, len);
193             nlh = NLMSG_NEXT(nlh, len)) {
194            switch (nlh->nlmsg_type) {
195              case NLMSG_DONE:
196                callback(proto, NULL);
197                return 0;
198              case NLMSG_ERROR: {
199                nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
200                return err->error;
201              }
202              default:
203                inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
204                callback(proto, msg);
205            }
206        }
207    } while (bytesread > 0);
208
209    return 0;
210}
211
212int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
213    if (msg == nullptr) {
214       return 0;
215    }
216
217    DestroyRequest request = {
218        .nlh = {
219            .nlmsg_type = SOCK_DESTROY,
220            .nlmsg_flags = NLM_F_REQUEST,
221        },
222        .req = {
223            .sdiag_family = msg->idiag_family,
224            .sdiag_protocol = proto,
225            .idiag_states = (uint32_t) (1 << msg->idiag_state),
226            .id = msg->id,
227        },
228    };
229    request.nlh.nlmsg_len = sizeof(request);
230
231    if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
232        return -errno;
233    }
234
235    int ret = checkError(mWriteSock);
236    if (!ret) mSocketsDestroyed++;
237    return ret;
238}
239
240int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
241    if (!hasSocks()) {
242        return -EBADFD;
243    }
244
245    if (int ret = sendDumpRequest(proto, family, addrstr)) {
246        return ret;
247    }
248
249    auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
250        return this->sockDestroy(proto, msg);
251    };
252
253    return readDiagMsg(proto, destroy);
254}
255
256int SockDiag::destroySockets(const char *addrstr) {
257    using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
258
259    mSocketsDestroyed = 0;
260    const auto start = std::chrono::steady_clock::now();
261    if (!strchr(addrstr, ':')) {
262        if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
263            ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
264            return ret;
265        }
266    }
267    if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
268        ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
269        return ret;
270    }
271    auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
272
273    if (mSocketsDestroyed > 0) {
274        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
275    }
276
277    return mSocketsDestroyed;
278}
279