1/*
2 * Copyright (C) 2018 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
26#include <glog/logging.h>
27#include <gflags/gflags.h>
28
29#include <unistd.h>
30
31#include "common/libs/fs/shared_fd.h"
32#include "common/libs/strings/str_split.h"
33#include "common/vsoc/lib/socket_forward_region_view.h"
34
35#ifdef CUTTLEFISH_HOST
36#include "host/libs/config/cuttlefish_config.h"
37#include "host/libs/adb_connection_maintainer/adb_connection_maintainer.h"
38#endif
39
40using vsoc::socket_forward::Packet;
41using vsoc::socket_forward::SocketForwardRegionView;
42
43#ifdef CUTTLEFISH_HOST
// Host-only flags. The two lists are paired by position: ParsePortsList()
// zips them so that the i-th --host_ports entry is served by a TCP server on
// the host and forwarded to the i-th --guest_ports entry. main() CHECKs that
// both are non-empty.
DEFINE_string(guest_ports, "",
              "Comma-separated list of ports on which to forward TCP "
              "connections to the guest.");
DEFINE_string(host_ports, "",
              "Comma-separated list of ports on which to run TCP servers on "
              "the host.");
50#endif
51
52namespace {
53// Sends packets, Shutdown(SHUT_WR) on destruction
54class SocketSender {
55 public:
56  explicit SocketSender(cvd::SharedFD socket) : socket_{std::move(socket)} {}
57
58  SocketSender(SocketSender&&) = default;
59  SocketSender& operator=(SocketSender&&) = default;
60
61  SocketSender(const SocketSender&&) = delete;
62  SocketSender& operator=(const SocketSender&) = delete;
63
64  ~SocketSender() {
65    if (socket_.operator->()) {  // check that socket_ was not moved-from
66      socket_->Shutdown(SHUT_WR);
67    }
68  }
69
70  ssize_t SendAll(const Packet& packet) {
71    ssize_t written{};
72    while (written < static_cast<ssize_t>(packet.payload_length())) {
73      if (!socket_->IsOpen()) {
74        return -1;
75      }
76      auto just_written =
77          socket_->Send(packet.payload() + written,
78                        packet.payload_length() - written, MSG_NOSIGNAL);
79      if (just_written <= 0) {
80        LOG(INFO) << "Couldn't write to client: "
81                  << strerror(socket_->GetErrno());
82        return just_written;
83      }
84      written += just_written;
85    }
86    return written;
87  }
88
89 private:
90  cvd::SharedFD socket_;
91};
92
93class SocketReceiver {
94 public:
95  explicit SocketReceiver(cvd::SharedFD socket) : socket_{std::move(socket)} {}
96
97  SocketReceiver(SocketReceiver&&) = default;
98  SocketReceiver& operator=(SocketReceiver&&) = default;
99
100  SocketReceiver(const SocketReceiver&&) = delete;
101  SocketReceiver& operator=(const SocketReceiver&) = delete;
102
103  // *packet will be empty if Read returns 0 or error
104  void Recv(Packet* packet) {
105    auto size = socket_->Read(packet->payload(), sizeof packet->payload());
106    if (size < 0) {
107      size = 0;
108    }
109    packet->set_payload_length(size);
110  }
111
112 private:
113  cvd::SharedFD socket_;
114};
115
116void SocketToShm(SocketReceiver socket_receiver,
117                 SocketForwardRegionView::Sender shm_sender) {
118  auto packet = Packet::MakeData();
119  while (true) {
120    socket_receiver.Recv(&packet);
121    if (packet.empty()) {
122      break;
123    }
124    if (!shm_sender.Send(packet)) {
125      break;
126    }
127  }
128  LOG(INFO) << "Socket to shm exiting";
129}
130
131void ShmToSocket(SocketSender socket_sender,
132                 SocketForwardRegionView::Receiver shm_receiver) {
133  Packet packet{};
134  while (true) {
135    shm_receiver.Recv(&packet);
136    if (packet.IsEnd()) {
137      break;
138    }
139    if (socket_sender.SendAll(packet) < 0) {
140      break;
141    }
142  }
143  LOG(INFO) << "Shm to socket exiting";
144}
145
146// One thread for reading from shm and writing into a socket.
147// One thread for reading from a socket and writing into shm.
148void LaunchWorkers(std::pair<SocketForwardRegionView::Sender,
149                             SocketForwardRegionView::Receiver>
150                       conn,
151                   cvd::SharedFD socket) {
152  // TODO create the SocketSender/Receiver in their respective threads?
153  std::thread(
154      SocketToShm, SocketReceiver{socket}, std::move(conn.first)).detach();
155  std::thread(
156      ShmToSocket, SocketSender{socket}, std::move(conn.second)).detach();
157}
158
159#ifdef CUTTLEFISH_HOST
// One forwarding rule: host_impl() listens for TCP clients on host_port and
// opens a shm connection to guest_port for each one.
struct PortPair {
  int guest_port;  // passed to SocketForwardRegionView::OpenConnection()
  int host_port;   // local TCP port the host-side server listens on
};
164
165void LaunchConnectionMaintainer(int port) {
166  std::thread(cvd::EstablishAndMaintainConnection, port).detach();
167}
168
169
170[[noreturn]] void host_impl(SocketForwardRegionView* shm,
171                            std::vector<PortPair> ports, std::size_t index) {
172  // launch a worker for the following port before handling the current port.
173  // recursion (instead of a loop) removes the need fore any join() or having
174  // the main thread do no work.
175  if (index + 1 < ports.size()) {
176    std::thread(host_impl, shm, ports, index + 1).detach();
177  }
178  auto guest_port = ports[index].guest_port;
179  auto host_port = ports[index].host_port;
180  LOG(INFO) << "starting server on " << host_port
181            << " for guest port " << guest_port;
182  auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM);
183  CHECK(server->IsOpen()) << "Could not start server on port " << host_port;
184  LaunchConnectionMaintainer(host_port);
185  while (true) {
186    auto client_socket = cvd::SharedFD::Accept(*server);
187    CHECK(client_socket->IsOpen()) << "error creating client socket";
188    LOG(INFO) << "client socket accepted";
189    auto conn = shm->OpenConnection(guest_port);
190    LOG(INFO) << "shm connection opened";
191    LaunchWorkers(std::move(conn), std::move(client_socket));
192  }
193}
194
195[[noreturn]] void host(SocketForwardRegionView* shm,
196                       std::vector<PortPair> ports) {
197  CHECK(!ports.empty());
198  host_impl(shm, ports, 0);
199}
200
201std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str,
202                                const std::string& host_ports_str) {
203  std::vector<PortPair> ports{};
204  auto guest_ports = cvd::StrSplit(guest_ports_str, ',');
205  auto host_ports = cvd::StrSplit(host_ports_str, ',');
206  CHECK(guest_ports.size() == host_ports.size());
207  for (std::size_t i = 0; i < guest_ports.size(); ++i) {
208    ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])});
209  }
210  return ports;
211
212}
213
214#else
215cvd::SharedFD OpenSocketConnection(int port) {
216  while (true) {
217    auto sock = cvd::SharedFD::SocketLocalClient(port, SOCK_STREAM);
218    if (sock->IsOpen()) {
219      return sock;
220    }
221    LOG(WARNING) << "could not connect on port " << port
222                 << ". sleeping for 1 second";
223    sleep(1);
224  }
225}
226
227[[noreturn]] void guest(SocketForwardRegionView* shm) {
228  LOG(INFO) << "Starting guest mainloop";
229  while (true) {
230    auto conn = shm->AcceptConnection();
231    LOG(INFO) << "shm connection accepted";
232    auto sock = OpenSocketConnection(conn.first.port());
233    CHECK(sock->IsOpen());
234    LOG(INFO) << "socket opened to " << conn.first.port();
235    LaunchWorkers(std::move(conn), std::move(sock));
236  }
237}
238#endif
239
240SocketForwardRegionView* GetShm() {
241  auto shm = SocketForwardRegionView::GetInstance(
242#ifdef CUTTLEFISH_HOST
243      vsoc::GetDomain().c_str()
244#endif
245  );
246  if (!shm) {
247    LOG(FATAL) << "Could not open SHM. Aborting.";
248  }
249  shm->CleanUpPreviousConnections();
250  return shm;
251}
252
253// makes sure we're running as root on the guest, no-op on the host
254void assert_correct_user() {
255#ifndef CUTTLEFISH_HOST
256  CHECK_EQ(getuid(), 0u) << "must run as root!";
257#endif
258}
259
260}  // namespace
261
262int main(int argc, char* argv[]) {
263  gflags::ParseCommandLineFlags(&argc, &argv, true);
264  assert_correct_user();
265
266  auto shm = GetShm();
267  auto worker = shm->StartWorker();
268
269#ifdef CUTTLEFISH_HOST
270  CHECK(!FLAGS_guest_ports.empty()) << "Must specify --guest_ports flag";
271  CHECK(!FLAGS_host_ports.empty()) << "Must specify --host_ports flag";
272  host(shm, ParsePortsList(FLAGS_guest_ports, FLAGS_host_ports));
273#else
274  guest(shm);
275#endif
276}
277