1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome_frame/http_negotiate.h"
6
7#include <atlbase.h>
8#include <atlcom.h>
9#include <htiframe.h>
10
11#include "base/logging.h"
12#include "base/memory/scoped_ptr.h"
13#include "base/strings/string_util.h"
14#include "base/strings/stringprintf.h"
15#include "base/strings/utf_string_conversions.h"
16#include "chrome_frame/bho.h"
17#include "chrome_frame/exception_barrier.h"
18#include "chrome_frame/html_utils.h"
19#include "chrome_frame/urlmon_moniker.h"
20#include "chrome_frame/urlmon_url_request.h"
21#include "chrome_frame/utils.h"
22#include "chrome_frame/vtable_patch_manager.h"
23#include "net/http/http_response_headers.h"
24#include "net/http/http_util.h"
25
// When true (the default), BeginningTransaction() rewrites the User-Agent
// header it hands back to its caller; when false the delegate's headers are
// returned untouched.
bool HttpNegotiatePatch::modify_user_agent_ = true;
// Lower-case name of the X-UA-Compatible HTTP header.
// NOTE(review): not referenced anywhere in this file; confirm whether it is
// still needed (e.g. declared extern in a header) or can be removed.
const char kUACompatibleHttpHeader[] = "x-ua-compatible";
// Lower-case name of the User-Agent header. It must stay lower-case because
// ExcludeFieldFromHeaders() requires a lower-case field name.
const char kLowerCaseUserAgent[] = "user-agent";

// From the latest urlmon.h. Symbol name prepended with LOCAL_ to
// avoid conflict (and therefore build errors) for those building with
// a newer Windows SDK.
// TODO(robertshield): Remove this once we update our SDK version.
// NOTE(review): not referenced anywhere in this file.
const int LOCAL_BINDSTATUS_SERVER_MIMETYPEAVAILABLE = 54;

// Vtable slot of IHttpNegotiate::BeginningTransaction. IHttpNegotiate derives
// from IUnknown, whose three methods (QueryInterface, AddRef, Release) occupy
// slots 0-2.
static const int kHttpNegotiateBeginningTransactionIndex = 3;

// Patch table for IHttpNegotiate (referred to below as IHttpNegotiate_PatchInfo
// and checked via IS_PATCHED). It routes the BeginningTransaction slot to
// HttpNegotiatePatch::BeginningTransaction, which receives the original
// implementation as its first argument.
BEGIN_VTABLE_PATCHES(IHttpNegotiate)
  VTABLE_PATCH_ENTRY(kHttpNegotiateBeginningTransactionIndex,
                     HttpNegotiatePatch::BeginningTransaction)
