/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/core/util/ctc/ctc_loss_calculator.h"
17
18namespace tensorflow {
19namespace ctc {
20
21// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3.
22// Starting with t = 0 instead of t = 1 used in the text.
23// Based on Kanishka's CTC.
24void CTCLossCalculator::CalculateForwardVariables(
25    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
26    Matrix* log_alpha) const {
27  // Number of cols is the number of time steps = number of cols in target
28  // after the output delay.
29  log_alpha->setConstant(kLogZero);
30
31  int U = l_prime.size();
32  int T = log_alpha->cols();
33
34  CHECK_EQ(U, log_alpha->rows());
35
36  // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6.
37  log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_));
38  // Below, l_prime[1] == labels[0]
39  auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_;
40  log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_));
41
42  for (int t = 1; t < T; ++t) {
43    // If there is not enough time to output the remaining labels or
44    // some labels have been skipped, then let log_alpha(u, t) continue to
45    // be kLogZero.
46    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
47         ++u) {
48      // Begin (GravesTh) Eq 7.9
49      // Add in the u, t - 1 term.
50      float sum_log_alpha = kLogZero;
51      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
52        sum_log_alpha = log_alpha->coeff(u, t - 1);
53      }
54
55      // Add in the u - 1, t - 1 term.
56      if (u > 0) {
57        sum_log_alpha =
58            LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1));
59      }
60
61      // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2).
62      if (u > 1) {
63        const bool matching_labels_merge =
64            ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]);
65        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
66          sum_log_alpha =
67              LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1));
68        }
69      }
70      // Multiply the summed alphas with the activation log probability.
71      log_alpha->coeffRef(u, t) =
72          log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha;
73    }  // End (GravesTh) Eq 7.9.
74  }
75}
76
77// Calculates the beta(t, u) as described in (GravesTh) Section 7.3.
78void CTCLossCalculator::CalculateBackwardVariables(
79    const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated,
80    Matrix* log_beta) const {
81  // Number of cols is the number of time steps =  number of cols in target.
82  // Matrix log_beta =
83  //    Matrix::Constant(l_prime.size(), y.cols() - output_delay_,
84  // kLogZero);
85  log_beta->setConstant(kLogZero);
86  int T = log_beta->cols();
87  int U = l_prime.size();
88  CHECK_EQ(U, log_beta->rows());
89
90  // Initial beta values in (GravesTh) Eq 7.13: log of probability 1.
91  for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0;
92
93  for (int t = T - 1 - 1; t >= 0; --t) {
94    // If there is not enough time to output the remaining labels or
95    // some labels have been skipped, then let log_beta(u, t) continue to
96    // be kLogZero.
97    for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1));
98         ++u) {
99      // Begin (GravesTh) Eq 7.15
100      // Add in the u, t + 1 term.
101      if (ctc_merge_repeated || l_prime[u] == blank_index_) {
102        log_beta->coeffRef(u, t) =
103            LogSumExp(log_beta->coeff(u, t),
104                      log_beta->coeff(u, t + 1) +
105                          log(y(l_prime[u], output_delay_ + t + 1)));
106      }
107
108      // Add in the u + 1, t + 1 term.
109      if (u + 1 < U) {
110        log_beta->coeffRef(u, t) =
111            LogSumExp(log_beta->coeff(u, t),
112                      log_beta->coeff(u + 1, t + 1) +
113                          log(y(l_prime[u + 1], output_delay_ + t + 1)));
114      }
115
116      // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2).
117      if (u + 2 < U) {
118        const bool matching_labels_merge =
119            ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]);
120        if (l_prime[u] != blank_index_ && !matching_labels_merge) {
121          // Add in u + 2 term.
122          log_beta->coeffRef(u, t) =
123              LogSumExp(log_beta->coeff(u, t),
124                        log_beta->coeff(u + 2, t + 1) +
125                            log(y(l_prime[u + 2], output_delay_ + t + 1)));
126        }
127      }  // End (GravesTh) Eq. 7.15
128    }
129  }
130}
131
132// Using (GravesTh) Eq 7.26 & 7.34.
133void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime,
134                                          const Matrix& y,
135                                          const Matrix& log_alpha,
136                                          const Matrix& log_beta,
137                                          float log_p_z_x, Matrix* dy) const {
138  // Only working with the leftmost part of dy for this batch element.
139  auto dy_b = dy->leftCols(y.cols());
140
141  // It is possible that no valid path is found if the activations for the
142  // targets are zero.
143  if (log_p_z_x == kLogZero) {
144    LOG(WARNING) << "No valid path found.";
145    dy_b = y;
146    return;
147  }
148
149  int L = y.rows();
150  int T = y.cols();
151  int U = l_prime.size();
152
153  for (int t = 0; t < T - output_delay_; ++t) {
154    Array prob_sum(L);
155    prob_sum.setConstant(kLogZero);
156
157    for (int u = 0; u < U; ++u) {
158      int l = l_prime[u];
159      prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t));
160    }
161
162    for (int l = 0; l < L; ++l) {
163      // Negative term in (GravesTh) Eq 7.28.
164      float negative_term = expf(prob_sum[l] - log_p_z_x);
165
166      dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term;
167    }
168  }
169}
170
171void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l,
172                                         std::vector<int>* l_prime) const {
173  // Assumption is that l_prime is empty.
174  l_prime->reserve(2 * l.size() + 1);
175
176  for (auto label : l) {
177    l_prime->push_back(blank_index_);
178    l_prime->push_back(label);
179  }
180  // Add final blank to l'.
181  l_prime->push_back(blank_index_);
182}
183
184}  // namespace ctc
185}  // namespace tensorflow
186