1/*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include <time.h>
29
30#ifdef WIN32
31#define WIN32_LEAN_AND_MEAN
32#include <windows.h>
33#include <winsock2.h>
34#include <ws2tcpip.h>
35#define SECURITY_WIN32
36#include <security.h>
37#endif
38
39#include "talk/base/httpcommon-inl.h"
40
41#include "talk/base/base64.h"
42#include "talk/base/common.h"
43#include "talk/base/cryptstring.h"
44#include "talk/base/httpcommon.h"
45#include "talk/base/socketaddress.h"
46#include "talk/base/stringdigest.h"
47#include "talk/base/stringencode.h"
48#include "talk/base/stringutils.h"
49
50namespace talk_base {
51
52#ifdef WIN32
53extern const ConstantLabel SECURITY_ERRORS[];
54#endif
55
56//////////////////////////////////////////////////////////////////////
57// Enum - TODO: expose globally later?
58//////////////////////////////////////////////////////////////////////
59
60bool find_string(size_t& index, const std::string& needle,
61                 const char* const haystack[], size_t max_index) {
62  for (index=0; index<max_index; ++index) {
63	if (_stricmp(needle.c_str(), haystack[index]) == 0) {
64	  return true;
65	}
66  }
67  return false;
68}
69
70template<class E>
71struct Enum {
72  static const char** Names;
73  static size_t Size;
74
75  static inline const char* Name(E val) { return Names[val]; }
76  static inline bool Parse(E& val, const std::string& name) {
77	size_t index;
78	if (!find_string(index, name, Names, Size))
79	  return false;
80	val = static_cast<E>(index);
81	return true;
82  }
83
84  E val;
85
86  inline operator E&() { return val; }
87  inline Enum& operator=(E rhs) { val = rhs; return *this; }
88
89  inline const char* name() const { return Name(val); }
90  inline bool assign(const std::string& name) { return Parse(val, name); }
91  inline Enum& operator=(const std::string& rhs) { assign(rhs); return *this; }
92};
93
94#define ENUM(e,n) \
95  template<> const char** Enum<e>::Names = n; \
96  template<> size_t Enum<e>::Size = sizeof(n)/sizeof(n[0])
97
98//////////////////////////////////////////////////////////////////////
99// HttpCommon
100//////////////////////////////////////////////////////////////////////
101
// Names of the HttpVersion enumerators.  Enum<> uses an enumerator's value as
// the index into this table (and the index back as the enumerator), so the
// entries must follow the enum's declaration order.
static const char* kHttpVersions[HVER_LAST+1] = {
  "1.0", "1.1", "Unknown"
};
ENUM(HttpVersion, kHttpVersions);

// Names of the HttpVerb enumerators, in enumerator order.
// NOTE(review): if HV_LAST+1 is larger than the number of names listed, the
// remaining slots are NULL and would be handed to _stricmp() by
// find_string(); keep the list and the enum in sync.
static const char* kHttpVerbs[HV_LAST+1] = {
  "GET", "POST", "PUT", "DELETE", "CONNECT", "HEAD"
};
ENUM(HttpVerb, kHttpVerbs);
111
// Canonical spellings of the HttpHeader enumerators, in enumerator order
// (Enum<> indexes this table by enumerator value).  Lookups by name go
// through find_string(), which compares case-insensitively, so the
// capitalization here only matters when names are emitted.
static const char* kHttpHeaders[HH_LAST+1] = {
  "Age",
  "Cache-Control",
  "Connection",
  "Content-Disposition",
  "Content-Length",
  "Content-Range",
  "Content-Type",
  "Cookie",
  "Date",
  "ETag",
  "Expires",
  "Host",
  "If-Modified-Since",
  "If-None-Match",
  "Keep-Alive",
  "Last-Modified",
  "Location",
  "Proxy-Authenticate",
  "Proxy-Authorization",
  "Proxy-Connection",
  "Range",
  "Set-Cookie",
  "TE",
  "Trailers",
  "Transfer-Encoding",
  "Upgrade",
  "User-Agent",
  "WWW-Authenticate",
};
ENUM(HttpHeader, kHttpHeaders);
143
144const char* ToString(HttpVersion version) {
145  return Enum<HttpVersion>::Name(version);
146}
147
148bool FromString(HttpVersion& version, const std::string& str) {
149  return Enum<HttpVersion>::Parse(version, str);
150}
151
152const char* ToString(HttpVerb verb) {
153  return Enum<HttpVerb>::Name(verb);
154}
155
156bool FromString(HttpVerb& verb, const std::string& str) {
157  return Enum<HttpVerb>::Parse(verb, str);
158}
159
160const char* ToString(HttpHeader header) {
161  return Enum<HttpHeader>::Name(header);
162}
163
164bool FromString(HttpHeader& header, const std::string& str) {
165  return Enum<HttpHeader>::Parse(header, str);
166}
167
168bool HttpCodeHasBody(uint32 code) {
169  return !HttpCodeIsInformational(code)
170         && (code != HC_NO_CONTENT) && (code != HC_NOT_MODIFIED);
171}
172
173bool HttpCodeIsCacheable(uint32 code) {
174  switch (code) {
175  case HC_OK:
176  case HC_NON_AUTHORITATIVE:
177  case HC_PARTIAL_CONTENT:
178  case HC_MULTIPLE_CHOICES:
179  case HC_MOVED_PERMANENTLY:
180  case HC_GONE:
181    return true;
182  default:
183    return false;
184  }
185}
186
187bool HttpHeaderIsEndToEnd(HttpHeader header) {
188  switch (header) {
189  case HH_CONNECTION:
190  case HH_KEEP_ALIVE:
191  case HH_PROXY_AUTHENTICATE:
192  case HH_PROXY_AUTHORIZATION:
193  case HH_PROXY_CONNECTION:  // Note part of RFC... this is non-standard header
194  case HH_TE:
195  case HH_TRAILERS:
196  case HH_TRANSFER_ENCODING:
197  case HH_UPGRADE:
198    return false;
199  default:
200    return true;
201  }
202}
203
204bool HttpHeaderIsCollapsible(HttpHeader header) {
205  switch (header) {
206  case HH_SET_COOKIE:
207  case HH_PROXY_AUTHENTICATE:
208  case HH_WWW_AUTHENTICATE:
209    return false;
210  default:
211    return true;
212  }
213}
214
215bool HttpShouldKeepAlive(const HttpData& data) {
216  std::string connection;
217  if ((data.hasHeader(HH_PROXY_CONNECTION, &connection)
218      || data.hasHeader(HH_CONNECTION, &connection))) {
219    return (_stricmp(connection.c_str(), "Keep-Alive") == 0);
220  }
221  return (data.version >= HVER_1_1);
222}
223
224namespace {
225
226inline bool IsEndOfAttributeName(size_t pos, size_t len, const char * data) {
227  if (pos >= len)
228    return true;
229  if (isspace(static_cast<unsigned char>(data[pos])))
230    return true;
231  // The reason for this complexity is that some attributes may contain trailing
232  // equal signs (like base64 tokens in Negotiate auth headers)
233  if ((pos+1 < len) && (data[pos] == '=') &&
234      !isspace(static_cast<unsigned char>(data[pos+1])) &&
235      (data[pos+1] != '=')) {
236    return true;
237  }
238  return false;
239}
240
241// TODO: unittest for EscapeAttribute and HttpComposeAttributes.
242
243std::string EscapeAttribute(const std::string& attribute) {
244  const size_t kMaxLength = attribute.length() * 2 + 1;
245  char* buffer = STACK_ARRAY(char, kMaxLength);
246  size_t len = escape(buffer, kMaxLength, attribute.data(), attribute.length(),
247                      "\"", '\\');
248  return std::string(buffer, len);
249}
250
251}  // anonymous namespace
252
253void HttpComposeAttributes(const HttpAttributeList& attributes, char separator,
254                           std::string* composed) {
255  std::stringstream ss;
256  for (size_t i=0; i<attributes.size(); ++i) {
257    if (i > 0) {
258      ss << separator << " ";
259    }
260    ss << attributes[i].first;
261    if (!attributes[i].second.empty()) {
262      ss << "=\"" << EscapeAttribute(attributes[i].second) << "\"";
263    }
264  }
265  *composed = ss.str();
266}
267
// Parses |len| bytes of |data| as a list of attributes (for example the
// contents of an authentication challenge, as HttpAuthenticate() uses it) and
// appends each one to |attributes|.
//
// Each attribute is a name (delimited by IsEndOfAttributeName()), optionally
// followed by '=' and a value.  A value is either a double-quoted string, in
// which a backslash makes the next character literal, or an unquoted run of
// characters ending at whitespace or ','.  Attributes without a value get an
// empty second.  Attributes may be separated by whitespace and/or one ','.
// Nothing is rejected: malformed input simply yields unusual names/values.
void HttpParseAttributes(const char * data, size_t len,
                         HttpAttributeList& attributes) {
  size_t pos = 0;
  while (true) {
    // Skip leading whitespace
    while ((pos < len) && isspace(static_cast<unsigned char>(data[pos]))) {
      ++pos;
    }

    // End of attributes?
    if (pos >= len)
      return;

    // Find end of attribute name
    size_t start = pos;
    while (!IsEndOfAttributeName(pos, len, data)) {
      ++pos;
    }

    HttpAttribute attribute;
    attribute.first.assign(data + start, data + pos);

    // Attribute has value?
    if ((pos < len) && (data[pos] == '=')) {
      ++pos; // Skip '='
      // Check if quoted value
      if ((pos < len) && (data[pos] == '"')) {
        // Copy up to the closing quote (which is consumed); an unterminated
        // quote runs to the end of the input.
        while (++pos < len) {
          if (data[pos] == '"') {
            ++pos;
            break;
          }
          // Backslash escape: take the following character literally.
          if ((data[pos] == '\\') && (pos + 1 < len))
            ++pos;
          attribute.second.append(1, data[pos]);
        }
      } else {
        // Unquoted value: runs until whitespace or ','.
        while ((pos < len) &&
            !isspace(static_cast<unsigned char>(data[pos])) &&
            (data[pos] != ',')) {
          attribute.second.append(1, data[pos++]);
        }
      }
    }

    attributes.push_back(attribute);
    if ((pos < len) && (data[pos] == ',')) ++pos; // Skip ','
  }
}
317
318bool HttpHasAttribute(const HttpAttributeList& attributes,
319                      const std::string& name,
320                      std::string* value) {
321  for (HttpAttributeList::const_iterator it = attributes.begin();
322       it != attributes.end(); ++it) {
323    if (it->first == name) {
324      if (value) {
325        *value = it->second;
326      }
327      return true;
328    }
329  }
330  return false;
331}
332
333bool HttpHasNthAttribute(HttpAttributeList& attributes,
334                         size_t index,
335                         std::string* name,
336                         std::string* value) {
337  if (index >= attributes.size())
338    return false;
339
340  if (name)
341    *name = attributes[index].first;
342  if (value)
343    *value = attributes[index].second;
344  return true;
345}
346
347bool HttpDateToSeconds(const std::string& date, unsigned long* seconds) {
348  const char* const kTimeZones[] = {
349    "UT", "GMT", "EST", "EDT", "CST", "CDT", "MST", "MDT", "PST", "PDT",
350    "A", "B", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M",
351    "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y"
352  };
353  const int kTimeZoneOffsets[] = {
354     0,  0, -5, -4, -6, -5, -7, -6, -8, -7,
355    -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12,
356     1,  2,  3,  4,  5,  6,  7,  8,  9,  10,  11,  12
357  };
358
359  ASSERT(NULL != seconds);
360  struct tm tval;
361  memset(&tval, 0, sizeof(tval));
362  char month[4], zone[6];
363  memset(month, 0, sizeof(month));
364  memset(zone, 0, sizeof(zone));
365
366  if (7 != sscanf(date.c_str(), "%*3s, %d %3s %d %d:%d:%d %5c",
367                  &tval.tm_mday, month, &tval.tm_year,
368                  &tval.tm_hour, &tval.tm_min, &tval.tm_sec, zone)) {
369    return false;
370  }
371  switch (toupper(month[2])) {
372  case 'N': tval.tm_mon = (month[1] == 'A') ? 0 : 5; break;
373  case 'B': tval.tm_mon = 1; break;
374  case 'R': tval.tm_mon = (month[0] == 'M') ? 2 : 3; break;
375  case 'Y': tval.tm_mon = 4; break;
376  case 'L': tval.tm_mon = 6; break;
377  case 'G': tval.tm_mon = 7; break;
378  case 'P': tval.tm_mon = 8; break;
379  case 'T': tval.tm_mon = 9; break;
380  case 'V': tval.tm_mon = 10; break;
381  case 'C': tval.tm_mon = 11; break;
382  }
383  tval.tm_year -= 1900;
384  unsigned long gmt, non_gmt = mktime(&tval);
385  if ((zone[0] == '+') || (zone[0] == '-')) {
386    if (!isdigit(zone[1]) || !isdigit(zone[2])
387        || !isdigit(zone[3]) || !isdigit(zone[4])) {
388      return false;
389    }
390    int hours = (zone[1] - '0') * 10 + (zone[2] - '0');
391    int minutes = (zone[3] - '0') * 10 + (zone[4] - '0');
392    int offset = (hours * 60 + minutes) * 60;
393    gmt = non_gmt + (zone[0] == '+') ? offset : -offset;
394  } else {
395    size_t zindex;
396    if (!find_string(zindex, zone, kTimeZones, ARRAY_SIZE(kTimeZones))) {
397      return false;
398    }
399    gmt = non_gmt + kTimeZoneOffsets[zindex] * 60 * 60;
400  }
401  // TODO: Android should support timezone, see b/2441195
402#if defined(OSX) || defined(ANDROID) || defined(BSD)
403  tm *tm_for_timezone = localtime((time_t *)&gmt);
404  *seconds = gmt + tm_for_timezone->tm_gmtoff;
405#else
406  *seconds = gmt - timezone;
407#endif
408  return true;
409}
410
411std::string HttpAddress(const SocketAddress& address, bool secure) {
412  return (address.port() == HttpDefaultPort(secure))
413          ? address.hostname() : address.ToString();
414}
415
416//////////////////////////////////////////////////////////////////////
417// HttpData
418//////////////////////////////////////////////////////////////////////
419
420void
421HttpData::clear(bool release_document) {
422  // Clear headers first, since releasing a document may have far-reaching
423  // effects.
424  headers_.clear();
425  if (release_document) {
426    document.reset();
427  }
428}
429
430void
431HttpData::copy(const HttpData& src) {
432  headers_ = src.headers_;
433}
434
435void
436HttpData::changeHeader(const std::string& name, const std::string& value,
437                       HeaderCombine combine) {
438  if (combine == HC_AUTO) {
439    HttpHeader header;
440    // Unrecognized headers are collapsible
441    combine = !FromString(header, name) || HttpHeaderIsCollapsible(header)
442              ? HC_YES : HC_NO;
443  } else if (combine == HC_REPLACE) {
444    headers_.erase(name);
445    combine = HC_NO;
446  }
447  // At this point, combine is one of (YES, NO, NEW)
448  if (combine != HC_NO) {
449    HeaderMap::iterator it = headers_.find(name);
450    if (it != headers_.end()) {
451      if (combine == HC_YES) {
452        it->second.append(",");
453        it->second.append(value);
454	  }
455      return;
456	}
457  }
458  headers_.insert(HeaderMap::value_type(name, value));
459}
460
461size_t HttpData::clearHeader(const std::string& name) {
462  return headers_.erase(name);
463}
464
465HttpData::iterator HttpData::clearHeader(iterator header) {
466  iterator deprecated = header++;
467  headers_.erase(deprecated);
468  return header;
469}
470
471bool
472HttpData::hasHeader(const std::string& name, std::string* value) const {
473  HeaderMap::const_iterator it = headers_.find(name);
474  if (it == headers_.end()) {
475    return false;
476  } else if (value) {
477    *value = it->second;
478  }
479  return true;
480}
481
482void HttpData::setContent(const std::string& content_type,
483                          StreamInterface* document) {
484  setHeader(HH_CONTENT_TYPE, content_type);
485  setDocumentAndLength(document);
486}
487
488void HttpData::setDocumentAndLength(StreamInterface* document) {
489  // TODO: Consider calling Rewind() here?
490  ASSERT(!hasHeader(HH_CONTENT_LENGTH, NULL));
491  ASSERT(!hasHeader(HH_TRANSFER_ENCODING, NULL));
492  ASSERT(document != NULL);
493  this->document.reset(document);
494  size_t content_length = 0;
495  if (this->document->GetAvailable(&content_length)) {
496    char buffer[32];
497    sprintfn(buffer, sizeof(buffer), "%d", content_length);
498    setHeader(HH_CONTENT_LENGTH, buffer);
499  } else {
500    setHeader(HH_TRANSFER_ENCODING, "chunked");
501  }
502}
503
504//
505// HttpRequestData
506//
507
508void
509HttpRequestData::clear(bool release_document) {
510  verb = HV_GET;
511  path.clear();
512  HttpData::clear(release_document);
513}
514
515void
516HttpRequestData::copy(const HttpRequestData& src) {
517  verb = src.verb;
518  path = src.path;
519  HttpData::copy(src);
520}
521
522size_t
523HttpRequestData::formatLeader(char* buffer, size_t size) const {
524  ASSERT(path.find(' ') == std::string::npos);
525  return sprintfn(buffer, size, "%s %.*s HTTP/%s", ToString(verb), path.size(),
526                  path.data(), ToString(version));
527}
528
529HttpError
530HttpRequestData::parseLeader(const char* line, size_t len) {
531  UNUSED(len);
532  unsigned int vmajor, vminor;
533  int vend, dstart, dend;
534  if ((sscanf(line, "%*s%n %n%*s%n HTTP/%u.%u", &vend, &dstart, &dend,
535              &vmajor, &vminor) != 2)
536      || (vmajor != 1)) {
537    return HE_PROTOCOL;
538  }
539  if (vminor == 0) {
540    version = HVER_1_0;
541  } else if (vminor == 1) {
542    version = HVER_1_1;
543  } else {
544    return HE_PROTOCOL;
545  }
546  std::string sverb(line, vend);
547  if (!FromString(verb, sverb.c_str())) {
548    return HE_PROTOCOL; // !?! HC_METHOD_NOT_SUPPORTED?
549  }
550  path.assign(line + dstart, line + dend);
551  return HE_NONE;
552}
553
554bool HttpRequestData::getAbsoluteUri(std::string* uri) const {
555  if (HV_CONNECT == verb)
556    return false;
557  Url<char> url(path);
558  if (url.valid()) {
559    uri->assign(path);
560    return true;
561  }
562  std::string host;
563  if (!hasHeader(HH_HOST, &host))
564    return false;
565  url.set_address(host);
566  url.set_full_path(path);
567  uri->assign(url.url());
568  return url.valid();
569}
570
571bool HttpRequestData::getRelativeUri(std::string* host,
572                                     std::string* path) const
573{
574  if (HV_CONNECT == verb)
575    return false;
576  Url<char> url(this->path);
577  if (url.valid()) {
578    host->assign(url.address());
579    path->assign(url.full_path());
580    return true;
581  }
582  if (!hasHeader(HH_HOST, host))
583    return false;
584  path->assign(this->path);
585  return true;
586}
587
588//
589// HttpResponseData
590//
591
592void
593HttpResponseData::clear(bool release_document) {
594  scode = HC_INTERNAL_SERVER_ERROR;
595  message.clear();
596  HttpData::clear(release_document);
597}
598
599void
600HttpResponseData::copy(const HttpResponseData& src) {
601  scode = src.scode;
602  message = src.message;
603  HttpData::copy(src);
604}
605
606void
607HttpResponseData::set_success(uint32 scode) {
608  this->scode = scode;
609  message.clear();
610  setHeader(HH_CONTENT_LENGTH, "0", false);
611}
612
613void
614HttpResponseData::set_success(const std::string& content_type,
615                              StreamInterface* document,
616                              uint32 scode) {
617  this->scode = scode;
618  message.erase(message.begin(), message.end());
619  setContent(content_type, document);
620}
621
622void
623HttpResponseData::set_redirect(const std::string& location, uint32 scode) {
624  this->scode = scode;
625  message.clear();
626  setHeader(HH_LOCATION, location);
627  setHeader(HH_CONTENT_LENGTH, "0", false);
628}
629
630void
631HttpResponseData::set_error(uint32 scode) {
632  this->scode = scode;
633  message.clear();
634  setHeader(HH_CONTENT_LENGTH, "0", false);
635}
636
637size_t
638HttpResponseData::formatLeader(char* buffer, size_t size) const {
639  size_t len = sprintfn(buffer, size, "HTTP/%s %lu", ToString(version), scode);
640  if (!message.empty()) {
641    len += sprintfn(buffer + len, size - len, " %.*s",
642                    message.size(), message.data());
643  }
644  return len;
645}
646
647HttpError
648HttpResponseData::parseLeader(const char* line, size_t len) {
649  size_t pos = 0;
650  unsigned int vmajor, vminor, temp_scode;
651  int temp_pos;
652  if (sscanf(line, "HTTP %u%n",
653             &temp_scode, &temp_pos) == 1) {
654    // This server's response has no version. :( NOTE: This happens for every
655    // response to requests made from Chrome plugins, regardless of the server's
656    // behaviour.
657    LOG(LS_VERBOSE) << "HTTP version missing from response";
658    version = HVER_UNKNOWN;
659  } else if ((sscanf(line, "HTTP/%u.%u %u%n",
660                     &vmajor, &vminor, &temp_scode, &temp_pos) == 3)
661             && (vmajor == 1)) {
662    // This server's response does have a version.
663    if (vminor == 0) {
664      version = HVER_1_0;
665    } else if (vminor == 1) {
666      version = HVER_1_1;
667    } else {
668      return HE_PROTOCOL;
669    }
670  } else {
671    return HE_PROTOCOL;
672  }
673  scode = temp_scode;
674  pos = static_cast<size_t>(temp_pos);
675  while ((pos < len) && isspace(static_cast<unsigned char>(line[pos]))) ++pos;
676  message.assign(line + pos, len - pos);
677  return HE_NONE;
678}
679
680//////////////////////////////////////////////////////////////////////
681// Http Authentication
682//////////////////////////////////////////////////////////////////////
683
// Set TEST_DIGEST to 1 to make HttpAuthenticate() use the canned challenge and
// cnonce below and ASSERT that the computed Digest response equals
// DIGEST_RESPONSE.
// NOTE(review): as written, enabling it does not build: DIGEST_METHOD exists
// only inside the commented-out block below, and HttpAuthenticate() assigns to
// its const-reference parameters |method| and |uri| under TEST_DIGEST.
#define TEST_DIGEST 0
#if TEST_DIGEST
// Disabled alternative test vector (these look like the RFC 2617 section 3.5
// example values).
/*
const char * const DIGEST_CHALLENGE =
  "Digest realm=\"testrealm@host.com\","
  " qop=\"auth,auth-int\","
  " nonce=\"dcd98b7102dd2f0e8b11d0f600bfb0c093\","
  " opaque=\"5ccc069c403ebaf9f0171e9517f40e41\"";
const char * const DIGEST_METHOD = "GET";
const char * const DIGEST_URI =
  "/dir/index.html";;
const char * const DIGEST_CNONCE =
  "0a4f113b";
const char * const DIGEST_RESPONSE =
  "6629fae49393a05397450978507c4ef1";
//user_ = "Mufasa";
//pass_ = "Circle Of Life";
*/
// Active test vector: a Squid proxy challenge and the expected response.
const char * const DIGEST_CHALLENGE =
  "Digest realm=\"Squid proxy-caching web server\","
  " nonce=\"Nny4QuC5PwiSDixJ\","
  " qop=\"auth\","
  " stale=false";
