1/* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17#include <cstdint> 18#include <cstdlib> 19#include <iostream> 20#include <memory> 21#include <mutex> 22#include <sstream> 23#include <string> 24#include <thread> 25#include <vector> 26#include <glog/logging.h> 27#include <gflags/gflags.h> 28 29#include <unistd.h> 30 31#include "common/libs/fs/shared_fd.h" 32#include "common/libs/strings/str_split.h" 33#include "common/vsoc/lib/socket_forward_region_view.h" 34 35#ifdef CUTTLEFISH_HOST 36#include "host/libs/config/cuttlefish_config.h" 37#include "host/libs/adb_connection_maintainer/adb_connection_maintainer.h" 38#endif 39 40using vsoc::socket_forward::Packet; 41using vsoc::socket_forward::SocketForwardRegionView; 42 43#ifdef CUTTLEFISH_HOST 44DEFINE_string(guest_ports, "", 45 "Comma-separated list of ports on which to forward TCP " 46 "connections to the guest."); 47DEFINE_string(host_ports, "", 48 "Comma-separated list of ports on which to run TCP servers on " 49 "the host."); 50#endif 51 52namespace { 53// Sends packets, Shutdown(SHUT_WR) on destruction 54class SocketSender { 55 public: 56 explicit SocketSender(cvd::SharedFD socket) : socket_{std::move(socket)} {} 57 58 SocketSender(SocketSender&&) = default; 59 SocketSender& operator=(SocketSender&&) = default; 60 61 SocketSender(const SocketSender&&) = delete; 62 SocketSender& operator=(const SocketSender&) = delete; 63 64 ~SocketSender() { 65 if (socket_.operator->()) { // check that socket_ was not moved-from 66 socket_->Shutdown(SHUT_WR); 67 } 68 } 69 70 ssize_t SendAll(const Packet& packet) { 71 ssize_t written{}; 72 while (written < static_cast<ssize_t>(packet.payload_length())) { 73 if (!socket_->IsOpen()) { 74 return -1; 75 } 76 auto just_written = 77 socket_->Send(packet.payload() + written, 78 packet.payload_length() - written, MSG_NOSIGNAL); 79 if (just_written <= 0) { 80 LOG(INFO) << "Couldn't write to client: " 81 << strerror(socket_->GetErrno()); 82 return just_written; 83 } 84 written += just_written; 85 } 86 return written; 87 } 88 89 private: 90 cvd::SharedFD socket_; 91}; 92 93class SocketReceiver { 94 public: 95 explicit SocketReceiver(cvd::SharedFD socket) : socket_{std::move(socket)} {} 96 97 SocketReceiver(SocketReceiver&&) = default; 98 SocketReceiver& operator=(SocketReceiver&&) = default; 99 100 SocketReceiver(const SocketReceiver&&) = delete; 101 SocketReceiver& operator=(const SocketReceiver&) = delete; 102 103 // *packet will be empty if Read returns 0 or error 104 void Recv(Packet* packet) { 105 auto size = socket_->Read(packet->payload(), sizeof packet->payload()); 106 if (size < 0) { 107 size = 0; 108 } 109 packet->set_payload_length(size); 110 } 111 112 private: 113 cvd::SharedFD socket_; 114}; 115 116void SocketToShm(SocketReceiver socket_receiver, 117 SocketForwardRegionView::Sender shm_sender) { 118 auto packet = Packet::MakeData(); 119 while (true) { 120 socket_receiver.Recv(&packet); 121 if (packet.empty()) { 122 break; 123 } 124 if (!shm_sender.Send(packet)) { 125 break; 126 } 127 } 128 LOG(INFO) << "Socket to shm exiting"; 129} 130 131void ShmToSocket(SocketSender socket_sender, 132 SocketForwardRegionView::Receiver shm_receiver) { 133 Packet packet{}; 134 while (true) { 135 shm_receiver.Recv(&packet); 136 if (packet.IsEnd()) { 137 break; 138 } 139 if (socket_sender.SendAll(packet) < 0) { 140 break; 141 } 142 } 143 LOG(INFO) << "Shm to socket exiting"; 144} 145 146// One thread for reading from shm and writing into a socket. 147// One thread for reading from a socket and writing into shm. 148void LaunchWorkers(std::pair<SocketForwardRegionView::Sender, 149 SocketForwardRegionView::Receiver> 150 conn, 151 cvd::SharedFD socket) { 152 // TODO create the SocketSender/Receiver in their respective threads? 153 std::thread( 154 SocketToShm, SocketReceiver{socket}, std::move(conn.first)).detach(); 155 std::thread( 156 ShmToSocket, SocketSender{socket}, std::move(conn.second)).detach(); 157} 158 159#ifdef CUTTLEFISH_HOST 160struct PortPair { 161 int guest_port; 162 int host_port; 163}; 164 165void LaunchConnectionMaintainer(int port) { 166 std::thread(cvd::EstablishAndMaintainConnection, port).detach(); 167} 168 169 170[[noreturn]] void host_impl(SocketForwardRegionView* shm, 171 std::vector<PortPair> ports, std::size_t index) { 172 // launch a worker for the following port before handling the current port. 173 // recursion (instead of a loop) removes the need fore any join() or having 174 // the main thread do no work. 175 if (index + 1 < ports.size()) { 176 std::thread(host_impl, shm, ports, index + 1).detach(); 177 } 178 auto guest_port = ports[index].guest_port; 179 auto host_port = ports[index].host_port; 180 LOG(INFO) << "starting server on " << host_port 181 << " for guest port " << guest_port; 182 auto server = cvd::SharedFD::SocketLocalServer(host_port, SOCK_STREAM); 183 CHECK(server->IsOpen()) << "Could not start server on port " << host_port; 184 LaunchConnectionMaintainer(host_port); 185 while (true) { 186 auto client_socket = cvd::SharedFD::Accept(*server); 187 CHECK(client_socket->IsOpen()) << "error creating client socket"; 188 LOG(INFO) << "client socket accepted"; 189 auto conn = shm->OpenConnection(guest_port); 190 LOG(INFO) << "shm connection opened"; 191 LaunchWorkers(std::move(conn), std::move(client_socket)); 192 } 193} 194 195[[noreturn]] void host(SocketForwardRegionView* shm, 196 std::vector<PortPair> ports) { 197 CHECK(!ports.empty()); 198 host_impl(shm, ports, 0); 199} 200 201std::vector<PortPair> ParsePortsList(const std::string& guest_ports_str, 202 const std::string& host_ports_str) { 203 std::vector<PortPair> ports{}; 204 auto guest_ports = cvd::StrSplit(guest_ports_str, ','); 205 auto host_ports = cvd::StrSplit(host_ports_str, ','); 206 CHECK(guest_ports.size() == host_ports.size()); 207 for (std::size_t i = 0; i < guest_ports.size(); ++i) { 208 ports.push_back({std::stoi(guest_ports[i]), std::stoi(host_ports[i])}); 209 } 210 return ports; 211 212} 213 214#else 215cvd::SharedFD OpenSocketConnection(int port) { 216 while (true) { 217 auto sock = cvd::SharedFD::SocketLocalClient(port, SOCK_STREAM); 218 if (sock->IsOpen()) { 219 return sock; 220 } 221 LOG(WARNING) << "could not connect on port " << port 222 << ". sleeping for 1 second"; 223 sleep(1); 224 } 225} 226 227[[noreturn]] void guest(SocketForwardRegionView* shm) { 228 LOG(INFO) << "Starting guest mainloop"; 229 while (true) { 230 auto conn = shm->AcceptConnection(); 231 LOG(INFO) << "shm connection accepted"; 232 auto sock = OpenSocketConnection(conn.first.port()); 233 CHECK(sock->IsOpen()); 234 LOG(INFO) << "socket opened to " << conn.first.port(); 235 LaunchWorkers(std::move(conn), std::move(sock)); 236 } 237} 238#endif 239 240SocketForwardRegionView* GetShm() { 241 auto shm = SocketForwardRegionView::GetInstance( 242#ifdef CUTTLEFISH_HOST 243 vsoc::GetDomain().c_str() 244#endif 245 ); 246 if (!shm) { 247 LOG(FATAL) << "Could not open SHM. Aborting."; 248 } 249 shm->CleanUpPreviousConnections(); 250 return shm; 251} 252 253// makes sure we're running as root on the guest, no-op on the host 254void assert_correct_user() { 255#ifndef CUTTLEFISH_HOST 256 CHECK_EQ(getuid(), 0u) << "must run as root!"; 257#endif 258} 259 260} // namespace 261 262int main(int argc, char* argv[]) { 263 gflags::ParseCommandLineFlags(&argc, &argv, true); 264 assert_correct_user(); 265 266 auto shm = GetShm(); 267 auto worker = shm->StartWorker(); 268 269#ifdef CUTTLEFISH_HOST 270 CHECK(!FLAGS_guest_ports.empty()) << "Must specify --guest_ports flag"; 271 CHECK(!FLAGS_host_ports.empty()) << "Must specify --host_ports flag"; 272 host(shm, ParsePortsList(FLAGS_guest_ports, FLAGS_host_ports)); 273#else 274 guest(shm); 275#endif 276} 277