1#include <cstdlib> 2#include <ctime> 3#include <fstream> 4#include <iostream> 5#include <limits> 6#include <string> 7#include <utility> 8#include <vector> 9 10#include <marisa.h> 11 12#include "./cmdopt.h" 13 14namespace { 15 16typedef std::pair<std::string, double> Key; 17 18int param_min_num_tries = 1; 19int param_max_num_tries = 10; 20int param_trie = MARISA_DEFAULT_TRIE; 21int param_tail = MARISA_DEFAULT_TAIL; 22int param_order = MARISA_DEFAULT_ORDER; 23bool predict_strs_flag = false; 24bool speed_flag = true; 25 26class Clock { 27 public: 28 Clock() : cl_(std::clock()) {} 29 30 void reset() { 31 cl_ = std::clock(); 32 } 33 34 double elasped() const { 35 std::clock_t cur = std::clock(); 36 return (cur == cl_) ? 0.01 : (1.0 * (cur - cl_) / CLOCKS_PER_SEC); 37 } 38 39 private: 40 std::clock_t cl_; 41}; 42 43void print_help(const char *cmd) { 44 std::cerr << "Usage: " << cmd << " [OPTION]... [FILE]...\n\n" 45 "Options:\n" 46 " -N, --min-num-tries=[N] limits the number of tries to N" 47 " (default: 1)\n" 48 " -n, --max-num-tries=[N] limits the number of tries to N" 49 " (default: 10)\n" 50 " -P, --patricia-trie build patricia tries (default)\n" 51 " -p, --prefix-trie build prefix tries\n" 52 " -T, --text-tail build a dictionary with text TAIL (default)\n" 53 " -b, --binary-tail build a dictionary with binary TAIL\n" 54 " -t, --without-tail build a dictionary without TAIL\n" 55 " -w, --weight-order arrange siblings in weight order (default)\n" 56 " -l, --label-order arrange siblings in label order\n" 57 " -I, --predict-ids get key IDs in predictive searches (default)\n" 58 " -i, --predict-strs restore key strings in predictive searches\n" 59 " -S, --print-speed print speed [1000 keys/s] (default)\n" 60 " -s, --print-time print time [us/key]\n" 61 " -h, --help print this help\n" 62 << std::endl; 63} 64 65void print_config() { 66 std::cout << "#tries: " << param_min_num_tries 67 << " - " << param_max_num_tries << std::endl; 68 69 switch (param_trie) { 70 case MARISA_PATRICIA_TRIE: { 71 std::cout << "trie: patricia" << std::endl; 72 break; 73 } 74 case MARISA_PREFIX_TRIE: { 75 std::cout << "trie: prefix" << std::endl; 76 break; 77 } 78 } 79 80 switch (param_tail) { 81 case MARISA_WITHOUT_TAIL: { 82 std::cout << "tail: no" << std::endl; 83 break; 84 } 85 case MARISA_BINARY_TAIL: { 86 std::cout << "tail: binary" << std::endl; 87 break; 88 } 89 case MARISA_TEXT_TAIL: { 90 std::cout << "tail: text" << std::endl; 91 break; 92 } 93 } 94 95 switch (param_order) { 96 case MARISA_LABEL_ORDER: { 97 std::cout << "order: label" << std::endl; 98 break; 99 } 100 case MARISA_WEIGHT_ORDER: { 101 std::cout << "order: weight" << std::endl; 102 break; 103 } 104 } 105 106 if (predict_strs_flag) { 107 std::cout << "predict: both IDs and strings" << std::endl; 108 } else { 109 std::cout << "predict: only IDs" << std::endl; 110 } 111} 112 113void print_time_info(std::size_t num_keys, double elasped) { 114 if (speed_flag) { 115 if (elasped == 0.0) { 116 std::printf(" %7s", "-"); 117 } else { 118 std::printf(" %7.2f", num_keys / elasped / 1000.0); 119 } 120 } else { 121 if (num_keys == 0) { 122 std::printf(" %7s", "-"); 123 } else { 124 std::printf(" %7.3f", 1000000.0 * elasped / num_keys); 125 } 126 } 127} 128 129void read_keys(std::istream *input, std::vector<Key> *keys) { 130 Key key; 131 std::string line; 132 while (std::getline(*input, line)) { 133 const std::string::size_type delim_pos = line.find_last_of('\t'); 134 if (delim_pos != line.npos) { 135 char *end_of_value; 136 key.second = std::strtod(&line[delim_pos + 1], &end_of_value); 137 if (*end_of_value == '\0') { 138 line.resize(delim_pos); 139 } else { 140 key.second = 1.0; 141 } 142 } else { 143 key.second = 1.0; 144 } 145 key.first = line; 146 keys->push_back(key); 147 } 148} 149 150int read_keys(const char * const *args, std::size_t num_args, 151 std::vector<Key> *keys) { 152 if (num_args == 0) { 153 read_keys(&std::cin, keys); 154 } 155 for (std::size_t i = 0; i < num_args; ++i) { 156 std::ifstream input_file(args[i], std::ios::binary); 157 if (!input_file) { 158 std::cerr << "error: failed to open a keyset file: " 159 << args[i] << std::endl; 160 return 10; 161 } 162 read_keys(&input_file, keys); 163 } 164 std::cout << "#keys: " << keys->size() << std::endl; 165 std::size_t total_length = 0; 166 for (std::size_t i = 0; i < keys->size(); ++i) { 167 total_length += (*keys)[i].first.length(); 168 } 169 std::cout << "total length: " << total_length << std::endl; 170 return 0; 171} 172 173void benchmark_build(const std::vector<Key> &keys, int num_tries, 174 marisa::Trie *trie, std::vector<marisa::UInt32> *key_ids) { 175 Clock cl; 176 trie->build(keys, key_ids, num_tries 177 | param_trie | param_tail | param_order); 178 std::printf(" %9lu", (unsigned long)trie->num_nodes()); 179 std::printf(" %9lu", (unsigned long)trie->total_size()); 180 print_time_info(keys.size(), cl.elasped()); 181} 182 183void benchmark_restore(const marisa::Trie &trie, 184 const std::vector<Key> &keys, 185 const std::vector<marisa::UInt32> &key_ids) { 186 Clock cl; 187 std::string key; 188 for (std::size_t i = 0; i < key_ids.size(); ++i) { 189 key.clear(); 190 trie.restore(key_ids[i], &key); 191 if (key != keys[i].first) { 192 std::cerr << "error: restore() failed" << std::endl; 193 return; 194 } 195 } 196 print_time_info(key_ids.size(), cl.elasped()); 197} 198 199void benchmark_lookup(const marisa::Trie &trie, 200 const std::vector<Key> &keys, 201 const std::vector<marisa::UInt32> &key_ids) { 202 Clock cl; 203 for (std::size_t i = 0; i < keys.size(); ++i) { 204 const marisa::UInt32 key_id = trie.lookup(keys[i].first); 205 if (key_id != key_ids[i]) { 206 std::cerr << "error: lookup() failed" << std::endl; 207 return; 208 } 209 } 210 print_time_info(keys.size(), cl.elasped()); 211} 212 213void benchmark_find(const marisa::Trie &trie, 214 const std::vector<Key> &keys, 215 const std::vector<marisa::UInt32> &key_ids) { 216 Clock cl; 217 std::vector<marisa::UInt32> found_key_ids; 218 for (std::size_t i = 0; i < keys.size(); ++i) { 219 found_key_ids.clear(); 220 const std::size_t num_keys = trie.find(keys[i].first, &found_key_ids); 221 if ((num_keys == 0) || (found_key_ids.back() != key_ids[i])) { 222 std::cerr << "error: find() failed" << std::endl; 223 return; 224 } 225 } 226 print_time_info(keys.size(), cl.elasped()); 227} 228 229void benchmark_predict_breadth_first(const marisa::Trie &trie, 230 const std::vector<Key> &keys, 231 const std::vector<marisa::UInt32> &key_ids) { 232 Clock cl; 233 std::vector<marisa::UInt32> found_key_ids; 234 std::vector<std::string> found_keys; 235 std::vector<std::string> *found_keys_ref = 236 predict_strs_flag ? &found_keys : NULL; 237 for (std::size_t i = 0; i < keys.size(); ++i) { 238 found_key_ids.clear(); 239 found_keys.clear(); 240 const std::size_t num_keys = trie.predict_breadth_first( 241 keys[i].first, &found_key_ids, found_keys_ref); 242 if ((num_keys == 0) || (found_key_ids.front() != key_ids[i])) { 243 std::cerr << "error: predict() failed" << std::endl; 244 return; 245 } 246 } 247 print_time_info(keys.size(), cl.elasped()); 248} 249 250void benchmark_predict_depth_first(const marisa::Trie &trie, 251 const std::vector<Key> &keys, 252 const std::vector<marisa::UInt32> &key_ids) { 253 Clock cl; 254 std::vector<marisa::UInt32> found_key_ids; 255 std::vector<std::string> found_keys; 256 std::vector<std::string> *found_keys_ref = 257 predict_strs_flag ? &found_keys : NULL; 258 for (std::size_t i = 0; i < keys.size(); ++i) { 259 found_key_ids.clear(); 260 found_keys.clear(); 261 const std::size_t num_keys = trie.predict_depth_first( 262 keys[i].first, &found_key_ids, found_keys_ref); 263 if ((num_keys == 0) || (found_key_ids.front() != key_ids[i])) { 264 std::cerr << "error: predict() failed" << std::endl; 265 return; 266 } 267 } 268 print_time_info(keys.size(), cl.elasped()); 269} 270 271void benchmark(const std::vector<Key> &keys, int num_tries) { 272 std::printf("%6d", num_tries); 273 marisa::Trie trie; 274 std::vector<marisa::UInt32> key_ids; 275 benchmark_build(keys, num_tries, &trie, &key_ids); 276 if (!trie.empty()) { 277 benchmark_restore(trie, keys, key_ids); 278 benchmark_lookup(trie, keys, key_ids); 279 benchmark_find(trie, keys, key_ids); 280 benchmark_predict_breadth_first(trie, keys, key_ids); 281 benchmark_predict_depth_first(trie, keys, key_ids); 282 } 283 std::printf("\n"); 284} 285 286int benchmark(const char * const *args, std::size_t num_args) try { 287 std::vector<Key> keys; 288 const int ret = read_keys(args, num_args, &keys); 289 if (ret != 0) { 290 return ret; 291 } 292 std::printf("------+---------+---------+-------+" 293 "-------+-------+-------+-------+-------\n"); 294 std::printf("%6s %9s %9s %7s %7s %7s %7s %7s %7s\n", 295 "#tries", "#nodes", "size", 296 "build", "restore", "lookup", "find", "predict", "predict"); 297 std::printf("%6s %9s %9s %7s %7s %7s %7s %7s %7s\n", 298 "", "", "", "", "", "", "", "breadth", "depth"); 299 if (speed_flag) { 300 std::printf("%6s %9s %9s %7s %7s %7s %7s %7s %7s\n", 301 "", "", "[bytes]", 302 "[K/s]", "[K/s]", "[K/s]", "[K/s]", "[K/s]", "[K/s]"); 303 } else { 304 std::printf("%6s %9s %9s %7s %7s %7s %7s %7s %7s\n", 305 "", "", "[bytes]", "[us]", "[us]", "[us]", "[us]", "[us]", "[us]"); 306 } 307 std::printf("------+---------+---------+-------+" 308 "-------+-------+-------+-------+-------\n"); 309 for (int i = param_min_num_tries; i <= param_max_num_tries; ++i) { 310 benchmark(keys, i); 311 } 312 std::printf("------+---------+---------+-------+" 313 "-------+-------+-------+-------+-------\n"); 314 return 0; 315} catch (const marisa::Exception &ex) { 316 std::cerr << ex.filename() << ':' << ex.line() 317 << ": " << ex.what() << std::endl; 318 return -1; 319} 320 321} // namespace 322 323int main(int argc, char *argv[]) { 324 std::ios::sync_with_stdio(false); 325 326 ::cmdopt_option long_options[] = { 327 { "min-num-tries", 1, NULL, 'N' }, 328 { "max-num-tries", 1, NULL, 'n' }, 329 { "patricia-trie", 0, NULL, 'P' }, 330 { "prefix-trie", 0, NULL, 'p' }, 331 { "text-tail", 0, NULL, 'T' }, 332 { "binary-tail", 0, NULL, 'b' }, 333 { "without-tail", 0, NULL, 't' }, 334 { "weight-order", 0, NULL, 'w' }, 335 { "label-order", 0, NULL, 'l' }, 336 { "predict-ids", 0, NULL, 'I' }, 337 { "predict-strs", 0, NULL, 'i' }, 338 { "print-speed", 0, NULL, 'S' }, 339 { "print-time", 0, NULL, 's' }, 340 { "help", 0, NULL, 'h' }, 341 { NULL, 0, NULL, 0 } 342 }; 343 ::cmdopt_t cmdopt; 344 ::cmdopt_init(&cmdopt, argc, argv, "N:n:PpTbtwlIiSsh", long_options); 345 int label; 346 while ((label = ::cmdopt_get(&cmdopt)) != -1) { 347 switch (label) { 348 case 'N': { 349 char *end_of_value; 350 const long value = std::strtol(cmdopt.optarg, &end_of_value, 10); 351 if ((*end_of_value != '\0') || (value <= 0) || 352 (value > MARISA_MAX_NUM_TRIES)) { 353 std::cerr << "error: option `-n' with an invalid argument: " 354 << cmdopt.optarg << std::endl; 355 } 356 param_min_num_tries = (int)value; 357 break; 358 } 359 case 'n': { 360 char *end_of_value; 361 const long value = std::strtol(cmdopt.optarg, &end_of_value, 10); 362 if ((*end_of_value != '\0') || (value <= 0) || 363 (value > MARISA_MAX_NUM_TRIES)) { 364 std::cerr << "error: option `-n' with an invalid argument: " 365 << cmdopt.optarg << std::endl; 366 } 367 param_max_num_tries = (int)value; 368 break; 369 } 370 case 'P': { 371 param_trie = MARISA_PATRICIA_TRIE; 372 break; 373 } 374 case 'p': { 375 param_trie = MARISA_PREFIX_TRIE; 376 break; 377 } 378 case 'T': { 379 param_tail = MARISA_TEXT_TAIL; 380 break; 381 } 382 case 'b': { 383 param_tail = MARISA_BINARY_TAIL; 384 break; 385 } 386 case 't': { 387 param_tail = MARISA_WITHOUT_TAIL; 388 break; 389 } 390 case 'w': { 391 param_order = MARISA_WEIGHT_ORDER; 392 break; 393 } 394 case 'l': { 395 param_order = MARISA_LABEL_ORDER; 396 break; 397 } 398 case 'I': { 399 predict_strs_flag = false; 400 break; 401 } 402 case 'i': { 403 predict_strs_flag = true; 404 break; 405 } 406 case 'S': { 407 speed_flag = true; 408 break; 409 } 410 case 's': { 411 speed_flag = false; 412 break; 413 } 414 case 'h': { 415 print_help(argv[0]); 416 return 0; 417 } 418 default: { 419 return 1; 420 } 421 } 422 } 423 print_config(); 424 return benchmark(cmdopt.argv + cmdopt.optind, cmdopt.argc - cmdopt.optind); 425} 426