1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A powerful dynamic attention wrapper object."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import math
24
25import numpy as np
26
27from tensorflow.contrib.framework.python.framework import tensor_util
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_shape
31from tensorflow.python.layers import base as layers_base
32from tensorflow.python.layers import core as layers_core
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import check_ops
35from tensorflow.python.ops import clip_ops
36from tensorflow.python.ops import functional_ops
37from tensorflow.python.ops import init_ops
38from tensorflow.python.ops import math_ops
39from tensorflow.python.ops import nn_ops
40from tensorflow.python.ops import random_ops
41from tensorflow.python.ops import rnn_cell_impl
42from tensorflow.python.ops import tensor_array_ops
43from tensorflow.python.ops import variable_scope
44from tensorflow.python.util import nest
45
46
47__all__ = [
48    "AttentionMechanism",
49    "AttentionWrapper",
50    "AttentionWrapperState",
51    "LuongAttention",
52    "BahdanauAttention",
53    "hardmax",
54    "safe_cumprod",
55    "monotonic_attention",
56    "BahdanauMonotonicAttention",
57    "LuongMonotonicAttention",
58]
59
60
61_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access
62
63
64class AttentionMechanism(object):
65
66  @property
67  def alignments_size(self):
68    raise NotImplementedError
69
70  @property
71  def state_size(self):
72    raise NotImplementedError
73
74
75def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
76  """Convert to tensor and possibly mask `memory`.
77
78  Args:
79    memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
80    memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
81    check_inner_dims_defined: Python boolean.  If `True`, the `memory`
82      argument's shape is checked to ensure all but the two outermost
83      dimensions are fully defined.
84
85  Returns:
86    A (possibly masked), checked, new `memory`.
87
88  Raises:
89    ValueError: If `check_inner_dims_defined` is `True` and not
90      `memory.shape[2:].is_fully_defined()`.
91  """
92  memory = nest.map_structure(
93      lambda m: ops.convert_to_tensor(m, name="memory"), memory)
94  if memory_sequence_length is not None:
95    memory_sequence_length = ops.convert_to_tensor(
96        memory_sequence_length, name="memory_sequence_length")
97  if check_inner_dims_defined:
98    def _check_dims(m):
99      if not m.get_shape()[2:].is_fully_defined():
100        raise ValueError("Expected memory %s to have fully defined inner dims, "
101                         "but saw shape: %s" % (m.name, m.get_shape()))
102    nest.map_structure(_check_dims, memory)
103  if memory_sequence_length is None:
104    seq_len_mask = None
105  else:
106    seq_len_mask = array_ops.sequence_mask(
107        memory_sequence_length,
108        maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
109        dtype=nest.flatten(memory)[0].dtype)
110    seq_len_batch_size = (
111        memory_sequence_length.shape[0].value
112        or array_ops.shape(memory_sequence_length)[0])
113  def _maybe_mask(m, seq_len_mask):
114    rank = m.get_shape().ndims
115    rank = rank if rank is not None else array_ops.rank(m)
116    extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
117    m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
118    if memory_sequence_length is not None:
119      message = ("memory_sequence_length and memory tensor batch sizes do not "
120                 "match.")
121      with ops.control_dependencies([
122          check_ops.assert_equal(
123              seq_len_batch_size, m_batch_size, message=message)]):
124        seq_len_mask = array_ops.reshape(
125            seq_len_mask,
126            array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
127        return m * seq_len_mask
128    else:
129      return m
130  return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
131
132
133def _maybe_mask_score(score, memory_sequence_length, score_mask_value):
134  if memory_sequence_length is None:
135    return score
136  message = ("All values in memory_sequence_length must greater than zero.")
137  with ops.control_dependencies(
138      [check_ops.assert_positive(memory_sequence_length, message=message)]):
139    score_mask = array_ops.sequence_mask(
140        memory_sequence_length, maxlen=array_ops.shape(score)[1])
141    score_mask_values = score_mask_value * array_ops.ones_like(score)
142    return array_ops.where(score_mask, score, score_mask_values)
143
144
145class _BaseAttentionMechanism(AttentionMechanism):
146  """A base AttentionMechanism class providing common functionality.
147
148  Common functionality includes:
149    1. Storing the query and memory layers.
150    2. Preprocessing and storing the memory.
151  """
152
153  def __init__(self,
154               query_layer,
155               memory,
156               probability_fn,
157               memory_sequence_length=None,
158               memory_layer=None,
159               check_inner_dims_defined=True,
160               score_mask_value=None,
161               name=None):
162    """Construct base AttentionMechanism class.
163
164    Args:
165      query_layer: Callable.  Instance of `tf.layers.Layer`.  The layer's depth
166        must match the depth of `memory_layer`.  If `query_layer` is not
167        provided, the shape of `query` must match that of `memory_layer`.
168      memory: The memory to query; usually the output of an RNN encoder.  This
169        tensor should be shaped `[batch_size, max_time, ...]`.
170      probability_fn: A `callable`.  Converts the score and previous alignments
171        to probabilities. Its signature should be:
172        `probabilities = probability_fn(score, state)`.
173      memory_sequence_length (optional): Sequence lengths for the batch entries
174        in memory.  If provided, the memory tensor rows are masked with zeros
175        for values past the respective sequence lengths.
176      memory_layer: Instance of `tf.layers.Layer` (may be None).  The layer's
177        depth must match the depth of `query_layer`.
178        If `memory_layer` is not provided, the shape of `memory` must match
179        that of `query_layer`.
180      check_inner_dims_defined: Python boolean.  If `True`, the `memory`
181        argument's shape is checked to ensure all but the two outermost
182        dimensions are fully defined.
183      score_mask_value: (optional): The mask value for score before passing into
184        `probability_fn`. The default is -inf. Only used if
185        `memory_sequence_length` is not None.
186      name: Name to use when creating ops.
187    """
188    if (query_layer is not None
189        and not isinstance(query_layer, layers_base.Layer)):
190      raise TypeError(
191          "query_layer is not a Layer: %s" % type(query_layer).__name__)
192    if (memory_layer is not None
193        and not isinstance(memory_layer, layers_base.Layer)):
194      raise TypeError(
195          "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
196    self._query_layer = query_layer
197    self._memory_layer = memory_layer
198    self.dtype = memory_layer.dtype
199    if not callable(probability_fn):
200      raise TypeError("probability_fn must be callable, saw type: %s" %
201                      type(probability_fn).__name__)
202    if score_mask_value is None:
203      score_mask_value = dtypes.as_dtype(
204          self._memory_layer.dtype).as_numpy_dtype(-np.inf)
205    self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
206        probability_fn(
207            _maybe_mask_score(score, memory_sequence_length, score_mask_value),
208            prev))
209    with ops.name_scope(
210        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
211      self._values = _prepare_memory(
212          memory, memory_sequence_length,
213          check_inner_dims_defined=check_inner_dims_defined)
214      self._keys = (
215          self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
216          else self._values)
217      self._batch_size = (
218          self._keys.shape[0].value or array_ops.shape(self._keys)[0])
219      self._alignments_size = (self._keys.shape[1].value or
220                               array_ops.shape(self._keys)[1])
221
222  @property
223  def memory_layer(self):
224    return self._memory_layer
225
226  @property
227  def query_layer(self):
228    return self._query_layer
229
230  @property
231  def values(self):
232    return self._values
233
234  @property
235  def keys(self):
236    return self._keys
237
238  @property
239  def batch_size(self):
240    return self._batch_size
241
242  @property
243  def alignments_size(self):
244    return self._alignments_size
245
246  @property
247  def state_size(self):
248    return self._alignments_size
249
250  def initial_alignments(self, batch_size, dtype):
251    """Creates the initial alignment values for the `AttentionWrapper` class.
252
253    This is important for AttentionMechanisms that use the previous alignment
254    to calculate the alignment at the next time step (e.g. monotonic attention).
255
256    The default behavior is to return a tensor of all zeros.
257
258    Args:
259      batch_size: `int32` scalar, the batch_size.
260      dtype: The `dtype`.
261
262    Returns:
263      A `dtype` tensor shaped `[batch_size, alignments_size]`
264      (`alignments_size` is the values' `max_time`).
265    """
266    max_time = self._alignments_size
267    return _zero_state_tensors(max_time, batch_size, dtype)
268
269  def initial_state(self, batch_size, dtype):
270    """Creates the initial state values for the `AttentionWrapper` class.
271
272    This is important for AttentionMechanisms that use the previous alignment
273    to calculate the alignment at the next time step (e.g. monotonic attention).
274
275    The default behavior is to return the same output as initial_alignments.
276
277    Args:
278      batch_size: `int32` scalar, the batch_size.
279      dtype: The `dtype`.
280
281    Returns:
282      A structure of all-zero tensors with shapes as described by `state_size`.
283    """
284    return self.initial_alignments(batch_size, dtype)
285
286
287def _luong_score(query, keys, scale):
288  """Implements Luong-style (multiplicative) scoring function.
289
290  This attention has two forms.  The first is standard Luong attention,
291  as described in:
292
293  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
294  "Effective Approaches to Attention-based Neural Machine Translation."
295  EMNLP 2015.  https://arxiv.org/abs/1508.04025
296
297  The second is the scaled form inspired partly by the normalized form of
298  Bahdanau attention.
299
300  To enable the second form, call this function with `scale=True`.
301
302  Args:
303    query: Tensor, shape `[batch_size, num_units]` to compare to keys.
304    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
305    scale: Whether to apply a scale to the score function.
306
307  Returns:
308    A `[batch_size, max_time]` tensor of unnormalized score values.
309
310  Raises:
311    ValueError: If `key` and `query` depths do not match.
312  """
313  depth = query.get_shape()[-1]
314  key_units = keys.get_shape()[-1]
315  if depth != key_units:
316    raise ValueError(
317        "Incompatible or unknown inner dimensions between query and keys.  "
318        "Query (%s) has units: %s.  Keys (%s) have units: %s.  "
319        "Perhaps you need to set num_units to the keys' dimension (%s)?"
320        % (query, depth, keys, key_units, key_units))
321  dtype = query.dtype
322
323  # Reshape from [batch_size, depth] to [batch_size, 1, depth]
324  # for matmul.
325  query = array_ops.expand_dims(query, 1)
326
327  # Inner product along the query units dimension.
328  # matmul shapes: query is [batch_size, 1, depth] and
329  #                keys is [batch_size, max_time, depth].
330  # the inner product is asked to **transpose keys' inner shape** to get a
331  # batched matmul on:
332  #   [batch_size, 1, depth] . [batch_size, depth, max_time]
333  # resulting in an output shape of:
334  #   [batch_size, 1, max_time].
335  # we then squeeze out the center singleton dimension.
336  score = math_ops.matmul(query, keys, transpose_b=True)
337  score = array_ops.squeeze(score, [1])
338
339  if scale:
340    # Scalar used in weight scaling
341    g = variable_scope.get_variable(
342        "attention_g", dtype=dtype, initializer=1.)
343    score = g * score
344  return score
345
346
347class LuongAttention(_BaseAttentionMechanism):
348  """Implements Luong-style (multiplicative) attention scoring.
349
350  This attention has two forms.  The first is standard Luong attention,
351  as described in:
352
353  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
354  "Effective Approaches to Attention-based Neural Machine Translation."
355  EMNLP 2015.  https://arxiv.org/abs/1508.04025
356
357  The second is the scaled form inspired partly by the normalized form of
358  Bahdanau attention.
359
360  To enable the second form, construct the object with parameter
361  `scale=True`.
362  """
363
364  def __init__(self,
365               num_units,
366               memory,
367               memory_sequence_length=None,
368               scale=False,
369               probability_fn=None,
370               score_mask_value=None,
371               dtype=None,
372               name="LuongAttention"):
373    """Construct the AttentionMechanism mechanism.
374
375    Args:
376      num_units: The depth of the attention mechanism.
377      memory: The memory to query; usually the output of an RNN encoder.  This
378        tensor should be shaped `[batch_size, max_time, ...]`.
379      memory_sequence_length: (optional) Sequence lengths for the batch entries
380        in memory.  If provided, the memory tensor rows are masked with zeros
381        for values past the respective sequence lengths.
382      scale: Python boolean.  Whether to scale the energy term.
383      probability_fn: (optional) A `callable`.  Converts the score to
384        probabilities.  The default is @{tf.nn.softmax}. Other options include
385        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
386        Its signature should be: `probabilities = probability_fn(score)`.
387      score_mask_value: (optional) The mask value for score before passing into
388        `probability_fn`. The default is -inf. Only used if
389        `memory_sequence_length` is not None.
390      dtype: The data type for the memory layer of the attention mechanism.
391      name: Name to use when creating ops.
392    """
393    # For LuongAttention, we only transform the memory layer; thus
394    # num_units **must** match expected the query depth.
395    if probability_fn is None:
396      probability_fn = nn_ops.softmax
397    if dtype is None:
398      dtype = dtypes.float32
399    wrapped_probability_fn = lambda score, _: probability_fn(score)
400    super(LuongAttention, self).__init__(
401        query_layer=None,
402        memory_layer=layers_core.Dense(
403            num_units, name="memory_layer", use_bias=False, dtype=dtype),
404        memory=memory,
405        probability_fn=wrapped_probability_fn,
406        memory_sequence_length=memory_sequence_length,
407        score_mask_value=score_mask_value,
408        name=name)
409    self._num_units = num_units
410    self._scale = scale
411    self._name = name
412
413  def __call__(self, query, state):
414    """Score the query based on the keys and values.
415
416    Args:
417      query: Tensor of dtype matching `self.values` and shape
418        `[batch_size, query_depth]`.
419      state: Tensor of dtype matching `self.values` and shape
420        `[batch_size, alignments_size]`
421        (`alignments_size` is memory's `max_time`).
422
423    Returns:
424      alignments: Tensor of dtype matching `self.values` and shape
425        `[batch_size, alignments_size]` (`alignments_size` is memory's
426        `max_time`).
427    """
428    with variable_scope.variable_scope(None, "luong_attention", [query]):
429      score = _luong_score(query, self._keys, self._scale)
430    alignments = self._probability_fn(score, state)
431    next_state = alignments
432    return alignments, next_state
433
434
435def _bahdanau_score(processed_query, keys, normalize):
436  """Implements Bahdanau-style (additive) scoring function.
437
438  This attention has two forms.  The first is Bhandanau attention,
439  as described in:
440
441  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
442  "Neural Machine Translation by Jointly Learning to Align and Translate."
443  ICLR 2015. https://arxiv.org/abs/1409.0473
444
445  The second is the normalized form.  This form is inspired by the
446  weight normalization article:
447
448  Tim Salimans, Diederik P. Kingma.
449  "Weight Normalization: A Simple Reparameterization to Accelerate
450   Training of Deep Neural Networks."
451  https://arxiv.org/abs/1602.07868
452
453  To enable the second form, set `normalize=True`.
454
455  Args:
456    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
457    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
458    normalize: Whether to normalize the score function.
459
460  Returns:
461    A `[batch_size, max_time]` tensor of unnormalized score values.
462  """
463  dtype = processed_query.dtype
464  # Get the number of hidden units from the trailing dimension of keys
465  num_units = keys.shape[2].value or array_ops.shape(keys)[2]
466  # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
467  processed_query = array_ops.expand_dims(processed_query, 1)
468  v = variable_scope.get_variable(
469      "attention_v", [num_units], dtype=dtype)
470  if normalize:
471    # Scalar used in weight normalization
472    g = variable_scope.get_variable(
473        "attention_g", dtype=dtype,
474        initializer=math.sqrt((1. / num_units)))
475    # Bias added prior to the nonlinearity
476    b = variable_scope.get_variable(
477        "attention_b", [num_units], dtype=dtype,
478        initializer=init_ops.zeros_initializer())
479    # normed_v = g * v / ||v||
480    normed_v = g * v * math_ops.rsqrt(
481        math_ops.reduce_sum(math_ops.square(v)))
482    return math_ops.reduce_sum(
483        normed_v * math_ops.tanh(keys + processed_query + b), [2])
484  else:
485    return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
486
487
488class BahdanauAttention(_BaseAttentionMechanism):
489  """Implements Bahdanau-style (additive) attention.
490
491  This attention has two forms.  The first is Bahdanau attention,
492  as described in:
493
494  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
495  "Neural Machine Translation by Jointly Learning to Align and Translate."
496  ICLR 2015. https://arxiv.org/abs/1409.0473
497
498  The second is the normalized form.  This form is inspired by the
499  weight normalization article:
500
501  Tim Salimans, Diederik P. Kingma.
502  "Weight Normalization: A Simple Reparameterization to Accelerate
503   Training of Deep Neural Networks."
504  https://arxiv.org/abs/1602.07868
505
506  To enable the second form, construct the object with parameter
507  `normalize=True`.
508  """
509
510  def __init__(self,
511               num_units,
512               memory,
513               memory_sequence_length=None,
514               normalize=False,
515               probability_fn=None,
516               score_mask_value=None,
517               dtype=None,
518               name="BahdanauAttention"):
519    """Construct the Attention mechanism.
520
521    Args:
522      num_units: The depth of the query mechanism.
523      memory: The memory to query; usually the output of an RNN encoder.  This
524        tensor should be shaped `[batch_size, max_time, ...]`.
525      memory_sequence_length (optional): Sequence lengths for the batch entries
526        in memory.  If provided, the memory tensor rows are masked with zeros
527        for values past the respective sequence lengths.
528      normalize: Python boolean.  Whether to normalize the energy term.
529      probability_fn: (optional) A `callable`.  Converts the score to
530        probabilities.  The default is @{tf.nn.softmax}. Other options include
531        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
532        Its signature should be: `probabilities = probability_fn(score)`.
533      score_mask_value: (optional): The mask value for score before passing into
534        `probability_fn`. The default is -inf. Only used if
535        `memory_sequence_length` is not None.
536      dtype: The data type for the query and memory layers of the attention
537        mechanism.
538      name: Name to use when creating ops.
539    """
540    if probability_fn is None:
541      probability_fn = nn_ops.softmax
542    if dtype is None:
543      dtype = dtypes.float32
544    wrapped_probability_fn = lambda score, _: probability_fn(score)
545    super(BahdanauAttention, self).__init__(
546        query_layer=layers_core.Dense(
547            num_units, name="query_layer", use_bias=False, dtype=dtype),
548        memory_layer=layers_core.Dense(
549            num_units, name="memory_layer", use_bias=False, dtype=dtype),
550        memory=memory,
551        probability_fn=wrapped_probability_fn,
552        memory_sequence_length=memory_sequence_length,
553        score_mask_value=score_mask_value,
554        name=name)
555    self._num_units = num_units
556    self._normalize = normalize
557    self._name = name
558
559  def __call__(self, query, state):
560    """Score the query based on the keys and values.
561
562    Args:
563      query: Tensor of dtype matching `self.values` and shape
564        `[batch_size, query_depth]`.
565      state: Tensor of dtype matching `self.values` and shape
566        `[batch_size, alignments_size]`
567        (`alignments_size` is memory's `max_time`).
568
569    Returns:
570      alignments: Tensor of dtype matching `self.values` and shape
571        `[batch_size, alignments_size]` (`alignments_size` is memory's
572        `max_time`).
573    """
574    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
575      processed_query = self.query_layer(query) if self.query_layer else query
576      score = _bahdanau_score(processed_query, self._keys, self._normalize)
577    alignments = self._probability_fn(score, state)
578    next_state = alignments
579    return alignments, next_state
580
581
582def safe_cumprod(x, *args, **kwargs):
583  """Computes cumprod of x in logspace using cumsum to avoid underflow.
584
585  The cumprod function and its gradient can result in numerical instabilities
586  when its argument has very small and/or zero values.  As long as the argument
587  is all positive, we can instead compute the cumulative product as
588  exp(cumsum(log(x))).  This function can be called identically to tf.cumprod.
589
590  Args:
591    x: Tensor to take the cumulative product of.
592    *args: Passed on to cumsum; these are identical to those in cumprod.
593    **kwargs: Passed on to cumsum; these are identical to those in cumprod.
594  Returns:
595    Cumulative product of x.
596  """
597  with ops.name_scope(None, "SafeCumprod", [x]):
598    x = ops.convert_to_tensor(x, name="x")
599    tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
600    return math_ops.exp(math_ops.cumsum(
601        math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs))
602
603
604def monotonic_attention(p_choose_i, previous_attention, mode):
605  """Compute monotonic attention distribution from choosing probabilities.
606
607  Monotonic attention implies that the input sequence is processed in an
608  explicitly left-to-right manner when generating the output sequence.  In
609  addition, once an input sequence element is attended to at a given output
610  timestep, elements occurring before it cannot be attended to at subsequent
611  output timesteps.  This function generates attention distributions according
612  to these assumptions.  For more information, see ``Online and Linear-Time
613  Attention by Enforcing Monotonic Alignments''.
614
615  Args:
616    p_choose_i: Probability of choosing input sequence/memory element i.  Should
617      be of shape (batch_size, input_sequence_length), and should all be in the
618      range [0, 1].
619    previous_attention: The attention distribution from the previous output
620      timestep.  Should be of shape (batch_size, input_sequence_length).  For
621      the first output timestep, preevious_attention[n] should be [1, 0, 0, ...,
622      0] for all n in [0, ... batch_size - 1].
623    mode: How to compute the attention distribution.  Must be one of
624      'recursive', 'parallel', or 'hard'.
625        * 'recursive' uses tf.scan to recursively compute the distribution.
626          This is slowest but is exact, general, and does not suffer from
627          numerical instabilities.
628        * 'parallel' uses parallelized cumulative-sum and cumulative-product
629          operations to compute a closed-form solution to the recurrence
630          relation defining the attention distribution.  This makes it more
631          efficient than 'recursive', but it requires numerical checks which
632          make the distribution non-exact.  This can be a problem in particular
633          when input_sequence_length is long and/or p_choose_i has entries very
634          close to 0 or 1.
635        * 'hard' requires that the probabilities in p_choose_i are all either 0
636          or 1, and subsequently uses a more efficient and exact solution.
637
638  Returns:
639    A tensor of shape (batch_size, input_sequence_length) representing the
640    attention distributions for each sequence in the batch.
641
642  Raises:
643    ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
644  """
645  # Force things to be tensors
646  p_choose_i = ops.convert_to_tensor(p_choose_i, name="p_choose_i")
647  previous_attention = ops.convert_to_tensor(
648      previous_attention, name="previous_attention")
649  if mode == "recursive":
650    # Use .shape[0].value when it's not None, or fall back on symbolic shape
651    batch_size = p_choose_i.shape[0].value or array_ops.shape(p_choose_i)[0]
652    # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
653    shifted_1mp_choose_i = array_ops.concat(
654        [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
655    # Compute attention distribution recursively as
656    # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
657    # attention[i] = p_choose_i[i]*q[i]
658    attention = p_choose_i*array_ops.transpose(functional_ops.scan(
659        # Need to use reshape to remind TF of the shape between loop iterations
660        lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)),
661        # Loop variables yz[0] and yz[1]
662        [array_ops.transpose(shifted_1mp_choose_i),
663         array_ops.transpose(previous_attention)],
664        # Initial value of x is just zeros
665        array_ops.zeros((batch_size,))))
666  elif mode == "parallel":
667    # safe_cumprod computes cumprod in logspace with numeric checks
668    cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
669    # Compute recurrence relation solution
670    attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum(
671        previous_attention /
672        # Clip cumprod_1mp to avoid divide-by-zero
673        clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1)
674  elif mode == "hard":
675    # Remove any probabilities before the index chosen last time step
676    p_choose_i *= math_ops.cumsum(previous_attention, axis=1)
677    # Now, use exclusive cumprod to remove probabilities after the first
678    # chosen index, like so:
679    # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1]
680    # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
681    # Product of above: [0, 0, 0, 1, 0, 0, 0, 0]
682    attention = p_choose_i*math_ops.cumprod(
683        1 - p_choose_i, axis=1, exclusive=True)
684  else:
685    raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
686  return attention
687
688
689def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
690                              seed=None):
691  """Attention probability function for monotonic attention.
692
693  Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
694  the model to make discrete attention decisions, passes them through a sigmoid
695  to obtain "choosing" probabilities, and then calls monotonic_attention to
696  obtain the attention distribution.  For more information, see
697
698  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
699  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
700  ICML 2017.  https://arxiv.org/abs/1704.00784
701
702  Args:
703    score: Unnormalized attention scores, shape `[batch_size, alignments_size]`
704    previous_alignments: Previous attention distribution, shape
705      `[batch_size, alignments_size]`
706    sigmoid_noise: Standard deviation of pre-sigmoid noise.  Setting this larger
707      than 0 will encourage the model to produce large attention scores,
708      effectively making the choosing probabilities discrete and the resulting
709      attention distribution one-hot.  It should be set to 0 at test-time, and
710      when hard attention is not desired.
711    mode: How to compute the attention distribution.  Must be one of
712      'recursive', 'parallel', or 'hard'.  See the docstring for
713      `tf.contrib.seq2seq.monotonic_attention` for more information.
714    seed: (optional) Random seed for pre-sigmoid noise.
715
716  Returns:
717    A `[batch_size, alignments_size]`-shape tensor corresponding to the
718    resulting attention distribution.
719  """
720  # Optionally add pre-sigmoid noise to the scores
721  if sigmoid_noise > 0:
722    noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype,
723                                     seed=seed)
724    score += sigmoid_noise*noise
725  # Compute "choosing" probabilities from the attention scores
726  if mode == "hard":
727    # When mode is hard, use a hard sigmoid
728    p_choose_i = math_ops.cast(score > 0, score.dtype)
729  else:
730    p_choose_i = math_ops.sigmoid(score)
731  # Convert from choosing probabilities to attention distribution
732  return monotonic_attention(p_choose_i, previous_alignments, mode)
733
734
735class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism):
736  """Base attention mechanism for monotonic attention.
737
738  Simply overrides the initial_alignments function to provide a dirac
739  distribution,which is needed in order for the monotonic attention
740  distributions to have the correct behavior.
741  """
742
743  def initial_alignments(self, batch_size, dtype):
744    """Creates the initial alignment values for the monotonic attentions.
745
746    Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
747    for all entries in the batch.
748
749    Args:
750      batch_size: `int32` scalar, the batch_size.
751      dtype: The `dtype`.
752
753    Returns:
754      A `dtype` tensor shaped `[batch_size, alignments_size]`
755      (`alignments_size` is the values' `max_time`).
756    """
757    max_time = self._alignments_size
758    return array_ops.one_hot(
759        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
760        dtype=dtype)
761
762
763class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
764  """Monotonic attention mechanism with Bahadanau-style energy function.
765
766  This type of attention encorces a monotonic constraint on the attention
767  distributions; that is once the model attends to a given point in the memory
768  it can't attend to any prior points at subsequence output timesteps.  It
769  achieves this by using the _monotonic_probability_fn instead of softmax to
770  construct its attention distributions.  Since the attention scores are passed
771  through a sigmoid, a learnable scalar bias parameter is applied after the
772  score function and before the sigmoid.  Otherwise, it is equivalent to
773  BahdanauAttention.  This approach is proposed in
774
775  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
776  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
777  ICML 2017.  https://arxiv.org/abs/1704.00784
778  """
779
780  def __init__(self,
781               num_units,
782               memory,
783               memory_sequence_length=None,
784               normalize=False,
785               score_mask_value=None,
786               sigmoid_noise=0.,
787               sigmoid_noise_seed=None,
788               score_bias_init=0.,
789               mode="parallel",
790               dtype=None,
791               name="BahdanauMonotonicAttention"):
792    """Construct the Attention mechanism.
793
794    Args:
795      num_units: The depth of the query mechanism.
796      memory: The memory to query; usually the output of an RNN encoder.  This
797        tensor should be shaped `[batch_size, max_time, ...]`.
798      memory_sequence_length (optional): Sequence lengths for the batch entries
799        in memory.  If provided, the memory tensor rows are masked with zeros
800        for values past the respective sequence lengths.
801      normalize: Python boolean.  Whether to normalize the energy term.
802      score_mask_value: (optional): The mask value for score before passing into
803        `probability_fn`. The default is -inf. Only used if
804        `memory_sequence_length` is not None.
805      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
806        for `_monotonic_probability_fn` for more information.
807      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
808      score_bias_init: Initial value for score bias scalar.  It's recommended to
809        initialize this to a negative value when the length of the memory is
810        large.
811      mode: How to compute the attention distribution.  Must be one of
812        'recursive', 'parallel', or 'hard'.  See the docstring for
813        `tf.contrib.seq2seq.monotonic_attention` for more information.
814      dtype: The data type for the query and memory layers of the attention
815        mechanism.
816      name: Name to use when creating ops.
817    """
818    # Set up the monotonic probability fn with supplied parameters
819    if dtype is None:
820      dtype = dtypes.float32
821    wrapped_probability_fn = functools.partial(
822        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
823        seed=sigmoid_noise_seed)
824    super(BahdanauMonotonicAttention, self).__init__(
825        query_layer=layers_core.Dense(
826            num_units, name="query_layer", use_bias=False, dtype=dtype),
827        memory_layer=layers_core.Dense(
828            num_units, name="memory_layer", use_bias=False, dtype=dtype),
829        memory=memory,
830        probability_fn=wrapped_probability_fn,
831        memory_sequence_length=memory_sequence_length,
832        score_mask_value=score_mask_value,
833        name=name)
834    self._num_units = num_units
835    self._normalize = normalize
836    self._name = name
837    self._score_bias_init = score_bias_init
838
839  def __call__(self, query, state):
840    """Score the query based on the keys and values.
841
842    Args:
843      query: Tensor of dtype matching `self.values` and shape
844        `[batch_size, query_depth]`.
845      state: Tensor of dtype matching `self.values` and shape
846        `[batch_size, alignments_size]`
847        (`alignments_size` is memory's `max_time`).
848
849    Returns:
850      alignments: Tensor of dtype matching `self.values` and shape
851        `[batch_size, alignments_size]` (`alignments_size` is memory's
852        `max_time`).
853    """
854    with variable_scope.variable_scope(
855        None, "bahdanau_monotonic_attention", [query]):
856      processed_query = self.query_layer(query) if self.query_layer else query
857      score = _bahdanau_score(processed_query, self._keys, self._normalize)
858      score_bias = variable_scope.get_variable(
859          "attention_score_bias", dtype=processed_query.dtype,
860          initializer=self._score_bias_init)
861      score += score_bias
862    alignments = self._probability_fn(score, state)
863    next_state = alignments
864    return alignments, next_state
865
866
867class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
868  """Monotonic attention mechanism with Luong-style energy function.
869
870  This type of attention encorces a monotonic constraint on the attention
871  distributions; that is once the model attends to a given point in the memory
872  it can't attend to any prior points at subsequence output timesteps.  It
873  achieves this by using the _monotonic_probability_fn instead of softmax to
874  construct its attention distributions.  Otherwise, it is equivalent to
875  LuongAttention.  This approach is proposed in
876
877  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
878  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
879  ICML 2017.  https://arxiv.org/abs/1704.00784
880  """
881
882  def __init__(self,
883               num_units,
884               memory,
885               memory_sequence_length=None,
886               scale=False,
887               score_mask_value=None,
888               sigmoid_noise=0.,
889               sigmoid_noise_seed=None,
890               score_bias_init=0.,
891               mode="parallel",
892               dtype=None,
893               name="LuongMonotonicAttention"):
894    """Construct the Attention mechanism.
895
896    Args:
897      num_units: The depth of the query mechanism.
898      memory: The memory to query; usually the output of an RNN encoder.  This
899        tensor should be shaped `[batch_size, max_time, ...]`.
900      memory_sequence_length (optional): Sequence lengths for the batch entries
901        in memory.  If provided, the memory tensor rows are masked with zeros
902        for values past the respective sequence lengths.
903      scale: Python boolean.  Whether to scale the energy term.
904      score_mask_value: (optional): The mask value for score before passing into
905        `probability_fn`. The default is -inf. Only used if
906        `memory_sequence_length` is not None.
907      sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
908        for `_monotonic_probability_fn` for more information.
909      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
910      score_bias_init: Initial value for score bias scalar.  It's recommended to
911        initialize this to a negative value when the length of the memory is
912        large.
913      mode: How to compute the attention distribution.  Must be one of
914        'recursive', 'parallel', or 'hard'.  See the docstring for
915        `tf.contrib.seq2seq.monotonic_attention` for more information.
916      dtype: The data type for the query and memory layers of the attention
917        mechanism.
918      name: Name to use when creating ops.
919    """
920    # Set up the monotonic probability fn with supplied parameters
921    if dtype is None:
922      dtype = dtypes.float32
923    wrapped_probability_fn = functools.partial(
924        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
925        seed=sigmoid_noise_seed)
926    super(LuongMonotonicAttention, self).__init__(
927        query_layer=None,
928        memory_layer=layers_core.Dense(
929            num_units, name="memory_layer", use_bias=False, dtype=dtype),
930        memory=memory,
931        probability_fn=wrapped_probability_fn,
932        memory_sequence_length=memory_sequence_length,
933        score_mask_value=score_mask_value,
934        name=name)
935    self._num_units = num_units
936    self._scale = scale
937    self._score_bias_init = score_bias_init
938    self._name = name
939
940  def __call__(self, query, state):
941    """Score the query based on the keys and values.
942
943    Args:
944      query: Tensor of dtype matching `self.values` and shape
945        `[batch_size, query_depth]`.
946      state: Tensor of dtype matching `self.values` and shape
947        `[batch_size, alignments_size]`
948        (`alignments_size` is memory's `max_time`).
949
950    Returns:
951      alignments: Tensor of dtype matching `self.values` and shape
952        `[batch_size, alignments_size]` (`alignments_size` is memory's
953        `max_time`).
954    """
955    with variable_scope.variable_scope(None, "luong_monotonic_attention",
956                                       [query]):
957      score = _luong_score(query, self._keys, self._scale)
958      score_bias = variable_scope.get_variable(
959          "attention_score_bias", dtype=query.dtype,
960          initializer=self._score_bias_init)
961      score += score_bias
962    alignments = self._probability_fn(score, state)
963    next_state = alignments
964    return alignments, next_state
965
966
967class AttentionWrapperState(
968    collections.namedtuple("AttentionWrapperState",
969                           ("cell_state", "attention", "time", "alignments",
970                            "alignment_history", "attention_state"))):
971  """`namedtuple` storing the state of a `AttentionWrapper`.
972
973  Contains:
974
975    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
976      step.
977    - `attention`: The attention emitted at the previous time step.
978    - `time`: int32 scalar containing the current time step.
979    - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
980       emitted at the previous time step for each attention mechanism.
981    - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
982       containing alignment matrices from all time steps for each attention
983       mechanism. Call `stack()` on each to convert to a `Tensor`.
984    - `attention_state`: A single or tuple of nested objects
985       containing attention mechanism state for each attention mechanism.
986       The objects may contain Tensors or TensorArrays.
987  """
988
989  def clone(self, **kwargs):
990    """Clone this object, overriding components provided by kwargs.
991
992    The new state fields' shape must match original state fields' shape. This
993    will be validated, and original fields' shape will be propagated to new
994    fields.
995
996    Example:
997
998    ```python
999    initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
1000    initial_state = initial_state.clone(cell_state=encoder_state)
1001    ```
1002
1003    Args:
1004      **kwargs: Any properties of the state object to replace in the returned
1005        `AttentionWrapperState`.
1006
1007    Returns:
1008      A new `AttentionWrapperState` whose properties are the same as
1009      this one, except any overridden properties as provided in `kwargs`.
1010    """
1011    def with_same_shape(old, new):
1012      """Check and set new tensor's shape."""
1013      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
1014        return tensor_util.with_same_shape(old, new)
1015      return new
1016
1017    return nest.map_structure(
1018        with_same_shape,
1019        self,
1020        super(AttentionWrapperState, self)._replace(**kwargs))
1021
1022
1023def hardmax(logits, name=None):
1024  """Returns batched one-hot vectors.
1025
1026  The depth index containing the `1` is that of the maximum logit value.
1027
1028  Args:
1029    logits: A batch tensor of logit values.
1030    name: Name to use when creating ops.
1031  Returns:
1032    A batched one-hot tensor.
1033  """
1034  with ops.name_scope(name, "Hardmax", [logits]):
1035    logits = ops.convert_to_tensor(logits, name="logits")
1036    if logits.get_shape()[-1].value is not None:
1037      depth = logits.get_shape()[-1].value
1038    else:
1039      depth = array_ops.shape(logits)[-1]
1040    return array_ops.one_hot(
1041        math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
1042
1043
1044def _compute_attention(attention_mechanism, cell_output, attention_state,
1045                       attention_layer):
1046  """Computes the attention and alignments for a given attention_mechanism."""
1047  alignments, next_attention_state = attention_mechanism(
1048      cell_output, state=attention_state)
1049
1050  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
1051  expanded_alignments = array_ops.expand_dims(alignments, 1)
1052  # Context is the inner product of alignments and values along the
1053  # memory time dimension.
1054  # alignments shape is
1055  #   [batch_size, 1, memory_time]
1056  # attention_mechanism.values shape is
1057  #   [batch_size, memory_time, memory_size]
1058  # the batched matmul is over memory_time, so the output shape is
1059  #   [batch_size, 1, memory_size].
1060  # we then squeeze out the singleton dim.
1061  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
1062  context = array_ops.squeeze(context, [1])
1063
1064  if attention_layer is not None:
1065    attention = attention_layer(array_ops.concat([cell_output, context], 1))
1066  else:
1067    attention = context
1068
1069  return attention, alignments, next_attention_state
1070
1071
1072class AttentionWrapper(rnn_cell_impl.RNNCell):
1073  """Wraps another `RNNCell` with attention.
1074  """
1075
1076  def __init__(self,
1077               cell,
1078               attention_mechanism,
1079               attention_layer_size=None,
1080               alignment_history=False,
1081               cell_input_fn=None,
1082               output_attention=True,
1083               initial_cell_state=None,
1084               name=None):
1085    """Construct the `AttentionWrapper`.
1086
1087    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
1088    `AttentionWrapper`, then you must ensure that:
1089
1090    - The encoder output has been tiled to `beam_width` via
1091      @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
1092    - The `batch_size` argument passed to the `zero_state` method of this
1093      wrapper is equal to `true_batch_size * beam_width`.
1094    - The initial state created with `zero_state` above contains a
1095      `cell_state` value containing properly tiled final state from the
1096      encoder.
1097
1098    An example:
1099
1100    ```
1101    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
1102        encoder_outputs, multiplier=beam_width)
1103    tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch(
1104        encoder_final_state, multiplier=beam_width)
1105    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
1106        sequence_length, multiplier=beam_width)
1107    attention_mechanism = MyFavoriteAttentionMechanism(
1108        num_units=attention_depth,
1109        memory=tiled_inputs,
1110        memory_sequence_length=tiled_sequence_length)
1111    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
1112    decoder_initial_state = attention_cell.zero_state(
1113        dtype, batch_size=true_batch_size * beam_width)
1114    decoder_initial_state = decoder_initial_state.clone(
1115        cell_state=tiled_encoder_final_state)
1116    ```
1117
1118    Args:
1119      cell: An instance of `RNNCell`.
1120      attention_mechanism: A list of `AttentionMechanism` instances or a single
1121        instance.
1122      attention_layer_size: A list of Python integers or a single Python
1123        integer, the depth of the attention (output) layer(s). If None
1124        (default), use the context as attention at each time step. Otherwise,
1125        feed the context and cell output into the attention layer to generate
1126        attention at each time step. If attention_mechanism is a list,
1127        attention_layer_size must be a list of the same length.
1128      alignment_history: Python boolean, whether to store alignment history
1129        from all time steps in the final output state (currently stored as a
1130        time major `TensorArray` on which you must call `stack()`).
1131      cell_input_fn: (optional) A `callable`.  The default is:
1132        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
1133      output_attention: Python bool.  If `True` (default), the output at each
1134        time step is the attention value.  This is the behavior of Luong-style
1135        attention mechanisms.  If `False`, the output at each time step is
1136        the output of `cell`.  This is the beahvior of Bhadanau-style
1137        attention mechanisms.  In both cases, the `attention` tensor is
1138        propagated to the next time step via the state and is used there.
1139        This flag only controls whether the attention mechanism is propagated
1140        up to the next cell in an RNN stack or to the top RNN output.
1141      initial_cell_state: The initial state value to use for the cell when
1142        the user calls `zero_state()`.  Note that if this value is provided
1143        now, and the user uses a `batch_size` argument of `zero_state` which
1144        does not match the batch size of `initial_cell_state`, proper
1145        behavior is not guaranteed.
1146      name: Name to use when creating ops.
1147
1148    Raises:
1149      TypeError: `attention_layer_size` is not None and (`attention_mechanism`
1150        is a list but `attention_layer_size` is not; or vice versa).
1151      ValueError: if `attention_layer_size` is not None, `attention_mechanism`
1152        is a list, and its length does not match that of `attention_layer_size`.
1153    """
1154    super(AttentionWrapper, self).__init__(name=name)
1155    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
1156      raise TypeError(
1157          "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
1158    if isinstance(attention_mechanism, (list, tuple)):
1159      self._is_multi = True
1160      attention_mechanisms = attention_mechanism
1161      for attention_mechanism in attention_mechanisms:
1162        if not isinstance(attention_mechanism, AttentionMechanism):
1163          raise TypeError(
1164              "attention_mechanism must contain only instances of "
1165              "AttentionMechanism, saw type: %s"
1166              % type(attention_mechanism).__name__)
1167    else:
1168      self._is_multi = False
1169      if not isinstance(attention_mechanism, AttentionMechanism):
1170        raise TypeError(
1171            "attention_mechanism must be an AttentionMechanism or list of "
1172            "multiple AttentionMechanism instances, saw type: %s"
1173            % type(attention_mechanism).__name__)
1174      attention_mechanisms = (attention_mechanism,)
1175
1176    if cell_input_fn is None:
1177      cell_input_fn = (
1178          lambda inputs, attention: array_ops.concat([inputs, attention], -1))
1179    else:
1180      if not callable(cell_input_fn):
1181        raise TypeError(
1182            "cell_input_fn must be callable, saw type: %s"
1183            % type(cell_input_fn).__name__)
1184
1185    if attention_layer_size is not None:
1186      attention_layer_sizes = tuple(
1187          attention_layer_size
1188          if isinstance(attention_layer_size, (list, tuple))
1189          else (attention_layer_size,))
1190      if len(attention_layer_sizes) != len(attention_mechanisms):
1191        raise ValueError(
1192            "If provided, attention_layer_size must contain exactly one "
1193            "integer per attention_mechanism, saw: %d vs %d"
1194            % (len(attention_layer_sizes), len(attention_mechanisms)))
1195      self._attention_layers = tuple(
1196          layers_core.Dense(
1197              attention_layer_size,
1198              name="attention_layer",
1199              use_bias=False,
1200              dtype=attention_mechanisms[i].dtype)
1201          for i, attention_layer_size in enumerate(attention_layer_sizes))
1202      self._attention_layer_size = sum(attention_layer_sizes)
1203    else:
1204      self._attention_layers = None
1205      self._attention_layer_size = sum(
1206          attention_mechanism.values.get_shape()[-1].value
1207          for attention_mechanism in attention_mechanisms)
1208
1209    self._cell = cell
1210    self._attention_mechanisms = attention_mechanisms
1211    self._cell_input_fn = cell_input_fn
1212    self._output_attention = output_attention
1213    self._alignment_history = alignment_history
1214    with ops.name_scope(name, "AttentionWrapperInit"):
1215      if initial_cell_state is None:
1216        self._initial_cell_state = None
1217      else:
1218        final_state_tensor = nest.flatten(initial_cell_state)[-1]
1219        state_batch_size = (
1220            final_state_tensor.shape[0].value
1221            or array_ops.shape(final_state_tensor)[0])
1222        error_message = (
1223            "When constructing AttentionWrapper %s: " % self._base_name +
1224            "Non-matching batch sizes between the memory "
1225            "(encoder output) and initial_cell_state.  Are you using "
1226            "the BeamSearchDecoder?  You may need to tile your initial state "
1227            "via the tf.contrib.seq2seq.tile_batch function with argument "
1228            "multiple=beam_width.")
1229        with ops.control_dependencies(
1230            self._batch_size_checks(state_batch_size, error_message)):
1231          self._initial_cell_state = nest.map_structure(
1232              lambda s: array_ops.identity(s, name="check_initial_cell_state"),
1233              initial_cell_state)
1234
1235  def _batch_size_checks(self, batch_size, error_message):
1236    return [check_ops.assert_equal(batch_size,
1237                                   attention_mechanism.batch_size,
1238                                   message=error_message)
1239            for attention_mechanism in self._attention_mechanisms]
1240
1241  def _item_or_tuple(self, seq):
1242    """Returns `seq` as tuple or the singular element.
1243
1244    Which is returned is determined by how the AttentionMechanism(s) were passed
1245    to the constructor.
1246
1247    Args:
1248      seq: A non-empty sequence of items or generator.
1249
1250    Returns:
1251       Either the values in the sequence as a tuple if AttentionMechanism(s)
1252       were passed to the constructor as a sequence or the singular element.
1253    """
1254    t = tuple(seq)
1255    if self._is_multi:
1256      return t
1257    else:
1258      return t[0]
1259
1260  @property
1261  def output_size(self):
1262    if self._output_attention:
1263      return self._attention_layer_size
1264    else:
1265      return self._cell.output_size
1266
1267  @property
1268  def state_size(self):
1269    """The `state_size` property of `AttentionWrapper`.
1270
1271    Returns:
1272      An `AttentionWrapperState` tuple containing shapes used by this object.
1273    """
1274    return AttentionWrapperState(
1275        cell_state=self._cell.state_size,
1276        time=tensor_shape.TensorShape([]),
1277        attention=self._attention_layer_size,
1278        alignments=self._item_or_tuple(
1279            a.alignments_size for a in self._attention_mechanisms),
1280        attention_state=self._item_or_tuple(
1281            a.state_size for a in self._attention_mechanisms),
1282        alignment_history=self._item_or_tuple(
1283            () for _ in self._attention_mechanisms))  # sometimes a TensorArray
1284
1285  def zero_state(self, batch_size, dtype):
1286    """Return an initial (zero) state tuple for this `AttentionWrapper`.
1287
1288    **NOTE** Please see the initializer documentation for details of how
1289    to call `zero_state` if using an `AttentionWrapper` with a
1290    `BeamSearchDecoder`.
1291
1292    Args:
1293      batch_size: `0D` integer tensor: the batch size.
1294      dtype: The internal state data type.
1295
1296    Returns:
1297      An `AttentionWrapperState` tuple containing zeroed out tensors and,
1298      possibly, empty `TensorArray` objects.
1299
1300    Raises:
1301      ValueError: (or, possibly at runtime, InvalidArgument), if
1302        `batch_size` does not match the output size of the encoder passed
1303        to the wrapper object at initialization time.
1304    """
1305    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
1306      if self._initial_cell_state is not None:
1307        cell_state = self._initial_cell_state
1308      else:
1309        cell_state = self._cell.zero_state(batch_size, dtype)
1310      error_message = (
1311          "When calling zero_state of AttentionWrapper %s: " % self._base_name +
1312          "Non-matching batch sizes between the memory "
1313          "(encoder output) and the requested batch size.  Are you using "
1314          "the BeamSearchDecoder?  If so, make sure your encoder output has "
1315          "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
1316          "the batch_size= argument passed to zero_state is "
1317          "batch_size * beam_width.")
1318      with ops.control_dependencies(
1319          self._batch_size_checks(batch_size, error_message)):
1320        cell_state = nest.map_structure(
1321            lambda s: array_ops.identity(s, name="checked_cell_state"),
1322            cell_state)
1323      return AttentionWrapperState(
1324          cell_state=cell_state,
1325          time=array_ops.zeros([], dtype=dtypes.int32),
1326          attention=_zero_state_tensors(self._attention_layer_size, batch_size,
1327                                        dtype),
1328          alignments=self._item_or_tuple(
1329              attention_mechanism.initial_alignments(batch_size, dtype)
1330              for attention_mechanism in self._attention_mechanisms),
1331          attention_state=self._item_or_tuple(
1332              attention_mechanism.initial_state(batch_size, dtype)
1333              for attention_mechanism in self._attention_mechanisms),
1334          alignment_history=self._item_or_tuple(
1335              tensor_array_ops.TensorArray(dtype=dtype, size=0,
1336                                           dynamic_size=True)
1337              if self._alignment_history else ()
1338              for _ in self._attention_mechanisms))
1339
1340  def call(self, inputs, state):
1341    """Perform a step of attention-wrapped RNN.
1342
1343    - Step 1: Mix the `inputs` and previous step's `attention` output via
1344      `cell_input_fn`.
1345    - Step 2: Call the wrapped `cell` with this input and its previous state.
1346    - Step 3: Score the cell's output with `attention_mechanism`.
1347    - Step 4: Calculate the alignments by passing the score through the
1348      `normalizer`.
1349    - Step 5: Calculate the context vector as the inner product between the
1350      alignments and the attention_mechanism's values (memory).
1351    - Step 6: Calculate the attention output by concatenating the cell output
1352      and context through the attention layer (a linear layer with
1353      `attention_layer_size` outputs).
1354
1355    Args:
1356      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
1357      state: An instance of `AttentionWrapperState` containing
1358        tensors from the previous time step.
1359
1360    Returns:
1361      A tuple `(attention_or_cell_output, next_state)`, where:
1362
1363      - `attention_or_cell_output` depending on `output_attention`.
1364      - `next_state` is an instance of `AttentionWrapperState`
1365         containing the state calculated at this time step.
1366
1367    Raises:
1368      TypeError: If `state` is not an instance of `AttentionWrapperState`.
1369    """
1370    if not isinstance(state, AttentionWrapperState):
1371      raise TypeError("Expected state to be instance of AttentionWrapperState. "
1372                      "Received type %s instead."  % type(state))
1373
1374    # Step 1: Calculate the true inputs to the cell based on the
1375    # previous attention value.
1376    cell_inputs = self._cell_input_fn(inputs, state.attention)
1377    cell_state = state.cell_state
1378    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
1379
1380    cell_batch_size = (
1381        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
1382    error_message = (
1383        "When applying AttentionWrapper %s: " % self.name +
1384        "Non-matching batch sizes between the memory "
1385        "(encoder output) and the query (decoder output).  Are you using "
1386        "the BeamSearchDecoder?  You may need to tile your memory input via "
1387        "the tf.contrib.seq2seq.tile_batch function with argument "
1388        "multiple=beam_width.")
1389    with ops.control_dependencies(
1390        self._batch_size_checks(cell_batch_size, error_message)):
1391      cell_output = array_ops.identity(
1392          cell_output, name="checked_cell_output")
1393
1394    if self._is_multi:
1395      previous_attention_state = state.attention_state
1396      previous_alignment_history = state.alignment_history
1397    else:
1398      previous_attention_state = [state.attention_state]
1399      previous_alignment_history = [state.alignment_history]
1400
1401    all_alignments = []
1402    all_attentions = []
1403    all_attention_states = []
1404    maybe_all_histories = []
1405    for i, attention_mechanism in enumerate(self._attention_mechanisms):
1406      attention, alignments, next_attention_state = _compute_attention(
1407          attention_mechanism, cell_output, previous_attention_state[i],
1408          self._attention_layers[i] if self._attention_layers else None)
1409      alignment_history = previous_alignment_history[i].write(
1410          state.time, alignments) if self._alignment_history else ()
1411
1412      all_attention_states.append(next_attention_state)
1413      all_alignments.append(alignments)
1414      all_attentions.append(attention)
1415      maybe_all_histories.append(alignment_history)
1416
1417    attention = array_ops.concat(all_attentions, 1)
1418    next_state = AttentionWrapperState(
1419        time=state.time + 1,
1420        cell_state=next_cell_state,
1421        attention=attention,
1422        attention_state=self._item_or_tuple(all_attention_states),
1423        alignments=self._item_or_tuple(all_alignments),
1424        alignment_history=self._item_or_tuple(maybe_all_histories))
1425
1426    if self._output_attention:
1427      return attention, next_state
1428    else:
1429      return cell_output, next_state
1430