1/*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include "talk/base/win32.h"
29#define SECURITY_WIN32
30#include <security.h>
31#include <schannel.h>
32
33#include <iomanip>
34#include <vector>
35
36#include "talk/base/common.h"
37#include "talk/base/logging.h"
38#include "talk/base/schanneladapter.h"
39#include "talk/base/sec_buffer.h"
40#include "talk/base/thread.h"
41
42namespace talk_base {
43
44/////////////////////////////////////////////////////////////////////////////
45// SChannelAdapter
46/////////////////////////////////////////////////////////////////////////////
47
// Name table for SECURITY_STATUS codes, used with ErrorName() when logging
// SSPI failures in this file.  Defined in another translation unit.
extern const ConstantLabel SECURITY_ERRORS[];
49
50const ConstantLabel SCHANNEL_BUFFER_TYPES[] = {
51  KLABEL(SECBUFFER_EMPTY),              //  0
52  KLABEL(SECBUFFER_DATA),               //  1
53  KLABEL(SECBUFFER_TOKEN),              //  2
54  KLABEL(SECBUFFER_PKG_PARAMS),         //  3
55  KLABEL(SECBUFFER_MISSING),            //  4
56  KLABEL(SECBUFFER_EXTRA),              //  5
57  KLABEL(SECBUFFER_STREAM_TRAILER),     //  6
58  KLABEL(SECBUFFER_STREAM_HEADER),      //  7
59  KLABEL(SECBUFFER_MECHLIST),           // 11
60  KLABEL(SECBUFFER_MECHLIST_SIGNATURE), // 12
61  KLABEL(SECBUFFER_TARGET),             // 13
62  KLABEL(SECBUFFER_CHANNEL_BINDINGS),   // 14
63  LASTLABEL
64};
65
66void DescribeBuffer(LoggingSeverity severity, const char* prefix,
67                    const SecBuffer& sb) {
68  LOG_V(severity)
69    << prefix
70    << "(" << sb.cbBuffer
71    << ", " << FindLabel(sb.BufferType & ~SECBUFFER_ATTRMASK,
72                          SCHANNEL_BUFFER_TYPES)
73    << ", " << sb.pvBuffer << ")";
74}
75
76void DescribeBuffers(LoggingSeverity severity, const char* prefix,
77                     const SecBufferDesc* sbd) {
78  if (!LOG_CHECK_LEVEL_V(severity))
79    return;
80  LOG_V(severity) << prefix << "(";
81  for (size_t i=0; i<sbd->cBuffers; ++i) {
82    DescribeBuffer(severity, "  ", sbd->pBuffers[i]);
83  }
84  LOG_V(severity) << ")";
85}
86
87const ULONG SSL_FLAGS_DEFAULT = ISC_REQ_ALLOCATE_MEMORY
88                              | ISC_REQ_CONFIDENTIALITY
89                              | ISC_REQ_EXTENDED_ERROR
90                              | ISC_REQ_INTEGRITY
91                              | ISC_REQ_REPLAY_DETECT
92                              | ISC_REQ_SEQUENCE_DETECT
93                              | ISC_REQ_STREAM;
94                              //| ISC_REQ_USE_SUPPLIED_CREDS;
95
96typedef std::vector<char> SChannelBuffer;
97
98struct SChannelAdapter::SSLImpl {
99  CredHandle cred;
100  CtxtHandle ctx;
101  bool cred_init, ctx_init;
102  SChannelBuffer inbuf, outbuf, readable;
103  SecPkgContext_StreamSizes sizes;
104
105  SSLImpl() : cred_init(false), ctx_init(false) { }
106};
107
108SChannelAdapter::SChannelAdapter(AsyncSocket* socket)
109  : SSLAdapter(socket), state_(SSL_NONE),
110    restartable_(false), signal_close_(false), message_pending_(false),
111    impl_(new SSLImpl) {
112}
113
114SChannelAdapter::~SChannelAdapter() {
115  Cleanup();
116}
117
118int
119SChannelAdapter::StartSSL(const char* hostname, bool restartable) {
120  if (state_ != SSL_NONE)
121    return ERROR_ALREADY_INITIALIZED;
122
123  ssl_host_name_ = hostname;
124  restartable_ = restartable;
125
126  if (socket_->GetState() != Socket::CS_CONNECTED) {
127    state_ = SSL_WAIT;
128    return 0;
129  }
130
131  state_ = SSL_CONNECTING;
132  if (int err = BeginSSL()) {
133    Error("BeginSSL", err, false);
134    return err;
135  }
136
137  return 0;
138}
139
140int
141SChannelAdapter::BeginSSL() {
142  LOG(LS_VERBOSE) << "BeginSSL: " << ssl_host_name_;
143  ASSERT(state_ == SSL_CONNECTING);
144
145  SECURITY_STATUS ret;
146
147  SCHANNEL_CRED sc_cred = { 0 };
148  sc_cred.dwVersion = SCHANNEL_CRED_VERSION;
149  //sc_cred.dwMinimumCipherStrength = 128; // Note: use system default
150  sc_cred.dwFlags = SCH_CRED_NO_DEFAULT_CREDS | SCH_CRED_AUTO_CRED_VALIDATION;
151
152  ret = AcquireCredentialsHandle(NULL, UNISP_NAME, SECPKG_CRED_OUTBOUND, NULL,
153                                 &sc_cred, NULL, NULL, &impl_->cred, NULL);
154  if (ret != SEC_E_OK) {
155    LOG(LS_ERROR) << "AcquireCredentialsHandle error: "
156                  << ErrorName(ret, SECURITY_ERRORS);
157    return ret;
158  }
159  impl_->cred_init = true;
160
161  if (LOG_CHECK_LEVEL(LS_VERBOSE)) {
162    SecPkgCred_CipherStrengths cipher_strengths = { 0 };
163    ret = QueryCredentialsAttributes(&impl_->cred,
164                                     SECPKG_ATTR_CIPHER_STRENGTHS,
165                                     &cipher_strengths);
166    if (SUCCEEDED(ret)) {
167      LOG(LS_VERBOSE) << "SChannel cipher strength: "
168                  << cipher_strengths.dwMinimumCipherStrength << " - "
169                  << cipher_strengths.dwMaximumCipherStrength;
170    }
171
172    SecPkgCred_SupportedAlgs supported_algs = { 0 };
173    ret = QueryCredentialsAttributes(&impl_->cred,
174                                     SECPKG_ATTR_SUPPORTED_ALGS,
175                                     &supported_algs);
176    if (SUCCEEDED(ret)) {
177      LOG(LS_VERBOSE) << "SChannel supported algorithms:";
178      for (DWORD i=0; i<supported_algs.cSupportedAlgs; ++i) {
179        ALG_ID alg_id = supported_algs.palgSupportedAlgs[i];
180        PCCRYPT_OID_INFO oinfo = CryptFindOIDInfo(CRYPT_OID_INFO_ALGID_KEY,
181                                                  &alg_id, 0);
182        LPCWSTR alg_name = (NULL != oinfo) ? oinfo->pwszName : L"Unknown";
183        LOG(LS_VERBOSE) << "  " << ToUtf8(alg_name) << " (" << alg_id << ")";
184      }
185      CSecBufferBase::FreeSSPI(supported_algs.palgSupportedAlgs);
186    }
187  }
188
189  ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0;
190  if (ignore_bad_cert())
191    flags |= ISC_REQ_MANUAL_CRED_VALIDATION;
192
193  CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out;
194  ret = InitializeSecurityContextA(&impl_->cred, NULL,
195                                   const_cast<char*>(ssl_host_name_.c_str()),
196                                   flags, 0, 0, NULL, 0,
197                                   &impl_->ctx, sb_out.desc(),
198                                   &ret_flags, NULL);
199  if (SUCCEEDED(ret))
200    impl_->ctx_init = true;
201  return ProcessContext(ret, NULL, sb_out.desc());
202}
203
204int
205SChannelAdapter::ContinueSSL() {
206  LOG(LS_VERBOSE) << "ContinueSSL";
207  ASSERT(state_ == SSL_CONNECTING);
208
209  SECURITY_STATUS ret;
210
211  CSecBufferBundle<2> sb_in;
212  sb_in[0].BufferType = SECBUFFER_TOKEN;
213  sb_in[0].cbBuffer = static_cast<unsigned long>(impl_->inbuf.size());
214  sb_in[0].pvBuffer = &impl_->inbuf[0];
215  //DescribeBuffers(LS_VERBOSE, "Input Buffer ", sb_in.desc());
216
217  ULONG flags = SSL_FLAGS_DEFAULT, ret_flags = 0;
218  if (ignore_bad_cert())
219    flags |= ISC_REQ_MANUAL_CRED_VALIDATION;
220
221  CSecBufferBundle<2, CSecBufferBase::FreeSSPI> sb_out;
222  ret = InitializeSecurityContextA(&impl_->cred, &impl_->ctx,
223                                   const_cast<char*>(ssl_host_name_.c_str()),
224                                   flags, 0, 0, sb_in.desc(), 0,
225                                   NULL, sb_out.desc(),
226                                   &ret_flags, NULL);
227  return ProcessContext(ret, sb_in.desc(), sb_out.desc());
228}
229
230int
231SChannelAdapter::ProcessContext(long int status, _SecBufferDesc* sbd_in,
232                                _SecBufferDesc* sbd_out) {
233  LoggingSeverity level = LS_ERROR;
234  if ((status == SEC_E_OK)
235      || (status != SEC_I_CONTINUE_NEEDED)
236      || (status != SEC_E_INCOMPLETE_MESSAGE)) {
237    level = LS_VERBOSE;  // Expected messages
238  }
239  LOG_V(level)
240    << "InitializeSecurityContext error: "
241    << ErrorName(status, SECURITY_ERRORS);
242  //if (sbd_in)
243  //  DescribeBuffers(LS_VERBOSE, "Input Buffer ", sbd_in);
244  //if (sbd_out)
245  //  DescribeBuffers(LS_VERBOSE, "Output Buffer ", sbd_out);
246
247  if (status == SEC_E_INCOMPLETE_MESSAGE) {
248    // Wait for more input from server.
249    return Flush();
250  }
251
252  if (FAILED(status)) {
253    // We can't continue.  Common errors:
254    // SEC_E_CERT_EXPIRED - Typically, this means the computer clock is wrong.
255    return status;
256  }
257
258  // Note: we check both input and output buffers for SECBUFFER_EXTRA.
259  // Experience shows it appearing in the input, but the documentation claims
260  // it should appear in the output.
261  size_t extra = 0;
262  if (sbd_in) {
263    for (size_t i=0; i<sbd_in->cBuffers; ++i) {
264      SecBuffer& buffer = sbd_in->pBuffers[i];
265      if (buffer.BufferType == SECBUFFER_EXTRA) {
266        extra += buffer.cbBuffer;
267      }
268    }
269  }
270  if (sbd_out) {
271    for (size_t i=0; i<sbd_out->cBuffers; ++i) {
272      SecBuffer& buffer = sbd_out->pBuffers[i];
273      if (buffer.BufferType == SECBUFFER_EXTRA) {
274        extra += buffer.cbBuffer;
275      } else if (buffer.BufferType == SECBUFFER_TOKEN) {
276        impl_->outbuf.insert(impl_->outbuf.end(),
277          reinterpret_cast<char*>(buffer.pvBuffer),
278          reinterpret_cast<char*>(buffer.pvBuffer) + buffer.cbBuffer);
279      }
280    }
281  }
282
283  if (extra) {
284    ASSERT(extra <= impl_->inbuf.size());
285    size_t consumed = impl_->inbuf.size() - extra;
286    memmove(&impl_->inbuf[0], &impl_->inbuf[consumed], extra);
287    impl_->inbuf.resize(extra);
288  } else {
289    impl_->inbuf.clear();
290  }
291
292  if (SEC_I_CONTINUE_NEEDED == status) {
293    // Send data to server and wait for response.
294    // Note: ContinueSSL will result in a Flush, anyway.
295    return impl_->inbuf.empty() ? Flush() : ContinueSSL();
296  }
297
298  if (SEC_E_OK == status) {
299    LOG(LS_VERBOSE) << "QueryContextAttributes";
300    status = QueryContextAttributes(&impl_->ctx, SECPKG_ATTR_STREAM_SIZES,
301                                    &impl_->sizes);
302    if (FAILED(status)) {
303      LOG(LS_ERROR) << "QueryContextAttributes error: "
304                    << ErrorName(status, SECURITY_ERRORS);
305      return status;
306    }
307
308    state_ = SSL_CONNECTED;
309
310    if (int err = DecryptData()) {
311      return err;
312    } else if (int err = Flush()) {
313      return err;
314    } else {
315      // If we decrypted any data, queue up a notification here
316      PostEvent();
317      // Signal our connectedness
318      AsyncSocketAdapter::OnConnectEvent(this);
319    }
320    return 0;
321  }
322
323  if (SEC_I_INCOMPLETE_CREDENTIALS == status) {
324    // We don't support client authentication in schannel.
325    return status;
326  }
327
328  // We don't expect any other codes
329  ASSERT(false);
330  return status;
331}
332
// Decrypts as much of impl_->inbuf as possible, appending the plaintext to
// impl_->readable.  Bytes that DecryptMessage reports as SECBUFFER_EXTRA
// (the unconsumed tail of inbuf) are moved to the front of inbuf and
// decrypted on the next loop iteration.
// Returns 0 once inbuf is empty or DecryptMessage reports
// SEC_E_INCOMPLETE_MESSAGE (the remaining bytes stay in inbuf until more data
// arrives); otherwise returns the failing DecryptMessage status.
int
SChannelAdapter::DecryptData() {
  SChannelBuffer& inbuf = impl_->inbuf;
  SChannelBuffer& readable = impl_->readable;

  while (!inbuf.empty()) {
    // Buffer 0 describes the whole ciphertext; the remaining (empty) buffers
    // are filled in by DecryptMessage.
    CSecBufferBundle<4> in_buf;
    in_buf[0].BufferType = SECBUFFER_DATA;
    in_buf[0].cbBuffer = static_cast<unsigned long>(inbuf.size());
    in_buf[0].pvBuffer = &inbuf[0];

    //DescribeBuffers(LS_VERBOSE, "Decrypt In ", in_buf.desc());
    SECURITY_STATUS status = DecryptMessage(&impl_->ctx, in_buf.desc(), 0, 0);
    //DescribeBuffers(LS_VERBOSE, "Decrypt Out ", in_buf.desc());

    // Note: We are explicitly treating SEC_E_OK, SEC_I_CONTEXT_EXPIRED, and
    // any other successful results as continue.
    if (SUCCEEDED(status)) {
      size_t data_len = 0, extra_len = 0;
      // The returned DATA buffers presumably point into inbuf (SChannel
      // decrypts in place), so copy the plaintext out before inbuf is
      // compacted below.
      for (size_t i=0; i<in_buf.desc()->cBuffers; ++i) {
        if (in_buf[i].BufferType == SECBUFFER_DATA) {
          data_len += in_buf[i].cbBuffer;
          readable.insert(readable.end(),
            reinterpret_cast<char*>(in_buf[i].pvBuffer),
            reinterpret_cast<char*>(in_buf[i].pvBuffer) + in_buf[i].cbBuffer);
        } else if (in_buf[i].BufferType == SECBUFFER_EXTRA) {
          extra_len += in_buf[i].cbBuffer;
        }
      }
      // There is a bug on Win2K where SEC_I_CONTEXT_EXPIRED is misclassified.
      // 0x15 is the TLS "alert" record content type; a record that yielded no
      // data and starts with it is treated as the peer's close notification.
      if ((data_len == 0) && (inbuf[0] == 0x15)) {
        status = SEC_I_CONTEXT_EXPIRED;
      }
      // Keep only the unconsumed tail of inbuf, moved to the front.
      // NOTE(review): if DecryptMessage ever succeeds while consuming nothing
      // (extra_len == inbuf.size()), this loop makes no progress — confirm
      // SChannel cannot report that.
      if (extra_len) {
        size_t consumed = inbuf.size() - extra_len;
        memmove(&inbuf[0], &inbuf[consumed], extra_len);
        inbuf.resize(extra_len);
      } else {
        inbuf.clear();
      }
      // TODO: Handle SEC_I_CONTEXT_EXPIRED to do clean shutdown
      if (status != SEC_E_OK) {
        LOG(LS_INFO) << "DecryptMessage returned continuation code: "
                      << ErrorName(status, SECURITY_ERRORS);
      }
      continue;
    }

    // A partial record: leave it in inbuf and wait for more bytes.
    if (status == SEC_E_INCOMPLETE_MESSAGE) {
      break;
    } else {
      return status;
    }
  }

  return 0;
}
390
391void
392SChannelAdapter::Cleanup() {
393  if (impl_->ctx_init)
394    DeleteSecurityContext(&impl_->ctx);
395  if (impl_->cred_init)
396    FreeCredentialsHandle(&impl_->cred);
397  delete impl_;
398}
399
400void
401SChannelAdapter::PostEvent() {
402  // Check if there's anything notable to signal
403  if (impl_->readable.empty() && !signal_close_)
404    return;
405
406  // Only one post in the queue at a time
407  if (message_pending_)
408    return;
409
410  if (Thread* thread = Thread::Current()) {
411    message_pending_ = true;
412    thread->Post(this);
413  } else {
414    LOG(LS_ERROR) << "No thread context available for SChannelAdapter";
415    ASSERT(false);
416  }
417}
418
419void
420SChannelAdapter::Error(const char* context, int err, bool signal) {
421  LOG(LS_WARNING) << "SChannelAdapter::Error("
422                  << context << ", "
423                  << ErrorName(err, SECURITY_ERRORS) << ")";
424  state_ = SSL_ERROR;
425  SetError(err);
426  if (signal)
427    AsyncSocketAdapter::OnCloseEvent(this, err);
428}
429
430int
431SChannelAdapter::Read() {
432  char buffer[4096];
433  SChannelBuffer& inbuf = impl_->inbuf;
434  while (true) {
435    int ret = AsyncSocketAdapter::Recv(buffer, sizeof(buffer));
436    if (ret > 0) {
437      inbuf.insert(inbuf.end(), buffer, buffer + ret);
438    } else if (GetError() == EWOULDBLOCK) {
439      return 0;  // Blocking
440    } else {
441      return GetError();
442    }
443  }
444}
445
446int
447SChannelAdapter::Flush() {
448  int result = 0;
449  size_t pos = 0;
450  SChannelBuffer& outbuf = impl_->outbuf;
451  while (pos < outbuf.size()) {
452    int sent = AsyncSocketAdapter::Send(&outbuf[pos], outbuf.size() - pos);
453    if (sent > 0) {
454      pos += sent;
455    } else if (GetError() == EWOULDBLOCK) {
456      break;  // Blocking
457    } else {
458      result = GetError();
459      break;
460    }
461  }
462  if (int remainder = outbuf.size() - pos) {
463    memmove(&outbuf[0], &outbuf[pos], remainder);
464    outbuf.resize(remainder);
465  } else {
466    outbuf.clear();
467  }
468  return result;
469}
470
471//
472// AsyncSocket Implementation
473//
474
475int
476SChannelAdapter::Send(const void* pv, size_t cb) {
477  switch (state_) {
478  case SSL_NONE:
479    return AsyncSocketAdapter::Send(pv, cb);
480
481  case SSL_WAIT:
482  case SSL_CONNECTING:
483    SetError(EWOULDBLOCK);
484    return SOCKET_ERROR;
485
486  case SSL_CONNECTED:
487    break;
488
489  case SSL_ERROR:
490  default:
491    return SOCKET_ERROR;
492  }
493
494  size_t written = 0;
495  SChannelBuffer& outbuf = impl_->outbuf;
496  while (written < cb) {
497    const size_t encrypt_len = std::min<size_t>(cb - written,
498                                                impl_->sizes.cbMaximumMessage);
499
500    CSecBufferBundle<4> out_buf;
501    out_buf[0].BufferType = SECBUFFER_STREAM_HEADER;
502    out_buf[0].cbBuffer = impl_->sizes.cbHeader;
503    out_buf[1].BufferType = SECBUFFER_DATA;
504    out_buf[1].cbBuffer = static_cast<unsigned long>(encrypt_len);
505    out_buf[2].BufferType = SECBUFFER_STREAM_TRAILER;
506    out_buf[2].cbBuffer = impl_->sizes.cbTrailer;
507
508    size_t packet_len = out_buf[0].cbBuffer
509                      + out_buf[1].cbBuffer
510                      + out_buf[2].cbBuffer;
511
512    SChannelBuffer message;
513    message.resize(packet_len);
514    out_buf[0].pvBuffer = &message[0];
515    out_buf[1].pvBuffer = &message[out_buf[0].cbBuffer];
516    out_buf[2].pvBuffer = &message[out_buf[0].cbBuffer + out_buf[1].cbBuffer];
517
518    memcpy(out_buf[1].pvBuffer,
519           static_cast<const char*>(pv) + written,
520           encrypt_len);
521
522    //DescribeBuffers(LS_VERBOSE, "Encrypt In ", out_buf.desc());
523    SECURITY_STATUS res = EncryptMessage(&impl_->ctx, 0, out_buf.desc(), 0);
524    //DescribeBuffers(LS_VERBOSE, "Encrypt Out ", out_buf.desc());
525
526    if (FAILED(res)) {
527      Error("EncryptMessage", res, false);
528      return SOCKET_ERROR;
529    }
530
531    // We assume that the header and data segments do not change length,
532    // or else encrypting the concatenated packet in-place is wrong.
533    ASSERT(out_buf[0].cbBuffer == impl_->sizes.cbHeader);
534    ASSERT(out_buf[1].cbBuffer == static_cast<unsigned long>(encrypt_len));
535
536    // However, the length of the trailer may change due to padding.
537    ASSERT(out_buf[2].cbBuffer <= impl_->sizes.cbTrailer);
538
539    packet_len = out_buf[0].cbBuffer
540               + out_buf[1].cbBuffer
541               + out_buf[2].cbBuffer;
542
543    written += encrypt_len;
544    outbuf.insert(outbuf.end(), &message[0], &message[packet_len-1]+1);
545  }
546
547  if (int err = Flush()) {
548    state_ = SSL_ERROR;
549    SetError(err);
550    return SOCKET_ERROR;
551  }
552
553  return static_cast<int>(written);
554}
555
556int
557SChannelAdapter::Recv(void* pv, size_t cb) {
558  switch (state_) {
559  case SSL_NONE:
560    return AsyncSocketAdapter::Recv(pv, cb);
561
562  case SSL_WAIT:
563  case SSL_CONNECTING:
564    SetError(EWOULDBLOCK);
565    return SOCKET_ERROR;
566
567  case SSL_CONNECTED:
568    break;
569
570  case SSL_ERROR:
571  default:
572    return SOCKET_ERROR;
573  }
574
575  SChannelBuffer& readable = impl_->readable;
576  if (readable.empty()) {
577    SetError(EWOULDBLOCK);
578    return SOCKET_ERROR;
579  }
580  size_t read = _min(cb, readable.size());
581  memcpy(pv, &readable[0], read);
582  if (size_t remaining = readable.size() - read) {
583    memmove(&readable[0], &readable[read], remaining);
584    readable.resize(remaining);
585  } else {
586    readable.clear();
587  }
588
589  PostEvent();
590  return static_cast<int>(read);
591}
592
593int
594SChannelAdapter::Close() {
595  if (!impl_->readable.empty()) {
596    LOG(WARNING) << "SChannelAdapter::Close with readable data";
597    // Note: this isn't strictly an error, but we're using it temporarily to
598    // track bugs.
599    //ASSERT(false);
600  }
601  if (state_ == SSL_CONNECTED) {
602    DWORD token = SCHANNEL_SHUTDOWN;
603    CSecBufferBundle<1> sb_in;
604    sb_in[0].BufferType = SECBUFFER_TOKEN;
605    sb_in[0].cbBuffer = sizeof(token);
606    sb_in[0].pvBuffer = &token;
607    ApplyControlToken(&impl_->ctx, sb_in.desc());
608    // TODO: In theory, to do a nice shutdown, we need to begin shutdown
609    // negotiation with more calls to InitializeSecurityContext.  Since the
610    // socket api doesn't support nice shutdown at this point, we don't bother.
611  }
612  Cleanup();
613  impl_ = new SSLImpl;
614  state_ = restartable_ ? SSL_WAIT : SSL_NONE;
615  signal_close_ = false;
616  message_pending_ = false;
617  return AsyncSocketAdapter::Close();
618}
619
620Socket::ConnState
621SChannelAdapter::GetState() const {
622  if (signal_close_)
623    return CS_CONNECTED;
624  ConnState state = socket_->GetState();
625  if ((state == CS_CONNECTED)
626      && ((state_ == SSL_WAIT) || (state_ == SSL_CONNECTING)))
627    state = CS_CONNECTING;
628  return state;
629}
630
631void
632SChannelAdapter::OnConnectEvent(AsyncSocket* socket) {
633  LOG(LS_VERBOSE) << "SChannelAdapter::OnConnectEvent";
634  if (state_ != SSL_WAIT) {
635    ASSERT(state_ == SSL_NONE);
636    AsyncSocketAdapter::OnConnectEvent(socket);
637    return;
638  }
639
640  state_ = SSL_CONNECTING;
641  if (int err = BeginSSL()) {
642    Error("BeginSSL", err);
643  }
644}
645
646void
647SChannelAdapter::OnReadEvent(AsyncSocket* socket) {
648  if (state_ == SSL_NONE) {
649    AsyncSocketAdapter::OnReadEvent(socket);
650    return;
651  }
652
653  if (int err = Read()) {
654    Error("Read", err);
655    return;
656  }
657
658  if (impl_->inbuf.empty())
659    return;
660
661  if (state_ == SSL_CONNECTED) {
662    if (int err = DecryptData()) {
663      Error("DecryptData", err);
664    } else if (!impl_->readable.empty()) {
665      AsyncSocketAdapter::OnReadEvent(this);
666    }
667  } else if (state_ == SSL_CONNECTING) {
668    if (int err = ContinueSSL()) {
669      Error("ContinueSSL", err);
670    }
671  }
672}
673
674void
675SChannelAdapter::OnWriteEvent(AsyncSocket* socket) {
676  if (state_ == SSL_NONE) {
677    AsyncSocketAdapter::OnWriteEvent(socket);
678    return;
679  }
680
681  if (int err = Flush()) {
682    Error("Flush", err);
683    return;
684  }
685
686  // See if we have more data to write
687  if (!impl_->outbuf.empty())
688    return;
689
690  // Buffer is empty, submit notification
691  if (state_ == SSL_CONNECTED) {
692    AsyncSocketAdapter::OnWriteEvent(socket);
693  }
694}
695
696void
697SChannelAdapter::OnCloseEvent(AsyncSocket* socket, int err) {
698  if ((state_ == SSL_NONE) || impl_->readable.empty()) {
699    AsyncSocketAdapter::OnCloseEvent(socket, err);
700    return;
701  }
702
703  // If readable is non-empty, then we have a pending Message
704  // that will allow us to signal close (eventually).
705  signal_close_ = true;
706}
707
708void
709SChannelAdapter::OnMessage(Message* pmsg) {
710  if (!message_pending_)
711    return;  // This occurs when socket is closed
712
713  message_pending_ = false;
714  if (!impl_->readable.empty()) {
715    AsyncSocketAdapter::OnReadEvent(this);
716  } else if (signal_close_) {
717    signal_close_ = false;
718    AsyncSocketAdapter::OnCloseEvent(this, 0); // TODO: cache this error?
719  }
720}
721
722} // namespace talk_base
723