const char * const DIGEST_URI =
  "/";
const char * const DIGEST_CNONCE =
  "6501d58e9a21cee1e7b5fec894ded024";
const char * const DIGEST_RESPONSE =
  "edffcb0829e755838b073a4a42de06bc";
#endif
714
// Wraps |str| in double quotes, prefixing each '"' and '\' inside it with a
// backslash.
std::string quote(const std::string& str) {
  std::string quoted(1, '"');
  for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) {
    const char ch = *it;
    if (ch == '\\' || ch == '"')
      quoted += '\\';
    quoted += ch;
  }
  quoted += '"';
  return quoted;
}
726
#ifdef WIN32
// Per-connection state for the SSPI-based Negotiate/NTLM schemes.  Holds the
// SSPI credential and security-context handles and releases both when
// destroyed.
struct NegotiateAuthContext : public HttpAuthContext {
  CredHandle cred;
  CtxtHandle ctx;
  size_t steps;                // incremented by each continuation round
  bool specified_credentials;  // true if explicit username/password was used

  NegotiateAuthContext(const std::string& auth, CredHandle c1, CtxtHandle c2)
  : HttpAuthContext(auth), cred(c1), ctx(c2), steps(0),
    specified_credentials(false)
  { }

  virtual ~NegotiateAuthContext() {
    DeleteSecurityContext(&ctx);
    FreeCredentialsHandle(&cred);
  }

