1// Copyright (c) 2012 The Chromium Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style license that can be
3// found in the LICENSE file.
4
5#include "chrome/browser/download/download_query.h"
6
7#include <algorithm>
8#include <string>
9#include <vector>
10
11#include "base/bind.h"
12#include "base/callback.h"
13#include "base/files/file_path.h"
14#include "base/i18n/case_conversion.h"
15#include "base/i18n/string_search.h"
16#include "base/logging.h"
17#include "base/memory/scoped_ptr.h"
18#include "base/prefs/pref_service.h"
19#include "base/stl_util.h"
20#include "base/strings/string16.h"
21#include "base/strings/string_split.h"
22#include "base/strings/stringprintf.h"
23#include "base/strings/utf_string_conversions.h"
24#include "base/time/time.h"
25#include "base/values.h"
26#include "chrome/browser/profiles/profile.h"
27#include "chrome/common/pref_names.h"
28#include "content/public/browser/content_browser_client.h"
29#include "content/public/browser/download_item.h"
30#include "net/base/net_util.h"
31#include "third_party/re2/re2/re2.h"
32#include "url/gurl.h"
33
34using content::DownloadDangerType;
35using content::DownloadItem;
36
37namespace {
38
39// Templatized base::Value::GetAs*().
40template <typename T> bool GetAs(const base::Value& in, T* out);
41template<> bool GetAs(const base::Value& in, bool* out) {
42  return in.GetAsBoolean(out);
43}
44template<> bool GetAs(const base::Value& in, int* out) {
45  return in.GetAsInteger(out);
46}
47template<> bool GetAs(const base::Value& in, std::string* out) {
48  return in.GetAsString(out);
49}
50template<> bool GetAs(const base::Value& in, base::string16* out) {
51  return in.GetAsString(out);
52}
53template<> bool GetAs(const base::Value& in, std::vector<base::string16>* out) {
54  out->clear();
55  const base::ListValue* list = NULL;
56  if (!in.GetAsList(&list))
57    return false;
58  for (size_t i = 0; i < list->GetSize(); ++i) {
59    base::string16 element;
60    if (!list->GetString(i, &element)) {
61      out->clear();
62      return false;
63    }
64    out->push_back(element);
65  }
66  return true;
67}
68
69// The next several functions are helpers for making Callbacks that access
70// DownloadItem fields.
71
72static bool MatchesQuery(
73    const std::vector<base::string16>& query_terms,
74    const DownloadItem& item) {
75  DCHECK(!query_terms.empty());
76  base::string16 url_raw(UTF8ToUTF16(item.GetOriginalUrl().spec()));
77  base::string16 url_formatted = url_raw;
78  if (item.GetBrowserContext()) {
79    Profile* profile = Profile::FromBrowserContext(item.GetBrowserContext());
80    url_formatted = net::FormatUrl(
81        item.GetOriginalUrl(),
82        profile->GetPrefs()->GetString(prefs::kAcceptLanguages));
83  }
84  base::string16 path(item.GetTargetFilePath().LossyDisplayName());
85
86  for (std::vector<base::string16>::const_iterator it = query_terms.begin();
87       it != query_terms.end(); ++it) {
88    base::string16 term = base::i18n::ToLower(*it);
89    if (!base::i18n::StringSearchIgnoringCaseAndAccents(
90            term, url_raw, NULL, NULL) &&
91        !base::i18n::StringSearchIgnoringCaseAndAccents(
92            term, url_formatted, NULL, NULL) &&
93        !base::i18n::StringSearchIgnoringCaseAndAccents(
94            term, path, NULL, NULL)) {
95      return false;
96    }
97  }
98  return true;
99}
100
101static int64 GetStartTimeMsEpoch(const DownloadItem& item) {
102  return (item.GetStartTime() - base::Time::UnixEpoch()).InMilliseconds();
103}
104
105static int64 GetEndTimeMsEpoch(const DownloadItem& item) {
106  return (item.GetEndTime() - base::Time::UnixEpoch()).InMilliseconds();
107}
108
109std::string TimeToISO8601(const base::Time& t) {
110  base::Time::Exploded exploded;
111  t.UTCExplode(&exploded);
112  return base::StringPrintf(
113      "%04d-%02d-%02dT%02d:%02d:%02d.%03dZ", exploded.year, exploded.month,
114      exploded.day_of_month, exploded.hour, exploded.minute, exploded.second,
115      exploded.millisecond);
116}
117
118static std::string GetStartTime(const DownloadItem& item) {
119  return TimeToISO8601(item.GetStartTime());
120}
121
122static std::string GetEndTime(const DownloadItem& item) {
123  return TimeToISO8601(item.GetEndTime());
124}
125
126static bool GetDangerAccepted(const DownloadItem& item) {
127  return (item.GetDangerType() ==
128          content::DOWNLOAD_DANGER_TYPE_USER_VALIDATED);
129}
130
131static bool GetExists(const DownloadItem& item) {
132  return !item.GetFileExternallyRemoved();
133}
134
135static base::string16 GetFilename(const DownloadItem& item) {
136  // This filename will be compared with strings that could be passed in by the
137  // user, who only sees LossyDisplayNames.
138  return item.GetTargetFilePath().LossyDisplayName();
139}
140
141static std::string GetFilenameUTF8(const DownloadItem& item) {
142  return UTF16ToUTF8(GetFilename(item));
143}
144
145static std::string GetUrl(const DownloadItem& item) {
146  return item.GetOriginalUrl().spec();
147}
148
149static DownloadItem::DownloadState GetState(const DownloadItem& item) {
150  return item.GetState();
151}
152
153static DownloadDangerType GetDangerType(const DownloadItem& item) {
154  return item.GetDangerType();
155}
156
157static int GetReceivedBytes(const DownloadItem& item) {
158  return item.GetReceivedBytes();
159}
160
161static int GetTotalBytes(const DownloadItem& item) {
162  return item.GetTotalBytes();
163}
164
165static std::string GetMimeType(const DownloadItem& item) {
166  return item.GetMimeType();
167}
168
169static bool IsPaused(const DownloadItem& item) {
170  return item.IsPaused();
171}
172
173enum ComparisonType {LT, EQ, GT};
174
175// Returns true if |item| matches the filter specified by |value|, |cmptype|,
176// and |accessor|. |accessor| is conceptually a function that takes a
177// DownloadItem and returns one of its fields, which is then compared to
178// |value|.
179template<typename ValueType>
180static bool FieldMatches(
181    const ValueType& value,
182    ComparisonType cmptype,
183    const base::Callback<ValueType(const DownloadItem&)>& accessor,
184    const DownloadItem& item) {
185  switch (cmptype) {
186    case LT: return accessor.Run(item) < value;
187    case EQ: return accessor.Run(item) == value;
188    case GT: return accessor.Run(item) > value;
189  }
190  NOTREACHED();
191  return false;
192}
193
194// Helper for building a Callback to FieldMatches<>().
195template <typename ValueType> DownloadQuery::FilterCallback BuildFilter(
196    const base::Value& value, ComparisonType cmptype,
197    ValueType (*accessor)(const DownloadItem&)) {
198  ValueType cpp_value;
199  if (!GetAs(value, &cpp_value)) return DownloadQuery::FilterCallback();
200  return base::Bind(&FieldMatches<ValueType>, cpp_value, cmptype,
201                    base::Bind(accessor));
202}
203
204// Returns true if |accessor.Run(item)| matches |pattern|.
205static bool FindRegex(
206    RE2* pattern,
207    const base::Callback<std::string(const DownloadItem&)>& accessor,
208    const DownloadItem& item) {
209  return RE2::PartialMatch(accessor.Run(item), *pattern);
210}
211
212// Helper for building a Callback to FindRegex().
213DownloadQuery::FilterCallback BuildRegexFilter(
214    const base::Value& regex_value,
215    std::string (*accessor)(const DownloadItem&)) {
216  std::string regex_str;
217  if (!GetAs(regex_value, &regex_str)) return DownloadQuery::FilterCallback();
218  scoped_ptr<RE2> pattern(new RE2(regex_str));
219  if (!pattern->ok()) return DownloadQuery::FilterCallback();
220  return base::Bind(&FindRegex, base::Owned(pattern.release()),
221                    base::Bind(accessor));
222}
223
224// Returns a ComparisonType to indicate whether a field in |left| is less than,
225// greater than or equal to the same field in |right|.
226template<typename ValueType>
227static ComparisonType Compare(
228    const base::Callback<ValueType(const DownloadItem&)>& accessor,
229    const DownloadItem& left, const DownloadItem& right) {
230  ValueType left_value = accessor.Run(left);
231  ValueType right_value = accessor.Run(right);
232  if (left_value > right_value) return GT;
233  if (left_value < right_value) return LT;
234  DCHECK_EQ(left_value, right_value);
235  return EQ;
236}
237
238}  // anonymous namespace
239
240DownloadQuery::DownloadQuery()
241  : limit_(kuint32max) {
242}
243
244DownloadQuery::~DownloadQuery() {
245}
246
247// AddFilter() pushes a new FilterCallback to filters_. Most FilterCallbacks are
248// Callbacks to FieldMatches<>(). Search() iterates over given DownloadItems,
249// discarding items for which any filter returns false. A DownloadQuery may have
250// zero or more FilterCallbacks.
251
252bool DownloadQuery::AddFilter(const DownloadQuery::FilterCallback& value) {
253  if (value.is_null()) return false;
254  filters_.push_back(value);
255  return true;
256}
257
258void DownloadQuery::AddFilter(DownloadItem::DownloadState state) {
259  AddFilter(base::Bind(&FieldMatches<DownloadItem::DownloadState>, state, EQ,
260      base::Bind(&GetState)));
261}
262
263void DownloadQuery::AddFilter(DownloadDangerType danger) {
264  AddFilter(base::Bind(&FieldMatches<DownloadDangerType>, danger, EQ,
265      base::Bind(&GetDangerType)));
266}
267
268bool DownloadQuery::AddFilter(DownloadQuery::FilterType type,
269                              const base::Value& value) {
270  switch (type) {
271    case FILTER_BYTES_RECEIVED:
272      return AddFilter(BuildFilter<int>(value, EQ, &GetReceivedBytes));
273    case FILTER_DANGER_ACCEPTED:
274      return AddFilter(BuildFilter<bool>(value, EQ, &GetDangerAccepted));
275    case FILTER_EXISTS:
276      return AddFilter(BuildFilter<bool>(value, EQ, &GetExists));
277    case FILTER_FILENAME:
278      return AddFilter(BuildFilter<base::string16>(value, EQ, &GetFilename));
279    case FILTER_FILENAME_REGEX:
280      return AddFilter(BuildRegexFilter(value, &GetFilenameUTF8));
281    case FILTER_MIME:
282      return AddFilter(BuildFilter<std::string>(value, EQ, &GetMimeType));
283    case FILTER_PAUSED:
284      return AddFilter(BuildFilter<bool>(value, EQ, &IsPaused));
285    case FILTER_QUERY: {
286      std::vector<base::string16> query_terms;
287      return GetAs(value, &query_terms) &&
288             (query_terms.empty() ||
289              AddFilter(base::Bind(&MatchesQuery, query_terms)));
290    }
291    case FILTER_ENDED_AFTER:
292      return AddFilter(BuildFilter<std::string>(value, GT, &GetEndTime));
293    case FILTER_ENDED_BEFORE:
294      return AddFilter(BuildFilter<std::string>(value, LT, &GetEndTime));
295    case FILTER_END_TIME:
296      return AddFilter(BuildFilter<std::string>(value, EQ, &GetEndTime));
297    case FILTER_STARTED_AFTER:
298      return AddFilter(BuildFilter<std::string>(value, GT, &GetStartTime));
299    case FILTER_STARTED_BEFORE:
300      return AddFilter(BuildFilter<std::string>(value, LT, &GetStartTime));
301    case FILTER_START_TIME:
302      return AddFilter(BuildFilter<std::string>(value, EQ, &GetStartTime));
303    case FILTER_TOTAL_BYTES:
304      return AddFilter(BuildFilter<int>(value, EQ, &GetTotalBytes));
305    case FILTER_TOTAL_BYTES_GREATER:
306      return AddFilter(BuildFilter<int>(value, GT, &GetTotalBytes));
307    case FILTER_TOTAL_BYTES_LESS:
308      return AddFilter(BuildFilter<int>(value, LT, &GetTotalBytes));
309    case FILTER_URL:
310      return AddFilter(BuildFilter<std::string>(value, EQ, &GetUrl));
311    case FILTER_URL_REGEX:
312      return AddFilter(BuildRegexFilter(value, &GetUrl));
313  }
314  return false;
315}
316
317bool DownloadQuery::Matches(const DownloadItem& item) const {
318  for (FilterCallbackVector::const_iterator filter = filters_.begin();
319        filter != filters_.end(); ++filter) {
320    if (!filter->Run(item))
321      return false;
322  }
323  return true;
324}
325
326// AddSorter() creates a Sorter and pushes it onto sorters_. A Sorter is a
327// direction and a Callback to Compare<>(). After filtering, Search() makes a
328// DownloadComparator functor from the sorters_ and passes the
329// DownloadComparator to std::partial_sort. std::partial_sort calls the
330// DownloadComparator with different pairs of DownloadItems.  DownloadComparator
331// iterates over the sorters until a callback returns ComparisonType LT or GT.
332// DownloadComparator returns true or false depending on that ComparisonType and
333// the sorter's direction in order to indicate to std::partial_sort whether the
334// left item is after or before the right item. If all sorters return EQ, then
335// DownloadComparator compares GetId. A DownloadQuery may have zero or more
336// Sorters, but there is one DownloadComparator per call to Search().
337
338struct DownloadQuery::Sorter {
339  typedef base::Callback<ComparisonType(
340      const DownloadItem&, const DownloadItem&)> SortType;
341
342  template<typename ValueType>
343  static Sorter Build(DownloadQuery::SortDirection adirection,
344                         ValueType (*accessor)(const DownloadItem&)) {
345    return Sorter(adirection, base::Bind(&Compare<ValueType>,
346        base::Bind(accessor)));
347  }
348
349  Sorter(DownloadQuery::SortDirection adirection,
350            const SortType& asorter)
351    : direction(adirection),
352      sorter(asorter) {
353  }
354  ~Sorter() {}
355
356  DownloadQuery::SortDirection direction;
357  SortType sorter;
358};
359
360class DownloadQuery::DownloadComparator {
361 public:
362  explicit DownloadComparator(const DownloadQuery::SorterVector& terms)
363    : terms_(terms) {
364  }
365
366  // Returns true if |left| sorts before |right|.
367  bool operator() (const DownloadItem* left, const DownloadItem* right);
368
369 private:
370  const DownloadQuery::SorterVector& terms_;
371
372  // std::sort requires this class to be copyable.
373};
374
375bool DownloadQuery::DownloadComparator::operator() (
376    const DownloadItem* left, const DownloadItem* right) {
377  for (DownloadQuery::SorterVector::const_iterator term = terms_.begin();
378       term != terms_.end(); ++term) {
379    switch (term->sorter.Run(*left, *right)) {
380      case LT: return term->direction == DownloadQuery::ASCENDING;
381      case GT: return term->direction == DownloadQuery::DESCENDING;
382      case EQ: break;  // break the switch but not the loop
383    }
384  }
385  CHECK_NE(left->GetId(), right->GetId());
386  return left->GetId() < right->GetId();
387}
388
389void DownloadQuery::AddSorter(DownloadQuery::SortType type,
390                              DownloadQuery::SortDirection direction) {
391  switch (type) {
392    case SORT_END_TIME:
393      sorters_.push_back(Sorter::Build<int64>(direction, &GetEndTimeMsEpoch));
394      break;
395    case SORT_START_TIME:
396      sorters_.push_back(Sorter::Build<int64>(direction, &GetStartTimeMsEpoch));
397      break;
398    case SORT_URL:
399      sorters_.push_back(Sorter::Build<std::string>(direction, &GetUrl));
400      break;
401    case SORT_FILENAME:
402      sorters_.push_back(
403          Sorter::Build<base::string16>(direction, &GetFilename));
404      break;
405    case SORT_DANGER:
406      sorters_.push_back(Sorter::Build<DownloadDangerType>(
407          direction, &GetDangerType));
408      break;
409    case SORT_DANGER_ACCEPTED:
410      sorters_.push_back(Sorter::Build<bool>(direction, &GetDangerAccepted));
411      break;
412    case SORT_EXISTS:
413      sorters_.push_back(Sorter::Build<bool>(direction, &GetExists));
414      break;
415    case SORT_STATE:
416      sorters_.push_back(Sorter::Build<DownloadItem::DownloadState>(
417          direction, &GetState));
418      break;
419    case SORT_PAUSED:
420      sorters_.push_back(Sorter::Build<bool>(direction, &IsPaused));
421      break;
422    case SORT_MIME:
423      sorters_.push_back(Sorter::Build<std::string>(direction, &GetMimeType));
424      break;
425    case SORT_BYTES_RECEIVED:
426      sorters_.push_back(Sorter::Build<int>(direction, &GetReceivedBytes));
427      break;
428    case SORT_TOTAL_BYTES:
429      sorters_.push_back(Sorter::Build<int>(direction, &GetTotalBytes));
430      break;
431  }
432}
433
434void DownloadQuery::FinishSearch(DownloadQuery::DownloadVector* results) const {
435  if (!sorters_.empty())
436    std::partial_sort(results->begin(),
437                      results->begin() + std::min(limit_, results->size()),
438                      results->end(),
439                      DownloadComparator(sorters_));
440  if (results->size() > limit_)
441    results->resize(limit_);
442}
443