1/*
2 * Copyright (C) 2009 Google Inc.  All rights reserved.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *      http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package com.google.polo.pairing;
18
19import com.google.polo.encoding.HexadecimalEncoder;
20import com.google.polo.encoding.SecretEncoder;
21import com.google.polo.exception.BadSecretException;
22import com.google.polo.exception.NoConfigurationException;
23import com.google.polo.exception.PoloException;
24import com.google.polo.exception.ProtocolErrorException;
25import com.google.polo.pairing.PairingListener.LogLevel;
26import com.google.polo.pairing.message.ConfigurationMessage;
27import com.google.polo.pairing.message.EncodingOption;
28import com.google.polo.pairing.message.OptionsMessage;
29import com.google.polo.pairing.message.OptionsMessage.ProtocolRole;
30import com.google.polo.pairing.message.PoloMessage;
31import com.google.polo.pairing.message.PoloMessage.PoloMessageType;
32import com.google.polo.pairing.message.SecretAckMessage;
33import com.google.polo.pairing.message.SecretMessage;
34import com.google.polo.wire.PoloWireInterface;
35
36import java.io.IOException;
37import java.security.NoSuchAlgorithmException;
38import java.security.SecureRandom;
39import java.security.cert.Certificate;
40import java.util.Arrays;
41import java.util.concurrent.BlockingQueue;
42import java.util.concurrent.LinkedBlockingQueue;
43import java.util.concurrent.TimeUnit;
44
45
46/**
47 * Implements the logic of and holds state for a single occurrence of the
48 * pairing protocol.
49 * <p>
50 * This abstract class implements the logic common to both client and server
51 * perspectives of the protocol.  Notably, the 'pairing' phase of the
52 * protocol has the same logic regardless of client/server status
53 * ({link PairingSession#doPairingPhase()}). Other phases of the protocol are
54 * specific to client/server status; see {@link ServerPairingSession} and
55 * {@link ClientPairingSession}.
56 * <p>
57 * The protocol is initiated by called
58 * {@link PairingSession#doPair(PairingListener)}
59 * The listener implementation is responsible for showing the shared secret
60 * to the user
61 * ({@link PairingListener#onPerformOutputDeviceRole(PairingSession, byte[])}),
62 * or in accepting the user input
63 * ({@link PairingListener#onPerformInputDeviceRole(PairingSession)}),
64 * depending on the role negotiated during initialization.
65 * <p>
66 * When operating in the input role, the session will block execution after
67 * calling {@link PairingListener#onPerformInputDeviceRole(PairingSession)} to
68 * wait for the secret.  The listener, or some activity resulting from it, must
69 * publish the input secret to the session via
70 * {@link PairingSession#setSecret(byte[])}.
71 */
72public abstract class PairingSession {
73
74  protected enum ProtocolState {
75      STATE_UNINITIALIZED,
76      STATE_INITIALIZING,
77      STATE_CONFIGURING,
78      STATE_PAIRING,
79      STATE_SUCCESS,
80      STATE_FAILURE,
81  }
82
83  /**
84   * Enable extra verbose debug logging.
85   */
86  private static final boolean DEBUG_VERBOSE = false;
87
88  /**
89   * Controls whether to verify the secret portion of the SecretAck message.
90   * <p>
91   * NOTE(mikey): One implementation does not send the secret back in
92   * the SecretAck.  This should be fixed, but in the meantime it is not
93   * essential that we verify it, since *any* acknowledgment from the
94   * sender is enough to indicate protocol success.
95   */
96  private static final boolean VERIFY_SECRET_ACK = false;
97
98  /**
99   * Timeout, in milliseconds, for polling the secret queue for a response from
100   * the listener.  This timeout is relevant only to periodically check the
101   * mAbort flag to terminate the protocol, which is set by calling teardown().
102   */
103  private static final int SECRET_POLL_TIMEOUT_MS = 500;
104
105  /**
106   * Performs the initialization phase of the protocol.
107   *
108   * @throws PoloException  if a protocol error occurred
109   * @throws IOException    if an error occurred in input/output
110   */
111  protected abstract void doInitializationPhase()
112      throws PoloException, IOException;
113
114  /**
115   * Performs the configuration phase of the protocol.
116   *
117   * @throws PoloException  if a protocol error occurred
118   * @throws IOException    if an error occurred in input/output
119   */
120  protected abstract void doConfigurationPhase()
121      throws PoloException, IOException;
122
123  /**
124   * Internal representation of challenge-response.
125   */
126  protected PoloChallengeResponse mChallenge;
127
128  /**
129   * Implementation of the transport layer.
130   */
131  private final PoloWireInterface mProtocol;
132
133  /**
134   * Context for the pairing session.
135   */
136  protected final PairingContext mPairingContext;
137
138  /**
139   * Local endpoint's supported options.
140   * <p>
141   * If this session is acting as a server, this message will be sent to the
142   * client in the Initialization phase.  If acting as a client, this member is
143   * used to store local options and compute the Configuration message (but
144   * is never transmitted directly).
145   */
146  protected OptionsMessage mLocalOptions;
147
148  /**
149   * Encoding scheme used for the session.
150   */
151  protected SecretEncoder mEncoder;
152
153  /**
154   * Name of the service being paired.
155   */
156  protected String mServiceName;
157
158  /**
159   * Name of the peer.
160   */
161  protected String mPeerName;
162
163  /**
164   * Configuration message for current session.
165   * <p>
166   * This is computed by the client and sent to the server.
167   */
168  protected ConfigurationMessage mSessionConfig;
169
170  /**
171   * Listener that will receive callbacks upon protocol events.
172   */
173  protected PairingListener mListener;
174
175  /**
176   * Internal state of the pairing session.
177   */
178  protected ProtocolState mState;
179
180  /**
181   * Threadsafe queue for receiving the messages sent by peer, user-given secret
182   * from the listener, or exceptions caught by async threads.
183   */
184  protected BlockingQueue<QueueMessage> mMessageQueue;
185
186  /**
187   * Flag set when the session should be aborted.
188   */
189  protected boolean mAbort;
190
191  /**
192   * Reader thread.
193   */
194  private final Thread mThread;
195
196  /**
197   * Constructor.
198   *
199   * @param protocol        the wire interface to operate against
200   * @param pairingContext  a PairingContext for the session
201   */
202  public PairingSession(PoloWireInterface protocol,
203      PairingContext pairingContext) {
204    mProtocol = protocol;
205    mPairingContext = pairingContext;
206    mState = ProtocolState.STATE_UNINITIALIZED;
207    mMessageQueue = new LinkedBlockingQueue<QueueMessage>();
208
209    Certificate clientCert = mPairingContext.getClientCertificate();
210    Certificate serverCert = mPairingContext.getServerCertificate();
211
212    mChallenge = new PoloChallengeResponse(clientCert, serverCert,
213        new PoloChallengeResponse.DebugLogger() {
214          public void debug(String message) {
215            logDebug(message);
216          }
217          public void verbose(String message) {
218            if (DEBUG_VERBOSE) {
219              logDebug(message);
220            }
221          }
222        });
223
224    mLocalOptions = new OptionsMessage();
225
226    if (mPairingContext.isServer()) {
227      mLocalOptions.setProtocolRolePreference(ProtocolRole.DISPLAY_DEVICE);
228    } else {
229      mLocalOptions.setProtocolRolePreference(ProtocolRole.INPUT_DEVICE);
230    }
231
232    mThread = new Thread(new Runnable() {
233      public void run() {
234        logDebug("Starting reader");
235        try {
236          while (!mAbort) {
237            try {
238              PoloMessage message = mProtocol.getNextMessage();
239              logDebug("Received: " + message.getClass());
240              mMessageQueue.put(new QueueMessage(message));
241            } catch (PoloException exception) {
242              logDebug("Exception while getting message: " + exception);
243              mMessageQueue.put(new QueueMessage(exception));
244              break;
245            } catch (IOException exception) {
246              logDebug("Exception while getting message: " + exception);
247              mMessageQueue.put(new QueueMessage(new PoloException(exception)));
248              break;
249            }
250          }
251        } catch (InterruptedException ie) {
252          logDebug("Interrupted: " + ie);
253        } finally {
254          logDebug("Reader is done");
255        }
256      }
257    });
258    mThread.start();
259  }
260
261  public void teardown() {
262    try {
263      // Send any error.
264      mProtocol.sendErrorMessage(new Exception());
265      mPairingContext.getPeerInputStream().close();
266      mPairingContext.getPeerOutputStream().close();
267    } catch (IOException e) {
268      // oh well.
269    }
270
271    // Unblock the blocking wait on the secret queue.
272    mAbort = true;
273    mThread.interrupt();
274  }
275
276  protected void log(LogLevel level, String message) {
277    if (mListener != null) {
278      mListener.onLogMessage(level, message);
279    }
280  }
281
282  /**
283   * Logs a debug message to the active listener.
284   */
285  public void logDebug(String message) {
286    log(LogLevel.LOG_DEBUG, message);
287  }
288
289  /**
290   * Logs an informational message to the active listener.
291   */
292  public void logInfo(String message) {
293    log(LogLevel.LOG_INFO, message);
294  }
295
296  /**
297   * Logs an error message to the active listener.
298   */
299  public void logError(String message) {
300    log(LogLevel.LOG_ERROR, message);
301  }
302
303  /**
304   * Adds an encoding to the supported input role encodings.  This method can
305   * only be called before the session has started.
306   * <p>
307   * If no input encodings have been added, then this endpoint cannot act as
308   * the input device protocol role.
309   *
310   * @param encoding  the {@link EncodingOption} to add
311   */
312  public void addInputEncoding(EncodingOption encoding) {
313    if (mState != ProtocolState.STATE_UNINITIALIZED) {
314      throw new IllegalStateException("Cannot add encodings once session " +
315          "has been started.");
316    }
317    // Legal values of GAMMALEN must be:
318    // - an even number of bytes
319    // - at least 2 bytes
320    if ((encoding.getSymbolLength() < 2) ||
321        ((encoding.getSymbolLength() % 2) != 0)) {
322        throw new IllegalArgumentException("Bad symbol length: " +
323            encoding.getSymbolLength());
324    }
325      mLocalOptions.addInputEncoding(encoding);
326  }
327
328  /**
329   * Adds an encoding to the supported output role encodings.  This method can
330   * only be called before the session has started.
331   * <p>
332   * If no output encodings have been added, then this endpoint cannot act as
333   * the output device protocol role.
334   *
335   * @param encoding  the {@link EncodingOption} to add
336   */
337  public void addOutputEncoding(EncodingOption encoding) {
338    if (mState != ProtocolState.STATE_UNINITIALIZED) {
339      throw new IllegalStateException("Cannot add encodings once session " +
340          "has been started.");
341    }
342    mLocalOptions.addOutputEncoding(encoding);
343  }
344
345  /**
346   * Changes the internal state.
347   *
348   * @param newState  the new state
349   */
350  private void setState(ProtocolState newState) {
351    logInfo("New state: " + newState);
352    mState = newState;
353  }
354
355  /**
356   * Runs the pairing protocol.
357   * <p>
358   * Supported input and output encodings must be specified
359   * first, using
360   * {@link PairingSession#addInputEncoding(EncodingOption)} and
361   * {@link PairingSession#addOutputEncoding(EncodingOption)},
362   * respectively.
363   *
364   * @param listener  the {@link PairingListener} for the session
365   * @return {@code true} if pairing was successful
366   */
367  public boolean doPair(PairingListener listener) {
368    mListener = listener;
369    mListener.onSessionCreated(this);
370
371    if (mPairingContext.isServer()) {
372      logDebug("Protocol started (SERVER mode)");
373    } else {
374      logDebug("Protocol started (CLIENT mode)");
375    }
376
377    logDebug("Local options: " + mLocalOptions.toString());
378
379    Certificate clientCert = mPairingContext.getClientCertificate();
380    if (DEBUG_VERBOSE) {
381      logDebug("Client certificate:");
382      logDebug(clientCert.toString());
383    }
384
385    Certificate serverCert = mPairingContext.getServerCertificate();
386
387    if (DEBUG_VERBOSE) {
388      logDebug("Server certificate:");
389      logDebug(serverCert.toString());
390    }
391
392    boolean success = false;
393
394    try {
395      setState(ProtocolState.STATE_INITIALIZING);
396      doInitializationPhase();
397
398      setState(ProtocolState.STATE_CONFIGURING);
399      doConfigurationPhase();
400
401      setState(ProtocolState.STATE_PAIRING);
402      doPairingPhase();
403
404      success = true;
405    } catch (ProtocolErrorException e) {
406      logDebug("Remote protocol failure: " + e);
407    } catch (PoloException e) {
408      try {
409        logDebug("Local protocol failure, attempting to send error: " + e);
410        mProtocol.sendErrorMessage(e);
411      } catch (IOException e1) {
412        logDebug("Error message send failed");
413      }
414    } catch (IOException e) {
415      logDebug("IOException: " + e);
416    }
417
418    if (success) {
419      setState(ProtocolState.STATE_SUCCESS);
420    } else {
421      setState(ProtocolState.STATE_FAILURE);
422    }
423
424    mListener.onSessionEnded(this);
425    return success;
426  }
427
428  /**
429   * Returns {@code true} if the session is in a terminal state (success or
430   * failure).
431   */
432  public boolean hasCompleted() {
433    switch (mState) {
434      case STATE_SUCCESS:
435      case STATE_FAILURE:
436        return true;
437      default:
438        return false;
439    }
440  }
441
442  public boolean hasSucceeded() {
443    return mState == ProtocolState.STATE_SUCCESS;
444  }
445
446  public String getServiceName() {
447    return mServiceName;
448  }
449
450  /**
451   * Sets the secret, as received from a user.  This method is only meaningful
452   * when the endpoint is acting as the input device role.
453   *
454   * @param secret  the secret, as a byte sequence
455   * @return        {@code true} if the secret was captured
456   */
457  public boolean setSecret(byte[] secret) {
458    if (!isInputDevice()) {
459      throw new IllegalStateException("Secret can only be set for " +
460          "input role session.");
461    } else if (mState != ProtocolState.STATE_PAIRING) {
462      throw new IllegalStateException("Secret can only be set while " +
463          "in pairing state.");
464    }
465    return mMessageQueue.offer(new QueueMessage(secret));
466  }
467
468  /**
469   * Executes the pairing phase of the protocol.
470   *
471   * @throws PoloException  if a protocol error occurred
472   * @throws IOException    if an error in the input/output occurred
473   */
474  protected void doPairingPhase() throws PoloException, IOException {
475    if (isInputDevice()) {
476      new Thread(new Runnable() {
477        public void run() {
478          logDebug("Calling listener for user input...");
479          try {
480            mListener.onPerformInputDeviceRole(PairingSession.this);
481          } catch (PoloException exception) {
482            logDebug("Sending exception: " + exception);
483            mMessageQueue.offer(new QueueMessage(exception));
484          } finally {
485            logDebug("Listener finished.");
486          }
487        }
488      }).start();
489
490      logDebug("Waiting for secret from Listener or ...");
491      QueueMessage message = waitForMessage();
492      if (message == null || !message.hasSecret()) {
493        throw new PoloException(
494            "Illegal state - no secret available: " + message);
495      }
496      byte[] userGamma = message.mSecret;
497      if (userGamma == null) {
498        throw new PoloException("Invalid secret.");
499      }
500
501      boolean match = mChallenge.checkGamma(userGamma);
502      if (match != true) {
503        throw new BadSecretException("Secret failed local check.");
504      }
505
506      byte[] userNonce = mChallenge.extractNonce(userGamma);
507      byte[] genAlpha = mChallenge.getAlpha(userNonce);
508
509      logDebug("Sending Secret reply...");
510      SecretMessage secretMessage = new SecretMessage(genAlpha);
511      mProtocol.sendMessage(secretMessage);
512
513      logDebug("Waiting for SecretAck...");
514      SecretAckMessage secretAck =
515          (SecretAckMessage) getNextMessage(PoloMessageType.SECRET_ACK);
516
517      if (VERIFY_SECRET_ACK) {
518        byte[] inbandAlpha = secretAck.getSecret();
519        if (!Arrays.equals(inbandAlpha, genAlpha)) {
520          throw new BadSecretException("Inband secret did not match. " +
521              "Expected [" + PoloUtil.bytesToHexString(genAlpha) +
522              "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]");
523        }
524      }
525    } else {
526      int symbolLength = mSessionConfig.getEncoding().getSymbolLength();
527      int nonceLength = symbolLength / 2;
528      int bytesNeeded = nonceLength / mEncoder.symbolsPerByte();
529
530      byte[] nonce = new byte[bytesNeeded];
531      SecureRandom random;
532      try {
533        random = SecureRandom.getInstance("SHA1PRNG");
534      } catch (NoSuchAlgorithmException e) {
535        throw new PoloException(e);
536      }
537      random.nextBytes(nonce);
538
539      // Display gamma
540      logDebug("Calling listener to display output...");
541      byte[] gamma = mChallenge.getGamma(nonce);
542      mListener.onPerformOutputDeviceRole(this, gamma);
543
544      logDebug("Waiting for Secret...");
545      SecretMessage secretMessage =
546          (SecretMessage) getNextMessage(PoloMessageType.SECRET);
547
548      byte[] localAlpha = mChallenge.getAlpha(nonce);
549      byte[] inbandAlpha = secretMessage.getSecret();
550      boolean matched = Arrays.equals(localAlpha, inbandAlpha);
551
552      if (!matched) {
553        throw new BadSecretException("Inband secret did not match. " +
554            "Expected [" + PoloUtil.bytesToHexString(localAlpha) +
555            "], got [" + PoloUtil.bytesToHexString(inbandAlpha) + "]");
556      }
557
558      logDebug("Sending SecretAck...");
559      byte[] genAlpha = mChallenge.getAlpha(nonce);
560      SecretAckMessage secretAck = new SecretAckMessage(inbandAlpha);
561      mProtocol.sendMessage(secretAck);
562    }
563  }
564
565  public SecretEncoder getEncoder() {
566    return mEncoder;
567  }
568
569  /**
570   * Sets the current session's configuration from a
571   * {@link ConfigurationMessage}.
572   *
573   * @param message         the session's config
574   * @throws PoloException  if the config was not valid for some reason
575   */
576  protected void setConfiguration(ConfigurationMessage message)
577      throws PoloException {
578    if (message == null || message.getEncoding() == null) {
579      throw new NoConfigurationException("No configuration is possible.");
580    }
581    if (message.getEncoding().getSymbolLength() % 2 != 0) {
582      throw new PoloException("Symbol length must be even.");
583    }
584    if (message.getEncoding().getSymbolLength() < 2) {
585      throw new PoloException("Symbol length must be >= 2 symbols.");
586    }
587    switch (message.getEncoding().getType()) {
588      case ENCODING_HEXADECIMAL:
589        mEncoder = new HexadecimalEncoder();
590        break;
591      default:
592        throw new PoloException("Unsupported encoding type.");
593    }
594    mSessionConfig = message;
595  }
596
597  /**
598   * Returns the role of this endpoint in the current session.
599   */
600  protected ProtocolRole getLocalRole() {
601    assert (mSessionConfig != null);
602    if (!mPairingContext.isServer()) {
603      return mSessionConfig.getClientRole();
604    } else {
605      return (mSessionConfig.getClientRole() == ProtocolRole.DISPLAY_DEVICE) ?
606          ProtocolRole.INPUT_DEVICE : ProtocolRole.DISPLAY_DEVICE;
607    }
608  }
609
610  /**
611   * Returns {@code true} if this endpoint will act as the input device.
612   */
613  protected boolean isInputDevice() {
614    return (getLocalRole() == ProtocolRole.INPUT_DEVICE);
615  }
616
617  /**
618   * Returns {@code true} if peer's name is set.
619   */
620  public boolean hasPeerName() {
621    return mPeerName != null;
622  }
623
624  /**
625   * Returns peer's name if set, {@code null} otherwise.
626   */
627  public String getPeerName() {
628    return mPeerName;
629  }
630
631  protected PoloMessage getNextMessage(PoloMessageType type)
632      throws PoloException {
633    QueueMessage message = waitForMessage();
634    if (message != null && message.hasPoloMessage()) {
635      if (!type.equals(message.mPoloMessage.getType())) {
636        throw new PoloException(
637            "Unexpected message type: " + message.mPoloMessage.getType());
638      }
639      return message.mPoloMessage;
640    }
641    throw new PoloException("Invalid state - expected polo message");
642  }
643
644  /**
645   * Returns next queued message. The method blocks until the secret or the
646   * polo message is available.
647   *
648   * @return the queued message, or null on error
649   * @throws PoloException if exception was queued
650   */
651  private QueueMessage waitForMessage() throws PoloException {
652    while (!mAbort) {
653      try {
654        QueueMessage message = mMessageQueue.poll(SECRET_POLL_TIMEOUT_MS,
655            TimeUnit.MILLISECONDS);
656
657        if (message != null) {
658          if (message.hasPoloException()) {
659            throw new PoloException(message.mPoloException);
660          }
661          return message;
662        }
663      } catch (InterruptedException e) {
664        break;
665      }
666    }
667
668    // Aborted or interrupted.
669    return null;
670  }
671
672  /**
673   * Sends message to the peer.
674   *
675   * @param message         the message
676   * @throws PoloException  if a protocol error occurred
677   * @throws IOException    if an error in the input/output occurred
678   */
679  protected void sendMessage(PoloMessage message)
680      throws IOException, PoloException {
681    mProtocol.sendMessage(message);
682  }
683
684  /**
685   * Queued message, that can carry information about secret, next read message,
686   * or exception caught by reader or input threads.
687   */
688  private static final class QueueMessage {
689    final PoloMessage mPoloMessage;
690    final PoloException mPoloException;
691    final byte[] mSecret;
692
693    private QueueMessage(
694        PoloMessage message, byte[] secret, PoloException exception) {
695      int nonNullCount = 0;
696      if (message != null) {
697        ++nonNullCount;
698      }
699      mPoloMessage = message;
700      if (exception != null) {
701        assert(nonNullCount == 0);
702        ++nonNullCount;
703      }
704      mPoloException = exception;
705      if (secret != null) {
706        assert(nonNullCount == 0);
707        ++nonNullCount;
708      }
709      mSecret = secret;
710      assert(nonNullCount == 1);
711    }
712
713    public QueueMessage(PoloMessage message) {
714      this(message, null, null);
715    }
716
717    public QueueMessage(byte[] secret) {
718      this(null, secret, null);
719    }
720
721    public QueueMessage(PoloException exception) {
722      this(null, null, exception);
723    }
724
725    public boolean hasPoloMessage() {
726      return mPoloMessage != null;
727    }
728
729    public boolean hasPoloException() {
730      return mPoloException != null;
731    }
732
733    public boolean hasSecret() {
734      return mSecret != null;
735    }
736
737    @Override
738    public String toString() {
739      StringBuilder builder = new StringBuilder("QueueMessage(");
740      if (hasPoloMessage()) {
741        builder.append("poloMessage = " + mPoloMessage);
742      }
743      if (hasPoloException()) {
744        builder.append("poloException = " + mPoloException);
745      }
746      if (hasSecret()) {
747        builder.append("secret = " + Arrays.toString(mSecret));
748      }
749      return builder.append(")").toString();
750    }
751  }
752
753}
754