 private:
  // The destructor releases the SSPI handles, so a copy would release them a
  // second time.  Declared but not defined to forbid copying.
  NegotiateAuthContext(const NegotiateAuthContext&);
  NegotiateAuthContext& operator=(const NegotiateAuthContext&);
};
#endif // WIN32
745
746HttpAuthResult HttpAuthenticate(
747  const char * challenge, size_t len,
748  const SocketAddress& server,
749  const std::string& method, const std::string& uri,
750  const std::string& username, const CryptString& password,
751  HttpAuthContext *& context, std::string& response, std::string& auth_method)
752{
753#if TEST_DIGEST
754  challenge = DIGEST_CHALLENGE;
755  len = strlen(challenge);
756#endif
757
758  HttpAttributeList args;
759  HttpParseAttributes(challenge, len, args);
760  HttpHasNthAttribute(args, 0, &auth_method, NULL);
761
762  if (context && (context->auth_method != auth_method))
763    return HAR_IGNORE;
764
765  // BASIC
766  if (_stricmp(auth_method.c_str(), "basic") == 0) {
767    if (context)
768      return HAR_CREDENTIALS; // Bad credentials
769    if (username.empty())
770      return HAR_CREDENTIALS; // Missing credentials
771
772    context = new HttpAuthContext(auth_method);
773
774    // TODO: convert sensitive to a secure buffer that gets securely deleted
775    //std::string decoded = username + ":" + password;
776    size_t len = username.size() + password.GetLength() + 2;
777    char * sensitive = new char[len];
778    size_t pos = strcpyn(sensitive, len, username.data(), username.size());
779    pos += strcpyn(sensitive + pos, len - pos, ":");
780    password.CopyTo(sensitive + pos, true);
781
782    response = auth_method;
783    response.append(" ");
784    // TODO: create a sensitive-source version of Base64::encode
785    response.append(Base64::Encode(sensitive));
786    memset(sensitive, 0, len);
787    delete [] sensitive;
788    return HAR_RESPONSE;
789  }
790
791  // DIGEST
792  if (_stricmp(auth_method.c_str(), "digest") == 0) {
793    if (context)
794      return HAR_CREDENTIALS; // Bad credentials
795    if (username.empty())
796      return HAR_CREDENTIALS; // Missing credentials
797
798    context = new HttpAuthContext(auth_method);
799
800    std::string cnonce, ncount;
801#if TEST_DIGEST
802    method = DIGEST_METHOD;
803    uri    = DIGEST_URI;
804    cnonce = DIGEST_CNONCE;
805#else
806    char buffer[256];
807    sprintf(buffer, "%d", static_cast<int>(time(0)));
808    cnonce = MD5(buffer);
809#endif
810    ncount = "00000001";
811
812    std::string realm, nonce, qop, opaque;
813    HttpHasAttribute(args, "realm", &realm);
814    HttpHasAttribute(args, "nonce", &nonce);
815    bool has_qop = HttpHasAttribute(args, "qop", &qop);
816    bool has_opaque = HttpHasAttribute(args, "opaque", &opaque);
817
818    // TODO: convert sensitive to be secure buffer
819    //std::string A1 = username + ":" + realm + ":" + password;
820    size_t len = username.size() + realm.size() + password.GetLength() + 3;
821    char * sensitive = new char[len];  // A1
822    size_t pos = strcpyn(sensitive, len, username.data(), username.size());
823    pos += strcpyn(sensitive + pos, len - pos, ":");
824    pos += strcpyn(sensitive + pos, len - pos, realm.c_str());
825    pos += strcpyn(sensitive + pos, len - pos, ":");
826    password.CopyTo(sensitive + pos, true);
827
828    std::string A2 = method + ":" + uri;
829    std::string middle;
830    if (has_qop) {
831      qop = "auth";
832      middle = nonce + ":" + ncount + ":" + cnonce + ":" + qop;
833    } else {
834      middle = nonce;
835    }
836    std::string HA1 = MD5(sensitive);
837    memset(sensitive, 0, len);
838    delete [] sensitive;
839    std::string HA2 = MD5(A2);
840    std::string dig_response = MD5(HA1 + ":" + middle + ":" + HA2);
841
842#if TEST_DIGEST
843    ASSERT(strcmp(dig_response.c_str(), DIGEST_RESPONSE) == 0);
844#endif
845
846    std::stringstream ss;
847    ss << auth_method;
848    ss << " username=" << quote(username);
849    ss << ", realm=" << quote(realm);
850    ss << ", nonce=" << quote(nonce);
851    ss << ", uri=" << quote(uri);
852    if (has_qop) {
853      ss << ", qop=" << qop;
854      ss << ", nc="  << ncount;
855      ss << ", cnonce=" << quote(cnonce);
856    }
857    ss << ", response=\"" << dig_response << "\"";
858    if (has_opaque) {
859      ss << ", opaque=" << quote(opaque);
860    }
861    response = ss.str();
862    return HAR_RESPONSE;
863  }
864
865#ifdef WIN32
866#if 1
867  bool want_negotiate = (_stricmp(auth_method.c_str(), "negotiate") == 0);
868  bool want_ntlm = (_stricmp(auth_method.c_str(), "ntlm") == 0);
869  // SPNEGO & NTLM
870  if (want_negotiate || want_ntlm) {
871    const size_t MAX_MESSAGE = 12000, MAX_SPN = 256;
872    char out_buf[MAX_MESSAGE], spn[MAX_SPN];
873
874#if 0 // Requires funky windows versions
875    DWORD len = MAX_SPN;
876    if (DsMakeSpn("HTTP", server.IPAsString().c_str(), NULL, server.port(),
877                  0, &len, spn) != ERROR_SUCCESS) {
878      LOG_F(WARNING) << "(Negotiate) - DsMakeSpn failed";
879      return HAR_IGNORE;
880    }
881#else
882    sprintfn(spn, MAX_SPN, "HTTP/%s", server.ToString().c_str());
883#endif
884
885    SecBuffer out_sec;
886    out_sec.pvBuffer   = out_buf;
887    out_sec.cbBuffer   = sizeof(out_buf);
888    out_sec.BufferType = SECBUFFER_TOKEN;
889
890    SecBufferDesc out_buf_desc;
891    out_buf_desc.ulVersion = 0;
892    out_buf_desc.cBuffers  = 1;
893    out_buf_desc.pBuffers  = &out_sec;
894
895    const ULONG NEG_FLAGS_DEFAULT =
896      //ISC_REQ_ALLOCATE_MEMORY
897      ISC_REQ_CONFIDENTIALITY
898      //| ISC_REQ_EXTENDED_ERROR
899      //| ISC_REQ_INTEGRITY
900      | ISC_REQ_REPLAY_DETECT
901      | ISC_REQ_SEQUENCE_DETECT
902      //| ISC_REQ_STREAM
903      //| ISC_REQ_USE_SUPPLIED_CREDS
904      ;
905
906    ::TimeStamp lifetime;
907    SECURITY_STATUS ret = S_OK;
908    ULONG ret_flags = 0, flags = NEG_FLAGS_DEFAULT;
909
910    bool specify_credentials = !username.empty();
911    size_t steps = 0;
912
913    //uint32 now = Time();
914
915    NegotiateAuthContext * neg = static_cast<NegotiateAuthContext *>(context);
916    if (neg) {
917      const size_t max_steps = 10;
918      if (++neg->steps >= max_steps) {
919        LOG(WARNING) << "AsyncHttpsProxySocket::Authenticate(Negotiate) too many retries";
920        return HAR_ERROR;
921      }
922      steps = neg->steps;
923
924      std::string challenge, decoded_challenge;
925      if (HttpHasNthAttribute(args, 1, &challenge, NULL)
926          && Base64::Decode(challenge, Base64::DO_STRICT,
927                            &decoded_challenge, NULL)) {
928        SecBuffer in_sec;
929        in_sec.pvBuffer   = const_cast<char *>(decoded_challenge.data());
930        in_sec.cbBuffer   = static_cast<unsigned long>(decoded_challenge.size());
931        in_sec.BufferType = SECBUFFER_TOKEN;
932
933        SecBufferDesc in_buf_desc;
934        in_buf_desc.ulVersion = 0;
935        in_buf_desc.cBuffers  = 1;
936        in_buf_desc.pBuffers  = &in_sec;
937
938        ret = InitializeSecurityContextA(&neg->cred, &neg->ctx, spn, flags, 0, SECURITY_NATIVE_DREP, &in_buf_desc, 0, &neg->ctx, &out_buf_desc, &ret_flags, &lifetime);
939        //LOG(INFO) << "$$$ InitializeSecurityContext @ " << TimeSince(now);
940        if (FAILED(ret)) {
941          LOG(LS_ERROR) << "InitializeSecurityContext returned: "
942                      << ErrorName(ret, SECURITY_ERRORS);
943          return HAR_ERROR;
944        }
945      } else if (neg->specified_credentials) {
946        // Try again with default credentials
947        specify_credentials = false;
948        delete context;
949        context = neg = 0;
950      } else {
951        return HAR_CREDENTIALS;
952      }
953    }
954
955    if (!neg) {
956      unsigned char userbuf[256], passbuf[256], domainbuf[16];
957      SEC_WINNT_AUTH_IDENTITY_A auth_id, * pauth_id = 0;
958      if (specify_credentials) {
959        memset(&auth_id, 0, sizeof(auth_id));
960        size_t len = password.GetLength()+1;
961        char * sensitive = new char[len];
962        password.CopyTo(sensitive, true);
963        std::string::size_type pos = username.find('\\');
964        if (pos == std::string::npos) {
965          auth_id.UserLength = static_cast<unsigned long>(
966            _min(sizeof(userbuf) - 1, username.size()));
967          memcpy(userbuf, username.c_str(), auth_id.UserLength);
968          userbuf[auth_id.UserLength] = 0;
969          auth_id.DomainLength = 0;
970          domainbuf[auth_id.DomainLength] = 0;
971          auth_id.PasswordLength = static_cast<unsigned long>(
972            _min(sizeof(passbuf) - 1, password.GetLength()));
973          memcpy(passbuf, sensitive, auth_id.PasswordLength);
974          passbuf[auth_id.PasswordLength] = 0;
975        } else {
976          auth_id.UserLength = static_cast<unsigned long>(
977            _min(sizeof(userbuf) - 1, username.size() - pos - 1));
978          memcpy(userbuf, username.c_str() + pos + 1, auth_id.UserLength);
979          userbuf[auth_id.UserLength] = 0;
980          auth_id.DomainLength = static_cast<unsigned long>(
981            _min(sizeof(domainbuf) - 1, pos));
982          memcpy(domainbuf, username.c_str(), auth_id.DomainLength);
983          domainbuf[auth_id.DomainLength] = 0;
984          auth_id.PasswordLength = static_cast<unsigned long>(
985            _min(sizeof(passbuf) - 1, password.GetLength()));
986          memcpy(passbuf, sensitive, auth_id.PasswordLength);
987          passbuf[auth_id.PasswordLength] = 0;
988        }
989        memset(sensitive, 0, len);
990        delete [] sensitive;
991        auth_id.User = userbuf;
992        auth_id.Domain = domainbuf;
993        auth_id.Password = passbuf;
994        auth_id.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
995        pauth_id = &auth_id;
996        LOG(LS_VERBOSE) << "Negotiate protocol: Using specified credentials";
997      } else {
998        LOG(LS_VERBOSE) << "Negotiate protocol: Using default credentials";
999      }
1000
1001      CredHandle cred;
1002      ret = AcquireCredentialsHandleA(0, want_negotiate ? NEGOSSP_NAME_A : NTLMSP_NAME_A, SECPKG_CRED_OUTBOUND, 0, pauth_id, 0, 0, &cred, &lifetime);
1003      //LOG(INFO) << "$$$ AcquireCredentialsHandle @ " << TimeSince(now);
1004      if (ret != SEC_E_OK) {
1005        LOG(LS_ERROR) << "AcquireCredentialsHandle error: "
1006                    << ErrorName(ret, SECURITY_ERRORS);
1007        return HAR_IGNORE;
1008      }
1009
1010      //CSecBufferBundle<5, CSecBufferBase::FreeSSPI> sb_out;
1011
1012      CtxtHandle ctx;
1013      ret = InitializeSecurityContextA(&cred, 0, spn, flags, 0, SECURITY_NATIVE_DREP, 0, 0, &ctx, &out_buf_desc, &ret_flags, &lifetime);
1014      //LOG(INFO) << "$$$ InitializeSecurityContext @ " << TimeSince(now);
1015      if (FAILED(ret)) {
1016        LOG(LS_ERROR) << "InitializeSecurityContext returned: "
1017                    << ErrorName(ret, SECURITY_ERRORS);
1018        FreeCredentialsHandle(&cred);
1019        return HAR_IGNORE;
1020      }
1021
1022      ASSERT(!context);
1023      context = neg = new NegotiateAuthContext(auth_method, cred, ctx);
1024      neg->specified_credentials = specify_credentials;
1025      neg->steps = steps;
1026    }
1027
1028    if ((ret == SEC_I_COMPLETE_NEEDED) || (ret == SEC_I_COMPLETE_AND_CONTINUE)) {
1029      ret = CompleteAuthToken(&neg->ctx, &out_buf_desc);
1030      //LOG(INFO) << "$$$ CompleteAuthToken @ " << TimeSince(now);
1031      LOG(LS_VERBOSE) << "CompleteAuthToken returned: "
1032                      << ErrorName(ret, SECURITY_ERRORS);
1033      if (FAILED(ret)) {
1034        return HAR_ERROR;
1035      }
1036    }
1037
1038    //LOG(INFO) << "$$$ NEGOTIATE took " << TimeSince(now) << "ms";
1039
1040    std::string decoded(out_buf, out_buf + out_sec.cbBuffer);
1041    response = auth_method;
1042    response.append(" ");
1043    response.append(Base64::Encode(decoded));
1044    return HAR_RESPONSE;
1045  }
1046#endif
1047#endif // WIN32
1048
1049  return HAR_IGNORE;
1050}
1051
1052//////////////////////////////////////////////////////////////////////
1053
1054} // namespace talk_base
1055