1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/ssl_hmac_channel_authenticator.h"
6
7#include "base/bind.h"
8#include "base/bind_helpers.h"
9#include "crypto/secure_util.h"
10#include "net/base/host_port_pair.h"
11#include "net/base/io_buffer.h"
12#include "net/base/net_errors.h"
13#include "net/cert/cert_verifier.h"
14#include "net/cert/x509_certificate.h"
15#include "net/http/transport_security_state.h"
16#include "net/socket/client_socket_factory.h"
17#include "net/socket/client_socket_handle.h"
18#include "net/socket/ssl_client_socket.h"
19#include "net/socket/ssl_server_socket.h"
20#include "net/ssl/ssl_config_service.h"
21#include "remoting/base/rsa_key_pair.h"
22#include "remoting/protocol/auth_util.h"
23
24namespace remoting {
25namespace protocol {
26
27// static
28scoped_ptr<SslHmacChannelAuthenticator>
29SslHmacChannelAuthenticator::CreateForClient(
30      const std::string& remote_cert,
31      const std::string& auth_key) {
32  scoped_ptr<SslHmacChannelAuthenticator> result(
33      new SslHmacChannelAuthenticator(auth_key));
34  result->remote_cert_ = remote_cert;
35  return result.Pass();
36}
37
38scoped_ptr<SslHmacChannelAuthenticator>
39SslHmacChannelAuthenticator::CreateForHost(
40    const std::string& local_cert,
41    scoped_refptr<RsaKeyPair> key_pair,
42    const std::string& auth_key) {
43  scoped_ptr<SslHmacChannelAuthenticator> result(
44      new SslHmacChannelAuthenticator(auth_key));
45  result->local_cert_ = local_cert;
46  result->local_key_pair_ = key_pair;
47  return result.Pass();
48}
49
50SslHmacChannelAuthenticator::SslHmacChannelAuthenticator(
51    const std::string& auth_key)
52    : auth_key_(auth_key) {
53}
54
55SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() {
56}
57
58void SslHmacChannelAuthenticator::SecureAndAuthenticate(
59    scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) {
60  DCHECK(CalledOnValidThread());
61  DCHECK(socket->IsConnected());
62
63  done_callback_ = done_callback;
64
65  int result;
66  if (is_ssl_server()) {
67    scoped_refptr<net::X509Certificate> cert =
68        net::X509Certificate::CreateFromBytes(
69            local_cert_.data(), local_cert_.length());
70    if (!cert.get()) {
71      LOG(ERROR) << "Failed to parse X509Certificate";
72      NotifyError(net::ERR_FAILED);
73      return;
74    }
75
76    net::SSLConfig ssl_config;
77    ssl_config.require_forward_secrecy = true;
78
79    scoped_ptr<net::SSLServerSocket> server_socket =
80        net::CreateSSLServerSocket(socket.Pass(),
81                                   cert.get(),
82                                   local_key_pair_->private_key(),
83                                   ssl_config);
84    net::SSLServerSocket* raw_server_socket = server_socket.get();
85    socket_ = server_socket.Pass();
86    result = raw_server_socket->Handshake(
87        base::Bind(&SslHmacChannelAuthenticator::OnConnected,
88                   base::Unretained(this)));
89  } else {
90    cert_verifier_.reset(net::CertVerifier::CreateDefault());
91    transport_security_state_.reset(new net::TransportSecurityState);
92
93    net::SSLConfig::CertAndStatus cert_and_status;
94    cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID;
95    cert_and_status.der_cert = remote_cert_;
96
97    net::SSLConfig ssl_config;
98    // Certificate verification and revocation checking are not needed
99    // because we use self-signed certs. Disable it so that the SSL
100    // layer doesn't try to initialize OCSP (OCSP works only on the IO
101    // thread).
102    ssl_config.cert_io_enabled = false;
103    ssl_config.rev_checking_enabled = false;
104    ssl_config.allowed_bad_certs.push_back(cert_and_status);
105
106    net::HostPortPair host_and_port(kSslFakeHostName, 0);
107    net::SSLClientSocketContext context;
108    context.cert_verifier = cert_verifier_.get();
109    context.transport_security_state = transport_security_state_.get();
110    scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle);
111    connection->SetSocket(socket.Pass());
112    socket_ =
113        net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket(
114            connection.Pass(), host_and_port, ssl_config, context);
115
116    result = socket_->Connect(
117        base::Bind(&SslHmacChannelAuthenticator::OnConnected,
118                   base::Unretained(this)));
119  }
120
121  if (result == net::ERR_IO_PENDING)
122    return;
123
124  OnConnected(result);
125}
126
127bool SslHmacChannelAuthenticator::is_ssl_server() {
128  return local_key_pair_.get() != NULL;
129}
130
131void SslHmacChannelAuthenticator::OnConnected(int result) {
132  if (result != net::OK) {
133    LOG(WARNING) << "Failed to establish SSL connection";
134    NotifyError(result);
135    return;
136  }
137
138  // Generate authentication digest to write to the socket.
139  std::string auth_bytes = GetAuthBytes(
140      socket_.get(), is_ssl_server() ?
141      kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_);
142  if (auth_bytes.empty()) {
143    NotifyError(net::ERR_FAILED);
144    return;
145  }
146
147  // Allocate a buffer to write the digest.
148  auth_write_buf_ = new net::DrainableIOBuffer(
149      new net::StringIOBuffer(auth_bytes), auth_bytes.size());
150
151  // Read an incoming token.
152  auth_read_buf_ = new net::GrowableIOBuffer();
153  auth_read_buf_->SetCapacity(kAuthDigestLength);
154
155  // If WriteAuthenticationBytes() results in |done_callback_| being
156  // called then we must not do anything else because this object may
157  // be destroyed at that point.
158  bool callback_called = false;
159  WriteAuthenticationBytes(&callback_called);
160  if (!callback_called)
161    ReadAuthenticationBytes();
162}
163
164void SslHmacChannelAuthenticator::WriteAuthenticationBytes(
165    bool* callback_called) {
166  while (true) {
167    int result = socket_->Write(
168        auth_write_buf_.get(),
169        auth_write_buf_->BytesRemaining(),
170        base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten,
171                   base::Unretained(this)));
172    if (result == net::ERR_IO_PENDING)
173      break;
174    if (!HandleAuthBytesWritten(result, callback_called))
175      break;
176  }
177}
178
179void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) {
180  DCHECK(CalledOnValidThread());
181
182  if (HandleAuthBytesWritten(result, NULL))
183    WriteAuthenticationBytes(NULL);
184}
185
186bool SslHmacChannelAuthenticator::HandleAuthBytesWritten(
187    int result, bool* callback_called) {
188  if (result <= 0) {
189    LOG(ERROR) << "Error writing authentication: " << result;
190    if (callback_called)
191      *callback_called = false;
192    NotifyError(result);
193    return false;
194  }
195
196  auth_write_buf_->DidConsume(result);
197  if (auth_write_buf_->BytesRemaining() > 0)
198    return true;
199
200  auth_write_buf_ = NULL;
201  CheckDone(callback_called);
202  return false;
203}
204
205void SslHmacChannelAuthenticator::ReadAuthenticationBytes() {
206  while (true) {
207    int result =
208        socket_->Read(auth_read_buf_.get(),
209                      auth_read_buf_->RemainingCapacity(),
210                      base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead,
211                                 base::Unretained(this)));
212    if (result == net::ERR_IO_PENDING)
213      break;
214    if (!HandleAuthBytesRead(result))
215      break;
216  }
217}
218
219void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) {
220  DCHECK(CalledOnValidThread());
221
222  if (HandleAuthBytesRead(result))
223    ReadAuthenticationBytes();
224}
225
226bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) {
227  if (read_result <= 0) {
228    NotifyError(read_result);
229    return false;
230  }
231
232  auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result);
233  if (auth_read_buf_->RemainingCapacity() > 0)
234    return true;
235
236  if (!VerifyAuthBytes(std::string(
237          auth_read_buf_->StartOfBuffer(),
238          auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) {
239    LOG(WARNING) << "Mismatched authentication";
240    NotifyError(net::ERR_FAILED);
241    return false;
242  }
243
244  auth_read_buf_ = NULL;
245  CheckDone(NULL);
246  return false;
247}
248
249bool SslHmacChannelAuthenticator::VerifyAuthBytes(
250    const std::string& received_auth_bytes) {
251  DCHECK(received_auth_bytes.length() == kAuthDigestLength);
252
253  // Compute expected auth bytes.
254  std::string auth_bytes = GetAuthBytes(
255      socket_.get(), is_ssl_server() ?
256      kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_);
257  if (auth_bytes.empty())
258    return false;
259
260  return crypto::SecureMemEqual(received_auth_bytes.data(),
261                                &(auth_bytes[0]), kAuthDigestLength);
262}
263
264void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) {
265  if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) {
266    DCHECK(socket_.get() != NULL);
267    if (callback_called)
268      *callback_called = true;
269    done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>());
270  }
271}
272
273void SslHmacChannelAuthenticator::NotifyError(int error) {
274  done_callback_.Run(static_cast<net::Error>(error),
275                     scoped_ptr<net::StreamSocket>());
276}
277
278}  // namespace protocol
279}  // namespace remoting
280