1
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6//     http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13//
14// Copyright 2005-2010 Google, Inc.
15// Author: sorenj@google.com (Jeffrey Sorensen)
16
17#include <fst/symbol-table-ops.h>
18
19namespace fst {
20
21SymbolTable *MergeSymbolTable(const SymbolTable &left, const SymbolTable &right,
22                              bool *right_relabel_output) {
23  // MergeSymbolTable detects several special cases.  It will return a reference
24  // copied version of SymbolTable of left or right if either symbol table is
25  // a superset of the other.
26  SymbolTable *merged = new SymbolTable("merge_" + left.Name() + "_" +
27                                        right.Name());
28  // copy everything from the left symbol table
29  bool left_has_all = true, right_has_all = true, relabel = false;
30  SymbolTableIterator liter(left);
31  for (; !liter.Done(); liter.Next()) {
32    merged->AddSymbol(liter.Symbol(), liter.Value());
33    if (right_has_all) {
34      int64 key = right.Find(liter.Symbol());
35      if (key == -1) {
36        right_has_all = false;
37      } else if (!relabel && key != liter.Value()) {
38        relabel = true;
39      }
40    }
41  }
42  if (right_has_all) {
43    delete merged;
44    if (right_relabel_output != NULL) {
45      *right_relabel_output = relabel;
46    }
47    return right.Copy();
48  }
49  // add all symbols we can from right symbol table
50  vector<string> conflicts;
51  SymbolTableIterator riter(right);
52  for (; !riter.Done(); riter.Next()) {
53    int64 key = merged->Find(riter.Symbol());
54    if (key != -1) {
55      // Symbol already exists, maybe with different value
56      if (key != riter.Value()) {
57        relabel = true;
58      }
59      continue;
60    }
61    // Symbol doesn't exist from left
62    left_has_all = false;
63    if (!merged->Find(riter.Value()).empty()) {
64      // we can't add this where we want to, add it later, in order
65      conflicts.push_back(riter.Symbol());
66      continue;
67    }
68    // there is a hole and we can add this symbol with its id
69    merged->AddSymbol(riter.Symbol(), riter.Value());
70  }
71  if (right_relabel_output != NULL) {
72    *right_relabel_output = relabel;
73  }
74  if (left_has_all) {
75    delete merged;
76    return left.Copy();
77  }
78  // Add all symbols that conflicted, in order
79  for (int i= 0; i < conflicts.size(); ++i) {
80    merged->AddSymbol(conflicts[i]);
81  }
82  return merged;
83}
84
85SymbolTable *CompactSymbolTable(const SymbolTable &syms) {
86  map<int, string> sorted;
87  SymbolTableIterator stiter(syms);
88  for (; !stiter.Done(); stiter.Next()) {
89    sorted[stiter.Value()] = stiter.Symbol();
90  }
91  SymbolTable *compact = new SymbolTable(syms.Name() + "_compact");
92  uint64 newkey = 0;
93  for (map<int, string>::const_iterator si = sorted.begin();
94       si != sorted.end(); ++si) {
95    compact->AddSymbol(si->second, newkey++);
96  }
97  return compact;
98}
99
100SymbolTable *FstReadSymbols(const string &filename, bool input_symbols) {
101  ifstream in(filename.c_str(), ifstream::in | ifstream::binary);
102  if (!in) {
103    LOG(ERROR) << "FstReadSymbols: Can't open file " << filename;
104    return NULL;
105  }
106  FstHeader hdr;
107  if (!hdr.Read(in, filename)) {
108    LOG(ERROR) << "FstReadSymbols: Couldn't read header from " << filename;
109    return NULL;
110  }
111  if (hdr.GetFlags() & FstHeader::HAS_ISYMBOLS) {
112    SymbolTable *isymbols = SymbolTable::Read(in, filename);
113    if (isymbols == NULL) {
114      LOG(ERROR) << "FstReadSymbols: Could not read input symbols from "
115                 << filename;
116      return NULL;
117    }
118    if (input_symbols) {
119      return isymbols;
120    }
121    delete isymbols;
122  }
123  if (hdr.GetFlags() & FstHeader::HAS_OSYMBOLS) {
124    SymbolTable *osymbols = SymbolTable::Read(in, filename);
125    if (osymbols == NULL) {
126      LOG(ERROR) << "FstReadSymbols: Could not read output symbols from "
127                 << filename;
128      return NULL;
129    }
130    if (!input_symbols) {
131      return osymbols;
132    }
133    delete osymbols;
134  }
135  LOG(ERROR) << "FstReadSymbols: The file " << filename
136             << " doesn't contain the requested symbols";
137  return NULL;
138}
139
140}  // namespace fst
141