1/*---------------------------------------------------------------------------*
2 *  vocab.cpp                                                                *
3 *                                                                           *
4 *  Copyright 2007, 2008 Nuance Communciations, Inc.                               *
5 *                                                                           *
6 *  Licensed under the Apache License, Version 2.0 (the 'License');          *
7 *  you may not use this file except in compliance with the License.         *
8 *                                                                           *
9 *  You may obtain a copy of the License at                                  *
10 *      http://www.apache.org/licenses/LICENSE-2.0                           *
11 *                                                                           *
12 *  Unless required by applicable law or agreed to in writing, software      *
13 *  distributed under the License is distributed on an 'AS IS' BASIS,        *
14 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
15 *  See the License for the specific language governing permissions and      *
16 *  limitations under the License.                                           *
17 *                                                                           *
18 *---------------------------------------------------------------------------*/
19
20#include <string>
21#include <iostream>
22#include <stdexcept>
23#include "ESR_Locale.h"
24#include "LCHAR.h"
25#include "pstdio.h"
26#include "ESR_Session.h"
27#include "SR_Vocabulary.h"
28
29#include "vocab.h"
30
/* Capacities (in LCHARs) of the phrase and pronunciation buffers used by
   Pronunciation::lookup(). */
#define MAX_LINE_LENGTH     256
#define MAX_PRONS_LENGTH 1024

#define DEBUG	0

/* Context symbol used in place of a missing left/right neighbour phoneme.
   It must be a char literal because it is assigned to char variables. */
