socket_stream.cc revision 94f8eba2f71f0d73e15bfd5b1c28bb0ed682706d
1// Copyright 2015 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <arpa/inet.h>
6#include <map>
7#include <netdb.h>
8#include <string>
9#include <sys/socket.h>
10#include <sys/types.h>
11#include <unistd.h>
12
13#include <base/bind.h>
14#include <base/bind_helpers.h>
15#include <base/files/file_util.h>
16#include <base/message_loop/message_loop.h>
17#include <base/strings/stringprintf.h>
18#include <brillo/bind_lambda.h>
19#include <brillo/streams/file_stream.h>
20#include <brillo/streams/tls_stream.h>
21
22#include "buffet/socket_stream.h"
23#include "buffet/weave_error_conversion.h"
24
25namespace buffet {
26
27using weave::provider::Network;
28
29namespace {
30
// Returns the printable (presentation) form of the IPv4 or IPv6 address held
// in |sa|. For any other address family, or if the conversion produces no
// text, returns "<Unknown address family: N>" where N is the numeric family.
std::string GetIPAddress(const sockaddr* sa) {
  char text[INET6_ADDRSTRLEN] = {};

  // Pick the family-specific address payload; stays null for other families.
  const void* raw_addr = nullptr;
  if (sa->sa_family == AF_INET)
    raw_addr = &reinterpret_cast<const sockaddr_in*>(sa)->sin_addr;
  else if (sa->sa_family == AF_INET6)
    raw_addr = &reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr;

  if (raw_addr != nullptr &&
      inet_ntop(sa->sa_family, raw_addr, text, sizeof(text)) != nullptr &&
      text[0] != '\0') {
    return text;
  }
  return "<Unknown address family: " + std::to_string(sa->sa_family) + ">";
}
55
56int ConnectSocket(const std::string& host, uint16_t port) {
57  std::string service = std::to_string(port);
58  addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
59  addrinfo* result = nullptr;
60  if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
61    PLOG(WARNING) << "Failed to resolve host name: " << host;
62    return -1;
63  }
64
65  int socket_fd = -1;
66  for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
67    socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
68    if (socket_fd < 0)
69      continue;
70
71    std::string addr = GetIPAddress(info->ai_addr);
72    LOG(INFO) << "Connecting to address: " << addr;
73    if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0)
74      break;  // Success.
75
76    PLOG(WARNING) << "Failed to connect to address: " << addr;
77    close(socket_fd);
78    socket_fd = -1;
79  }
80
81  freeaddrinfo(result);
82  return socket_fd;
83}
84
85void OnSuccess(const Network::OpenSslSocketCallback& callback,
86               brillo::StreamPtr tls_stream) {
87  callback.Run(
88      std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}},
89      nullptr);
90}
91
92void OnError(const weave::DoneCallback& callback,
93             const brillo::Error* brillo_error) {
94  weave::ErrorPtr error;
95  ConvertError(*brillo_error, &error);
96  callback.Run(std::move(error));
97}
98
99}  // namespace
100
101void SocketStream::Read(void* buffer,
102                        size_t size_to_read,
103                        const ReadCallback& callback) {
104  brillo::ErrorPtr brillo_error;
105  if (!ptr_->ReadAsync(
106          buffer, size_to_read,
107          base::Bind([](const ReadCallback& callback,
108                        size_t size) { callback.Run(size, nullptr); },
109                     callback),
110          base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) {
111    weave::ErrorPtr error;
112    ConvertError(*brillo_error, &error);
113    base::MessageLoop::current()->PostTask(
114        FROM_HERE, base::Bind(callback, 0, base::Passed(&error)));
115  }
116}
117
118void SocketStream::Write(const void* buffer,
119                         size_t size_to_write,
120                         const WriteCallback& callback) {
121  brillo::ErrorPtr brillo_error;
122  if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr),
123                           base::Bind(&OnError, callback), &brillo_error)) {
124    weave::ErrorPtr error;
125    ConvertError(*brillo_error, &error);
126    base::MessageLoop::current()->PostTask(
127        FROM_HERE, base::Bind(callback, base::Passed(&error)));
128  }
129}
130
131void SocketStream::CancelPendingOperations() {
132  ptr_->CancelPendingAsyncOperations();
133}
134
135std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking(
136    const std::string& host,
137    uint16_t port) {
138  int socket_fd = ConnectSocket(host, port);
139  if (socket_fd <= 0)
140    return nullptr;
141
142  auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr);
143  if (ptr_)
144    return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}};
145
146  close(socket_fd);
147  return nullptr;
148}
149
150void SocketStream::TlsConnect(std::unique_ptr<Stream> socket,
151                              const std::string& host,
152                              const Network::OpenSslSocketCallback& callback) {
153  SocketStream* stream = static_cast<SocketStream*>(socket.get());
154  brillo::TlsStream::Connect(
155      std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback),
156      base::Bind(&OnError, base::Bind(callback, nullptr)));
157}
158
159}  // namespace buffet
160