END_VTABLE_PATCHES()
42
43namespace {
44
45class SimpleBindStatusCallback : public CComObjectRootEx<CComSingleThreadModel>,
46                                 public IBindStatusCallback {
47 public:
48  BEGIN_COM_MAP(SimpleBindStatusCallback)
49    COM_INTERFACE_ENTRY(IBindStatusCallback)
50  END_COM_MAP()
51
52  // IBindStatusCallback implementation
53  STDMETHOD(OnStartBinding)(DWORD reserved, IBinding* binding) {
54    return E_NOTIMPL;
55  }
56
57  STDMETHOD(GetPriority)(LONG* priority) {
58    return E_NOTIMPL;
59  }
60  STDMETHOD(OnLowResource)(DWORD reserved) {
61    return E_NOTIMPL;
62  }
63
64  STDMETHOD(OnProgress)(ULONG progress, ULONG max_progress,
65                        ULONG status_code, LPCWSTR status_text) {
66    return E_NOTIMPL;
67  }
68  STDMETHOD(OnStopBinding)(HRESULT result, LPCWSTR error) {
69    return E_NOTIMPL;
70  }
71
72  STDMETHOD(GetBindInfo)(DWORD* bind_flags, BINDINFO* bind_info) {
73    return E_NOTIMPL;
74  }
75
76  STDMETHOD(OnDataAvailable)(DWORD flags, DWORD size, FORMATETC* formatetc,
77    STGMEDIUM* storage) {
78    return E_NOTIMPL;
79  }
80  STDMETHOD(OnObjectAvailable)(REFIID iid, IUnknown* object) {
81    return E_NOTIMPL;
82  }
83};
84
85// Returns the full user agent header from the HTTP header strings passed to
86// IHttpNegotiate::BeginningTransaction. Looks first in |additional_headers|
87// and if it can't be found there looks in |headers|.
88std::string GetUserAgentFromHeaders(LPCWSTR headers,
89                                    LPCWSTR additional_headers) {
90  using net::HttpUtil;
91
92  std::string ascii_headers;
93  if (additional_headers) {
94    ascii_headers = WideToASCII(additional_headers);
95  }
96
97  // Extract "User-Agent" from |additional_headers| or |headers|.
98  HttpUtil::HeadersIterator headers_iterator(ascii_headers.begin(),
99                                             ascii_headers.end(), "\r\n");
100  std::string user_agent_value;
101  if (headers_iterator.AdvanceTo(kLowerCaseUserAgent)) {
102    user_agent_value = headers_iterator.values();
103  } else if (headers != NULL) {
104    // See if there's a user-agent header specified in the original headers.
105    std::string original_headers(WideToASCII(headers));
106    HttpUtil::HeadersIterator original_it(original_headers.begin(),
107        original_headers.end(), "\r\n");
108    if (original_it.AdvanceTo(kLowerCaseUserAgent))
109      user_agent_value = original_it.values();
110  }
111
112  return user_agent_value;
113}
114
115// Removes the named header |field| from a set of headers. |field| must be
116// lower-case.
117std::string ExcludeFieldFromHeaders(const std::string& old_headers,
118                                    const char* field) {
119  using net::HttpUtil;
120  std::string new_headers;
121  new_headers.reserve(old_headers.size());
122  HttpUtil::HeadersIterator headers_iterator(old_headers.begin(),
123                                             old_headers.end(), "\r\n");
124  while (headers_iterator.GetNext()) {
125    if (!LowerCaseEqualsASCII(headers_iterator.name_begin(),
126                              headers_iterator.name_end(),
127                              field)) {
128      new_headers.append(headers_iterator.name_begin(),
129                         headers_iterator.name_end());
130      new_headers += ": ";
131      new_headers.append(headers_iterator.values_begin(),
132                         headers_iterator.values_end());
133      new_headers += "\r\n";
134    }
135  }
136
137  return new_headers;
138}
139
140std::string MutateCFUserAgentString(LPCWSTR headers,
141                                    LPCWSTR additional_headers,
142                                    bool add_user_agent) {
143  std::string user_agent_value(GetUserAgentFromHeaders(headers,
144                                                       additional_headers));
145
146  // Use the default "User-Agent" if none was provided.
147  if (user_agent_value.empty())
148    user_agent_value = http_utils::GetDefaultUserAgent();
149
150  // Now add chromeframe to it.
151  user_agent_value = add_user_agent ?
152      http_utils::AddChromeFrameToUserAgentValue(user_agent_value) :
153      http_utils::RemoveChromeFrameFromUserAgentValue(user_agent_value);
154
155  // Build a new set of additional headers, skipping the existing user agent
156  // value if present.
157  return ReplaceOrAddUserAgent(additional_headers, user_agent_value);
158}
159
160}  // end namespace
161
162
163std::string AppendCFUserAgentString(LPCWSTR headers,
164                                    LPCWSTR additional_headers) {
165  return MutateCFUserAgentString(headers, additional_headers, true);
166}
167
168
169// Looks for a user agent header found in |headers| or |additional_headers|
170// then returns |additional_headers| with a modified user agent header that does
171// not include the chromeframe token.
172std::string RemoveCFUserAgentString(LPCWSTR headers,
173                                    LPCWSTR additional_headers) {
174  return MutateCFUserAgentString(headers, additional_headers, false);
175}
176
177
178// Unconditionally adds the specified |user_agent_value| to the given set of
179// |headers|, removing any that were already there.
180std::string ReplaceOrAddUserAgent(LPCWSTR headers,
181                                  const std::string& user_agent_value) {
182  std::string new_headers;
183  if (headers) {
184    std::string ascii_headers(WideToASCII(headers));
185    // Build new headers, skip the existing user agent value from
186    // existing headers.
187    new_headers = ExcludeFieldFromHeaders(ascii_headers, kLowerCaseUserAgent);
188  }
189  new_headers += "User-Agent: ";
190  new_headers += user_agent_value;
191  new_headers += "\r\n";
192  return new_headers;
193}
194
195HttpNegotiatePatch::HttpNegotiatePatch() {
196}
197
198HttpNegotiatePatch::~HttpNegotiatePatch() {
199}
200
// static
// Installs the IHttpNegotiate::BeginningTransaction vtable patch. A temporary
// async bind context is created only to reach urlmon's "_BSCB_Holder_" object,
// from which an IHttpNegotiate implementation is obtained and patched (see
// PatchHttpNegotiate). Returns true if the patch was already installed or was
// installed successfully.
bool HttpNegotiatePatch::Initialize() {
  if (IS_PATCHED(IHttpNegotiate)) {
    DLOG(WARNING) << __FUNCTION__ << " called more than once.";
    return true;
  }
  // Use our SimpleBindStatusCallback class as we need a temporary object that
  // implements IBindStatusCallback.
  CComObjectStackEx<SimpleBindStatusCallback> request;
  base::win::ScopedComPtr<IBindCtx> bind_ctx;
  // On failure |bind_ctx| stays NULL, and the failing |hr| is what gets
  // reported by the final SUCCEEDED(hr) check.
  HRESULT hr = CreateAsyncBindCtx(0, &request, NULL, bind_ctx.Receive());
  DCHECK(SUCCEEDED(hr)) << "CreateAsyncBindCtx";
  if (bind_ctx) {
    base::win::ScopedComPtr<IUnknown> bscb_holder;
    // "_BSCB_Holder_" is a urlmon-internal object parameter (undocumented).
    // The GetObjectParam HRESULT is ignored; a NULL |bscb_holder| is treated
    // as the failure signal instead.
    bind_ctx->GetObjectParam(L"_BSCB_Holder_", bscb_holder.Receive());
    if (bscb_holder) {
      hr = PatchHttpNegotiate(bscb_holder);
    } else {
      NOTREACHED() << "Failed to get _BSCB_Holder_";
      hr = E_UNEXPECTED;
    }
    // Release the bind context explicitly, while |request| (a stack-allocated
    // COM object) is still alive. NOTE(review): CComObjectStackEx checks that
    // its reference count is zero when destroyed, so this ordering matters.
    bind_ctx.Release();
  }

  return SUCCEEDED(hr);
}
227
228// static
229void HttpNegotiatePatch::Uninitialize() {
230  vtable_patch::UnpatchInterfaceMethods(IHttpNegotiate_PatchInfo);
231}
232
233// static
234HRESULT HttpNegotiatePatch::PatchHttpNegotiate(IUnknown* to_patch) {
235  DCHECK(to_patch);
236  DCHECK_IS_NOT_PATCHED(IHttpNegotiate);
237
238  base::win::ScopedComPtr<IHttpNegotiate> http;
239  HRESULT hr = http.QueryFrom(to_patch);
240  if (FAILED(hr)) {
241    hr = DoQueryService(IID_IHttpNegotiate, to_patch, http.Receive());
242  }
243
244  if (http) {
245    hr = vtable_patch::PatchInterfaceMethods(http, IHttpNegotiate_PatchInfo);
246    DLOG_IF(ERROR, FAILED(hr))
247        << base::StringPrintf("HttpNegotiate patch failed 0x%08X", hr);
248  } else {
249    DLOG(WARNING)
250        << base::StringPrintf("IHttpNegotiate not supported 0x%08X", hr);
251  }
252  return hr;
253}
254
255// static
256HRESULT HttpNegotiatePatch::BeginningTransaction(
257    IHttpNegotiate_BeginningTransaction_Fn original, IHttpNegotiate* me,
258    LPCWSTR url, LPCWSTR headers, DWORD reserved, LPWSTR* additional_headers) {
259  DVLOG(1) << __FUNCTION__ << " " << url << " headers:\n" << headers;
260
261  HRESULT hr = original(me, url, headers, reserved, additional_headers);
262
263  if (FAILED(hr)) {
264    DLOG(WARNING) << __FUNCTION__ << " Delegate returned an error";
265    return hr;
266  }
267  if (modify_user_agent_) {
268    std::string updated_headers;
269
270    if (IsGcfDefaultRenderer() &&
271        RendererTypeForUrl(url) == RENDERER_TYPE_CHROME_DEFAULT_RENDERER) {
272      // Replace the user-agent header with Chrome's.
273      updated_headers = ReplaceOrAddUserAgent(*additional_headers,
274                                              http_utils::GetChromeUserAgent());
275    } else if (ShouldRemoveUAForUrl(url)) {
276      updated_headers = RemoveCFUserAgentString(headers, *additional_headers);
277    } else {
278      updated_headers = AppendCFUserAgentString(headers, *additional_headers);
279    }
280
281    *additional_headers = reinterpret_cast<wchar_t*>(::CoTaskMemRealloc(
282        *additional_headers,
283        (updated_headers.length() + 1) * sizeof(wchar_t)));
284    lstrcpyW(*additional_headers, ASCIIToWide(updated_headers).c_str());
285  } else {
286    // TODO(erikwright): Remove the user agent if it is present (i.e., because
287    // of PostPlatform setting in the registry).
288  }
289  return S_OK;
290}
291