#define GENERIC_CONTEXT '#'
37
38Vocabulary::Vocabulary( std::string const & vocFileName )
39{
40    ESR_ReturnCode rc;
41    rc = SR_VocabularyLoad(vocFileName.c_str(), &m_hVocab);
42    if (rc != ESR_SUCCESS)
43    {
44        std::cout << "Error: " << ESR_rc2str(rc) <<std::endl;
45        exit (-1);
46    }
47}
48
49Vocabulary::~Vocabulary()
50{
51    SR_VocabularyDestroy(m_hVocab);
52}
53
54Pronunciation::Pronunciation()
55{
56}
57
58Pronunciation::~Pronunciation()
59{
60}
61
62void Pronunciation::clear()
63{
64    m_Prons.clear();
65    for (unsigned int ii=0;ii<m_ModelIDs.size();ii++ )
66    {
67        m_ModelIDs[ii].clear();
68    }
69    m_ModelIDs.clear();
70}
71
72int Pronunciation::lookup(  Vocabulary & vocab, std::string  & phrase )
73{
74    ESR_ReturnCode rc;
75    LCHAR prons[MAX_PRONS_LENGTH];
76    LCHAR* c_phrase;
77    size_t len;
78
79    LCHAR s[MAX_LINE_LENGTH];
80    strcpy (s, phrase.c_str() ); // No conversion for std::string to wchar
81    //clear();
82
83    memset (prons, 0x00, sizeof(LCHAR));
84
85    c_phrase = s;
86    SR_Vocabulary *p_SRVocab = vocab.getSRVocabularyHandle();
87#if DEBUG
88    std::cout << "DEBUG: " << phrase <<" to be looked up" << std::endl;
89#endif
90    rc = SR_VocabularyGetPronunciation( p_SRVocab, c_phrase, prons, &len );
91    if (rc != ESR_SUCCESS)
92        //  std::cout <<"ERORORORORROOR!" <<std::endl;
93        std::cout <<"ERROR: " << ESR_rc2str(rc) << std::endl;
94    else {
95#if DEBUG
96        std::cout <<"OUTPUT: " << prons << " num " << len << std::endl;
97#endif
98        size_t len_used;
99        LCHAR *pron = 0;
100        for(len_used=0; len_used <len; ) {
101            pron = &prons[0]+len_used;
102            len_used += LSTRLEN(pron)+1;
103#if DEBUG
104            std::cout << "DEBUG: used " << len_used << " now " << LSTRLEN(pron) << std::endl;
105#endif
106            std::string pronString( pron ); // wstring conversion if needed
107            addPron( pronString );
108#if DEBUG
109            std::cout << "DEBUG: " << phrase << " " << pron << std::endl;
110#endif
111        }
112    }
113    return getPronCount();
114}
115
116
117int Pronunciation::addPron( std::string & s )
118{
119    m_Prons.push_back( s );
120    return m_Prons.size();
121}
122
123int Pronunciation::getPronCount()
124{  // returns number of prons
125    return m_Prons.size();
126}
127
128bool Pronunciation::getPron( int index, std::string &s )
129{
130 // returns string length used
131    try {
132      s = m_Prons.at(index);
133    }
134    catch(std::out_of_range& err) {
135      std::cerr << "out_of_range: " << err.what() << std::endl;
136    }
137    return true;
138}
139
140void Pronunciation::print()
141{
142  std::string s;
143  for (int ii=0; ii< getPronCount(); ii++) {
144    getPron(ii, s);
145#if DEBUG
146    std::cout << "Pron #" << ii << ": " << s << std::endl;
147#endif
148  }
149}
150
151void Pronunciation::printModelIDs()
152{
153  std::string s;
154  for (int ii=0; ii< getPronCount(); ii++) {
155    getPron(ii, s);
156#if DEBUG
157    std::cout << "  Pron #" << ii << ": " << s << std::endl;
158    std::cout << "    Model IDs: ";
159#endif
160    for (int jj=0;jj<getModelCount(ii);jj++) {
161      std::cout << " " << getModelID(ii,jj);
162    }
163#if DEBUG
164    std::cout <<  std::endl;
165#endif
166  }
167}
168
169int Pronunciation::getPhonemeCount( int pronIndex )
170{
171  std::string s;
172  getPron(pronIndex, s);
173  return s.size();
174}
175
176bool Pronunciation::getPhoneme( int pronIndex, int picIndex , std::string &phoneme )
177{
178  std::string s;
179  getPron(pronIndex, s);
180  phoneme= s.at(picIndex);
181  return true;
182}
183
184
185bool Pronunciation::getPIC( int pronIndex, int picIndex, std::string &pic )
186{
187  std::string pron;
188  char lphon;
189  char cphon;
190  char rphon;
191
192  getPron( pronIndex, pron );
193  int numPhonemes = pron.size();
194  if ( 1==numPhonemes ) {
195    lphon=GENERIC_CONTEXT;
196    rphon=GENERIC_CONTEXT;
197    cphon = pron.at(0);
198  }
199  else
200    {
201      if ( 0==picIndex ) {
202	lphon=GENERIC_CONTEXT;
203	rphon=GENERIC_CONTEXT;
204      }
205      else if( numPhonemes-1==picIndex ) {
206	lphon = pron.at(picIndex-1);
207	rphon=GENERIC_CONTEXT;
208      }
209      else {
210	lphon = pron.at(picIndex-1);
211	rphon = pron.at(picIndex+1);
212      }
213      cphon = pron.at(picIndex);
214      pic = lphon + cphon + rphon;
215    }
216  return true;
217}
218
219int Pronunciation::lookupModelIDs( AcousticModel &acoustic )
220{
221  // Looks up all hmms for all prons
222  std::string pron;
223  char lphon;
224  char cphon;
225  char rphon;
226
227  int numProns = getPronCount();
228  int totalCount=0;
229  for (int ii=0;ii < numProns; ii++ )
230    {
231      getPron( ii, pron );
232      std::vector<int> idList; // Create storage
233      int numPhonemes = getPhonemeCount(ii);
234      if (1==numPhonemes) {
235	lphon=GENERIC_CONTEXT;
236	rphon=GENERIC_CONTEXT;
237	cphon = pron.at(0);
238      }
239      else
240      for ( int jj=0;jj<numPhonemes;jj++ )
241	{
242	  std::string pic;
243	  getPIC(ii, jj, pic);
244	  lphon = pron.at(0);
245	  cphon = pron.at(1);
246	  rphon = pron.at(2);
247	  int id = CA_ArbdataGetModelIdsForPIC( acoustic.getCAModelHandle(), lphon, cphon,  rphon );
248#if DEBUG
249	  std::cout <<"DEBUG model id: " << lphon <<cphon << rphon << "  "<< id << std::endl;
250#endif
251
252	  idList.push_back(id);
253	}
254      m_ModelIDs.push_back(idList);
255      totalCount+=numPhonemes;
256    }
257  return totalCount;
258}
259
260int Pronunciation::getModelCount( int pronIndex )
261{
262  return m_ModelIDs[pronIndex].size();
263}
264
265int Pronunciation::getModelID( int pronIndex, int modelPos )
266{
267  return m_ModelIDs[pronIndex][modelPos];
268}
269
270AcousticModel::AcousticModel( std::string & arbFileName )
271{
272  m_CA_Arbdata = CA_LoadArbdata( arbFileName.c_str() );
273  if (!m_CA_Arbdata)
274    {
275      std::cout << "Error: while trying to load " << arbFileName.c_str() << std::endl;
276      exit (-1);
277    }
278
279}
280
281AcousticModel::~AcousticModel()
282{
283  CA_FreeArbdata( m_CA_Arbdata);
284}
285
286int AcousticModel::getStateIndices(int id, std::vector<int> & stateIDs)
287{
288  srec_arbdata *allotree = (srec_arbdata*) m_CA_Arbdata;
289  int numStates = allotree->hmm_infos[id].num_states;
290#if DEBUG
291  std::cout << "getStateIndices: count = " << numStates <<std::endl;
292#endif
293  for (int ii=0; ii <numStates; ii++ ) {
294    stateIDs.push_back( allotree->hmm_infos[id].state_indices[ii] );
295#if DEBUG
296    std::cout <<  allotree->hmm_infos[id].state_indices[ii] ;
297#endif
298  }
299#if DEBUG
300  std::cout << std::endl;
301#endif
302    return stateIDs.size();
303}
304
305