1#include <algorithm>
2#include <functional>
3#include <queue>
4#include <stdexcept>
5
6#include "range.h"
7#include "trie.h"
8
9namespace marisa_alpha {
10
11void Trie::build(const char * const *keys, std::size_t num_keys,
12    const std::size_t *key_lengths, const double *key_weights,
13    UInt32 *key_ids, int flags) {
14  MARISA_ALPHA_THROW_IF((keys == NULL) && (num_keys != 0),
15      MARISA_ALPHA_PARAM_ERROR);
16  Vector<Key<String> > temp_keys;
17  temp_keys.resize(num_keys);
18  for (std::size_t i = 0; i < temp_keys.size(); ++i) {
19    MARISA_ALPHA_THROW_IF(keys[i] == NULL, MARISA_ALPHA_PARAM_ERROR);
20    std::size_t length = 0;
21    if (key_lengths == NULL) {
22      while (keys[i][length] != '\0') {
23        ++length;
24      }
25    } else {
26      length = key_lengths[i];
27    }
28    MARISA_ALPHA_THROW_IF(length > MARISA_ALPHA_MAX_LENGTH,
29        MARISA_ALPHA_SIZE_ERROR);
30    temp_keys[i].set_str(String(keys[i], length));
31    temp_keys[i].set_weight((key_weights != NULL) ? key_weights[i] : 1.0);
32  }
33  build_trie(temp_keys, key_ids, flags);
34}
35
36void Trie::build(const std::vector<std::string> &keys,
37    std::vector<UInt32> *key_ids, int flags) {
38  Vector<Key<String> > temp_keys;
39  temp_keys.resize(keys.size());
40  for (std::size_t i = 0; i < temp_keys.size(); ++i) {
41    MARISA_ALPHA_THROW_IF(keys[i].length() > MARISA_ALPHA_MAX_LENGTH,
42        MARISA_ALPHA_SIZE_ERROR);
43    temp_keys[i].set_str(String(keys[i].c_str(), keys[i].length()));
44    temp_keys[i].set_weight(1.0);
45  }
46  build_trie(temp_keys, key_ids, flags);
47}
48
49void Trie::build(const std::vector<std::pair<std::string, double> > &keys,
50    std::vector<UInt32> *key_ids, int flags) {
51  Vector<Key<String> > temp_keys;
52  temp_keys.resize(keys.size());
53  for (std::size_t i = 0; i < temp_keys.size(); ++i) {
54    MARISA_ALPHA_THROW_IF(keys[i].first.length() > MARISA_ALPHA_MAX_LENGTH,
55        MARISA_ALPHA_SIZE_ERROR);
56    temp_keys[i].set_str(String(
57        keys[i].first.c_str(), keys[i].first.length()));
58    temp_keys[i].set_weight(keys[i].second);
59  }
60  build_trie(temp_keys, key_ids, flags);
61}
62
63void Trie::build_trie(Vector<Key<String> > &keys,
64    std::vector<UInt32> *key_ids, int flags) {
65  if (key_ids == NULL) {
66    build_trie(keys, static_cast<UInt32 *>(NULL), flags);
67    return;
68  }
69  try {
70    std::vector<UInt32> temp_key_ids(keys.size());
71    build_trie(keys, temp_key_ids.empty() ? NULL : &temp_key_ids[0], flags);
72    key_ids->swap(temp_key_ids);
73  } catch (const std::bad_alloc &) {
74    MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
75  } catch (const std::length_error &) {
76    MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
77  }
78}
79
80void Trie::build_trie(Vector<Key<String> > &keys,
81    UInt32 *key_ids, int flags) {
82  Trie temp;
83  Vector<UInt32> terminals;
84  Progress progress(flags);
85  MARISA_ALPHA_THROW_IF(!progress.is_valid(), MARISA_ALPHA_PARAM_ERROR);
86  temp.build_trie(keys, &terminals, progress);
87
88  typedef std::pair<UInt32, UInt32> TerminalIdPair;
89  Vector<TerminalIdPair> pairs;
90  pairs.resize(terminals.size());
91  for (UInt32 i = 0; i < pairs.size(); ++i) {
92    pairs[i].first = terminals[i];
93    pairs[i].second = i;
94  }
95  terminals.clear();
96  std::sort(pairs.begin(), pairs.end());
97
98  UInt32 node = 0;
99  for (UInt32 i = 0; i < pairs.size(); ++i) {
100    while (node < pairs[i].first) {
101      temp.terminal_flags_.push_back(false);
102      ++node;
103    }
104    if (node == pairs[i].first) {
105      temp.terminal_flags_.push_back(true);
106      ++node;
107    }
108  }
109  while (node < temp.labels_.size()) {
110    temp.terminal_flags_.push_back(false);
111    ++node;
112  }
113  terminal_flags_.push_back(false);
114  temp.terminal_flags_.build();
115  temp.terminal_flags_.clear_select0s();
116  progress.test_total_size(temp.terminal_flags_.total_size());
117
118  if (key_ids != NULL) {
119    for (UInt32 i = 0; i < pairs.size(); ++i) {
120      key_ids[pairs[i].second] = temp.node_to_key_id(pairs[i].first);
121    }
122  }
123  MARISA_ALPHA_THROW_IF(progress.total_size() != temp.total_size(),
124      MARISA_ALPHA_UNEXPECTED_ERROR);
125  temp.swap(this);
126}
127
128template <typename T>
129void Trie::build_trie(Vector<Key<T> > &keys,
130    Vector<UInt32> *terminals, Progress &progress) {
131  build_cur(keys, terminals, progress);
132  progress.test_total_size(louds_.total_size());
133  progress.test_total_size(sizeof(num_first_branches_));
134  progress.test_total_size(sizeof(num_keys_));
135  if (link_flags_.empty()) {
136    labels_.shrink();
137    progress.test_total_size(labels_.total_size());
138    progress.test_total_size(link_flags_.total_size());
139    progress.test_total_size(links_.total_size());
140    progress.test_total_size(tail_.total_size());
141    return;
142  }
143
144  Vector<UInt32> next_terminals;
145  build_next(keys, &next_terminals, progress);
146
147  if (has_trie()) {
148    progress.test_total_size(trie_->terminal_flags_.total_size());
149  } else if (tail_.mode() == MARISA_ALPHA_BINARY_TAIL) {
150    labels_.push_back('\0');
151    link_flags_.push_back(true);
152  }
153  link_flags_.build();
154
155  for (UInt32 i = 0; i < next_terminals.size(); ++i) {
156    labels_[link_flags_.select1(i)] = (UInt8)(next_terminals[i] % 256);
157    next_terminals[i] /= 256;
158  }
159  link_flags_.clear_select0s();
160  if (has_trie() || (tail_.mode() == MARISA_ALPHA_TEXT_TAIL)) {
161    link_flags_.clear_select1s();
162  }
163
164  links_.build(next_terminals);
165  labels_.shrink();
166  progress.test_total_size(labels_.total_size());
167  progress.test_total_size(link_flags_.total_size());
168  progress.test_total_size(links_.total_size());
169  progress.test_total_size(tail_.total_size());
170}
171
172template <typename T>
173void Trie::build_cur(Vector<Key<T> > &keys,
174    Vector<UInt32> *terminals, Progress &progress) try {
175  num_keys_ = sort_keys(keys);
176  louds_.push_back(true);
177  louds_.push_back(false);
178  labels_.push_back('\0');
179  link_flags_.push_back(false);
180
181  Vector<Key<T> > rest_keys;
182  std::queue<Range> queue;
183  Vector<WRange> wranges;
184  queue.push(Range(0, (UInt32)keys.size(), 0));
185  while (!queue.empty()) {
186    const UInt32 node = (UInt32)(link_flags_.size() - queue.size());
187    Range range = queue.front();
188    queue.pop();
189
190    while ((range.begin() < range.end()) &&
191        (keys[range.begin()].str().length() == range.pos())) {
192      keys[range.begin()].set_terminal(node);
193      range.set_begin(range.begin() + 1);
194    }
195    if (range.begin() == range.end()) {
196      louds_.push_back(false);
197      continue;
198    }
199
200    wranges.clear();
201    double weight = keys[range.begin()].weight();
202    for (UInt32 i = range.begin() + 1; i < range.end(); ++i) {
203      if (keys[i - 1].str()[range.pos()] != keys[i].str()[range.pos()]) {
204        wranges.push_back(WRange(range.begin(), i, range.pos(), weight));
205        range.set_begin(i);
206        weight = 0.0;
207      }
208      weight += keys[i].weight();
209    }
210    wranges.push_back(WRange(range, weight));
211    if (progress.order() == MARISA_ALPHA_WEIGHT_ORDER) {
212      std::stable_sort(wranges.begin(), wranges.end(), std::greater<WRange>());
213    }
214    if (node == 0) {
215      num_first_branches_ = wranges.size();
216    }
217    for (UInt32 i = 0; i < wranges.size(); ++i) {
218      const WRange &wrange = wranges[i];
219      UInt32 pos = wrange.pos() + 1;
220      if ((progress.tail() != MARISA_ALPHA_WITHOUT_TAIL) ||
221          !progress.is_last()) {
222        while (pos < keys[wrange.begin()].str().length()) {
223          UInt32 j;
224          for (j = wrange.begin() + 1; j < wrange.end(); ++j) {
225            if (keys[j - 1].str()[pos] != keys[j].str()[pos]) {
226              break;
227            }
228          }
229          if (j < wrange.end()) {
230            break;
231          }
232          ++pos;
233        }
234      }
235      if ((progress.trie() != MARISA_ALPHA_PATRICIA_TRIE) &&
236          (pos != keys[wrange.end() - 1].str().length())) {
237        pos = wrange.pos() + 1;
238      }
239      louds_.push_back(true);
240      if (pos == wrange.pos() + 1) {
241        labels_.push_back(keys[wrange.begin()].str()[wrange.pos()]);
242        link_flags_.push_back(false);
243      } else {
244        labels_.push_back('\0');
245        link_flags_.push_back(true);
246        Key<T> rest_key;
247        rest_key.set_str(keys[wrange.begin()].str().substr(
248            wrange.pos(), pos - wrange.pos()));
249        rest_key.set_weight(wrange.weight());
250        rest_keys.push_back(rest_key);
251      }
252      wranges[i].set_pos(pos);
253      queue.push(wranges[i].range());
254    }
255    louds_.push_back(false);
256  }
257  louds_.push_back(false);
258  louds_.build();
259  if (progress.trie_id() != 0) {
260    louds_.clear_select0s();
261  }
262  if (rest_keys.empty()) {
263    link_flags_.clear();
264  }
265
266  build_terminals(keys, terminals);
267  keys.swap(&rest_keys);
268} catch (const std::bad_alloc &) {
269  MARISA_ALPHA_THROW(MARISA_ALPHA_MEMORY_ERROR);
270} catch (const std::length_error &) {
271  MARISA_ALPHA_THROW(MARISA_ALPHA_SIZE_ERROR);
272}
273
274void Trie::build_next(Vector<Key<String> > &keys,
275    Vector<UInt32> *terminals, Progress &progress) {
276  if (progress.is_last()) {
277    Vector<String> strs;
278    strs.resize(keys.size());
279    for (UInt32 i = 0; i < strs.size(); ++i) {
280      strs[i] = keys[i].str();
281    }
282    tail_.build(strs, terminals, progress.tail());
283    return;
284  }
285  Vector<Key<RString> > rkeys;
286  rkeys.resize(keys.size());
287  for (UInt32 i = 0; i < rkeys.size(); ++i) {
288    rkeys[i].set_str(RString(keys[i].str()));
289    rkeys[i].set_weight(keys[i].weight());
290  }
291  keys.clear();
292  trie_.reset(new (std::nothrow) Trie);
293  MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
294  trie_->build_trie(rkeys, terminals, ++progress);
295}
296
297void Trie::build_next(Vector<Key<RString> > &rkeys,
298    Vector<UInt32> *terminals, Progress &progress) {
299  if (progress.is_last()) {
300    Vector<String> strs;
301    strs.resize(rkeys.size());
302    for (UInt32 i = 0; i < strs.size(); ++i) {
303      strs[i] = String(rkeys[i].str().ptr(), rkeys[i].str().length());
304    }
305    tail_.build(strs, terminals, progress.tail());
306    return;
307  }
308  trie_.reset(new (std::nothrow) Trie);
309  MARISA_ALPHA_THROW_IF(!has_trie(), MARISA_ALPHA_MEMORY_ERROR);
310  trie_->build_trie(rkeys, terminals, ++progress);
311}
312
313template <typename T>
314UInt32 Trie::sort_keys(Vector<Key<T> > &keys) const {
315  if (keys.empty()) {
316    return 0;
317  }
318  for (UInt32 i = 0; i < keys.size(); ++i) {
319    keys[i].set_id(i);
320  }
321  std::sort(keys.begin(), keys.end());
322  UInt32 count = 1;
323  for (UInt32 i = 1; i < keys.size(); ++i) {
324    if (keys[i - 1].str() != keys[i].str()) {
325      ++count;
326    }
327  }
328  return count;
329}
330
331template <typename T>
332void Trie::build_terminals(const Vector<Key<T> > &keys,
333    Vector<UInt32> *terminals) const {
334  Vector<UInt32> temp_terminals;
335  temp_terminals.resize(keys.size());
336  for (UInt32 i = 0; i < keys.size(); ++i) {
337    temp_terminals[keys[i].id()] = keys[i].terminal();
338  }
339  temp_terminals.swap(terminals);
340}
341
342}  // namespace marisa_alpha
343