1// Copyright 2013 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "remoting/protocol/negotiating_client_authenticator.h"
6
7#include <algorithm>
8#include <sstream>
9
10#include "base/bind.h"
11#include "base/callback.h"
12#include "base/logging.h"
13#include "base/strings/string_split.h"
14#include "remoting/protocol/channel_authenticator.h"
15#include "remoting/protocol/pairing_client_authenticator.h"
16#include "remoting/protocol/v2_authenticator.h"
17#include "third_party/webrtc/libjingle/xmllite/xmlelement.h"
18
19namespace remoting {
20namespace protocol {
21
22NegotiatingClientAuthenticator::NegotiatingClientAuthenticator(
23    const std::string& client_pairing_id,
24    const std::string& shared_secret,
25    const std::string& authentication_tag,
26    const FetchSecretCallback& fetch_secret_callback,
27    scoped_ptr<ThirdPartyClientAuthenticator::TokenFetcher> token_fetcher,
28    const std::vector<AuthenticationMethod>& methods)
29    : NegotiatingAuthenticatorBase(MESSAGE_READY),
30      client_pairing_id_(client_pairing_id),
31      shared_secret_(shared_secret),
32      authentication_tag_(authentication_tag),
33      fetch_secret_callback_(fetch_secret_callback),
34      token_fetcher_(token_fetcher.Pass()),
35      method_set_by_host_(false),
36      weak_factory_(this) {
37  DCHECK(!methods.empty());
38  for (std::vector<AuthenticationMethod>::const_iterator it = methods.begin();
39       it != methods.end(); ++it) {
40    AddMethod(*it);
41  }
42}
43
44NegotiatingClientAuthenticator::~NegotiatingClientAuthenticator() {
45}
46
47void NegotiatingClientAuthenticator::ProcessMessage(
48    const buzz::XmlElement* message,
49    const base::Closure& resume_callback) {
50  DCHECK_EQ(state(), WAITING_MESSAGE);
51
52  std::string method_attr = message->Attr(kMethodAttributeQName);
53  AuthenticationMethod method = AuthenticationMethod::FromString(method_attr);
54
55  // The host picked a method different from the one the client had selected.
56  if (method != current_method_) {
57    // The host must pick a method that is valid and supported by the client,
58    // and it must not change methods after it has picked one.
59    if (method_set_by_host_ || !method.is_valid() ||
60        std::find(methods_.begin(), methods_.end(), method) == methods_.end()) {
61      state_ = REJECTED;
62      rejection_reason_ = PROTOCOL_ERROR;
63      resume_callback.Run();
64      return;
65    }
66
67    current_method_ = method;
68    method_set_by_host_ = true;
69    state_ = PROCESSING_MESSAGE;
70
71    // Copy the message since the authenticator may process it asynchronously.
72    base::Closure callback = base::Bind(
73        &NegotiatingAuthenticatorBase::ProcessMessageInternal,
74        base::Unretained(this), base::Owned(new buzz::XmlElement(*message)),
75        resume_callback);
76    CreateAuthenticatorForCurrentMethod(WAITING_MESSAGE, callback);
77    return;
78  }
79  ProcessMessageInternal(message, resume_callback);
80}
81
82scoped_ptr<buzz::XmlElement> NegotiatingClientAuthenticator::GetNextMessage() {
83  DCHECK_EQ(state(), MESSAGE_READY);
84
85  // This is the first message to the host, send a list of supported methods.
86  if (!current_method_.is_valid()) {
87    // If no authentication method has been chosen, see if we can optimistically
88    // choose one.
89    scoped_ptr<buzz::XmlElement> result;
90    CreatePreferredAuthenticator();
91    if (current_authenticator_) {
92      DCHECK(current_authenticator_->state() == MESSAGE_READY);
93      result = GetNextMessageInternal();
94    } else {
95      result = CreateEmptyAuthenticatorMessage();
96    }
97
98    // Include a list of supported methods.
99    std::stringstream supported_methods(std::stringstream::out);
100    for (std::vector<AuthenticationMethod>::iterator it = methods_.begin();
101         it != methods_.end(); ++it) {
102      if (it != methods_.begin())
103        supported_methods << kSupportedMethodsSeparator;
104      supported_methods << it->ToString();
105    }
106    result->AddAttr(kSupportedMethodsAttributeQName, supported_methods.str());
107    state_ = WAITING_MESSAGE;
108    return result.Pass();
109  }
110  return GetNextMessageInternal();
111}
112
113void NegotiatingClientAuthenticator::CreateAuthenticatorForCurrentMethod(
114    Authenticator::State preferred_initial_state,
115    const base::Closure& resume_callback) {
116  DCHECK(current_method_.is_valid());
117  if (current_method_.type() == AuthenticationMethod::THIRD_PARTY) {
118    // |ThirdPartyClientAuthenticator| takes ownership of |token_fetcher_|.
119    // The authentication method negotiation logic should guarantee that only
120    // one |ThirdPartyClientAuthenticator| will need to be created per session.
121    DCHECK(token_fetcher_);
122    current_authenticator_.reset(new ThirdPartyClientAuthenticator(
123        token_fetcher_.Pass()));
124    resume_callback.Run();
125  } else {
126    DCHECK(current_method_.type() == AuthenticationMethod::SPAKE2 ||
127           current_method_.type() == AuthenticationMethod::SPAKE2_PAIR);
128    bool pairing_supported =
129        (current_method_.type() == AuthenticationMethod::SPAKE2_PAIR);
130    SecretFetchedCallback callback = base::Bind(
131        &NegotiatingClientAuthenticator::CreateV2AuthenticatorWithSecret,
132        weak_factory_.GetWeakPtr(), preferred_initial_state, resume_callback);
133    fetch_secret_callback_.Run(pairing_supported, callback);
134  }
135}
136
137void NegotiatingClientAuthenticator::CreatePreferredAuthenticator() {
138  if (!client_pairing_id_.empty() && !shared_secret_.empty() &&
139      std::find(methods_.begin(), methods_.end(),
140                AuthenticationMethod::Spake2Pair()) != methods_.end()) {
141    // If the client specified a pairing id and shared secret, then create a
142    // PairingAuthenticator.
143    current_authenticator_.reset(new PairingClientAuthenticator(
144        client_pairing_id_, shared_secret_, fetch_secret_callback_,
145        authentication_tag_));
146    current_method_ = AuthenticationMethod::Spake2Pair();
147  }
148}
149
150void NegotiatingClientAuthenticator::CreateV2AuthenticatorWithSecret(
151    Authenticator::State initial_state,
152    const base::Closure& resume_callback,
153    const std::string& shared_secret) {
154  current_authenticator_ = V2Authenticator::CreateForClient(
155      AuthenticationMethod::ApplyHashFunction(
156          current_method_.hash_function(), authentication_tag_, shared_secret),
157      initial_state);
158  resume_callback.Run();
159}
160
161}  // namespace protocol
162}  // namespace remoting
163