# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Seq2seq loss operations for use in sequence models.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops

__all__ = ["sequence_loss"]


def sequence_loss(logits,
                  targets,
                  weights,
                  average_across_timesteps=True,
                  average_across_batch=True,
                  softmax_loss_function=None,
                  name=None):
  """Weighted cross-entropy loss for a sequence of logits.

  Depending on the values of `average_across_timesteps` and
  `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these
  arguments reduce the cross-entropy at each target, which has shape
  `[batch_size, sequence_length]`, over their respective dimensions. For
  example, if `average_across_timesteps` is `True` and `average_across_batch`
  is `False`, then the return Tensor will have shape `[batch_size]`.

  Args:
    logits: A Tensor of shape
      `[batch_size, sequence_length, num_decoder_symbols]` and dtype float.
      The logits correspond to the prediction across all classes at each
      timestep.
    targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
      int. The target represents the true class at each timestep.
    weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
      float. `weights` constitutes the weighting of each prediction in the
      sequence. When using `weights` as masking, set all valid timesteps to 1
      and all padded timesteps to 0, e.g. a mask returned by `tf.sequence_mask`.
    average_across_timesteps: If set, sum the cost across the sequence
      dimension and divide the cost by the total label weight across timesteps.
    average_across_batch: If set, sum the cost across the batch dimension and
      divide the cost by the total label weight across the batch.
    softmax_loss_function: Function (labels, logits) -> loss-batch
      to be used instead of the standard softmax (the default if this is None).
      **Note that to avoid confusion, it is required for the function to accept
      named arguments.**
    name: Optional name for this operation, defaults to "sequence_loss".

  Returns:
    A float Tensor of rank 0, 1, or 2 depending on the
    `average_across_timesteps` and `average_across_batch` arguments. By default,
    it has rank 0 (scalar) and is the weighted average cross-entropy
    (log-perplexity) per symbol.

  Raises:
    ValueError: logits does not have 3 dimensions or targets does not have 2
      dimensions or weights does not have 2 dimensions.
  """
  if len(logits.get_shape()) != 3:
    raise ValueError("Logits must be a "
                     "[batch_size x sequence_length x logits] tensor")
  if len(targets.get_shape()) != 2:
    raise ValueError("Targets must be a [batch_size x sequence_length] "
                     "tensor")
  if len(weights.get_shape()) != 2:
    raise ValueError("Weights must be a [batch_size x sequence_length] "
                     "tensor")
  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
    # Flatten the batch and time dimensions so a single (sparse) softmax
    # cross-entropy call covers every timestep of every sequence.
    num_classes = array_ops.shape(logits)[2]
    logits_flat = array_ops.reshape(logits, [-1, num_classes])
    targets = array_ops.reshape(targets, [-1])
    if softmax_loss_function is None:
      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
          labels=targets, logits=logits_flat)
    else:
      # Named arguments are required of the custom loss function (see
      # docstring) so labels/logits cannot be silently swapped.
      crossent = softmax_loss_function(labels=targets, logits=logits_flat)
    # Apply the per-timestep weighting/masking while still flat.
    crossent *= array_ops.reshape(weights, [-1])
    if average_across_timesteps and average_across_batch:
      # Single scalar: weighted average cross-entropy per symbol.
      crossent = math_ops.reduce_sum(crossent)
      total_size = math_ops.reduce_sum(weights)
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    else:
      # Restore the [batch_size, sequence_length] shape before reducing
      # over only one of the two dimensions.
      batch_size = array_ops.shape(logits)[0]
      sequence_length = array_ops.shape(logits)[1]
      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
    if average_across_timesteps and not average_across_batch:
      # Reduce over time; the result has shape [batch_size].
      crossent = math_ops.reduce_sum(crossent, axis=[1])
      total_size = math_ops.reduce_sum(weights, axis=[1])
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    if not average_across_timesteps and average_across_batch:
      # Reduce over the batch; the result has shape [sequence_length].
      crossent = math_ops.reduce_sum(crossent, axis=[0])
      total_size = math_ops.reduce_sum(weights, axis=[0])
      total_size += 1e-12  # to avoid division by 0 for all-0 weights
      crossent /= total_size
    return crossent