1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/ssl_hmac_channel_authenticator.h"
6
7#include "base/bind.h"
8#include "base/bind_helpers.h"
9#include "crypto/secure_util.h"
10#include "net/base/host_port_pair.h"
11#include "net/base/io_buffer.h"
12#include "net/base/net_errors.h"
13#include "net/cert/x509_certificate.h"
14#include "net/http/transport_security_state.h"
15#include "net/socket/client_socket_factory.h"
16#include "net/socket/client_socket_handle.h"
17#include "net/socket/ssl_client_socket.h"
18#include "net/socket/ssl_client_socket_openssl.h"
19#include "net/socket/ssl_server_socket.h"
20#include "net/ssl/ssl_config_service.h"
21#include "remoting/base/rsa_key_pair.h"
22#include "remoting/protocol/auth_util.h"
23
24namespace remoting {
25namespace protocol {
26
27// static
28scoped_ptr<SslHmacChannelAuthenticator>
29SslHmacChannelAuthenticator::CreateForClient(
30      const std::string& remote_cert,
31      const std::string& auth_key) {
32  scoped_ptr<SslHmacChannelAuthenticator> result(
33      new SslHmacChannelAuthenticator(auth_key));
34  result->remote_cert_ = remote_cert;
35  return result.Pass();
36}
37
38scoped_ptr<SslHmacChannelAuthenticator>
39SslHmacChannelAuthenticator::CreateForHost(
40    const std::string& local_cert,
41    scoped_refptr<RsaKeyPair> key_pair,
42    const std::string& auth_key) {
43  scoped_ptr<SslHmacChannelAuthenticator> result(
44      new SslHmacChannelAuthenticator(auth_key));
45  result->local_cert_ = local_cert;
46  result->local_key_pair_ = key_pair;
47  return result.Pass();
48}
49
50SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
51    const std::string& auth_key)
52    : auth_key_(auth_key) {
53}
54
55SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
56}
57
// Secures |socket| with SSL and then authenticates the peer by exchanging
// HMAC digests (see OnConnected()). This side acts as the SSL server when a
// local key pair was supplied (CreateForHost()) and as the SSL client
// otherwise (CreateForClient()). |done_callback| is stored and later run with
// net::OK and the secured socket on success, or with an error code and a NULL
// socket on failure.
void SslHmacChannelAuthenticator::SecureAndAuthenticate(
    scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
  DCHECK(CalledOnValidThread());
  DCHECK(socket->IsConnected());

  // Kept until authentication finishes; consumed by CallDoneCallback().
  done_callback_ = done_callback;

  // Result of starting the SSL handshake. net::ERR_IO_PENDING means
  // OnConnected() will be invoked later through the bound callback.
  int result;
  if (is_ssl_server()) {
#if defined(OS_NACL)
    // Client plugin doesn't use server SSL sockets, and so SSLServerSocket
    // implementation is not compiled for NaCl as part of net_nacl.
    NOTREACHED();
    result = net::ERR_FAILED;
#else
    // Host side: parse the local certificate so it can be presented to the
    // client.
    scoped_refptr<net::X509Certificate> cert =
        net::X509Certificate::CreateFromBytes(
            local_cert_.data(), local_cert_.length());
    if (!cert.get()) {
      LOG(ERROR) << "Failed to parse X509Certificate";
      NotifyError(net::ERR_FAILED);
      return;
    }

    // Require cipher suites that provide forward secrecy.
    net::SSLConfig ssl_config;
    ssl_config.require_forward_secrecy = true;

    scoped_ptr<net::SSLServerSocket> server_socket =
        net::CreateSSLServerSocket(socket.Pass(),
                                   cert.get(),
                                   local_key_pair_->private_key(),
                                   ssl_config);
    // |socket_| takes ownership before the handshake starts. Keep a raw
    // pointer so the SSLServerSocket-specific Handshake() can still be called.
    net::SSLServerSocket* raw_server_socket = server_socket.get();
    socket_ = server_socket.Pass();
    // NOTE(review): base::Unretained presumably relies on |socket_| being
    // owned by |this|, so the socket (and its pending callback) cannot
    // outlive this object -- confirm.
    result = raw_server_socket->Handshake(
        base::Bind(&SslHmacChannelAuthenticator::OnConnected,
                   base::Unretained(this)));
#endif
  } else {
    // Client side. |transport_security_state_| is owned by this object and is
    // handed to the SSL client socket through |context| below.
    transport_security_state_.reset(new net::TransportSecurityState);

    // Accept the expected host certificate (|remote_cert_|) even though it is
    // not issued by a trusted certificate authority.
    net::SSLConfig::CertAndStatus cert_and_status;
    cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
    cert_and_status.der_cert = remote_cert_;

    net::SSLConfig ssl_config;
    // Certificate verification and revocation checking are not needed
    // because we use self-signed certs. Disable it so that the SSL
    // layer doesn't try to initialize OCSP (OCSP works only on the IO
    // thread).
    ssl_config.cert_io_enabled = false;
    ssl_config.rev_checking_enabled = false;
    ssl_config.allowed_bad_certs.push_back(cert_and_status);

    // Placeholder host name and port required by the SSL client socket.
    net::HostPortPair host_and_port(kSslFakeHostName, 0);
    net::SSLClientSocketContext context;
    context.transport_security_state = transport_security_state_.get();
    // The SSL client socket takes the transport socket wrapped in a handle.
    scoped_ptr<net::ClientSocketHandle> socket_handle(
        new net::ClientSocketHandle);
    socket_handle->SetSocket(socket.Pass());

#if defined(OS_NACL)
    // net_nacl doesn't include ClientSocketFactory.
    socket_.reset(new net::SSLClientSocketOpenSSL(
        socket_handle.Pass(), host_and_port, ssl_config, context));
#else
    socket_ =
        net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
            socket_handle.Pass(), host_and_port, ssl_config, context);
#endif

    result = socket_->Connect(
        base::Bind(&SslHmacChannelAuthenticator::OnConnected,
                   base::Unretained(this)));
  }

  // An asynchronous completion is delivered to OnConnected() through the bound
  // callback; a synchronous one (success or failure) is handled right here.
  if (result == net::ERR_IO_PENDING)
    return;

  OnConnected(result);
}
139
140bool SslHmacChannelAuthenticator::is_ssl_server() {
141  return local_key_pair_.get() != NULL;
142}
143
144void SslHmacChannelAuthenticator::OnConnected(int result) {
145  if (result != net::OK) {
146    LOG(WARNING) << "Failed to establish SSL connection";
147    NotifyError(result);
148    return;
149  }
150
151  // Generate authentication digest to write to the socket.
152  std::string auth_bytes = GetAuthBytes(
153      socket_.get(), is_ssl_server() ?
154      kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
155  if (auth_bytes.empty()) {
156    NotifyError(net::ERR_FAILED);
157    return;
158  }
159
160  // Allocate a buffer to write the digest.
161  auth_write_buf_ = new net::DrainableIOBuffer(
162      new net::StringIOBuffer(auth_bytes), auth_bytes.size());
163
164  // Read an incoming token.
165  auth_read_buf_ = new net::GrowableIOBuffer();
166  auth_read_buf_->SetCapacity(kAuthDigestLength);
167
168  // If WriteAuthenticationBytes() results in |done_callback_| being
169  // called then we must not do anything else because this object may
170  // be destroyed at that point.
171  bool callback_called = false;
172  WriteAuthenticationBytes(&callback_called);
173  if (!callback_called)
174    ReadAuthenticationBytes();
175}
176
177void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
178    bool* callback_called) {
179  while (true) {
180    int result = socket_->Write(
181        auth_write_buf_.get(),
182        auth_write_buf_->BytesRemaining(),
183        base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
184                   base::Unretained(this)));
185    if (result == net::ERR_IO_PENDING)
186      break;
187    if (!HandleAuthBytesWritten(result, callback_called))
188      break;
189  }
190}
191
192void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
193  DCHECK(CalledOnValidThread());
194
195  if (HandleAuthBytesWritten(result, NULL))
196    WriteAuthenticationBytes(NULL);
197}
198
199bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
200    int result, bool* callback_called) {
201  if (result <= 0) {
202    LOG(ERROR) << "Error writing authentication: " << result;
203    if (callback_called)
204      *callback_called = false;
205    NotifyError(result);
206    return false;
207  }
208
209  auth_write_buf_->DidConsume(result);
210  if (auth_write_buf_->BytesRemaining() > 0)
211    return true;
212
213  auth_write_buf_ = NULL;
214  CheckDone(callback_called);
215  return false;
216}
217
218void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
219  while (true) {
220    int result =
221        socket_->Read(auth_read_buf_.get(),
222                      auth_read_buf_->RemainingCapacity(),
223                      base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
224                                 base::Unretained(this)));
225    if (result == net::ERR_IO_PENDING)
226      break;
227    if (!HandleAuthBytesRead(result))
228      break;
229  }
230}
231
232void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
233  DCHECK(CalledOnValidThread());
234
235  if (HandleAuthBytesRead(result))
236    ReadAuthenticationBytes();
237}
238
239bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
240  if (read_result <= 0) {
241    NotifyError(read_result);
242    return false;
243  }
244
245  auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
246  if (auth_read_buf_->RemainingCapacity() > 0)
247    return true;
248
249  if (!VerifyAuthBytes(std::string(
250          auth_read_buf_->StartOfBuffer(),
251          auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
252    LOG(WARNING) << "Mismatched authentication";
253    NotifyError(net::ERR_FAILED);
254    return false;
255  }
256
257  auth_read_buf_ = NULL;
258  CheckDone(NULL);
259  return false;
260}
261
262bool SslHmacChannelAuthenticator::VerifyAuthBytes(
263    const std::string& received_auth_bytes) {
264  DCHECK(received_auth_bytes.length() == kAuthDigestLength);
265
266  // Compute expected auth bytes.
267  std::string auth_bytes = GetAuthBytes(
268      socket_.get(), is_ssl_server() ?
269      kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
270  if (auth_bytes.empty())
271    return false;
272
273  return crypto::SecureMemEqual(received_auth_bytes.data(),
274                                &(auth_bytes[0]), kAuthDigestLength);
275}
276
277void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
278  if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
279    DCHECK(socket_.get() != NULL);
280    if (callback_called)
281      *callback_called = true;
282
283    CallDoneCallback(net::OK, socket_.PassAs<net::StreamSocket>());
284  }
285}
286
287void SslHmacChannelAuthenticator::NotifyError(int error) {
288  CallDoneCallback(error, scoped_ptr<net::StreamSocket>());
289}
290
291void SslHmacChannelAuthenticator::CallDoneCallback(
292    int error,
293    scoped_ptr<net::StreamSocket> socket) {
294  DoneCallback callback = done_callback_;
295  done_callback_.Reset();
296  callback.Run(error, socket.Pass());
297}
298
299}  // namespace protocol
300}  // namespace remoting
301