1/* FILE:		sub_phon.cpp
2 *  DATE MODIFIED:	31-Aug-07
3 *  DESCRIPTION:	Part of the  SREC graph compiler project source files.
4 *
5 *  Copyright 2007, 2008 Nuance Communications, Inc.                               *
6 *                                                                           *
7 *  Licensed under the Apache License, Version 2.0 (the 'License');          *
8 *  you may not use this file except in compliance with the License.         *
9 *                                                                           *
10 *  You may obtain a copy of the License at                                  *
11 *      http://www.apache.org/licenses/LICENSE-2.0                           *
12 *                                                                           *
13 *  Unless required by applicable law or agreed to in writing, software      *
14 *  distributed under the License is distributed on an 'AS IS' BASIS,        *
15 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
16 *  See the License for the specific language governing permissions and      *
17 *  limitations under the License.                                           *
18 *                                                                           *
19 *---------------------------------------------------------------------------*/
20
#include <assert.h>

#include <iostream>
#include <sstream>
#include <string>
#include <vector>
25
26#define DEBUG           0
27
28#include "sub_grph.h"
29#include "grxmldoc.h"
30
31void SubGraph::ExpandPhonemes ( GRXMLDoc &doc )
32{
33    int ii, wordId, phoneId, currId, newId, nextId, arcCount;
34    Pronunciation pron;
35    int pronCount;
36    NUANArc *arcOne;
37    std::string modelLabel, word;
38
39    {
40        std::stringstream ss;
41        ss << SILENCE_CONTEXT;
42        modelLabel= ss.str();
43        silenceId = doc.addPhonemeToList(modelLabel);
44    }
45    {
46        std::stringstream ss;
47        ss << INTRA_SILENCE_CONTEXT;
48        modelLabel= ss.str();
49        intraId = doc.addPhonemeToList(modelLabel);
50    }
51    UpdateVertexCount (0);
52    arcCount= numArc;
53    for (ii= 0; ii < arcCount; ii++) {
54        wordId= arc[ii]->GetInput();
55        if (wordId >= 0) {
56	    doc.findLabel(wordId, word );
57            if (IsSlot (word)) {
58	        // std::cout << "Found slot "<< word <<std::endl;
59	        newId= NewVertexId();
60	        arcOne= CreateArc (NONE_LABEL, NONE_LABEL, arc[ii]->GetFromId(), newId);
61		arcOne->AssignCentre (NONE_LABEL);
62	        nextId= NewVertexId();
63	        //  special case
64	        arcOne= CreateArc (-wordId, wordId, newId, nextId);
65		arcOne->AssignCentre (NONE_LABEL);
66	        // (void) CreateArc (-wordId, NONE_LABEL, arc[ii]->GetFromId(), newId);
67	        arcOne= CreateArc (WB_LABEL, NONE_LABEL, nextId, arc[ii]->GetToId());
68		arcOne->AssignCentre (NONE_LABEL);
69	        // (void) CreateArc (WB_LABEL, wordId, newId, arc[ii]->GetToId());
70            }
71            else {
72	        pron.clear();
73	        pron.lookup( *(doc.getVocabulary()), word );
74	        pronCount = pron.getPronCount();
75	        for (int jj= 0; jj < pronCount; jj++) {
76	            currId= arc[ii]->GetFromId();
77	            int modelCount = pron.getPhonemeCount(jj);
78	            for (int kk= 0; kk < modelCount; kk++) {
79	                newId= NewVertexId();
80	                pron.getPhoneme(jj, kk, modelLabel);
81	                //std::cout << "ExpandPhonemes adding "<< modelLabel <<std::endl;
82	                phoneId = doc.addPhonemeToList( modelLabel );
83	                arcOne= CreateArc (phoneId, NONE_LABEL, currId, newId);
84                        if (phoneId == intraId)
85		            arcOne->AssignCentre (silenceId);
86                        else
87		            arcOne->AssignCentre (phoneId);
88	                currId= newId;
89	            }
90	            arcOne= CreateArc (WB_LABEL, wordId, currId, arc[ii]->GetToId());
91		    arcOne->AssignCentre (NONE_LABEL);
92	        }
93	        //  End of loop
94            }
95	    arc[ii]->AssignInput (DISCARD_LABEL);   //  Delete original arc
96        }
97    }
98    RemoveDiscardedArcs ();
99
100    for (ii= 0; ii < numArc; ii++) {
101	arc[ii]->AssignLeft (NONE_LABEL);
102	arc[ii]->AssignRight (NONE_LABEL);
103    }
104
105    SortLanguage ();
106
107    return;
108}
109
110void SubGraph::AddLeftContexts ()
111{
112    int ii, rix, currId, leftC;
113
114    SortLanguage();
115    SortLanguageReverse();
116    for (ii= 0; ii < numArc; ii++) {
117        if (arc[ii]->GetInput() >= 0) {
118            currId= arc[ii]->GetFromId();
119            rix= FindToIndex (currId);
120            if (rix >= 0) {
121                leftC= arc[backwardList[rix]]->GetCentre();
122                arc[ii]->AssignLeft(leftC);
123            }
124            else if (currId != startId)
125                printf ("Shouldn't get here (L) %d\n", currId);
126        }
127        else
128            arc[ii]->AssignLeft (NONE_LABEL);
129    }
130    return;
131}
132
133void SubGraph::AddRightContexts ()
134{
135    int ii, rix, currId, rightC;
136
137    SortLanguage();
138    SortLanguageReverse();
139    for (ii= 0; ii < numArc; ii++) {
140        if (arc[ii]->GetInput() >= 0) {
141            currId= arc[ii]->GetToId();
142            rix= FindFromIndex (currId);
143            if (rix >= 0) {
144                rightC= arc[forwardList[rix]]->GetCentre();
145                arc[ii]->AssignRight (rightC);
146            }
147            else
148                printf ("Shouldn't get here (R) %d\n", currId);
149        }
150        else
151            arc[ii]->AssignRight (NONE_LABEL);
152    }
153    return;
154}
155
156void SubGraph::ExpandToHMMs ( GRXMLDoc &doc )
157{
158    int ii, currId, newId, arcCount, left, right, centre;
159    int modelCount;
160    NUANArc *arcOne;
161
162    UpdateVertexCount (0);
163    arcCount= numArc;
164    for (ii= 0; ii < arcCount; ii++) {
165        std::vector<int> modelSequence;
166	if (arc[ii]->GetInput() >= 0) {      //  i.e. proper phoneme
167            centre= arc[ii]->GetCentre();
168	    left= arc[ii]->GetLeft();
169	    right= arc[ii]->GetRight();
170#if DEBUG
171            std::cout << "HMM PIC:" << left <<" " << centre <<" " << right << std::endl;
172#endif
173	    doc.getHMMSequence (centre, left, right, modelSequence);
174	    modelCount = modelSequence.size();
175#if DEBUG
176            std::cout << "HMM: " << centre << " number of HMMs = " << modelCount <<std::endl;
177#endif
178	    if (modelCount >= 0) {
179		currId= arc[ii]->GetFromId();
180		for (int jj= 0; jj < modelCount; jj++) {
181                    if (jj == (modelCount - 1))
182                        newId= arc[ii]->GetToId();
183                    else
184		        newId= NewVertexId();
185		    arcOne= CreateArc (modelSequence[jj], NONE_LABEL, currId, newId);
186		    arcOne->AssignCentre (arc[ii]->GetInput());
187		    arcOne->AssignLeft (arc[ii]->GetLeft());
188		    arcOne->AssignRight (arc[ii]->GetRight());
189#if DEBUG
190                    std::cout << "HMM phoneme: " << modelSequence[jj] << " ";
191#endif
192		    currId= newId;
193		}
194#if DEBUG
195                std::cout << " centre " << arc[ii]->GetInput() << std::endl;
196#endif
197                arc[ii]->AssignInput (DISCARD_LABEL);       //  Delete original arc
198            }
199        }
200    }
201    RemoveDiscardedArcs ();
202
203    SortLanguage ();
204
205    return;
206}
207
208void SubGraph::ExpandIntraWordSilence ( GRXMLDoc &doc )
209{
210    int ii, fix, bix, firstId, newId, modelCount, followCount, currId, count;
211    int left, centre, right;
212    NUANArc *arcOne;
213
214    SortLanguage();
215    SortLanguageReverse();
216
217#if DEBUG
218    std::cout << "Intra sil search " << intraId << std::endl;
219#endif
220    count= numArc;
221    for (ii= 0; ii < count; ii++) {
222        if (arc[ii]->GetCentre() == intraId) {
223#if DEBUG
224            std::cout << "Intra sil: " << arc[ii]->GetFromId() << " " << arc[ii]->GetToId() << std::endl;
225#endif
226
227            fix= FindToIndex (arc[ii]->GetFromId());
228            if (fix < 0)
229                return;
230            while (fix < sortRevNum
231             && arc[backwardList[fix]]->GetToId() == arc[ii]->GetFromId()) {
232		//  left triphone
233                newId= NewVertexId();
234		left= arc[backwardList[fix]]->GetLeft();
235		centre= arc[ii]->GetLeft();
236		right= arc[ii]->GetRight();
237#if DEBUG
238                std::cout << "HMM PIC:" << left <<" " << centre <<" " << right << std::endl;
239#endif
240                std::vector<int> modelSequence;
241	        doc.getHMMSequence (centre, left, right, modelSequence);
242	        modelCount = modelSequence.size();
243#if DEBUG
244                std::cout << "HMM: " << centre << " number of HMMs = " << modelCount <<std::endl;
245#endif
246	        if (modelCount >= 0) {
247		    currId= arc[backwardList[fix]]->GetFromId();
248		    for (int jj= 0; jj < modelCount; jj++) {
249		        newId= NewVertexId();
250		        arcOne= CreateArc (modelSequence[jj],
251			    arc[backwardList[fix]]->GetOutput(), currId, newId);
252		        arcOne->AssignCentre (centre);
253#if DEBUG
254                        std::cout << "HMM phoneme: " << modelSequence[jj] << " ";
255#endif
256		        currId= newId;
257		    }
258#if DEBUG
259                    std::cout << " " << centre << std::endl;
260#endif
261		}
262		firstId= newId;
263
264		//  right block
265                bix= FindFromIndex (arc[ii]->GetToId());
266                if (bix < 0)
267                    return;
268                while (bix < sortNum
269                 && arc[forwardList[bix]]->GetFromId() == arc[ii]->GetToId()) {
270                    fix++;
271		    //  right triphone
272		    left= arc[ii]->GetLeft();
273		    centre= arc[ii]->GetRight();
274		    right= arc[forwardList[bix]]->GetRight();
275
276#if DEBUG
277                    std::cout << "HMM PIC:" << left <<" " << centre <<" " << right << std::endl;
278#endif
279                    std::vector<int> followSequence;
280	            doc.getHMMSequence (centre, left, right, followSequence);
281	            followCount = followSequence.size();
282#if DEBUG
283                    std::cout << "HMM: " << centre << " number of HMMs = " << followCount <<std::endl;
284#endif
285
286	            if (followCount >= 0) {
287		        currId= firstId;
288		        for (int jj= 0; jj < followCount; jj++) {
289                            if (jj == (followCount - 1))
290                                newId= arc[forwardList[bix]]->GetToId();
291                            else
292		                newId= NewVertexId();
293		            arcOne= CreateArc (followSequence[jj],
294				arc[forwardList[bix]]->GetOutput(), currId, newId);
295		            arcOne->AssignCentre (centre);
296#if DEBUG
297                            std::cout << "HMM phoneme: " << followSequence[jj] << " ";
298#endif
299		            currId= newId;
300		        }
301#if DEBUG
302                        std::cout << " " << centre << std::endl;
303#endif
304		    }
305                    bix++;
306                }
307		fix++;
308            }
309            // arc[ii]->AssignInput (silenceId);
310        }
311    }
312    return;
313}
314
315void SubGraph::ShiftOutputsToLeft ()
316{
317    UpdateVertexCount (0);
318    SortLanguage();
319    SortLanguageReverse();
320    ReverseMarkArcs();
321    MarkNodesByOutputAndClearArcs();
322    return;
323}
324
325void SubGraph::ReverseMarkArcs ()
326{
327    int ii;
328
329    for (ii= 0; ii < numArc; ii++)
330        if (arc[ii]->GetInput() == WB_LABEL)
331            ReverseMarkOutput (arc[ii]->GetFromId(), startId, arc[ii]->GetOutput());
332    return;
333}
334
335void SubGraph::ReverseMarkOutput (int currId, int initialId, int outId)
336{
337    int rix;
338
339    rix= FindToIndex (currId);
340    if (rix < 0)
341        return;
342    while (rix < sortRevNum && arc[backwardList[rix]]->GetToId() == currId) {
343        if (arc[backwardList[rix]]->GetOutput() != DISCARD_LABEL    //  not resolved yet
344         && arc[backwardList[rix]]->GetInput() >= 0) { // excludes word boundary
345            if (arc[backwardList[rix]]->GetOutput() == NONE_LABEL)
346                arc[backwardList[rix]]->AssignOutput (outId);
347            else if (outId != arc[backwardList[rix]]->GetOutput())
348                arc[backwardList[rix]]->AssignOutput(DISCARD_LABEL);
349            ReverseMarkOutput (arc[backwardList[rix]]->GetFromId(), initialId, outId);
350        }
351        rix++;
352    }
353    return;
354}
355
356void SubGraph::MarkNodesByOutputAndClearArcs ()
357{
358    int ii, currId, rix;
359
360    int *nodeList= new int [numVertex];
361    for (ii= 0; ii < numVertex; ii++)
362        nodeList[ii]= NONE_LABEL;
363
364    //  Associate outputs with destination node
365    for (ii= 0; ii < numArc; ii++) {
366        currId= arc[ii]->GetToId();
367        if (currId >= 0) {
368	    if (arc[ii]->GetInput() == WB_LABEL)
369                nodeList[currId]= DISCARD_LABEL;
370            else if (nodeList[currId] != DISCARD_LABEL) {
371                if (nodeList[currId] == NONE_LABEL)
372                    nodeList[currId]= arc[ii]->GetOutput();
373                else if (nodeList[currId] != arc[ii]->GetOutput())
374                    nodeList[currId]= DISCARD_LABEL;
375            }
376        }
377    }
378
379    //  Now discard all arcs other than those emanating from unique assignments
380    for (ii= 0; ii < numArc; ii++) {
381        currId= arc[ii]->GetFromId();
382        if (nodeList[currId] >= 0 && arc[ii]->GetOutput() >= 0) // unique ones
383            arc[ii]->AssignOutput(DISCARD_LABEL);
384    }
385
386    //  Finally, special case for intra-word silence
387    for (ii= 0; ii < numArc; ii++) {
388	if (arc[ii]->GetOutput() >= 0 && arc[ii]->GetCentre() == intraId) {
389            currId= arc[ii]->GetToId();
390#if DEBUG
391            std::cout << "Intra silence: " << currId << " " << arc[ii]->GetFromId() << std::endl;
392#endif
393            rix= FindFromIndex (currId);
394            if (rix < 0)
395                continue;
396            while (rix < sortNum && arc[forwardList[rix]]->GetFromId() == currId) {
397                assert (arc[forwardList[rix]]->GetOutput() == DISCARD_LABEL);
398                arc[forwardList[rix]]->AssignOutput(arc[ii]->GetOutput());
399		rix++;
400	    }
401            arc[ii]->AssignOutput(DISCARD_LABEL);
402        }
403    }
404
405    delete [] nodeList;
406    return;
407}
408
409void SubGraph::FinalProcessing (GRXMLDoc &doc)
410{
411    ExpandWordBoundaries (doc);
412    AddInitialFinalSilences (doc);
413    return;
414}
415
416void SubGraph::ExpandWordBoundaries (GRXMLDoc &doc)
417{
418    int ii, newId, count;
419    NUANArc  *arcOne;
420
421    count= numArc;
422    for (ii= 0; ii < count; ii++) {
423        std::vector<int> modelSequence;
424        doc.getHMMSequence (silenceId, -1, -1, modelSequence);
425        if (arc[ii]->GetInput() == WB_LABEL) {
426            newId= NewVertexId();
427            // (void) CreateArc (NONE_LABEL, NONE_LABEL, arc[ii]->GetFromId(), newId);
428            // arcOne= CreateArc (modelSequence[0], NONE_LABEL, arc[ii]->GetFromId(), newId);
429            arcOne= CreateArc (modelSequence[0], NONE_LABEL, arc[ii]->GetFromId(), newId);
430	    arcOne->AssignCentre (silenceId);
431            (void) CreateArc (WB_LABEL, arc[ii]->GetOutput(), newId, arc[ii]->GetToId());
432            // arc[ii]->AssignInput (DISCARD_LABEL);
433        }
434    }
435    return;
436}
437
438void SubGraph::AddInitialFinalSilences (GRXMLDoc &doc)
439{
440    int ii, rix, newId, intId, count;
441    NUANArc *arcOne;
442
443    SortLanguage();
444    newId= NewVertexId();
445    rix= FindFromIndex (startId);
446    if (rix < 0)
447        return;
448    while (rix < sortNum && arc[forwardList[rix]]->GetFromId() == startId) {
449        arc[forwardList[rix]]->AssignFromId (newId);
450        rix++;
451    }
452    std::vector<int> modelSequence;
453    doc.getHMMSequence (silenceId, -1, -1, modelSequence);
454    intId= NewVertexId();
455    arcOne= CreateArc (modelSequence[0], INITIAL_LABEL, startId, intId);
456    arcOne->AssignCentre (silenceId);
457    (void) CreateArc (WB_LABEL, NONE_LABEL, intId, newId);
458
459    count= numArc;
460    newId= NewVertexId();
461    for (ii= 0; ii < count; ii++) {
462        if (arc[ii]->GetInput() == TERMINAL_LABEL) {
463            arc[ii]->AssignInput (modelSequence[0]);
464            arc[ii]->AssignCentre (silenceId);
465            arc[ii]->AssignOutput (FINAL_LABEL);
466            arc[ii]->AssignToId (newId);
467        }
468    }
469    (void) CreateArc (TERMINAL_LABEL, TERMINAL_LABEL, newId, newId);
470
471    return;
472}
473