1// Copyright 2012 Google Inc. All Rights Reserved. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15#include "polo/pairing/pairingsession.h" 16 17#include <glog/logging.h> 18#include "polo/encoding/hexadecimalencoder.h" 19#include "polo/util/poloutil.h" 20 21namespace polo { 22namespace pairing { 23 24PairingSession::PairingSession(wire::PoloWireAdapter* wire, 25 PairingContext* context, 26 PoloChallengeResponse* challenge) 27 : state_(kUninitialized), 28 wire_(wire), 29 context_(context), 30 challenge_(challenge), 31 configuration_(NULL), 32 encoder_(NULL), 33 nonce_(NULL), 34 secret_(NULL) { 35 wire_->set_listener(this); 36 37 local_options_.set_protocol_role_preference(context->is_server() ? 38 message::OptionsMessage::kDisplayDevice 39 : message::OptionsMessage::kInputDevice); 40} 41 42PairingSession::~PairingSession() { 43 if (configuration_) { 44 delete configuration_; 45 } 46 47 if (encoder_) { 48 delete encoder_; 49 } 50 51 if (nonce_) { 52 delete nonce_; 53 } 54 55 if (secret_) { 56 delete secret_; 57 } 58} 59 60void PairingSession::AddInputEncoding( 61 const encoding::EncodingOption& encoding) { 62 if (state_ != kUninitialized) { 63 LOG(ERROR) << "Attempt to add input encoding to active session"; 64 return; 65 } 66 67 if (!IsValidEncodingOption(encoding)) { 68 LOG(ERROR) << "Invalid input encoding: " << encoding.ToString(); 69 return; 70 } 71 72 local_options_.AddInputEncoding(encoding); 73} 74 75void PairingSession::AddOutputEncoding( 76 const encoding::EncodingOption& encoding) { 77 if (state_ != kUninitialized) { 78 LOG(ERROR) << "Attempt to add output encoding to active session"; 79 return; 80 } 81 82 if (!IsValidEncodingOption(encoding)) { 83 LOG(ERROR) << "Invalid output encoding: " << encoding.ToString(); 84 return; 85 } 86 87 local_options_.AddOutputEncoding(encoding); 88} 89 90bool PairingSession::SetSecret(const Gamma& secret) { 91 secret_ = new Gamma(secret); 92 93 if (!IsInputDevice() || state_ != kWaitingForSecret) { 94 LOG(ERROR) << "Invalid state: unexpected secret"; 95 return false; 96 } 97 98 if (!challenge().CheckGamma(secret)) { 99 LOG(ERROR) << "Secret failed local check"; 100 return false; 101 } 102 103 nonce_ = challenge().ExtractNonce(secret); 104 if (!nonce_) { 105 LOG(ERROR) << "Failed to extract nonce"; 106 return false; 107 } 108 109 const Alpha* gen_alpha = challenge().GetAlpha(*nonce_); 110 if (!gen_alpha) { 111 LOG(ERROR) << "Failed to get alpha"; 112 return false; 113 } 114 115 message::SecretMessage secret_message(*gen_alpha); 116 delete gen_alpha; 117 118 wire_->SendSecretMessage(secret_message); 119 120 LOG(INFO) << "Waiting for SecretAck..."; 121 wire_->GetNextMessage(); 122 123 return true; 124} 125 126void PairingSession::DoPair(PairingListener *listener) { 127 listener_ = listener; 128 listener_->OnSessionCreated(); 129 130 if (context_->is_server()) { 131 LOG(INFO) << "Pairing started (SERVER mode)"; 132 } else { 133 LOG(INFO) << "Pairing started (CLIENT mode)"; 134 } 135 LOG(INFO) << "Local options: " << local_options_.ToString(); 136 137 set_state(kInitializing); 138 DoInitializationPhase(); 139} 140 141void PairingSession::DoPairingPhase() { 142 if (IsInputDevice()) { 143 DoInputPairing(); 144 } else { 145 DoOutputPairing(); 146 } 147} 148 149void PairingSession::DoInputPairing() { 150 set_state(kWaitingForSecret); 151 listener_->OnPerformInputDeviceRole(); 152} 153 154void PairingSession::DoOutputPairing() { 155 size_t nonce_length = configuration_->encoding().symbol_length() / 2; 156 size_t bytes_needed = nonce_length / encoder_->symbols_per_byte(); 157 158 uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed); 159 nonce_ = new Nonce(random, random + bytes_needed); 160 delete[] random; 161 162 const Gamma* gamma = challenge().GetGamma(*nonce_); 163 if (!gamma) { 164 LOG(ERROR) << "Failed to get gamma"; 165 wire()->SendErrorMessage(kErrorProtocol); 166 listener()->OnError(kErrorProtocol); 167 return; 168 } 169 170 listener_->OnPerformOutputDeviceRole(*gamma); 171 delete gamma; 172 173 set_state(kWaitingForSecret); 174 175 LOG(INFO) << "Waiting for Secret..."; 176 wire_->GetNextMessage(); 177} 178 179void PairingSession::set_state(ProtocolState state) { 180 LOG(INFO) << "New state: " << state; 181 state_ = state; 182} 183 184bool PairingSession::SetConfiguration( 185 const message::ConfigurationMessage& message) { 186 const encoding::EncodingOption& encoding = message.encoding(); 187 188 if (!IsValidEncodingOption(encoding)) { 189 LOG(ERROR) << "Invalid configuration: " << encoding.ToString(); 190 return false; 191 } 192 193 if (encoder_) { 194 delete encoder_; 195 encoder_ = NULL; 196 } 197 198 switch (encoding.encoding_type()) { 199 case encoding::EncodingOption::kHexadecimal: 200 encoder_ = new encoding::HexadecimalEncoder(); 201 break; 202 default: 203 LOG(ERROR) << "Unsupported encoding type: " 204 << encoding.encoding_type(); 205 return false; 206 } 207 208 if (configuration_) { 209 delete configuration_; 210 } 211 configuration_ = new message::ConfigurationMessage(message.encoding(), 212 message.client_role()); 213 return true; 214} 215 216void PairingSession::OnSecretMessage(const message::SecretMessage& message) { 217 if (state() != kWaitingForSecret) { 218 LOG(ERROR) << "Invalid state: unexpected secret message"; 219 wire()->SendErrorMessage(kErrorProtocol); 220 listener()->OnError(kErrorProtocol); 221 return; 222 } 223 224 if (!VerifySecret(message.secret())) { 225 wire()->SendErrorMessage(kErrorInvalidChallengeResponse); 226 listener_->OnError(kErrorInvalidChallengeResponse); 227 return; 228 } 229 230 const Alpha* alpha = challenge().GetAlpha(*nonce_); 231 if (!alpha) { 232 LOG(ERROR) << "Failed to get alpha"; 233 wire()->SendErrorMessage(kErrorProtocol); 234 listener()->OnError(kErrorProtocol); 235 return; 236 } 237 238 message::SecretAckMessage ack(*alpha); 239 delete alpha; 240 241 wire_->SendSecretAckMessage(ack); 242 243 listener_->OnPairingSuccess(); 244} 245 246void PairingSession::OnSecretAckMessage( 247 const message::SecretAckMessage& message) { 248 if (kVerifySecretAck && !VerifySecret(message.secret())) { 249 wire()->SendErrorMessage(kErrorInvalidChallengeResponse); 250 listener_->OnError(kErrorInvalidChallengeResponse); 251 return; 252 } 253 254 listener_->OnPairingSuccess(); 255} 256 257void PairingSession::OnError(pairing::PoloError error) { 258 listener_->OnError(error); 259} 260 261bool PairingSession::VerifySecret(const Alpha& secret) const { 262 if (!nonce_) { 263 LOG(ERROR) << "Nonce not set"; 264 return false; 265 } 266 267 const Alpha* gen_alpha = challenge().GetAlpha(*nonce_); 268 if (!gen_alpha) { 269 LOG(ERROR) << "Failed to get alpha"; 270 return false; 271 } 272 273 bool valid = (secret == *gen_alpha); 274 275 if (!valid) { 276 LOG(ERROR) << "Inband secret did not match. Expected [" 277 << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size()) 278 << "], got [" 279 << util::PoloUtil::BytesToHexString(&secret[0], secret.size()) 280 << "]"; 281 } 282 283 delete gen_alpha; 284 return valid; 285} 286 287message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const { 288 if (!configuration_) { 289 return message::OptionsMessage::kUnknown; 290 } 291 292 if (context_->is_client()) { 293 return configuration_->client_role(); 294 } else { 295 return configuration_->client_role() == 296 message::OptionsMessage::kDisplayDevice ? 297 message::OptionsMessage::kInputDevice 298 : message::OptionsMessage::kDisplayDevice; 299 } 300} 301 302bool PairingSession::IsInputDevice() const { 303 return GetLocalRole() == message::OptionsMessage::kInputDevice; 304} 305 306bool PairingSession::IsValidEncodingOption( 307 const encoding::EncodingOption& option) const { 308 // Legal values of GAMMALEN must be an even number of at least 2 bytes. 309 return option.encoding_type() != encoding::EncodingOption::kUnknown 310 && (option.symbol_length() % 2 == 0) 311 && (option.symbol_length() >= 2); 312} 313 314} // namespace pairing 315} // namespace polo 316