1// Copyright 2014 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "extensions/browser/api/socket/tcp_socket.h"
6
7#include "extensions/browser/api/api_resource.h"
8#include "net/base/address_list.h"
9#include "net/base/ip_endpoint.h"
10#include "net/base/net_errors.h"
11#include "net/base/rand_callback.h"
12#include "net/socket/tcp_client_socket.h"
13
14namespace extensions {
15
// Reported by Listen() when the socket has already been put into client mode
// by a Connect() call.
static const char kTCPSocketTypeInvalidError[] =
    "Cannot call both connect and listen on the same socket.";

// Reported by Listen() when the underlying server socket fails to listen.
static const char kSocketListenError[] =
    "Could not listen on the specified port.";
19
20static base::LazyInstance<
21    BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> > >
22    g_factory = LAZY_INSTANCE_INITIALIZER;
23
24// static
25template <>
26BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPSocket> >*
27ApiResourceManager<ResumableTCPSocket>::GetFactoryInstance() {
28  return g_factory.Pointer();
29}
30
31static base::LazyInstance<BrowserContextKeyedAPIFactory<
32    ApiResourceManager<ResumableTCPServerSocket> > > g_server_factory =
33    LAZY_INSTANCE_INITIALIZER;
34
35// static
36template <>
37BrowserContextKeyedAPIFactory<ApiResourceManager<ResumableTCPServerSocket> >*
38ApiResourceManager<ResumableTCPServerSocket>::GetFactoryInstance() {
39  return g_server_factory.Pointer();
40}
41
42TCPSocket::TCPSocket(const std::string& owner_extension_id)
43    : Socket(owner_extension_id), socket_mode_(UNKNOWN) {}
44
45TCPSocket::TCPSocket(net::TCPClientSocket* tcp_client_socket,
46                     const std::string& owner_extension_id,
47                     bool is_connected)
48    : Socket(owner_extension_id),
49      socket_(tcp_client_socket),
50      socket_mode_(CLIENT) {
51  this->is_connected_ = is_connected;
52}
53
54TCPSocket::TCPSocket(net::TCPServerSocket* tcp_server_socket,
55                     const std::string& owner_extension_id)
56    : Socket(owner_extension_id),
57      server_socket_(tcp_server_socket),
58      socket_mode_(SERVER) {}
59
60// static
61TCPSocket* TCPSocket::CreateSocketForTesting(
62    net::TCPClientSocket* tcp_client_socket,
63    const std::string& owner_extension_id,
64    bool is_connected) {
65  return new TCPSocket(tcp_client_socket, owner_extension_id, is_connected);
66}
67
68// static
69TCPSocket* TCPSocket::CreateServerSocketForTesting(
70    net::TCPServerSocket* tcp_server_socket,
71    const std::string& owner_extension_id) {
72  return new TCPSocket(tcp_server_socket, owner_extension_id);
73}
74
75TCPSocket::~TCPSocket() { Disconnect(); }
76
77void TCPSocket::Connect(const std::string& address,
78                        int port,
79                        const CompletionCallback& callback) {
80  DCHECK(!callback.is_null());
81
82  if (socket_mode_ == SERVER || !connect_callback_.is_null()) {
83    callback.Run(net::ERR_CONNECTION_FAILED);
84    return;
85  }
86  DCHECK(!server_socket_.get());
87  socket_mode_ = CLIENT;
88  connect_callback_ = callback;
89
90  int result = net::ERR_CONNECTION_FAILED;
91  do {
92    if (is_connected_)
93      break;
94
95    net::AddressList address_list;
96    if (!StringAndPortToAddressList(address, port, &address_list)) {
97      result = net::ERR_ADDRESS_INVALID;
98      break;
99    }
100
101    socket_.reset(
102        new net::TCPClientSocket(address_list, NULL, net::NetLog::Source()));
103
104    connect_callback_ = callback;
105    result = socket_->Connect(
106        base::Bind(&TCPSocket::OnConnectComplete, base::Unretained(this)));
107  } while (false);
108
109  if (result != net::ERR_IO_PENDING)
110    OnConnectComplete(result);
111}
112
113void TCPSocket::Disconnect() {
114  is_connected_ = false;
115  if (socket_.get())
116    socket_->Disconnect();
117  server_socket_.reset(NULL);
118  connect_callback_.Reset();
119  read_callback_.Reset();
120  accept_callback_.Reset();
121  accept_socket_.reset(NULL);
122}
123
124int TCPSocket::Bind(const std::string& address, int port) {
125  return net::ERR_FAILED;
126}
127
128void TCPSocket::Read(int count, const ReadCompletionCallback& callback) {
129  DCHECK(!callback.is_null());
130
131  if (socket_mode_ != CLIENT) {
132    callback.Run(net::ERR_FAILED, NULL);
133    return;
134  }
135
136  if (!read_callback_.is_null()) {
137    callback.Run(net::ERR_IO_PENDING, NULL);
138    return;
139  }
140
141  if (count < 0) {
142    callback.Run(net::ERR_INVALID_ARGUMENT, NULL);
143    return;
144  }
145
146  if (!socket_.get() || !IsConnected()) {
147    callback.Run(net::ERR_SOCKET_NOT_CONNECTED, NULL);
148    return;
149  }
150
151  read_callback_ = callback;
152  scoped_refptr<net::IOBuffer> io_buffer = new net::IOBuffer(count);
153  int result = socket_->Read(
154      io_buffer.get(),
155      count,
156      base::Bind(
157          &TCPSocket::OnReadComplete, base::Unretained(this), io_buffer));
158
159  if (result != net::ERR_IO_PENDING)
160    OnReadComplete(io_buffer, result);
161}
162
163void TCPSocket::RecvFrom(int count,
164                         const RecvFromCompletionCallback& callback) {
165  callback.Run(net::ERR_FAILED, NULL, NULL, 0);
166}
167
168void TCPSocket::SendTo(scoped_refptr<net::IOBuffer> io_buffer,
169                       int byte_count,
170                       const std::string& address,
171                       int port,
172                       const CompletionCallback& callback) {
173  callback.Run(net::ERR_FAILED);
174}
175
176bool TCPSocket::SetKeepAlive(bool enable, int delay) {
177  if (!socket_.get())
178    return false;
179  return socket_->SetKeepAlive(enable, delay);
180}
181
182bool TCPSocket::SetNoDelay(bool no_delay) {
183  if (!socket_.get())
184    return false;
185  return socket_->SetNoDelay(no_delay);
186}
187
188int TCPSocket::Listen(const std::string& address,
189                      int port,
190                      int backlog,
191                      std::string* error_msg) {
192  if (socket_mode_ == CLIENT) {
193    *error_msg = kTCPSocketTypeInvalidError;
194    return net::ERR_NOT_IMPLEMENTED;
195  }
196  DCHECK(!socket_.get());
197  socket_mode_ = SERVER;
198
199  scoped_ptr<net::IPEndPoint> bind_address(new net::IPEndPoint());
200  if (!StringAndPortToIPEndPoint(address, port, bind_address.get()))
201    return net::ERR_INVALID_ARGUMENT;
202
203  if (!server_socket_.get()) {
204    server_socket_.reset(new net::TCPServerSocket(NULL, net::NetLog::Source()));
205  }
206  int result = server_socket_->Listen(*bind_address, backlog);
207  if (result)
208    *error_msg = kSocketListenError;
209  return result;
210}
211
212void TCPSocket::Accept(const AcceptCompletionCallback& callback) {
213  if (socket_mode_ != SERVER || !server_socket_.get()) {
214    callback.Run(net::ERR_FAILED, NULL);
215    return;
216  }
217
218  // Limits to only 1 blocked accept call.
219  if (!accept_callback_.is_null()) {
220    callback.Run(net::ERR_FAILED, NULL);
221    return;
222  }
223
224  int result = server_socket_->Accept(
225      &accept_socket_,
226      base::Bind(&TCPSocket::OnAccept, base::Unretained(this)));
227  if (result == net::ERR_IO_PENDING) {
228    accept_callback_ = callback;
229  } else if (result == net::OK) {
230    accept_callback_ = callback;
231    this->OnAccept(result);
232  } else {
233    callback.Run(result, NULL);
234  }
235}
236
237bool TCPSocket::IsConnected() {
238  RefreshConnectionStatus();
239  return is_connected_;
240}
241
242bool TCPSocket::GetPeerAddress(net::IPEndPoint* address) {
243  if (!socket_.get())
244    return false;
245  return !socket_->GetPeerAddress(address);
246}
247
248bool TCPSocket::GetLocalAddress(net::IPEndPoint* address) {
249  if (socket_.get()) {
250    return !socket_->GetLocalAddress(address);
251  } else if (server_socket_.get()) {
252    return !server_socket_->GetLocalAddress(address);
253  } else {
254    return false;
255  }
256}
257
258Socket::SocketType TCPSocket::GetSocketType() const { return Socket::TYPE_TCP; }
259
260int TCPSocket::WriteImpl(net::IOBuffer* io_buffer,
261                         int io_buffer_size,
262                         const net::CompletionCallback& callback) {
263  if (socket_mode_ != CLIENT)
264    return net::ERR_FAILED;
265  else if (!socket_.get() || !IsConnected())
266    return net::ERR_SOCKET_NOT_CONNECTED;
267  else
268    return socket_->Write(io_buffer, io_buffer_size, callback);
269}
270
271void TCPSocket::RefreshConnectionStatus() {
272  if (!is_connected_)
273    return;
274  if (server_socket_)
275    return;
276  if (!socket_->IsConnected()) {
277    Disconnect();
278  }
279}
280
281void TCPSocket::OnConnectComplete(int result) {
282  DCHECK(!connect_callback_.is_null());
283  DCHECK(!is_connected_);
284  is_connected_ = result == net::OK;
285  connect_callback_.Run(result);
286  connect_callback_.Reset();
287}
288
289void TCPSocket::OnReadComplete(scoped_refptr<net::IOBuffer> io_buffer,
290                               int result) {
291  DCHECK(!read_callback_.is_null());
292  read_callback_.Run(result, io_buffer);
293  read_callback_.Reset();
294}
295
296void TCPSocket::OnAccept(int result) {
297  DCHECK(!accept_callback_.is_null());
298  if (result == net::OK && accept_socket_.get()) {
299    accept_callback_.Run(
300        result, static_cast<net::TCPClientSocket*>(accept_socket_.release()));
301  } else {
302    accept_callback_.Run(result, NULL);
303  }
304  accept_callback_.Reset();
305}
306
307ResumableTCPSocket::ResumableTCPSocket(const std::string& owner_extension_id)
308    : TCPSocket(owner_extension_id),
309      persistent_(false),
310      buffer_size_(0),
311      paused_(false) {}
312
313ResumableTCPSocket::ResumableTCPSocket(net::TCPClientSocket* tcp_client_socket,
314                                       const std::string& owner_extension_id,
315                                       bool is_connected)
316    : TCPSocket(tcp_client_socket, owner_extension_id, is_connected),
317      persistent_(false),
318      buffer_size_(0),
319      paused_(false) {}
320
321bool ResumableTCPSocket::IsPersistent() const { return persistent(); }
322
323ResumableTCPServerSocket::ResumableTCPServerSocket(
324    const std::string& owner_extension_id)
325    : TCPSocket(owner_extension_id), persistent_(false), paused_(false) {}
326
327bool ResumableTCPServerSocket::IsPersistent() const { return persistent(); }
328
329}  // namespace extensions
330