1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/core/util/ctc/ctc_loss_calculator.h" 17 18namespace tensorflow { 19namespace ctc { 20 21// Calculates the alpha(t, u) as described in (GravesTh) Section 7.3. 22// Starting with t = 0 instead of t = 1 used in the text. 23// Based on Kanishka's CTC. 24void CTCLossCalculator::CalculateForwardVariables( 25 const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated, 26 Matrix* log_alpha) const { 27 // Number of cols is the number of time steps = number of cols in target 28 // after the output delay. 29 log_alpha->setConstant(kLogZero); 30 31 int U = l_prime.size(); 32 int T = log_alpha->cols(); 33 34 CHECK_EQ(U, log_alpha->rows()); 35 36 // Initial alpha values in (GravesTh) Eq 7.5 and Eq 7.6. 37 log_alpha->coeffRef(0, 0) = log(y(blank_index_, output_delay_)); 38 // Below, l_prime[1] == labels[0] 39 auto label_0 = (l_prime.size() > 1) ? l_prime[1] : blank_index_; 40 log_alpha->coeffRef(1, 0) = log(y(label_0, output_delay_)); 41 42 for (int t = 1; t < T; ++t) { 43 // If there is not enough time to output the remaining labels or 44 // some labels have been skipped, then let log_alpha(u, t) continue to 45 // be kLogZero. 46 for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1)); 47 ++u) { 48 // Begin (GravesTh) Eq 7.9 49 // Add in the u, t - 1 term. 50 float sum_log_alpha = kLogZero; 51 if (ctc_merge_repeated || l_prime[u] == blank_index_) { 52 sum_log_alpha = log_alpha->coeff(u, t - 1); 53 } 54 55 // Add in the u - 1, t - 1 term. 56 if (u > 0) { 57 sum_log_alpha = 58 LogSumExp(sum_log_alpha, log_alpha->coeff(u - 1, t - 1)); 59 } 60 61 // Add in the u - 2, t - 1 term if l_prime(u) != blank or l_prime(u-2). 62 if (u > 1) { 63 const bool matching_labels_merge = 64 ctc_merge_repeated && (l_prime[u] == l_prime[u - 2]); 65 if (l_prime[u] != blank_index_ && !matching_labels_merge) { 66 sum_log_alpha = 67 LogSumExp(sum_log_alpha, log_alpha->coeff(u - 2, t - 1)); 68 } 69 } 70 // Multiply the summed alphas with the activation log probability. 71 log_alpha->coeffRef(u, t) = 72 log(y(l_prime[u], output_delay_ + t)) + sum_log_alpha; 73 } // End (GravesTh) Eq 7.9. 74 } 75} 76 77// Calculates the beta(t, u) as described in (GravesTh) Section 7.3. 78void CTCLossCalculator::CalculateBackwardVariables( 79 const std::vector<int>& l_prime, const Matrix& y, bool ctc_merge_repeated, 80 Matrix* log_beta) const { 81 // Number of cols is the number of time steps = number of cols in target. 82 // Matrix log_beta = 83 // Matrix::Constant(l_prime.size(), y.cols() - output_delay_, 84 // kLogZero); 85 log_beta->setConstant(kLogZero); 86 int T = log_beta->cols(); 87 int U = l_prime.size(); 88 CHECK_EQ(U, log_beta->rows()); 89 90 // Initial beta values in (GravesTh) Eq 7.13: log of probability 1. 91 for (int u = U - 2; u < U; ++u) log_beta->coeffRef(u, T - 1) = 0; 92 93 for (int t = T - 1 - 1; t >= 0; --t) { 94 // If there is not enough time to output the remaining labels or 95 // some labels have been skipped, then let log_beta(u, t) continue to 96 // be kLogZero. 97 for (int u = std::max(0, U - (2 * (T - t))); u < std::min(U, 2 * (t + 1)); 98 ++u) { 99 // Begin (GravesTh) Eq 7.15 100 // Add in the u, t + 1 term. 101 if (ctc_merge_repeated || l_prime[u] == blank_index_) { 102 log_beta->coeffRef(u, t) = 103 LogSumExp(log_beta->coeff(u, t), 104 log_beta->coeff(u, t + 1) + 105 log(y(l_prime[u], output_delay_ + t + 1))); 106 } 107 108 // Add in the u + 1, t + 1 term. 109 if (u + 1 < U) { 110 log_beta->coeffRef(u, t) = 111 LogSumExp(log_beta->coeff(u, t), 112 log_beta->coeff(u + 1, t + 1) + 113 log(y(l_prime[u + 1], output_delay_ + t + 1))); 114 } 115 116 // Add in the u + 2, t + 1 term if l_prime(u) != blank or l_prime(u+2). 117 if (u + 2 < U) { 118 const bool matching_labels_merge = 119 ctc_merge_repeated && (l_prime[u] == l_prime[u + 2]); 120 if (l_prime[u] != blank_index_ && !matching_labels_merge) { 121 // Add in u + 2 term. 122 log_beta->coeffRef(u, t) = 123 LogSumExp(log_beta->coeff(u, t), 124 log_beta->coeff(u + 2, t + 1) + 125 log(y(l_prime[u + 2], output_delay_ + t + 1))); 126 } 127 } // End (GravesTh) Eq. 7.15 128 } 129 } 130} 131 132// Using (GravesTh) Eq 7.26 & 7.34. 133void CTCLossCalculator::CalculateGradient(const std::vector<int>& l_prime, 134 const Matrix& y, 135 const Matrix& log_alpha, 136 const Matrix& log_beta, 137 float log_p_z_x, Matrix* dy) const { 138 // Only working with the leftmost part of dy for this batch element. 139 auto dy_b = dy->leftCols(y.cols()); 140 141 // It is possible that no valid path is found if the activations for the 142 // targets are zero. 143 if (log_p_z_x == kLogZero) { 144 LOG(WARNING) << "No valid path found."; 145 dy_b = y; 146 return; 147 } 148 149 int L = y.rows(); 150 int T = y.cols(); 151 int U = l_prime.size(); 152 153 for (int t = 0; t < T - output_delay_; ++t) { 154 Array prob_sum(L); 155 prob_sum.setConstant(kLogZero); 156 157 for (int u = 0; u < U; ++u) { 158 int l = l_prime[u]; 159 prob_sum[l] = LogSumExp(prob_sum[l], log_alpha(u, t) + log_beta(u, t)); 160 } 161 162 for (int l = 0; l < L; ++l) { 163 // Negative term in (GravesTh) Eq 7.28. 164 float negative_term = expf(prob_sum[l] - log_p_z_x); 165 166 dy_b(l, output_delay_ + t) = y(l, output_delay_ + t) - negative_term; 167 } 168 } 169} 170 171void CTCLossCalculator::GetLPrimeIndices(const std::vector<int>& l, 172 std::vector<int>* l_prime) const { 173 // Assumption is that l_prime is empty. 174 l_prime->reserve(2 * l.size() + 1); 175 176 for (auto label : l) { 177 l_prime->push_back(blank_index_); 178 l_prime->push_back(label); 179 } 180 // Add final blank to l'. 181 l_prime->push_back(blank_index_); 182} 183 184} // namespace ctc 185} // namespace tensorflow 186