1/*
2 * libjingle
3 * Copyright 2004--2005, Google Inc.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 *
8 *  1. Redistributions of source code must retain the above copyright notice,
9 *     this list of conditions and the following disclaimer.
10 *  2. Redistributions in binary form must reproduce the above copyright notice,
11 *     this list of conditions and the following disclaimer in the documentation
12 *     and/or other materials provided with the distribution.
13 *  3. The name of the author may not be used to endorse or promote products
14 *     derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17 * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18 * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19 * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 */
27
28#include <time.h>
29
30#ifdef WIN32
31#define WIN32_LEAN_AND_MEAN
32#include <windows.h>
33#include <winsock2.h>
34#include <ws2tcpip.h>
35#define SECURITY_WIN32
36#include <security.h>
37#endif
38
39#include "talk/base/httpcommon-inl.h"
40
41#include "talk/base/base64.h"
42#include "talk/base/common.h"
43#include "talk/base/cryptstring.h"
44#include "talk/base/httpcommon.h"
45#include "talk/base/socketaddress.h"
46#include "talk/base/stringdigest.h"
47#include "talk/base/stringencode.h"
48#include "talk/base/stringutils.h"
49
50namespace talk_base {
51
#ifdef WIN32
// Names for SSPI status codes, defined elsewhere; passed to ErrorName() when
// logging failures of the SSPI calls in HttpAuthenticate().
extern const ConstantLabel SECURITY_ERRORS[];
#endif
55
56//////////////////////////////////////////////////////////////////////
57// Enum - TODO: expose globally later?
58//////////////////////////////////////////////////////////////////////
59
60bool find_string(size_t& index, const std::string& needle,
61                 const char* const haystack[], size_t max_index) {
62  for (index=0; index<max_index; ++index) {
63	if (_stricmp(needle.c_str(), haystack[index]) == 0) {
64	  return true;
65	}
66  }
67  return false;
68}
69
70template<class E>
71struct Enum {
72  static const char** Names;
73  static size_t Size;
74
75  static inline const char* Name(E val) { return Names[val]; }
76  static inline bool Parse(E& val, const std::string& name) {
77	size_t index;
78	if (!find_string(index, name, Names, Size))
79	  return false;
80	val = static_cast<E>(index);
81	return true;
82  }
83
84  E val;
85
86  inline operator E&() { return val; }
87  inline Enum& operator=(E rhs) { val = rhs; return *this; }
88
89  inline const char* name() const { return Name(val); }
90  inline bool assign(const std::string& name) { return Parse(val, name); }
91  inline Enum& operator=(const std::string& rhs) { assign(rhs); return *this; }
92};
93
94#define ENUM(e,n) \
95  template<> const char** Enum<e>::Names = n; \
96  template<> size_t Enum<e>::Size = sizeof(n)/sizeof(n[0])
97
98//////////////////////////////////////////////////////////////////////
99// HttpCommon
100//////////////////////////////////////////////////////////////////////
101
// Name tables backing the Enum<> specializations. Each table is indexed by
// enumerator value, so entries must stay in the same order as the matching
// enum declaration. The explicit bound (X_LAST+1) makes the compiler reject
// surplus entries; missing ones would silently become NULL and break lookups.
static const char* kHttpVersions[HVER_LAST+1] = {
  "1.0", "1.1", "Unknown"
};
ENUM(HttpVersion, kHttpVersions);

// Request method names, indexed by HttpVerb.
static const char* kHttpVerbs[HV_LAST+1] = {
  "GET", "POST", "PUT", "DELETE", "CONNECT", "HEAD"
};
ENUM(HttpVerb, kHttpVerbs);

// Header field names, indexed by HttpHeader. FromString() matches them
// case-insensitively (find_string uses _stricmp).
static const char* kHttpHeaders[HH_LAST+1] = {
  "Age",
  "Cache-Control",
  "Connection",
  "Content-Disposition",
  "Content-Length",
  "Content-Range",
  "Content-Type",
  "Cookie",
  "Date",
  "ETag",
  "Expires",
  "Host",
  "If-Modified-Since",
  "If-None-Match",
  "Keep-Alive",
  "Last-Modified",
  "Location",
  "Proxy-Authenticate",
  "Proxy-Authorization",
  "Proxy-Connection",
  "Range",
  "Set-Cookie",
  "TE",
  // NOTE(review): the registered field name is "Trailer" (RFC 7230 section
  // 4.4); "Trailers" mirrors the hop-by-hop list in RFC 2616. Confirm before
  // changing, since both parsing and serialization use this string.
  "Trailers",
  "Transfer-Encoding",
  "Upgrade",
  "User-Agent",
  "WWW-Authenticate",
};
ENUM(HttpHeader, kHttpHeaders);
143
144const char* ToString(HttpVersion version) {
145  return Enum<HttpVersion>::Name(version);
146}
147
148bool FromString(HttpVersion& version, const std::string& str) {
149  return Enum<HttpVersion>::Parse(version, str);
150}
151
152const char* ToString(HttpVerb verb) {
153  return Enum<HttpVerb>::Name(verb);
154}
155
156bool FromString(HttpVerb& verb, const std::string& str) {
157  return Enum<HttpVerb>::Parse(verb, str);
158}
159
160const char* ToString(HttpHeader header) {
161  return Enum<HttpHeader>::Name(header);
162}
163
164bool FromString(HttpHeader& header, const std::string& str) {
165  return Enum<HttpHeader>::Parse(header, str);
166}
167
168bool HttpCodeHasBody(uint32 code) {
169  return !HttpCodeIsInformational(code)
170         && (code != HC_NO_CONTENT) && (code != HC_NOT_MODIFIED);
171}
172
173bool HttpCodeIsCacheable(uint32 code) {
174  switch (code) {
175  case HC_OK:
176  case HC_NON_AUTHORITATIVE:
177  case HC_PARTIAL_CONTENT:
178  case HC_MULTIPLE_CHOICES:
179  case HC_MOVED_PERMANENTLY:
180  case HC_GONE:
181    return true;
182  default:
183    return false;
184  }
185}
186
187bool HttpHeaderIsEndToEnd(HttpHeader header) {
188  switch (header) {
189  case HH_CONNECTION:
190  case HH_KEEP_ALIVE:
191  case HH_PROXY_AUTHENTICATE:
192  case HH_PROXY_AUTHORIZATION:
193  case HH_PROXY_CONNECTION:  // Note part of RFC... this is non-standard header
194  case HH_TE:
195  case HH_TRAILERS:
196  case HH_TRANSFER_ENCODING:
197  case HH_UPGRADE:
198    return false;
199  default:
200    return true;
201  }
202}
203
204bool HttpHeaderIsCollapsible(HttpHeader header) {
205  switch (header) {
206  case HH_SET_COOKIE:
207  case HH_PROXY_AUTHENTICATE:
208  case HH_WWW_AUTHENTICATE:
209    return false;
210  default:
211    return true;
212  }
213}
214
215bool HttpShouldKeepAlive(const HttpData& data) {
216  std::string connection;
217  if ((data.hasHeader(HH_PROXY_CONNECTION, &connection)
218      || data.hasHeader(HH_CONNECTION, &connection))) {
219    return (_stricmp(connection.c_str(), "Keep-Alive") == 0);
220  }
221  return (data.version >= HVER_1_1);
222}
223
224namespace {
225
226inline bool IsEndOfAttributeName(size_t pos, size_t len, const char * data) {
227  if (pos >= len)
228    return true;
229  if (isspace(static_cast<unsigned char>(data[pos])))
230    return true;
231  // The reason for this complexity is that some attributes may contain trailing
232  // equal signs (like base64 tokens in Negotiate auth headers)
233  if ((pos+1 < len) && (data[pos] == '=') &&
234      !isspace(static_cast<unsigned char>(data[pos+1])) &&
235      (data[pos+1] != '=')) {
236    return true;
237  }
238  return false;
239}
240
241// TODO: unittest for EscapeAttribute and HttpComposeAttributes.
242
243std::string EscapeAttribute(const std::string& attribute) {
244  const size_t kMaxLength = attribute.length() * 2 + 1;
245  char* buffer = STACK_ARRAY(char, kMaxLength);
246  size_t len = escape(buffer, kMaxLength, attribute.data(), attribute.length(),
247                      "\"", '\\');
248  return std::string(buffer, len);
249}
250
251}  // anonymous namespace
252
253void HttpComposeAttributes(const HttpAttributeList& attributes, char separator,
254                           std::string* composed) {
255  std::stringstream ss;
256  for (size_t i=0; i<attributes.size(); ++i) {
257    if (i > 0) {
258      ss << separator << " ";
259    }
260    ss << attributes[i].first;
261    if (!attributes[i].second.empty()) {
262      ss << "=\"" << EscapeAttribute(attributes[i].second) << "\"";
263    }
264  }
265  *composed = ss.str();
266}
267
// Parses an attribute list, such as the parameters of an HTTP authentication
// challenge (e.g. `Digest realm="x", nonce="y"`), appending one (name, value)
// pair per attribute to |attributes|. Existing entries are kept.
//
// Parsing rules, as implemented below:
//  - A name ends at whitespace, at end of input, or at an '=' followed by
//    something other than whitespace or another '=' (IsEndOfAttributeName).
//    Commas do NOT end a name, so "a, b" yields the names "a," and "b".
//  - After '=', a value is either a double-quoted string or an unquoted run
//    that stops at whitespace or ','. Inside a quoted string, a backslash
//    escapes the next character, and an unterminated quote runs to the end
//    of input.
//  - An attribute with no '=' gets an empty value.
//  - A single ',' immediately after an attribute is skipped.
void HttpParseAttributes(const char * data, size_t len,
                         HttpAttributeList& attributes) {
  size_t pos = 0;
  while (true) {
    // Skip leading whitespace
    while ((pos < len) && isspace(static_cast<unsigned char>(data[pos]))) {
      ++pos;
    }

    // End of attributes?
    if (pos >= len)
      return;

    // Find end of attribute name
    size_t start = pos;
    while (!IsEndOfAttributeName(pos, len, data)) {
      ++pos;
    }

    HttpAttribute attribute;
    attribute.first.assign(data + start, data + pos);

    // Attribute has value?
    if ((pos < len) && (data[pos] == '=')) {
      ++pos; // Skip '='
      // Check if quoted value
      if ((pos < len) && (data[pos] == '"')) {
        // The pre-increment steps past the opening quote on the first pass.
        while (++pos < len) {
          if (data[pos] == '"') {
            ++pos;  // Consume the closing quote.
            break;
          }
          // A backslash escapes the next character. A backslash that is the
          // last byte of input is kept literally.
          if ((data[pos] == '\\') && (pos + 1 < len))
            ++pos;
          attribute.second.append(1, data[pos]);
        }
      } else {
        while ((pos < len) &&
            !isspace(static_cast<unsigned char>(data[pos])) &&
            (data[pos] != ',')) {
          attribute.second.append(1, data[pos++]);
        }
      }
    }

    attributes.push_back(attribute);
    if ((pos < len) && (data[pos] == ',')) ++pos; // Skip ','
  }
}
317
318bool HttpHasAttribute(const HttpAttributeList& attributes,
319                      const std::string& name,
320                      std::string* value) {
321  for (HttpAttributeList::const_iterator it = attributes.begin();
322       it != attributes.end(); ++it) {
323    if (it->first == name) {
324      if (value) {
325        *value = it->second;
326      }
327      return true;
328    }
329  }
330  return false;
331}
332
333bool HttpHasNthAttribute(HttpAttributeList& attributes,
334                         size_t index,
335                         std::string* name,
336                         std::string* value) {
337  if (index >= attributes.size())
338    return false;
339
340  if (name)
341    *name = attributes[index].first;
342  if (value)
343    *value = attributes[index].second;
344  return true;
345}
346
347bool HttpDateToSeconds(const std::string& date, time_t* seconds) {
348  const char* const kTimeZones[] = {
349    "UT", "GMT", "EST", "EDT", "CST", "CDT", "MST", "MDT", "PST", "PDT",
350    "A", "B", "C", "D", "E", "F", "G", "H", "I", "K", "L", "M",
351    "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y"
352  };
353  const int kTimeZoneOffsets[] = {
354     0,  0, -5, -4, -6, -5, -7, -6, -8, -7,
355    -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12,
356     1,  2,  3,  4,  5,  6,  7,  8,  9,  10,  11,  12
357  };
358
359  ASSERT(NULL != seconds);
360  struct tm tval;
361  memset(&tval, 0, sizeof(tval));
362  char month[4], zone[6];
363  memset(month, 0, sizeof(month));
364  memset(zone, 0, sizeof(zone));
365
366  if (7 != sscanf(date.c_str(), "%*3s, %d %3s %d %d:%d:%d %5c",
367                  &tval.tm_mday, month, &tval.tm_year,
368                  &tval.tm_hour, &tval.tm_min, &tval.tm_sec, zone)) {
369    return false;
370  }
371  switch (toupper(month[2])) {
372  case 'N': tval.tm_mon = (month[1] == 'A') ? 0 : 5; break;
373  case 'B': tval.tm_mon = 1; break;
374  case 'R': tval.tm_mon = (month[0] == 'M') ? 2 : 3; break;
375  case 'Y': tval.tm_mon = 4; break;
376  case 'L': tval.tm_mon = 6; break;
377  case 'G': tval.tm_mon = 7; break;
378  case 'P': tval.tm_mon = 8; break;
379  case 'T': tval.tm_mon = 9; break;
380  case 'V': tval.tm_mon = 10; break;
381  case 'C': tval.tm_mon = 11; break;
382  }
383  tval.tm_year -= 1900;
384  size_t gmt, non_gmt = mktime(&tval);
385  if ((zone[0] == '+') || (zone[0] == '-')) {
386    if (!isdigit(zone[1]) || !isdigit(zone[2])
387        || !isdigit(zone[3]) || !isdigit(zone[4])) {
388      return false;
389    }
390    int hours = (zone[1] - '0') * 10 + (zone[2] - '0');
391    int minutes = (zone[3] - '0') * 10 + (zone[4] - '0');
392    int offset = (hours * 60 + minutes) * 60;
393    gmt = non_gmt + ((zone[0] == '+') ? offset : -offset);
394  } else {
395    size_t zindex;
396    if (!find_string(zindex, zone, kTimeZones, ARRAY_SIZE(kTimeZones))) {
397      return false;
398    }
399    gmt = non_gmt + kTimeZoneOffsets[zindex] * 60 * 60;
400  }
401  // TODO: Android should support timezone, see b/2441195
402#if defined(OSX) || defined(ANDROID) || defined(BSD)
403  tm *tm_for_timezone = localtime((time_t *)&gmt);
404  *seconds = gmt + tm_for_timezone->tm_gmtoff;
405#else
406  *seconds = gmt - timezone;
407#endif
408  return true;
409}
410
411std::string HttpAddress(const SocketAddress& address, bool secure) {
412  return (address.port() == HttpDefaultPort(secure))
413          ? address.hostname() : address.ToString();
414}
415
416//////////////////////////////////////////////////////////////////////
417// HttpData
418//////////////////////////////////////////////////////////////////////
419
420void
421HttpData::clear(bool release_document) {
422  // Clear headers first, since releasing a document may have far-reaching
423  // effects.
424  headers_.clear();
425  if (release_document) {
426    document.reset();
427  }
428}
429
430void
431HttpData::copy(const HttpData& src) {
432  headers_ = src.headers_;
433}
434
435void
436HttpData::changeHeader(const std::string& name, const std::string& value,
437                       HeaderCombine combine) {
438  if (combine == HC_AUTO) {
439    HttpHeader header;
440    // Unrecognized headers are collapsible
441    combine = !FromString(header, name) || HttpHeaderIsCollapsible(header)
442              ? HC_YES : HC_NO;
443  } else if (combine == HC_REPLACE) {
444    headers_.erase(name);
445    combine = HC_NO;
446  }
447  // At this point, combine is one of (YES, NO, NEW)
448  if (combine != HC_NO) {
449    HeaderMap::iterator it = headers_.find(name);
450    if (it != headers_.end()) {
451      if (combine == HC_YES) {
452        it->second.append(",");
453        it->second.append(value);
454	  }
455      return;
456	}
457  }
458  headers_.insert(HeaderMap::value_type(name, value));
459}
460
461size_t HttpData::clearHeader(const std::string& name) {
462  return headers_.erase(name);
463}
464
465HttpData::iterator HttpData::clearHeader(iterator header) {
466  iterator deprecated = header++;
467  headers_.erase(deprecated);
468  return header;
469}
470
471bool
472HttpData::hasHeader(const std::string& name, std::string* value) const {
473  HeaderMap::const_iterator it = headers_.find(name);
474  if (it == headers_.end()) {
475    return false;
476  } else if (value) {
477    *value = it->second;
478  }
479  return true;
480}
481
482void HttpData::setContent(const std::string& content_type,
483                          StreamInterface* document) {
484  setHeader(HH_CONTENT_TYPE, content_type);
485  setDocumentAndLength(document);
486}
487
488void HttpData::setDocumentAndLength(StreamInterface* document) {
489  // TODO: Consider calling Rewind() here?
490  ASSERT(!hasHeader(HH_CONTENT_LENGTH, NULL));
491  ASSERT(!hasHeader(HH_TRANSFER_ENCODING, NULL));
492  ASSERT(document != NULL);
493  this->document.reset(document);
494  size_t content_length = 0;
495  if (this->document->GetAvailable(&content_length)) {
496    char buffer[32];
497    sprintfn(buffer, sizeof(buffer), "%d", content_length);
498    setHeader(HH_CONTENT_LENGTH, buffer);
499  } else {
500    setHeader(HH_TRANSFER_ENCODING, "chunked");
501  }
502}
503
504//
505// HttpRequestData
506//
507
508void
509HttpRequestData::clear(bool release_document) {
510  verb = HV_GET;
511  path.clear();
512  HttpData::clear(release_document);
513}
514
515void
516HttpRequestData::copy(const HttpRequestData& src) {
517  verb = src.verb;
518  path = src.path;
519  HttpData::copy(src);
520}
521
522size_t
523HttpRequestData::formatLeader(char* buffer, size_t size) const {
524  ASSERT(path.find(' ') == std::string::npos);
525  return sprintfn(buffer, size, "%s %.*s HTTP/%s", ToString(verb), path.size(),
526                  path.data(), ToString(version));
527}
528
529HttpError
530HttpRequestData::parseLeader(const char* line, size_t len) {
531  unsigned int vmajor, vminor;
532  int vend, dstart, dend;
533  // sscanf isn't safe with strings that aren't null-terminated, and there is
534  // no guarantee that |line| is. Create a local copy that is null-terminated.
535  std::string line_str(line, len);
536  line = line_str.c_str();
537  if ((sscanf(line, "%*s%n %n%*s%n HTTP/%u.%u",
538              &vend, &dstart, &dend, &vmajor, &vminor) != 2)
539      || (vmajor != 1)) {
540    return HE_PROTOCOL;
541  }
542  if (vminor == 0) {
543    version = HVER_1_0;
544  } else if (vminor == 1) {
545    version = HVER_1_1;
546  } else {
547    return HE_PROTOCOL;
548  }
549  std::string sverb(line, vend);
550  if (!FromString(verb, sverb.c_str())) {
551    return HE_PROTOCOL; // !?! HC_METHOD_NOT_SUPPORTED?
552  }
553  path.assign(line + dstart, line + dend);
554  return HE_NONE;
555}
556
557bool HttpRequestData::getAbsoluteUri(std::string* uri) const {
558  if (HV_CONNECT == verb)
559    return false;
560  Url<char> url(path);
561  if (url.valid()) {
562    uri->assign(path);
563    return true;
564  }
565  std::string host;
566  if (!hasHeader(HH_HOST, &host))
567    return false;
568  url.set_address(host);
569  url.set_full_path(path);
570  uri->assign(url.url());
571  return url.valid();
572}
573
574bool HttpRequestData::getRelativeUri(std::string* host,
575                                     std::string* path) const
576{
577  if (HV_CONNECT == verb)
578    return false;
579  Url<char> url(this->path);
580  if (url.valid()) {
581    host->assign(url.address());
582    path->assign(url.full_path());
583    return true;
584  }
585  if (!hasHeader(HH_HOST, host))
586    return false;
587  path->assign(this->path);
588  return true;
589}
590
591//
592// HttpResponseData
593//
594
595void
596HttpResponseData::clear(bool release_document) {
597  scode = HC_INTERNAL_SERVER_ERROR;
598  message.clear();
599  HttpData::clear(release_document);
600}
601
602void
603HttpResponseData::copy(const HttpResponseData& src) {
604  scode = src.scode;
605  message = src.message;
606  HttpData::copy(src);
607}
608
609void
610HttpResponseData::set_success(uint32 scode) {
611  this->scode = scode;
612  message.clear();
613  setHeader(HH_CONTENT_LENGTH, "0", false);
614}
615
616void
617HttpResponseData::set_success(const std::string& content_type,
618                              StreamInterface* document,
619                              uint32 scode) {
620  this->scode = scode;
621  message.erase(message.begin(), message.end());
622  setContent(content_type, document);
623}
624
625void
626HttpResponseData::set_redirect(const std::string& location, uint32 scode) {
627  this->scode = scode;
628  message.clear();
629  setHeader(HH_LOCATION, location);
630  setHeader(HH_CONTENT_LENGTH, "0", false);
631}
632
633void
634HttpResponseData::set_error(uint32 scode) {
635  this->scode = scode;
636  message.clear();
637  setHeader(HH_CONTENT_LENGTH, "0", false);
638}
639
640size_t
641HttpResponseData::formatLeader(char* buffer, size_t size) const {
642  size_t len = sprintfn(buffer, size, "HTTP/%s %lu", ToString(version), scode);
643  if (!message.empty()) {
644    len += sprintfn(buffer + len, size - len, " %.*s",
645                    message.size(), message.data());
646  }
647  return len;
648}
649
650HttpError
651HttpResponseData::parseLeader(const char* line, size_t len) {
652  size_t pos = 0;
653  unsigned int vmajor, vminor, temp_scode;
654  int temp_pos;
655  // sscanf isn't safe with strings that aren't null-terminated, and there is
656  // no guarantee that |line| is. Create a local copy that is null-terminated.
657  std::string line_str(line, len);
658  line = line_str.c_str();
659  if (sscanf(line, "HTTP %u%n",
660             &temp_scode, &temp_pos) == 1) {
661    // This server's response has no version. :( NOTE: This happens for every
662    // response to requests made from Chrome plugins, regardless of the server's
663    // behaviour.
664    LOG(LS_VERBOSE) << "HTTP version missing from response";
665    version = HVER_UNKNOWN;
666  } else if ((sscanf(line, "HTTP/%u.%u %u%n",
667                     &vmajor, &vminor, &temp_scode, &temp_pos) == 3)
668             && (vmajor == 1)) {
669    // This server's response does have a version.
670    if (vminor == 0) {
671      version = HVER_1_0;
672    } else if (vminor == 1) {
673      version = HVER_1_1;
674    } else {
675      return HE_PROTOCOL;
676    }
677  } else {
678    return HE_PROTOCOL;
679  }
680  scode = temp_scode;
681  pos = static_cast<size_t>(temp_pos);
682  while ((pos < len) && isspace(static_cast<unsigned char>(line[pos]))) ++pos;
683  message.assign(line + pos, len - pos);
684  return HE_NONE;
685}
686
687//////////////////////////////////////////////////////////////////////
688// Http Authentication
689//////////////////////////////////////////////////////////////////////
690
// Setting TEST_DIGEST to 1 makes HttpAuthenticate() ignore the real challenge,
// use DIGEST_CHALLENGE below instead, and ASSERT that the computed Digest
// response equals DIGEST_RESPONSE.
// NOTE(review): that build does not compile as written. DIGEST_METHOD is only
// defined inside the commented-out example, and HttpAuthenticate() assigns to
// its const reference parameters |method| and |uri| under TEST_DIGEST.
#define TEST_DIGEST 0
#if TEST_DIGEST
/*
const char * const DIGEST_CHALLENGE =
  "Digest realm=\"testrealm@host.com\","
  " qop=\"auth,auth-int\","
  " nonce=\"dcd98b7102dd2f0e8b11d0f600bfb0c093\","
  " opaque=\"5ccc069c403ebaf9f0171e9517f40e41\"";
const char * const DIGEST_METHOD = "GET";
const char * const DIGEST_URI =
  "/dir/index.html";;
const char * const DIGEST_CNONCE =
  "0a4f113b";
const char * const DIGEST_RESPONSE =
  "6629fae49393a05397450978507c4ef1";
//user_ = "Mufasa";
//pass_ = "Circle Of Life";
*/
// Captured exchange with a Squid proxy, used as the active test vector.
const char * const DIGEST_CHALLENGE =
  "Digest realm=\"Squid proxy-caching web server\","
  " nonce=\"Nny4QuC5PwiSDixJ\","
  " qop=\"auth\","
  " stale=false";
const char * const DIGEST_URI =
  "/";
const char * const DIGEST_CNONCE =
  "6501d58e9a21cee1e7b5fec894ded024";
const char * const DIGEST_RESPONSE =
  "edffcb0829e755838b073a4a42de06bc";
#endif
721
// Wraps |str| in double quotes, backslash-escaping embedded '"' and '\\'
// characters (an HTTP quoted-string).
std::string quote(const std::string& str) {
  std::string quoted(1, '"');
  for (std::string::const_iterator it = str.begin(); it != str.end(); ++it) {
    const char ch = *it;
    if (ch == '\\' || ch == '"')
      quoted += '\\';
    quoted += ch;
  }
  quoted += '"';
  return quoted;
}
733
#ifdef WIN32
// HttpAuthenticate() state for the SSPI-based Negotiate and NTLM schemes.
// Owns the SSPI credential and context handles and releases both on
// destruction. |steps| counts handshake rounds. |specified_credentials|
// records whether explicit (rather than default) credentials were used.
struct NegotiateAuthContext : public HttpAuthContext {
  CredHandle cred;
  CtxtHandle ctx;
  size_t steps;
  bool specified_credentials;

  NegotiateAuthContext(const std::string& auth, CredHandle c1, CtxtHandle c2)
  : HttpAuthContext(auth), cred(c1), ctx(c2), steps(0),
    specified_credentials(false)
  { }

  virtual ~NegotiateAuthContext() {
    DeleteSecurityContext(&ctx);
    FreeCredentialsHandle(&cred);
  }

 private:
  // Not copyable: a copy would release the same SSPI handles a second time.
  // Declared but intentionally not defined.
  NegotiateAuthContext(const NegotiateAuthContext&);
  NegotiateAuthContext& operator=(const NegotiateAuthContext&);
};
#endif // WIN32
752
// Computes the answer to an HTTP authentication challenge.
//
// |challenge| (|len| bytes, not necessarily NUL-terminated) is parsed with
// HttpParseAttributes(). Its first attribute is the scheme name, which is
// written to |auth_method|; if the challenge has no attributes, |auth_method|
// keeps whatever the caller passed in.
// Basic and Digest are handled on every platform. Negotiate and NTLM are
// handled through SSPI on Windows only. |method| and |uri| are used only by
// Digest, and |server| only to build the SSPI service principal name.
//
// |context| carries per-scheme state between rounds. This function allocates
// it with new, and for Negotiate/NTLM may delete and replace it. Releasing it
// afterwards is presumably the caller's job -- TODO confirm.
//
// Returns:
//   HAR_RESPONSE    |response| holds the scheme name followed by the
//                   credentials to send back.
//   HAR_CREDENTIALS credentials are missing, or a previous attempt with this
//                   context was rejected.
//   HAR_IGNORE      unsupported scheme, a scheme different from |context|'s,
//                   or SSPI credential/context setup failed.
//   HAR_ERROR       an SSPI call failed mid-handshake, or too many rounds.
HttpAuthResult HttpAuthenticate(
  const char * challenge, size_t len,
  const SocketAddress& server,
  const std::string& method, const std::string& uri,
  const std::string& username, const CryptString& password,
  HttpAuthContext *& context, std::string& response, std::string& auth_method)
{
#if TEST_DIGEST
  challenge = DIGEST_CHALLENGE;
  len = strlen(challenge);
#endif

  HttpAttributeList args;
  HttpParseAttributes(challenge, len, args);
  HttpHasNthAttribute(args, 0, &auth_method, NULL);

  // An existing context only answers challenges for its own scheme.
  if (context && (context->auth_method != auth_method))
    return HAR_IGNORE;

  // BASIC
  if (_stricmp(auth_method.c_str(), "basic") == 0) {
    // Basic is single-shot: an existing context means the credentials already
    // sent were rejected.
    if (context)
      return HAR_CREDENTIALS; // Bad credentials
    if (username.empty())
      return HAR_CREDENTIALS; // Missing credentials

    context = new HttpAuthContext(auth_method);

    // TODO: convert sensitive to a secure buffer that gets securely deleted
    //std::string decoded = username + ":" + password;
    // Room for "username:password" plus the NUL written by CopyTo(..., true).
    // (This |len| shadows the parameter, which is no longer needed.)
    size_t len = username.size() + password.GetLength() + 2;
    char * sensitive = new char[len];
    size_t pos = strcpyn(sensitive, len, username.data(), username.size());
    pos += strcpyn(sensitive + pos, len - pos, ":");
    password.CopyTo(sensitive + pos, true);

    response = auth_method;
    response.append(" ");
    // TODO: create a sensitive-source version of Base64::encode
    response.append(Base64::Encode(sensitive));
    // Wipe the plaintext credentials before releasing the buffer.
    memset(sensitive, 0, len);
    delete [] sensitive;
    return HAR_RESPONSE;
  }

  // DIGEST
  if (_stricmp(auth_method.c_str(), "digest") == 0) {
    // As with Basic, an existing context means the last answer was rejected.
    if (context)
      return HAR_CREDENTIALS; // Bad credentials
    if (username.empty())
      return HAR_CREDENTIALS; // Missing credentials

    context = new HttpAuthContext(auth_method);

    std::string cnonce, ncount;
#if TEST_DIGEST
    method = DIGEST_METHOD;
    uri    = DIGEST_URI;
    cnonce = DIGEST_CNONCE;
#else
    // Client nonce: MD5 of the current time in seconds.
    // NOTE(review): this is predictable; a random value would be stronger.
    char buffer[256];
    sprintf(buffer, "%d", static_cast<int>(time(0)));
    cnonce = MD5(buffer);
#endif
    // Each context answers exactly one challenge (see the check above), so
    // the nonce count is always 1.
    ncount = "00000001";

    std::string realm, nonce, qop, opaque;
    HttpHasAttribute(args, "realm", &realm);
    HttpHasAttribute(args, "nonce", &nonce);
    bool has_qop = HttpHasAttribute(args, "qop", &qop);
    bool has_opaque = HttpHasAttribute(args, "opaque", &opaque);

    // TODO: convert sensitive to be secure buffer
    //std::string A1 = username + ":" + realm + ":" + password;
    // The +3 covers the two ':' separators and the NUL written by CopyTo().
    size_t len = username.size() + realm.size() + password.GetLength() + 3;
    char * sensitive = new char[len];  // A1
    size_t pos = strcpyn(sensitive, len, username.data(), username.size());
    pos += strcpyn(sensitive + pos, len - pos, ":");
    pos += strcpyn(sensitive + pos, len - pos, realm.c_str());
    pos += strcpyn(sensitive + pos, len - pos, ":");
    password.CopyTo(sensitive + pos, true);

    std::string A2 = method + ":" + uri;
    std::string middle;
    if (has_qop) {
      // Only "auth" is implemented. Whenever the server offers any qop list,
      // "auth" is selected regardless of its contents.
      qop = "auth";
      middle = nonce + ":" + ncount + ":" + cnonce + ":" + qop;
    } else {
      middle = nonce;
    }
    std::string HA1 = MD5(sensitive);
    memset(sensitive, 0, len);
    delete [] sensitive;
    std::string HA2 = MD5(A2);
    // response = MD5(HA1 ":" nonce [":" nc ":" cnonce ":" qop] ":" HA2)
    std::string dig_response = MD5(HA1 + ":" + middle + ":" + HA2);

#if TEST_DIGEST
    ASSERT(strcmp(dig_response.c_str(), DIGEST_RESPONSE) == 0);
#endif

    std::stringstream ss;
    ss << auth_method;
    ss << " username=" << quote(username);
    ss << ", realm=" << quote(realm);
    ss << ", nonce=" << quote(nonce);
    ss << ", uri=" << quote(uri);
    if (has_qop) {
      ss << ", qop=" << qop;
      ss << ", nc="  << ncount;
      ss << ", cnonce=" << quote(cnonce);
    }
    ss << ", response=\"" << dig_response << "\"";
    if (has_opaque) {
      ss << ", opaque=" << quote(opaque);
    }
    response = ss.str();
    return HAR_RESPONSE;
  }

#ifdef WIN32
#if 1
  bool want_negotiate = (_stricmp(auth_method.c_str(), "negotiate") == 0);
  bool want_ntlm = (_stricmp(auth_method.c_str(), "ntlm") == 0);
  // SPNEGO & NTLM
  if (want_negotiate || want_ntlm) {
    const size_t MAX_MESSAGE = 12000, MAX_SPN = 256;
    char out_buf[MAX_MESSAGE], spn[MAX_SPN];

    // Service principal name for the SSPI calls: "HTTP/<server>".
#if 0 // Requires funky windows versions
    DWORD len = MAX_SPN;
    if (DsMakeSpn("HTTP", server.HostAsURIString().c_str(), NULL,
                  server.port(),
                  0, &len, spn) != ERROR_SUCCESS) {
      LOG_F(WARNING) << "(Negotiate) - DsMakeSpn failed";
      return HAR_IGNORE;
    }
#else
    sprintfn(spn, MAX_SPN, "HTTP/%s", server.ToString().c_str());
#endif

    // Output token buffer that SSPI fills in; cbBuffer is updated to the
    // token's actual length.
    SecBuffer out_sec;
    out_sec.pvBuffer   = out_buf;
    out_sec.cbBuffer   = sizeof(out_buf);
    out_sec.BufferType = SECBUFFER_TOKEN;

    SecBufferDesc out_buf_desc;
    out_buf_desc.ulVersion = 0;
    out_buf_desc.cBuffers  = 1;
    out_buf_desc.pBuffers  = &out_sec;

    const ULONG NEG_FLAGS_DEFAULT =
      //ISC_REQ_ALLOCATE_MEMORY
      ISC_REQ_CONFIDENTIALITY
      //| ISC_REQ_EXTENDED_ERROR
      //| ISC_REQ_INTEGRITY
      | ISC_REQ_REPLAY_DETECT
      | ISC_REQ_SEQUENCE_DETECT
      //| ISC_REQ_STREAM
      //| ISC_REQ_USE_SUPPLIED_CREDS
      ;

    ::TimeStamp lifetime;
    SECURITY_STATUS ret = S_OK;
    ULONG ret_flags = 0, flags = NEG_FLAGS_DEFAULT;

    bool specify_credentials = !username.empty();
    size_t steps = 0;

    //uint32 now = Time();

    // The scheme check above guarantees any existing context was created by
    // this branch for the same scheme, so the downcast is safe here.
    NegotiateAuthContext * neg = static_cast<NegotiateAuthContext *>(context);
    if (neg) {
      // Continuing a handshake. Guard against endless challenge/response
      // loops.
      const size_t max_steps = 10;
      if (++neg->steps >= max_steps) {
        LOG(WARNING) << "AsyncHttpsProxySocket::Authenticate(Negotiate) too many retries";
        return HAR_ERROR;
      }
      steps = neg->steps;

      // The server's token is the second attribute, base64-encoded.
      std::string challenge, decoded_challenge;
      if (HttpHasNthAttribute(args, 1, &challenge, NULL)
          && Base64::Decode(challenge, Base64::DO_STRICT,
                            &decoded_challenge, NULL)) {
        SecBuffer in_sec;
        in_sec.pvBuffer   = const_cast<char *>(decoded_challenge.data());
        in_sec.cbBuffer   = static_cast<unsigned long>(decoded_challenge.size());
        in_sec.BufferType = SECBUFFER_TOKEN;

        SecBufferDesc in_buf_desc;
        in_buf_desc.ulVersion = 0;
        in_buf_desc.cBuffers  = 1;
        in_buf_desc.pBuffers  = &in_sec;

        ret = InitializeSecurityContextA(&neg->cred, &neg->ctx, spn, flags, 0, SECURITY_NATIVE_DREP, &in_buf_desc, 0, &neg->ctx, &out_buf_desc, &ret_flags, &lifetime);
        //LOG(INFO) << "$$$ InitializeSecurityContext @ " << TimeSince(now);
        if (FAILED(ret)) {
          LOG(LS_ERROR) << "InitializeSecurityContext returned: "
                      << ErrorName(ret, SECURITY_ERRORS);
          return HAR_ERROR;
        }
      } else if (neg->specified_credentials) {
        // Try again with default credentials
        // (discard this context; a fresh one is built below without an
        // explicit identity).
        specify_credentials = false;
        delete context;
        context = neg = 0;
      } else {
        return HAR_CREDENTIALS;
      }
    }

    if (!neg) {
      // Starting a new handshake: acquire credentials, then produce the first
      // token.
      unsigned char userbuf[256], passbuf[256], domainbuf[16];
      SEC_WINNT_AUTH_IDENTITY_A auth_id, * pauth_id = 0;
      if (specify_credentials) {
        // A "DOMAIN\user" username is split at the first backslash. Each part
        // is truncated to fit its fixed-size buffer (the domain to 15 chars).
        // NOTE(review): userbuf/passbuf keep a plaintext password copy on the
        // stack and, unlike |sensitive|, are never wiped.
        memset(&auth_id, 0, sizeof(auth_id));
        size_t len = password.GetLength()+1;
        char * sensitive = new char[len];
        password.CopyTo(sensitive, true);
        std::string::size_type pos = username.find('\\');
        if (pos == std::string::npos) {
          auth_id.UserLength = static_cast<unsigned long>(
            _min(sizeof(userbuf) - 1, username.size()));
          memcpy(userbuf, username.c_str(), auth_id.UserLength);
          userbuf[auth_id.UserLength] = 0;
          auth_id.DomainLength = 0;
          domainbuf[auth_id.DomainLength] = 0;
          auth_id.PasswordLength = static_cast<unsigned long>(
            _min(sizeof(passbuf) - 1, password.GetLength()));
          memcpy(passbuf, sensitive, auth_id.PasswordLength);
          passbuf[auth_id.PasswordLength] = 0;
        } else {
          auth_id.UserLength = static_cast<unsigned long>(
            _min(sizeof(userbuf) - 1, username.size() - pos - 1));
          memcpy(userbuf, username.c_str() + pos + 1, auth_id.UserLength);
          userbuf[auth_id.UserLength] = 0;
          auth_id.DomainLength = static_cast<unsigned long>(
            _min(sizeof(domainbuf) - 1, pos));
          memcpy(domainbuf, username.c_str(), auth_id.DomainLength);
          domainbuf[auth_id.DomainLength] = 0;
          auth_id.PasswordLength = static_cast<unsigned long>(
            _min(sizeof(passbuf) - 1, password.GetLength()));
          memcpy(passbuf, sensitive, auth_id.PasswordLength);
          passbuf[auth_id.PasswordLength] = 0;
        }
        memset(sensitive, 0, len);
        delete [] sensitive;
        auth_id.User = userbuf;
        auth_id.Domain = domainbuf;
        auth_id.Password = passbuf;
        auth_id.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI;
        pauth_id = &auth_id;
        LOG(LS_VERBOSE) << "Negotiate protocol: Using specified credentials";
      } else {
        LOG(LS_VERBOSE) << "Negotiate protocol: Using default credentials";
      }

      CredHandle cred;
      ret = AcquireCredentialsHandleA(0, want_negotiate ? NEGOSSP_NAME_A : NTLMSP_NAME_A, SECPKG_CRED_OUTBOUND, 0, pauth_id, 0, 0, &cred, &lifetime);
      //LOG(INFO) << "$$$ AcquireCredentialsHandle @ " << TimeSince(now);
      if (ret != SEC_E_OK) {
        LOG(LS_ERROR) << "AcquireCredentialsHandle error: "
                    << ErrorName(ret, SECURITY_ERRORS);
        return HAR_IGNORE;
      }

      //CSecBufferBundle<5, CSecBufferBase::FreeSSPI> sb_out;

      CtxtHandle ctx;
      ret = InitializeSecurityContextA(&cred, 0, spn, flags, 0, SECURITY_NATIVE_DREP, 0, 0, &ctx, &out_buf_desc, &ret_flags, &lifetime);
      //LOG(INFO) << "$$$ InitializeSecurityContext @ " << TimeSince(now);
      if (FAILED(ret)) {
        LOG(LS_ERROR) << "InitializeSecurityContext returned: "
                    << ErrorName(ret, SECURITY_ERRORS);
        FreeCredentialsHandle(&cred);
        return HAR_IGNORE;
      }

      // The new context takes ownership of both SSPI handles.
      ASSERT(!context);
      context = neg = new NegotiateAuthContext(auth_method, cred, ctx);
      neg->specified_credentials = specify_credentials;
      neg->steps = steps;
    }

    // Some SSPI results require the token to be finalized before sending.
    if ((ret == SEC_I_COMPLETE_NEEDED) || (ret == SEC_I_COMPLETE_AND_CONTINUE)) {
      ret = CompleteAuthToken(&neg->ctx, &out_buf_desc);
      //LOG(INFO) << "$$$ CompleteAuthToken @ " << TimeSince(now);
      LOG(LS_VERBOSE) << "CompleteAuthToken returned: "
                      << ErrorName(ret, SECURITY_ERRORS);
      if (FAILED(ret)) {
        return HAR_ERROR;
      }
    }

    //LOG(INFO) << "$$$ NEGOTIATE took " << TimeSince(now) << "ms";

    // Send the produced token, base64-encoded, after the scheme name.
    std::string decoded(out_buf, out_buf + out_sec.cbBuffer);
    response = auth_method;
    response.append(" ");
    response.append(Base64::Encode(decoded));
    return HAR_RESPONSE;
  }
#endif
#endif // WIN32

  // Unsupported scheme.
  return HAR_IGNORE;
}
1059
1060//////////////////////////////////////////////////////////////////////
1061
1062} // namespace talk_base
1063