1// Copyright 2012 Google Inc. All Rights Reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "polo/pairing/pairingsession.h"
16
17#include <glog/logging.h>
18#include "polo/encoding/hexadecimalencoder.h"
19#include "polo/util/poloutil.h"
20
21namespace polo {
22namespace pairing {
23
24PairingSession::PairingSession(wire::PoloWireAdapter* wire,
25                               PairingContext* context,
26                               PoloChallengeResponse* challenge)
27    : state_(kUninitialized),
28      wire_(wire),
29      context_(context),
30      challenge_(challenge),
31      configuration_(NULL),
32      encoder_(NULL),
33      nonce_(NULL),
34      secret_(NULL) {
35  wire_->set_listener(this);
36
37  local_options_.set_protocol_role_preference(context->is_server() ?
38      message::OptionsMessage::kDisplayDevice
39      : message::OptionsMessage::kInputDevice);
40}
41
42PairingSession::~PairingSession() {
43  if (configuration_) {
44    delete configuration_;
45  }
46
47  if (encoder_) {
48    delete encoder_;
49  }
50
51  if (nonce_) {
52    delete nonce_;
53  }
54
55  if (secret_) {
56    delete secret_;
57  }
58}
59
60void PairingSession::AddInputEncoding(
61    const encoding::EncodingOption& encoding) {
62  if (state_ != kUninitialized) {
63    LOG(ERROR) << "Attempt to add input encoding to active session";
64    return;
65  }
66
67  if (!IsValidEncodingOption(encoding)) {
68    LOG(ERROR) << "Invalid input encoding: " << encoding.ToString();
69    return;
70  }
71
72  local_options_.AddInputEncoding(encoding);
73}
74
75void PairingSession::AddOutputEncoding(
76    const encoding::EncodingOption& encoding) {
77  if (state_ != kUninitialized) {
78    LOG(ERROR) << "Attempt to add output encoding to active session";
79    return;
80  }
81
82  if (!IsValidEncodingOption(encoding)) {
83    LOG(ERROR) << "Invalid output encoding: " << encoding.ToString();
84    return;
85  }
86
87  local_options_.AddOutputEncoding(encoding);
88}
89
90bool PairingSession::SetSecret(const Gamma& secret) {
91  secret_ = new Gamma(secret);
92
93  if (!IsInputDevice() || state_ != kWaitingForSecret) {
94    LOG(ERROR) << "Invalid state: unexpected secret";
95    return false;
96  }
97
98  if (!challenge().CheckGamma(secret)) {
99    LOG(ERROR) << "Secret failed local check";
100    return false;
101  }
102
103  nonce_ = challenge().ExtractNonce(secret);
104  if (!nonce_) {
105    LOG(ERROR) << "Failed to extract nonce";
106    return false;
107  }
108
109  const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
110  if (!gen_alpha) {
111    LOG(ERROR) << "Failed to get alpha";
112    return false;
113  }
114
115  message::SecretMessage secret_message(*gen_alpha);
116  delete gen_alpha;
117
118  wire_->SendSecretMessage(secret_message);
119
120  LOG(INFO) << "Waiting for SecretAck...";
121  wire_->GetNextMessage();
122
123  return true;
124}
125
126void PairingSession::DoPair(PairingListener *listener) {
127  listener_ = listener;
128  listener_->OnSessionCreated();
129
130  if (context_->is_server()) {
131    LOG(INFO) << "Pairing started (SERVER mode)";
132  } else {
133    LOG(INFO) << "Pairing started (CLIENT mode)";
134  }
135  LOG(INFO) << "Local options: " << local_options_.ToString();
136
137  set_state(kInitializing);
138  DoInitializationPhase();
139}
140
141void PairingSession::DoPairingPhase() {
142  if (IsInputDevice()) {
143    DoInputPairing();
144  } else {
145    DoOutputPairing();
146  }
147}
148
149void PairingSession::DoInputPairing() {
150  set_state(kWaitingForSecret);
151  listener_->OnPerformInputDeviceRole();
152}
153
154void PairingSession::DoOutputPairing() {
155  size_t nonce_length = configuration_->encoding().symbol_length() / 2;
156  size_t bytes_needed = nonce_length / encoder_->symbols_per_byte();
157
158  uint8_t* random = util::PoloUtil::GenerateRandomBytes(bytes_needed);
159  nonce_ = new Nonce(random, random + bytes_needed);
160  delete[] random;
161
162  const Gamma* gamma = challenge().GetGamma(*nonce_);
163  if (!gamma) {
164    LOG(ERROR) << "Failed to get gamma";
165    wire()->SendErrorMessage(kErrorProtocol);
166    listener()->OnError(kErrorProtocol);
167    return;
168  }
169
170  listener_->OnPerformOutputDeviceRole(*gamma);
171  delete gamma;
172
173  set_state(kWaitingForSecret);
174
175  LOG(INFO) << "Waiting for Secret...";
176  wire_->GetNextMessage();
177}
178
179void PairingSession::set_state(ProtocolState state) {
180  LOG(INFO) << "New state: " << state;
181  state_ = state;
182}
183
184bool PairingSession::SetConfiguration(
185    const message::ConfigurationMessage& message) {
186  const encoding::EncodingOption& encoding = message.encoding();
187
188  if (!IsValidEncodingOption(encoding)) {
189    LOG(ERROR) << "Invalid configuration: " << encoding.ToString();
190    return false;
191  }
192
193  if (encoder_) {
194    delete encoder_;
195    encoder_ = NULL;
196  }
197
198  switch (encoding.encoding_type()) {
199    case encoding::EncodingOption::kHexadecimal:
200      encoder_ = new encoding::HexadecimalEncoder();
201      break;
202    default:
203      LOG(ERROR) << "Unsupported encoding type: "
204          << encoding.encoding_type();
205      return false;
206  }
207
208  if (configuration_) {
209    delete configuration_;
210  }
211  configuration_ = new message::ConfigurationMessage(message.encoding(),
212                                                     message.client_role());
213  return true;
214}
215
216void PairingSession::OnSecretMessage(const message::SecretMessage& message) {
217  if (state() != kWaitingForSecret) {
218    LOG(ERROR) << "Invalid state: unexpected secret message";
219    wire()->SendErrorMessage(kErrorProtocol);
220    listener()->OnError(kErrorProtocol);
221    return;
222  }
223
224  if (!VerifySecret(message.secret())) {
225    wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
226    listener_->OnError(kErrorInvalidChallengeResponse);
227    return;
228  }
229
230  const Alpha* alpha = challenge().GetAlpha(*nonce_);
231  if (!alpha) {
232    LOG(ERROR) << "Failed to get alpha";
233    wire()->SendErrorMessage(kErrorProtocol);
234    listener()->OnError(kErrorProtocol);
235    return;
236  }
237
238  message::SecretAckMessage ack(*alpha);
239  delete alpha;
240
241  wire_->SendSecretAckMessage(ack);
242
243  listener_->OnPairingSuccess();
244}
245
246void PairingSession::OnSecretAckMessage(
247    const message::SecretAckMessage& message) {
248  if (kVerifySecretAck && !VerifySecret(message.secret())) {
249    wire()->SendErrorMessage(kErrorInvalidChallengeResponse);
250    listener_->OnError(kErrorInvalidChallengeResponse);
251    return;
252  }
253
254  listener_->OnPairingSuccess();
255}
256
257void PairingSession::OnError(pairing::PoloError error) {
258  listener_->OnError(error);
259}
260
261bool PairingSession::VerifySecret(const Alpha& secret) const {
262  if (!nonce_) {
263    LOG(ERROR) << "Nonce not set";
264    return false;
265  }
266
267  const Alpha* gen_alpha = challenge().GetAlpha(*nonce_);
268  if (!gen_alpha) {
269    LOG(ERROR) << "Failed to get alpha";
270    return false;
271  }
272
273  bool valid = (secret == *gen_alpha);
274
275  if (!valid) {
276    LOG(ERROR) << "Inband secret did not match. Expected ["
277        << util::PoloUtil::BytesToHexString(&(*gen_alpha)[0], gen_alpha->size())
278        << "], got ["
279        << util::PoloUtil::BytesToHexString(&secret[0], secret.size())
280        << "]";
281  }
282
283  delete gen_alpha;
284  return valid;
285}
286
287message::OptionsMessage::ProtocolRole PairingSession::GetLocalRole() const {
288  if (!configuration_) {
289    return message::OptionsMessage::kUnknown;
290  }
291
292  if (context_->is_client()) {
293    return configuration_->client_role();
294  } else {
295    return configuration_->client_role() ==
296        message::OptionsMessage::kDisplayDevice ?
297            message::OptionsMessage::kInputDevice
298            : message::OptionsMessage::kDisplayDevice;
299  }
300}
301
302bool PairingSession::IsInputDevice() const {
303  return GetLocalRole() == message::OptionsMessage::kInputDevice;
304}
305
306bool PairingSession::IsValidEncodingOption(
307    const encoding::EncodingOption& option) const {
308  // Legal values of GAMMALEN must be an even number of at least 2 bytes.
309  return option.encoding_type() != encoding::EncodingOption::kUnknown
310      && (option.symbol_length() % 2 == 0)
311      && (option.symbol_length() >= 2);
312}
313
314}  // namespace pairing
315}  // namespace polo
316