socket_stream.cc revision 4170585fe75d99036883229081420f2972dd4ec1
1// Copyright 2015 The Chromium OS Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include <arpa/inet.h>
6#include <map>
7#include <netdb.h>
8#include <string>
9#include <sys/socket.h>
10#include <sys/types.h>
11#include <unistd.h>
12
13#include <base/bind.h>
14#include <base/files/file_util.h>
15#include <base/message_loop/message_loop.h>
16#include <base/strings/stringprintf.h>
17#include <brillo/streams/file_stream.h>
18#include <brillo/streams/tls_stream.h>
19
20#include "buffet/socket_stream.h"
21#include "buffet/weave_error_conversion.h"
22
23namespace buffet {
24
25namespace {
26
27std::string GetIPAddress(const sockaddr* sa) {
28  std::string addr;
29  char str[INET6_ADDRSTRLEN] = {};
30  switch (sa->sa_family) {
31    case AF_INET:
32      if (inet_ntop(AF_INET,
33                    &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str,
34                    sizeof(str))) {
35        addr = str;
36      }
37      break;
38
39    case AF_INET6:
40      if (inet_ntop(AF_INET6,
41                    &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr),
42                    str, sizeof(str))) {
43        addr = str;
44      }
45      break;
46  }
47  if (addr.empty())
48    addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family);
49  return addr;
50}
51
52int ConnectSocket(const std::string& host, uint16_t port) {
53  std::string service = std::to_string(port);
54  addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
55  addrinfo* result = nullptr;
56  if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
57    PLOG(WARNING) << "Failed to resolve host name: " << host;
58    return -1;
59  }
60
61  int socket_fd = -1;
62  for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
63    socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
64    if (socket_fd < 0)
65      continue;
66
67    std::string addr = GetIPAddress(info->ai_addr);
68    LOG(INFO) << "Connecting to address: " << addr;
69    if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0)
70      break;  // Success.
71
72    PLOG(WARNING) << "Failed to connect to address: " << addr;
73    close(socket_fd);
74    socket_fd = -1;
75  }
76
77  freeaddrinfo(result);
78  return socket_fd;
79}
80
81void OnSuccess(const base::Callback<void(std::unique_ptr<weave::Stream>)>&
82                   success_callback,
83               brillo::StreamPtr tls_stream) {
84  success_callback.Run(
85      std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}});
86}
87
88void OnError(const base::Callback<void(weave::ErrorPtr)>& error_callback,
89             const brillo::Error* chromeos_error) {
90  weave::ErrorPtr error;
91  ConvertError(*chromeos_error, &error);
92  error_callback.Run(std::move(error));
93}
94
95}  // namespace
96
97void SocketStream::Read(void* buffer,
98                        size_t size_to_read,
99                        const ReadSuccessCallback& success_callback,
100                        const weave::ErrorCallback& error_callback) {
101  brillo::ErrorPtr chromeos_error;
102  if (!ptr_->ReadAsync(buffer, size_to_read, success_callback,
103                       base::Bind(&OnError, error_callback), &chromeos_error)) {
104    weave::ErrorPtr error;
105    ConvertError(*chromeos_error, &error);
106    base::MessageLoop::current()->PostTask(
107        FROM_HERE, base::Bind(error_callback, base::Passed(&error)));
108  }
109}
110
111void SocketStream::Write(const void* buffer,
112                         size_t size_to_write,
113                         const weave::SuccessCallback& success_callback,
114                         const weave::ErrorCallback& error_callback) {
115  brillo::ErrorPtr chromeos_error;
116  if (!ptr_->WriteAllAsync(buffer, size_to_write, success_callback,
117                           base::Bind(&OnError, error_callback),
118                           &chromeos_error)) {
119    weave::ErrorPtr error;
120    ConvertError(*chromeos_error, &error);
121    base::MessageLoop::current()->PostTask(
122        FROM_HERE, base::Bind(error_callback, base::Passed(&error)));
123  }
124}
125
126void SocketStream::CancelPendingOperations() {
127  ptr_->CancelPendingAsyncOperations();
128}
129
130std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking(
131    const std::string& host,
132    uint16_t port) {
133  int socket_fd = ConnectSocket(host, port);
134  if (socket_fd <= 0)
135    return nullptr;
136
137  auto ptr_ =
138      brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr);
139  if (ptr_)
140    return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}};
141
142  close(socket_fd);
143  return nullptr;
144}
145
146void SocketStream::TlsConnect(
147    std::unique_ptr<Stream> socket,
148    const std::string& host,
149    const base::Callback<void(std::unique_ptr<Stream>)>& success_callback,
150    const weave::ErrorCallback& error_callback) {
151  SocketStream* stream = static_cast<SocketStream*>(socket.get());
152  brillo::TlsStream::Connect(std::move(stream->ptr_), host,
153                             base::Bind(&OnSuccess, success_callback),
154                             base::Bind(&OnError, error_callback));
155}
156
157}  // namespace buffet
158