1/*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "chre_host/socket_server.h"
18
19#include <poll.h>
20
21#include <cassert>
22#include <cinttypes>
23#include <csignal>
24#include <cstdlib>
25#include <map>
26#include <mutex>
27
28#include <cutils/sockets.h>
29
30#include "chre_host/log.h"
31
32namespace android {
33namespace chre {
34
35std::atomic<bool> SocketServer::sSignalReceived(false);
36
37namespace {
38
39void maskAllSignals() {
40  sigset_t signalMask;
41  sigfillset(&signalMask);
42  if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
43    LOG_ERROR("Couldn't mask all signals", errno);
44  }
45}
46
47void maskAllSignalsExceptIntAndTerm() {
48  sigset_t signalMask;
49  sigfillset(&signalMask);
50  sigdelset(&signalMask, SIGINT);
51  sigdelset(&signalMask, SIGTERM);
52  if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
53    LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
54  }
55}
56
57}  // anonymous namespace
58
59SocketServer::SocketServer() {
60  // Initialize the socket fds field for all inactive client slots to -1, so
61  // poll skips over it, and we don't attempt to send on it
62  for (size_t i = 1; i <= kMaxActiveClients; i++) {
63    mPollFds[i].fd = -1;
64    mPollFds[i].events = POLLIN;
65  }
66}
67
68void SocketServer::run(const char *socketName, bool allowSocketCreation,
69                       ClientMessageCallback clientMessageCallback) {
70  mClientMessageCallback = clientMessageCallback;
71
72  mSockFd = android_get_control_socket(socketName);
73  if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
74    LOGI("Didn't inherit socket, creating...");
75    mSockFd = socket_local_server(socketName,
76                                  ANDROID_SOCKET_NAMESPACE_RESERVED,
77                                  SOCK_SEQPACKET);
78  }
79
80  if (mSockFd == INVALID_SOCKET) {
81    LOGE("Couldn't get/create socket");
82  } else {
83    int ret = listen(mSockFd, kMaxPendingConnectionRequests);
84    if (ret < 0) {
85      LOG_ERROR("Couldn't listen on socket", errno);
86    } else {
87      serviceSocket();
88    }
89
90    {
91      std::lock_guard<std::mutex> lock(mClientsMutex);
92      for (const auto& pair : mClients) {
93        int clientSocket = pair.first;
94        if (close(clientSocket) != 0) {
95          LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
96               pair.second.clientId, strerror(errno));
97        }
98      }
99      mClients.clear();
100    }
101    close(mSockFd);
102  }
103}
104
105void SocketServer::sendToAllClients(const void *data, size_t length) {
106  std::lock_guard<std::mutex> lock(mClientsMutex);
107
108  int deliveredCount = 0;
109  for (const auto& pair : mClients) {
110    int clientSocket = pair.first;
111    uint16_t clientId = pair.second.clientId;
112    if (sendToClientSocket(data, length, clientSocket, clientId)) {
113      deliveredCount++;
114    } else if (errno == EINTR) {
115      // Exit early if we were interrupted - we should only get this for
116      // SIGINT/SIGTERM, so we should exit quickly
117      break;
118    }
119  }
120
121  if (deliveredCount == 0) {
122    LOGW("Got message but didn't deliver to any clients");
123  }
124}
125
126bool SocketServer::sendToClientById(const void *data, size_t length,
127                                    uint16_t clientId) {
128  std::lock_guard<std::mutex> lock(mClientsMutex);
129
130  bool sent = false;
131  for (const auto& pair : mClients) {
132    uint16_t thisClientId = pair.second.clientId;
133    if (thisClientId == clientId) {
134      int clientSocket = pair.first;
135      sent = sendToClientSocket(data, length, clientSocket, thisClientId);
136      break;
137    }
138  }
139
140  return sent;
141}
142
143void SocketServer::acceptClientConnection() {
144  int clientSocket = accept(mSockFd, NULL, NULL);
145  if (clientSocket < 0) {
146    LOG_ERROR("Couldn't accept client connection", errno);
147  } else if (mClients.size() >= kMaxActiveClients) {
148    LOGW("Rejecting client request - maximum number of clients reached");
149    close(clientSocket);
150  } else {
151    ClientData clientData;
152    clientData.clientId = mNextClientId++;
153
154    // We currently don't handle wraparound - if we're getting this many
155    // connects/disconnects, then something is wrong.
156    // TODO: can handle this properly by iterating over the existing clients to
157    // avoid a conflict.
158    if (clientData.clientId == 0) {
159      LOGE("Couldn't allocate client ID");
160      std::exit(-1);
161    }
162
163    bool slotFound = false;
164    for (size_t i = 1; i <= kMaxActiveClients; i++) {
165      if (mPollFds[i].fd < 0) {
166        mPollFds[i].fd = clientSocket;
167        slotFound = true;
168        break;
169      }
170    }
171
172    if (!slotFound) {
173      LOGE("Couldn't find slot for client!");
174      assert(slotFound);
175      close(clientSocket);
176    } else {
177      {
178        std::lock_guard<std::mutex> lock(mClientsMutex);
179        mClients[clientSocket] = clientData;
180      }
181      LOGI("Accepted new client connection (count %zu), assigned client ID %"
182           PRIu16, mClients.size(), clientData.clientId);
183    }
184  }
185}
186
187void SocketServer::handleClientData(int clientSocket) {
188  const ClientData& clientData = mClients[clientSocket];
189  uint16_t clientId = clientData.clientId;
190
191  uint8_t buffer[kMaxPacketSize];
192  ssize_t packetSize = recv(clientSocket, buffer, sizeof(buffer), MSG_DONTWAIT);
193  if (packetSize < 0) {
194    LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
195         strerror(errno));
196  } else if (packetSize == 0) {
197    LOGI("Client %" PRIu16 " disconnected", clientId);
198    disconnectClient(clientSocket);
199  } else {
200    LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
201    mClientMessageCallback(clientId, buffer, packetSize);
202  }
203}
204
205void SocketServer::disconnectClient(int clientSocket) {
206  {
207    std::lock_guard<std::mutex> lock(mClientsMutex);
208    mClients.erase(clientSocket);
209  }
210  close(clientSocket);
211
212  bool removed = false;
213  for (size_t i = 1; i <= kMaxActiveClients; i++) {
214    if (mPollFds[i].fd == clientSocket) {
215      mPollFds[i].fd = -1;
216      removed = true;
217      break;
218    }
219  }
220
221  if (!removed) {
222    LOGE("Out of sync");
223    assert(removed);
224  }
225}
226
227bool SocketServer::sendToClientSocket(const void *data, size_t length,
228                                      int clientSocket, uint16_t clientId) {
229  errno = 0;
230  ssize_t bytesSent = send(clientSocket, data, length, 0);
231  if (bytesSent < 0) {
232    LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s",
233         length, clientId, strerror(errno));
234  } else if (bytesSent == 0) {
235    LOGW("Client %" PRIu16 " disconnected before message could be delivered",
236         clientId);
237  } else {
238    LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
239         clientId);
240  }
241
242  return (bytesSent > 0);
243}
244
245void SocketServer::serviceSocket() {
246  constexpr size_t kListenIndex = 0;
247  static_assert(kListenIndex == 0, "Code assumes that the first index is "
248                "always the listen socket");
249
250  mPollFds[kListenIndex].fd = mSockFd;
251  mPollFds[kListenIndex].events = POLLIN;
252
253  // Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
254  // and ignore other signals
255  sigset_t signalMask;
256  sigfillset(&signalMask);
257  sigdelset(&signalMask, SIGINT);
258  sigdelset(&signalMask, SIGTERM);
259
260  // Masking signals here ensure that after this point, we won't handle INT/TERM
261  // until after we call into ppoll()
262  maskAllSignals();
263  std::signal(SIGINT, signalHandler);
264  std::signal(SIGTERM, signalHandler);
265
266  LOGI("Ready to accept connections");
267  while (!sSignalReceived) {
268    int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
269    maskAllSignalsExceptIntAndTerm();
270    if (ret == -1) {
271      LOGI("Exiting poll loop: %s", strerror(errno));
272      break;
273    }
274
275    if (mPollFds[kListenIndex].revents & POLLIN) {
276      acceptClientConnection();
277    }
278
279    for (size_t i = 1; i <= kMaxActiveClients; i++) {
280      if (mPollFds[i].fd < 0) {
281        continue;
282      }
283
284      if (mPollFds[i].revents & POLLIN) {
285        handleClientData(mPollFds[i].fd);
286      }
287    }
288
289    // Mask all signals to ensure that sSignalReceived can't become true between
290    // checking it in the while condition and calling into ppoll()
291    maskAllSignals();
292  }
293}
294
295void SocketServer::signalHandler(int signal) {
296  LOGD("Caught signal %d", signal);
297  sSignalReceived = true;
298}
299
300}  // namespace chre
301}  // namespace android
302