1// Copyright 2013 The Chromium Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style license that can be 3// found in the LICENSE file. 4 5#include "remoting/protocol/negotiating_client_authenticator.h" 6 7#include <algorithm> 8#include <sstream> 9 10#include "base/bind.h" 11#include "base/callback.h" 12#include "base/logging.h" 13#include "base/strings/string_split.h" 14#include "remoting/protocol/channel_authenticator.h" 15#include "remoting/protocol/pairing_client_authenticator.h" 16#include "remoting/protocol/v2_authenticator.h" 17#include "third_party/libjingle/source/talk/xmllite/xmlelement.h" 18 19namespace remoting { 20namespace protocol { 21 22NegotiatingClientAuthenticator::NegotiatingClientAuthenticator( 23 const std::string& client_pairing_id, 24 const std::string& shared_secret, 25 const std::string& authentication_tag, 26 const FetchSecretCallback& fetch_secret_callback, 27 scoped_ptr<ThirdPartyClientAuthenticator::TokenFetcher> token_fetcher, 28 const std::vector<AuthenticationMethod>& methods) 29 : NegotiatingAuthenticatorBase(MESSAGE_READY), 30 client_pairing_id_(client_pairing_id), 31 shared_secret_(shared_secret), 32 authentication_tag_(authentication_tag), 33 fetch_secret_callback_(fetch_secret_callback), 34 token_fetcher_(token_fetcher.Pass()), 35 method_set_by_host_(false), 36 weak_factory_(this) { 37 DCHECK(!methods.empty()); 38 for (std::vector<AuthenticationMethod>::const_iterator it = methods.begin(); 39 it != methods.end(); ++it) { 40 AddMethod(*it); 41 } 42} 43 44NegotiatingClientAuthenticator::~NegotiatingClientAuthenticator() { 45} 46 47void NegotiatingClientAuthenticator::ProcessMessage( 48 const buzz::XmlElement* message, 49 const base::Closure& resume_callback) { 50 DCHECK_EQ(state(), WAITING_MESSAGE); 51 52 std::string method_attr = message->Attr(kMethodAttributeQName); 53 AuthenticationMethod method = AuthenticationMethod::FromString(method_attr); 54 55 // The host picked a method different from the one the client had selected. 56 if (method != current_method_) { 57 // The host must pick a method that is valid and supported by the client, 58 // and it must not change methods after it has picked one. 59 if (method_set_by_host_ || !method.is_valid() || 60 std::find(methods_.begin(), methods_.end(), method) == methods_.end()) { 61 state_ = REJECTED; 62 rejection_reason_ = PROTOCOL_ERROR; 63 resume_callback.Run(); 64 return; 65 } 66 67 current_method_ = method; 68 method_set_by_host_ = true; 69 state_ = PROCESSING_MESSAGE; 70 71 // Copy the message since the authenticator may process it asynchronously. 72 base::Closure callback = base::Bind( 73 &NegotiatingAuthenticatorBase::ProcessMessageInternal, 74 base::Unretained(this), base::Owned(new buzz::XmlElement(*message)), 75 resume_callback); 76 CreateAuthenticatorForCurrentMethod(WAITING_MESSAGE, callback); 77 return; 78 } 79 ProcessMessageInternal(message, resume_callback); 80} 81 82scoped_ptr<buzz::XmlElement> NegotiatingClientAuthenticator::GetNextMessage() { 83 DCHECK_EQ(state(), MESSAGE_READY); 84 85 // This is the first message to the host, send a list of supported methods. 86 if (!current_method_.is_valid()) { 87 // If no authentication method has been chosen, see if we can optimistically 88 // choose one. 89 scoped_ptr<buzz::XmlElement> result; 90 CreatePreferredAuthenticator(); 91 if (current_authenticator_) { 92 DCHECK(current_authenticator_->state() == MESSAGE_READY); 93 result = GetNextMessageInternal(); 94 } else { 95 result = CreateEmptyAuthenticatorMessage(); 96 } 97 98 // Include a list of supported methods. 99 std::stringstream supported_methods(std::stringstream::out); 100 for (std::vector<AuthenticationMethod>::iterator it = methods_.begin(); 101 it != methods_.end(); ++it) { 102 if (it != methods_.begin()) 103 supported_methods << kSupportedMethodsSeparator; 104 supported_methods << it->ToString(); 105 } 106 result->AddAttr(kSupportedMethodsAttributeQName, supported_methods.str()); 107 state_ = WAITING_MESSAGE; 108 return result.Pass(); 109 } 110 return GetNextMessageInternal(); 111} 112 113void NegotiatingClientAuthenticator::CreateAuthenticatorForCurrentMethod( 114 Authenticator::State preferred_initial_state, 115 const base::Closure& resume_callback) { 116 DCHECK(current_method_.is_valid()); 117 if (current_method_.type() == AuthenticationMethod::THIRD_PARTY) { 118 // |ThirdPartyClientAuthenticator| takes ownership of |token_fetcher_|. 119 // The authentication method negotiation logic should guarantee that only 120 // one |ThirdPartyClientAuthenticator| will need to be created per session. 121 DCHECK(token_fetcher_); 122 current_authenticator_.reset(new ThirdPartyClientAuthenticator( 123 token_fetcher_.Pass())); 124 resume_callback.Run(); 125 } else { 126 DCHECK(current_method_.type() == AuthenticationMethod::SPAKE2 || 127 current_method_.type() == AuthenticationMethod::SPAKE2_PAIR); 128 bool pairing_supported = 129 (current_method_.type() == AuthenticationMethod::SPAKE2_PAIR); 130 SecretFetchedCallback callback = base::Bind( 131 &NegotiatingClientAuthenticator::CreateV2AuthenticatorWithSecret, 132 weak_factory_.GetWeakPtr(), preferred_initial_state, resume_callback); 133 fetch_secret_callback_.Run(pairing_supported, callback); 134 } 135} 136 137void NegotiatingClientAuthenticator::CreatePreferredAuthenticator() { 138 if (!client_pairing_id_.empty() && !shared_secret_.empty() && 139 std::find(methods_.begin(), methods_.end(), 140 AuthenticationMethod::Spake2Pair()) != methods_.end()) { 141 // If the client specified a pairing id and shared secret, then create a 142 // PairingAuthenticator. 143 current_authenticator_.reset(new PairingClientAuthenticator( 144 client_pairing_id_, shared_secret_, fetch_secret_callback_, 145 authentication_tag_)); 146 current_method_ = AuthenticationMethod::Spake2Pair(); 147 } 148} 149 150void NegotiatingClientAuthenticator::CreateV2AuthenticatorWithSecret( 151 Authenticator::State initial_state, 152 const base::Closure& resume_callback, 153 const std::string& shared_secret) { 154 current_authenticator_ = V2Authenticator::CreateForClient( 155 AuthenticationMethod::ApplyHashFunction( 156 current_method_.hash_function(), authentication_tag_, shared_secret), 157 initial_state); 158 resume_callback.Run(); 159} 160 161} // namespace protocol 162} // namespace remoting 163