1// extract-main.h
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Copyright 2005-2010 Google, Inc.
16// Author: riley@google.com (Michael Riley)
17// Modified: jpr@google.com (Jake Ratkiewicz) to use the new arc-dispatch
18
19// \file
// Extracts component FSTs from a finite-state archive.
21//
22
23#ifndef FST_EXTENSIONS_FAR_EXTRACT_H__
24#define FST_EXTENSIONS_FAR_EXTRACT_H__
25
26#include <string>
27#include <vector>
28using std::vector;
29
30#include <fst/extensions/far/far.h>
31
32namespace fst {
33
34template<class Arc>
35inline void FarWriteFst(const Fst<Arc>* fst, string key,
36                        string* okey, int* nrep,
37                        const int32 &generate_filenames, int i,
38                        const string &filename_prefix,
39                        const string &filename_suffix) {
40  if (key == *okey)
41    ++*nrep;
42  else
43    *nrep = 0;
44
45  *okey = key;
46
47  string ofilename;
48  if (generate_filenames) {
49    ostringstream tmp;
50    tmp.width(generate_filenames);
51    tmp.fill('0');
52    tmp << i;
53    ofilename = tmp.str();
54  } else {
55    if (*nrep > 0) {
56      ostringstream tmp;
57      tmp << '.' << nrep;
58      key.append(tmp.str().data(), tmp.str().size());
59    }
60    ofilename = key;
61  }
62  fst->Write(filename_prefix + ofilename + filename_suffix);
63}
64
65template<class Arc>
66void FarExtract(const vector<string> &ifilenames,
67                const int32 &generate_filenames,
68                const string &keys,
69                const string &key_separator,
70                const string &range_delimiter,
71                const string &filename_prefix,
72                const string &filename_suffix) {
73  FarReader<Arc> *far_reader = FarReader<Arc>::Open(ifilenames);
74  if (!far_reader) return;
75
76  string okey;
77  int nrep = 0;
78
79  vector<char *> key_vector;
80  // User has specified a set of fsts to extract, where some of the "fsts" could
81  // be ranges.
82  if (!keys.empty()) {
83    char *keys_cstr = new char[keys.size()+1];
84    strcpy(keys_cstr, keys.c_str());
85    SplitToVector(keys_cstr, key_separator.c_str(), &key_vector, true);
86    int i = 0;
87    for (int k = 0; k < key_vector.size(); ++k, ++i) {
88      string key = string(key_vector[k]);
89      char *key_cstr = new char[key.size()+1];
90      strcpy(key_cstr, key.c_str());
91      vector<char *> range_vector;
92      SplitToVector(key_cstr, range_delimiter.c_str(), &range_vector, false);
93      if (range_vector.size() == 1) {  // Not a range
94        if (!far_reader->Find(key)) {
95          LOG(ERROR) << "FarExtract: Cannot find key: " << key;
96          return;
97        }
98        const Fst<Arc> &fst = far_reader->GetFst();
99        FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
100                    filename_prefix, filename_suffix);
101      } else if (range_vector.size() == 2) {  // A legal range
102        string begin_key = string(range_vector[0]);
103        string end_key = string(range_vector[1]);
104        if (begin_key.empty() || end_key.empty()) {
105          LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
106          return;
107        }
108        if (!far_reader->Find(begin_key)) {
109          LOG(ERROR) << "FarExtract: Cannot find key: " << begin_key;
110          return;
111        }
112        for ( ; !far_reader->Done(); far_reader->Next(), ++i) {
113          string ikey = far_reader->GetKey();
114          if (end_key < ikey) break;
115          const Fst<Arc> &fst = far_reader->GetFst();
116          FarWriteFst(&fst, ikey, &okey, &nrep, generate_filenames, i,
117                      filename_prefix, filename_suffix);
118        }
119      } else {
120        LOG(ERROR) << "FarExtract: Illegal range specification: " << key;
121        return;
122      }
123      delete key_cstr;
124    }
125    delete keys_cstr;
126    return;
127  }
128  // Nothing specified: extract everything.
129  for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) {
130    string key = far_reader->GetKey();
131    const Fst<Arc> &fst = far_reader->GetFst();
132    FarWriteFst(&fst, key, &okey, &nrep, generate_filenames, i,
133                filename_prefix, filename_suffix);
134  }
135  return;
136}
137
138}  // namespace fst
139
140#endif  // FST_EXTENSIONS_FAR_EXTRACT_H__
141