// socket_stream.cc revision 94f8eba2f71f0d73e15bfd5b1c28bb0ed682706d
1// Copyright 2015 The Chromium OS Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include <arpa/inet.h> 6#include <map> 7#include <netdb.h> 8#include <string> 9#include <sys/socket.h> 10#include <sys/types.h> 11#include <unistd.h> 12 13#include <base/bind.h> 14#include <base/bind_helpers.h> 15#include <base/files/file_util.h> 16#include <base/message_loop/message_loop.h> 17#include <base/strings/stringprintf.h> 18#include <brillo/bind_lambda.h> 19#include <brillo/streams/file_stream.h> 20#include <brillo/streams/tls_stream.h> 21 22#include "buffet/socket_stream.h" 23#include "buffet/weave_error_conversion.h" 24 25namespace buffet { 26 27using weave::provider::Network; 28 29namespace { 30 31std::string GetIPAddress(const sockaddr* sa) { 32 std::string addr; 33 char str[INET6_ADDRSTRLEN] = {}; 34 switch (sa->sa_family) { 35 case AF_INET: 36 if (inet_ntop(AF_INET, 37 &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str, 38 sizeof(str))) { 39 addr = str; 40 } 41 break; 42 43 case AF_INET6: 44 if (inet_ntop(AF_INET6, 45 &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr), 46 str, sizeof(str))) { 47 addr = str; 48 } 49 break; 50 } 51 if (addr.empty()) 52 addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family); 53 return addr; 54} 55 56int ConnectSocket(const std::string& host, uint16_t port) { 57 std::string service = std::to_string(port); 58 addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM}; 59 addrinfo* result = nullptr; 60 if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) { 61 PLOG(WARNING) << "Failed to resolve host name: " << host; 62 return -1; 63 } 64 65 int socket_fd = -1; 66 for (const addrinfo* info = result; info != nullptr; info = info->ai_next) { 67 socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol); 68 if (socket_fd < 0) 69 continue; 70 71 std::string addr = GetIPAddress(info->ai_addr); 72 LOG(INFO) << 
"Connecting to address: " << addr; 73 if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0) 74 break; // Success. 75 76 PLOG(WARNING) << "Failed to connect to address: " << addr; 77 close(socket_fd); 78 socket_fd = -1; 79 } 80 81 freeaddrinfo(result); 82 return socket_fd; 83} 84 85void OnSuccess(const Network::OpenSslSocketCallback& callback, 86 brillo::StreamPtr tls_stream) { 87 callback.Run( 88 std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}}, 89 nullptr); 90} 91 92void OnError(const weave::DoneCallback& callback, 93 const brillo::Error* brillo_error) { 94 weave::ErrorPtr error; 95 ConvertError(*brillo_error, &error); 96 callback.Run(std::move(error)); 97} 98 99} // namespace 100 101void SocketStream::Read(void* buffer, 102 size_t size_to_read, 103 const ReadCallback& callback) { 104 brillo::ErrorPtr brillo_error; 105 if (!ptr_->ReadAsync( 106 buffer, size_to_read, 107 base::Bind([](const ReadCallback& callback, 108 size_t size) { callback.Run(size, nullptr); }, 109 callback), 110 base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) { 111 weave::ErrorPtr error; 112 ConvertError(*brillo_error, &error); 113 base::MessageLoop::current()->PostTask( 114 FROM_HERE, base::Bind(callback, 0, base::Passed(&error))); 115 } 116} 117 118void SocketStream::Write(const void* buffer, 119 size_t size_to_write, 120 const WriteCallback& callback) { 121 brillo::ErrorPtr brillo_error; 122 if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr), 123 base::Bind(&OnError, callback), &brillo_error)) { 124 weave::ErrorPtr error; 125 ConvertError(*brillo_error, &error); 126 base::MessageLoop::current()->PostTask( 127 FROM_HERE, base::Bind(callback, base::Passed(&error))); 128 } 129} 130 131void SocketStream::CancelPendingOperations() { 132 ptr_->CancelPendingAsyncOperations(); 133} 134 135std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking( 136 const std::string& host, 137 uint16_t port) { 138 int socket_fd = 
ConnectSocket(host, port); 139 if (socket_fd <= 0) 140 return nullptr; 141 142 auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr); 143 if (ptr_) 144 return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}}; 145 146 close(socket_fd); 147 return nullptr; 148} 149 150void SocketStream::TlsConnect(std::unique_ptr<Stream> socket, 151 const std::string& host, 152 const Network::OpenSslSocketCallback& callback) { 153 SocketStream* stream = static_cast<SocketStream*>(socket.get()); 154 brillo::TlsStream::Connect( 155 std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback), 156 base::Bind(&OnError, base::Bind(callback, nullptr))); 157} 158 159} // namespace buffet 160