1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Seq2seq loss operations for use in sequence models.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.framework import ops
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import nn_ops
26
27__all__ = ["sequence_loss"]
28
29
30def sequence_loss(logits,
31                  targets,
32                  weights,
33                  average_across_timesteps=True,
34                  average_across_batch=True,
35                  softmax_loss_function=None,
36                  name=None):
37  """Weighted cross-entropy loss for a sequence of logits.
38
39  Depending on the values of `average_across_timesteps` and
40  `average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these
41  arguments reduce the cross-entropy at each target, which has shape
42  `[batch_size, sequence_length]`, over their respective dimensions. For
43  example, if `average_across_timesteps` is `True` and `average_across_batch`
44  is `False`, then the return Tensor will have shape `[batch_size]`.
45
46  Args:
47    logits: A Tensor of shape
48      `[batch_size, sequence_length, num_decoder_symbols]` and dtype float.
49      The logits correspond to the prediction across all classes at each
50      timestep.
51    targets: A Tensor of shape `[batch_size, sequence_length]` and dtype
52      int. The target represents the true class at each timestep.
53    weights: A Tensor of shape `[batch_size, sequence_length]` and dtype
54      float. `weights` constitutes the weighting of each prediction in the
55      sequence. When using `weights` as masking, set all valid timesteps to 1
56      and all padded timesteps to 0, e.g. a mask returned by `tf.sequence_mask`.
57    average_across_timesteps: If set, sum the cost across the sequence
58      dimension and divide the cost by the total label weight across timesteps.
59    average_across_batch: If set, sum the cost across the batch dimension and
60      divide the returned cost by the batch size.
61    softmax_loss_function: Function (labels, logits) -> loss-batch
62      to be used instead of the standard softmax (the default if this is None).
63      **Note that to avoid confusion, it is required for the function to accept
64      named arguments.**
65    name: Optional name for this operation, defaults to "sequence_loss".
66
67  Returns:
68    A float Tensor of rank 0, 1, or 2 depending on the
69    `average_across_timesteps` and `average_across_batch` arguments. By default,
70    it has rank 0 (scalar) and is the weighted average cross-entropy
71    (log-perplexity) per symbol.
72
73  Raises:
74    ValueError: logits does not have 3 dimensions or targets does not have 2
75                dimensions or weights does not have 2 dimensions.
76  """
77  if len(logits.get_shape()) != 3:
78    raise ValueError("Logits must be a "
79                     "[batch_size x sequence_length x logits] tensor")
80  if len(targets.get_shape()) != 2:
81    raise ValueError("Targets must be a [batch_size x sequence_length] "
82                     "tensor")
83  if len(weights.get_shape()) != 2:
84    raise ValueError("Weights must be a [batch_size x sequence_length] "
85                     "tensor")
86  with ops.name_scope(name, "sequence_loss", [logits, targets, weights]):
87    num_classes = array_ops.shape(logits)[2]
88    logits_flat = array_ops.reshape(logits, [-1, num_classes])
89    targets = array_ops.reshape(targets, [-1])
90    if softmax_loss_function is None:
91      crossent = nn_ops.sparse_softmax_cross_entropy_with_logits(
92          labels=targets, logits=logits_flat)
93    else:
94      crossent = softmax_loss_function(labels=targets, logits=logits_flat)
95    crossent *= array_ops.reshape(weights, [-1])
96    if average_across_timesteps and average_across_batch:
97      crossent = math_ops.reduce_sum(crossent)
98      total_size = math_ops.reduce_sum(weights)
99      total_size += 1e-12  # to avoid division by 0 for all-0 weights
100      crossent /= total_size
101    else:
102      batch_size = array_ops.shape(logits)[0]
103      sequence_length = array_ops.shape(logits)[1]
104      crossent = array_ops.reshape(crossent, [batch_size, sequence_length])
105    if average_across_timesteps and not average_across_batch:
106      crossent = math_ops.reduce_sum(crossent, axis=[1])
107      total_size = math_ops.reduce_sum(weights, axis=[1])
108      total_size += 1e-12  # to avoid division by 0 for all-0 weights
109      crossent /= total_size
110    if not average_across_timesteps and average_across_batch:
111      crossent = math_ops.reduce_sum(crossent, axis=[0])
112      total_size = math_ops.reduce_sum(weights, axis=[0])
113      total_size += 1e-12  # to avoid division by 0 for all-0 weights
114      crossent /= total_size
115    return crossent
116