1// Copyright (c) 2012 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "remoting/protocol/ssl_hmac_channel_authenticator.h" 6 7#include "base/bind.h" 8#include "base/bind_helpers.h" 9#include "crypto/secure_util.h" 10#include "net/base/host_port_pair.h" 11#include "net/base/io_buffer.h" 12#include "net/base/net_errors.h" 13#include "net/cert/cert_verifier.h" 14#include "net/cert/x509_certificate.h" 15#include "net/http/transport_security_state.h" 16#include "net/socket/client_socket_factory.h" 17#include "net/socket/client_socket_handle.h" 18#include "net/socket/ssl_client_socket.h" 19#include "net/socket/ssl_server_socket.h" 20#include "net/ssl/ssl_config_service.h" 21#include "remoting/base/rsa_key_pair.h" 22#include "remoting/protocol/auth_util.h" 23 24namespace remoting { 25namespace protocol { 26 27// static 28scoped_ptr<SslHmacChannelAuthenticator> 29SslHmacChannelAuthenticator::CreateForClient( 30 const std::string& remote_cert, 31 const std::string& auth_key) { 32 scoped_ptr<SslHmacChannelAuthenticator> result( 33 new SslHmacChannelAuthenticator(auth_key)); 34 result->remote_cert_ = remote_cert; 35 return result.Pass(); 36} 37 38scoped_ptr<SslHmacChannelAuthenticator> 39SslHmacChannelAuthenticator::CreateForHost( 40 const std::string& local_cert, 41 scoped_refptr<RsaKeyPair> key_pair, 42 const std::string& auth_key) { 43 scoped_ptr<SslHmacChannelAuthenticator> result( 44 new SslHmacChannelAuthenticator(auth_key)); 45 result->local_cert_ = local_cert; 46 result->local_key_pair_ = key_pair; 47 return result.Pass(); 48} 49 50SslHmacChannelAuthenticator::SslHmacChannelAuthenticator( 51 const std::string& auth_key) 52 : auth_key_(auth_key) { 53} 54 55SslHmacChannelAuthenticator::~SslHmacChannelAuthenticator() { 56} 57 58void SslHmacChannelAuthenticator::SecureAndAuthenticate( 59 scoped_ptr<net::StreamSocket> socket, const DoneCallback& done_callback) { 60 DCHECK(CalledOnValidThread()); 61 DCHECK(socket->IsConnected()); 62 63 done_callback_ = done_callback; 64 65 int result; 66 if (is_ssl_server()) { 67 scoped_refptr<net::X509Certificate> cert = 68 net::X509Certificate::CreateFromBytes( 69 local_cert_.data(), local_cert_.length()); 70 if (!cert.get()) { 71 LOG(ERROR) << "Failed to parse X509Certificate"; 72 NotifyError(net::ERR_FAILED); 73 return; 74 } 75 76 net::SSLConfig ssl_config; 77 ssl_config.require_forward_secrecy = true; 78 79 scoped_ptr<net::SSLServerSocket> server_socket = 80 net::CreateSSLServerSocket(socket.Pass(), 81 cert.get(), 82 local_key_pair_->private_key(), 83 ssl_config); 84 net::SSLServerSocket* raw_server_socket = server_socket.get(); 85 socket_ = server_socket.Pass(); 86 result = raw_server_socket->Handshake( 87 base::Bind(&SslHmacChannelAuthenticator::OnConnected, 88 base::Unretained(this))); 89 } else { 90 cert_verifier_.reset(net::CertVerifier::CreateDefault()); 91 transport_security_state_.reset(new net::TransportSecurityState); 92 93 net::SSLConfig::CertAndStatus cert_and_status; 94 cert_and_status.cert_status = net::CERT_STATUS_AUTHORITY_INVALID; 95 cert_and_status.der_cert = remote_cert_; 96 97 net::SSLConfig ssl_config; 98 // Certificate verification and revocation checking are not needed 99 // because we use self-signed certs. Disable it so that the SSL 100 // layer doesn't try to initialize OCSP (OCSP works only on the IO 101 // thread). 102 ssl_config.cert_io_enabled = false; 103 ssl_config.rev_checking_enabled = false; 104 ssl_config.allowed_bad_certs.push_back(cert_and_status); 105 106 net::HostPortPair host_and_port(kSslFakeHostName, 0); 107 net::SSLClientSocketContext context; 108 context.cert_verifier = cert_verifier_.get(); 109 context.transport_security_state = transport_security_state_.get(); 110 scoped_ptr<net::ClientSocketHandle> connection(new net::ClientSocketHandle); 111 connection->SetSocket(socket.Pass()); 112 socket_ = 113 net::ClientSocketFactory::GetDefaultFactory()->CreateSSLClientSocket( 114 connection.Pass(), host_and_port, ssl_config, context); 115 116 result = socket_->Connect( 117 base::Bind(&SslHmacChannelAuthenticator::OnConnected, 118 base::Unretained(this))); 119 } 120 121 if (result == net::ERR_IO_PENDING) 122 return; 123 124 OnConnected(result); 125} 126 127bool SslHmacChannelAuthenticator::is_ssl_server() { 128 return local_key_pair_.get() != NULL; 129} 130 131void SslHmacChannelAuthenticator::OnConnected(int result) { 132 if (result != net::OK) { 133 LOG(WARNING) << "Failed to establish SSL connection"; 134 NotifyError(result); 135 return; 136 } 137 138 // Generate authentication digest to write to the socket. 139 std::string auth_bytes = GetAuthBytes( 140 socket_.get(), is_ssl_server() ? 141 kHostAuthSslExporterLabel : kClientAuthSslExporterLabel, auth_key_); 142 if (auth_bytes.empty()) { 143 NotifyError(net::ERR_FAILED); 144 return; 145 } 146 147 // Allocate a buffer to write the digest. 148 auth_write_buf_ = new net::DrainableIOBuffer( 149 new net::StringIOBuffer(auth_bytes), auth_bytes.size()); 150 151 // Read an incoming token. 152 auth_read_buf_ = new net::GrowableIOBuffer(); 153 auth_read_buf_->SetCapacity(kAuthDigestLength); 154 155 // If WriteAuthenticationBytes() results in |done_callback_| being 156 // called then we must not do anything else because this object may 157 // be destroyed at that point. 158 bool callback_called = false; 159 WriteAuthenticationBytes(&callback_called); 160 if (!callback_called) 161 ReadAuthenticationBytes(); 162} 163 164void SslHmacChannelAuthenticator::WriteAuthenticationBytes( 165 bool* callback_called) { 166 while (true) { 167 int result = socket_->Write( 168 auth_write_buf_.get(), 169 auth_write_buf_->BytesRemaining(), 170 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesWritten, 171 base::Unretained(this))); 172 if (result == net::ERR_IO_PENDING) 173 break; 174 if (!HandleAuthBytesWritten(result, callback_called)) 175 break; 176 } 177} 178 179void SslHmacChannelAuthenticator::OnAuthBytesWritten(int result) { 180 DCHECK(CalledOnValidThread()); 181 182 if (HandleAuthBytesWritten(result, NULL)) 183 WriteAuthenticationBytes(NULL); 184} 185 186bool SslHmacChannelAuthenticator::HandleAuthBytesWritten( 187 int result, bool* callback_called) { 188 if (result <= 0) { 189 LOG(ERROR) << "Error writing authentication: " << result; 190 if (callback_called) 191 *callback_called = false; 192 NotifyError(result); 193 return false; 194 } 195 196 auth_write_buf_->DidConsume(result); 197 if (auth_write_buf_->BytesRemaining() > 0) 198 return true; 199 200 auth_write_buf_ = NULL; 201 CheckDone(callback_called); 202 return false; 203} 204 205void SslHmacChannelAuthenticator::ReadAuthenticationBytes() { 206 while (true) { 207 int result = 208 socket_->Read(auth_read_buf_.get(), 209 auth_read_buf_->RemainingCapacity(), 210 base::Bind(&SslHmacChannelAuthenticator::OnAuthBytesRead, 211 base::Unretained(this))); 212 if (result == net::ERR_IO_PENDING) 213 break; 214 if (!HandleAuthBytesRead(result)) 215 break; 216 } 217} 218 219void SslHmacChannelAuthenticator::OnAuthBytesRead(int result) { 220 DCHECK(CalledOnValidThread()); 221 222 if (HandleAuthBytesRead(result)) 223 ReadAuthenticationBytes(); 224} 225 226bool SslHmacChannelAuthenticator::HandleAuthBytesRead(int read_result) { 227 if (read_result <= 0) { 228 NotifyError(read_result); 229 return false; 230 } 231 232 auth_read_buf_->set_offset(auth_read_buf_->offset() + read_result); 233 if (auth_read_buf_->RemainingCapacity() > 0) 234 return true; 235 236 if (!VerifyAuthBytes(std::string( 237 auth_read_buf_->StartOfBuffer(), 238 auth_read_buf_->StartOfBuffer() + kAuthDigestLength))) { 239 LOG(WARNING) << "Mismatched authentication"; 240 NotifyError(net::ERR_FAILED); 241 return false; 242 } 243 244 auth_read_buf_ = NULL; 245 CheckDone(NULL); 246 return false; 247} 248 249bool SslHmacChannelAuthenticator::VerifyAuthBytes( 250 const std::string& received_auth_bytes) { 251 DCHECK(received_auth_bytes.length() == kAuthDigestLength); 252 253 // Compute expected auth bytes. 254 std::string auth_bytes = GetAuthBytes( 255 socket_.get(), is_ssl_server() ? 256 kClientAuthSslExporterLabel : kHostAuthSslExporterLabel, auth_key_); 257 if (auth_bytes.empty()) 258 return false; 259 260 return crypto::SecureMemEqual(received_auth_bytes.data(), 261 &(auth_bytes[0]), kAuthDigestLength); 262} 263 264void SslHmacChannelAuthenticator::CheckDone(bool* callback_called) { 265 if (auth_write_buf_.get() == NULL && auth_read_buf_.get() == NULL) { 266 DCHECK(socket_.get() != NULL); 267 if (callback_called) 268 *callback_called = true; 269 done_callback_.Run(net::OK, socket_.PassAs<net::StreamSocket>()); 270 } 271} 272 273void SslHmacChannelAuthenticator::NotifyError(int error) { 274 done_callback_.Run(static_cast<net::Error>(error), 275 scoped_ptr<net::StreamSocket>()); 276} 277 278} // namespace protocol 279} // namespace remoting 280