rnn_cell.py revision bab22b9f25741e172bb70ff1f82dc803ced0f579
1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
3d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
4d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# you may not use this file except in compliance with the License.
5d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# You may obtain a copy of the License at
6d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
7d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
8d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
9d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
10d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
11d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# See the License for the specific language governing permissions and
13d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# limitations under the License.
14d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# ==============================================================================
15d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
16d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower"""Module for constructing RNN Cells."""
17d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import absolute_import
18d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import division
19d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import print_function
20d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlowerimport collections
22d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerimport math
23d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
24bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.contrib.compiler import jit
2534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.contrib.layers.python.layers import layers
2637fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell
2737fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerfrom tensorflow.python.framework import dtypes
29bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.python.framework import op_def_registry
30d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.framework import ops
31d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops
32d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import clip_ops
3334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.python.ops import init_ops
34d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops
35d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import nn_ops
36d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import variable_scope as vs
375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging
384c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenechfrom tensorflow.python.util import nest
39d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
4037fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The variable is fetched (or created) as row-shards via
  `_get_sharded_variable`.

  - If that yields a single shard, the shard variable itself is returned.
  - Otherwise the shards are concatenated along axis 0 with an op named
    `name + "/concat"`. The result is registered in the
    `GraphKeys.CONCATENATED_VARIABLES` collection.
  - A later call whose expected tensor name matches an entry in that
    collection returns the cached tensor instead of building a new concat op.

  NOTE(review): the cache key is derived from the current *variable* scope
  name, while the concat op is named by the current *name* scope. The lookup
  only hits when the two agree.

  Args:
    name: Base name of the variable; shards are named `name_0`, `name_1`, ...
    shape: Full (unsharded) shape of the variable; sharded along dimension 0.
    dtype: dtype of the shard variables.
    num_shards: Number of shards requested for dimension 0.

  Returns:
    The single shard variable, or the concatenation of all shards along
    axis 0.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # At the root variable scope the scope name is "", so only insert the "/"
  # separator when there is a scope; otherwise the expected name would start
  # with "/" and could never match an existing tensor, defeating the cache.
  scope_name = vs.get_variable_scope().name
  scope_prefix = scope_name + "/" if scope_name else ""
  concat_full_name = scope_prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
57d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
58d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
def _shard_sizes(num_rows, num_shards):
  """Return the row count of each of `num_shards` near-equal shards.

  The first `num_rows % num_shards` shards get one extra row, so the sizes
  sum to `num_rows` and differ from each other by at most one.  Uses exact
  integer arithmetic (no float division).
  """
  unit_shard_size, remaining_rows = divmod(num_rows, num_shards)
  return [unit_shard_size + 1 if i < remaining_rows else unit_shard_size
          for i in range(num_shards)]


def _get_sharded_variable(name, shape, dtype, num_shards):
  """Get a list of sharded variables with the given dtype.

  Dimension 0 of `shape` is split into `num_shards` near-equal pieces (see
  `_shard_sizes`); every other dimension is kept whole.  Shard `i` is obtained
  with `vs.get_variable(name + "_%d" % i, ...)`, in order `0 .. num_shards-1`.

  Args:
    name: Base name of the variable; shard `i` is named `name_i`.
    shape: Full shape of the variable as a list or tuple; sharded along
      dimension 0.
    dtype: dtype passed to `vs.get_variable` for every shard.
    num_shards: Number of shards; must satisfy `1 <= num_shards <= shape[0]`.

  Returns:
    A list of `num_shards` variables whose dimension-0 sizes sum to
    `shape[0]`.

  Raises:
    ValueError: If `num_shards` is less than 1 or greater than `shape[0]`.
  """
  if num_shards < 1:
    raise ValueError("num_shards must be at least 1: shape=%s, num_shards=%d" %
                     (shape, num_shards))
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                     (shape, num_shards))

  # list() so that tuple shapes work as well as list shapes.
  trailing_dims = list(shape[1:])
  return [
      vs.get_variable(name + "_%d" % i, [size] + trailing_dims, dtype=dtype)
      for i, size in enumerate(_shard_sizes(shape[0], num_shards))
  ]
75d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
76d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  There is no separately learned input gate: it is tied to the forget gate as
  `1 - f`, so the cell update is `c = f * c_prev + (1 - f) * activation(j)`.
  The class uses optional peep-hole connections, and an optional projection
  layer.
  """

  def __init__(self, num_units, use_peepholes=False,
               initializer=None, num_proj=None, proj_clip=None,
               num_unit_shards=1, num_proj_shards=1,
               forget_bias=1.0, state_is_tuple=False,
               activation=math_ops.tanh):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None (or 0), no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip`
        is provided, then the projected values are clipped elementwise to
        within `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  By default (False), they are concatenated
        along the column axis.  This default behavior will soon be deprecated.
      activation: Activation function of the inner states.
    """
    if not state_is_tuple:
      # `warning` rather than the deprecated `warn` alias.
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation

    # A falsy num_proj (None or 0) means "no projection"; __call__ applies the
    # same rule so the produced state always matches state_size.
    if num_proj:
      self._state_size = (
          core_rnn_cell.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          core_rnn_cell.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def __call__(self, inputs, state, scope=None):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x input_size.  The size of dimension 1
        must be known statically.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to
        "coupled_input_forget_gate_lstm_cell".

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    # Same rule as __init__: a falsy num_proj (None or 0) disables projection,
    # in which case the recurrent output m has num_units columns.
    use_projection = bool(self._num_proj)
    num_proj = self._num_proj if use_projection else self._num_units

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      # Concatenated layout: [c (num_units columns) | m (num_proj columns)].
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with vs.variable_scope(scope or "coupled_input_forget_gate_lstm_cell",
                           initializer=self._initializer):
      # One fused matrix for the three learned blocks (j, f, o); the input
      # gate is derived from f, hence 3 (not 4) * num_units columns.
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 3 * self._num_units],
          dtype, self._num_unit_shards)

      b = vs.get_variable(
          "B",
          shape=[3 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

      # j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([inputs, m_prev], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

      # Forget gate, with an optional diagonal (peephole) term on c_prev.
      if self._use_peepholes:
        w_f_diag = vs.get_variable(
            "W_F_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = vs.get_variable(
            "W_O_diag", shape=[self._num_units], dtype=dtype)
        f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
      else:
        f_act = sigmoid(f + self._forget_bias)

      # Coupled input/forget gate: the input gate is 1 - f_act.
      c = (f_act * c_prev + (1 - f_act) * self._activation(j))

      # Output gate; its peephole looks at the *new* cell state c.
      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if use_projection:
        concat_w_proj = _get_concat_variable(
            "W_P", [self._num_units, self._num_proj],
            dtype, self._num_proj_shards)

        m = math_ops.matmul(m, concat_w_proj)
        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
2561032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
2571032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
25837fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass TimeFreqLSTMCell(core_rnn_cell.RNNCell):
259d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
260d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
261d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  This implementation is based on:
262d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
263d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Tara N. Sainath and Bo Li
264d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
265d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
266d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
267d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  It uses peep-hole connections and optional cell clipping.
268d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """
269d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
270d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
271d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               cell_clip=None, initializer=None,
272d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
273d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               feature_size=None, frequency_skip=None):
274d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
275d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
276d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
277d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
278d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      use_peepholes: bool, set True to enable diagonal/peephole connections.
279d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      cell_clip: (optional) A float value, if provided the cell state is clipped
280d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        by this value prior to the cell output activation.
281d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
282d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        projection matrices.
283d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
284d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        matrix is stored across num_unit_shards.
285d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      forget_bias: float, Biases of the forget gate are initialized by default
286d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        to 1 in order to reduce the scale of forgetting at the beginning
287d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        of the training.
288d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      feature_size: int, The size of the input feature the LSTM spans over.
289d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      frequency_skip: int, The amount the LSTM filter is shifted by in
290d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        frequency.
291d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
292d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_units = num_units
293d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._use_peepholes = use_peepholes
294d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._cell_clip = cell_clip
295d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._initializer = initializer
296d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_unit_shards = num_unit_shards
297d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._forget_bias = forget_bias
298d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._feature_size = feature_size
299d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._frequency_skip = frequency_skip
300d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._state_size = 2 * num_units
301d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._output_size = num_units
302d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
303d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
304d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def output_size(self):
305d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._output_size
306d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
307d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
308d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def state_size(self):
309d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._state_size
310d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
311d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __call__(self, inputs, state, scope=None):
312d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Run one step of LSTM.
313d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
314d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
315d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      inputs: input Tensor, 2D, batch x num_units.
316d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      state: state Tensor, 2D, batch x state_size.
317d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      scope: VariableScope for the created subgraph; defaults to
318d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        "TimeFreqLSTMCell".
319d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
320d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
321d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A tuple containing:
322d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
323d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
324d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        Here output_dim is num_units.
325d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x state_size, Tensor representing the new state of LSTM
326d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
327d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
328d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
329d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        a different dimension.
330d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
331e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    sigmoid = math_ops.sigmoid
332e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    tanh = math_ops.tanh
333d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
334d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = self._make_tf_features(inputs)
335d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    dtype = inputs.dtype
336d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
33792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope or "time_freq_lstm_cell",
338d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                           initializer=self._initializer):  # "TimeFreqLSTMCell"
339d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      concat_w = _get_concat_variable(
340d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
341d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          dtype, self._num_unit_shards)
342d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      b = vs.get_variable(
3434ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          "B",
3444ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          shape=[4 * self._num_units],
3454ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          initializer=init_ops.zeros_initializer(),
3464ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          dtype=dtype)
347d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
348763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower      # Diagonal connections
349763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower      if self._use_peepholes:
350763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower        w_f_diag = vs.get_variable(
351763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower            "W_F_diag", shape=[self._num_units], dtype=dtype)
352763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower        w_i_diag = vs.get_variable(
353763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower            "W_I_diag", shape=[self._num_units], dtype=dtype)
354763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower        w_o_diag = vs.get_variable(
355763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower            "W_O_diag", shape=[self._num_units], dtype=dtype)
356763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower
357d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      # initialize the first freq state to be zero
358d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
359d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                     self._num_units], dtype)
360d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      for fq in range(len(freq_inputs)):
361d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
362d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                 [-1, self._num_units])
363d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
364d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                 [-1, self._num_units])
365d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
3660e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower        cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
3670e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower                                       1)
368d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
369a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower        i, j, f, o = array_ops.split(
370a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower            value=lstm_matrix, num_or_size_splits=4, axis=1)
371d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
372d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
373d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
374d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               sigmoid(i + w_i_diag * c_prev) * tanh(j))
375d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        else:
376d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
377d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
378d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._cell_clip is not None:
379d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          # pylint: disable=invalid-unary-operand-type
380d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
381d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          # pylint: enable=invalid-unary-operand-type
382d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
383d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
384d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          m = sigmoid(o + w_o_diag * c) * tanh(c)
385d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        else:
386d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          m = sigmoid(o) * tanh(c)
387d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        m_prev_freq = m
388d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if fq == 0:
3890e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower          state_out = array_ops.concat([c, m], 1)
390d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          m_out = m
391d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        else:
3920e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower          state_out = array_ops.concat([state_out, c, m], 1)
3930e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower          m_out = array_ops.concat([m_out, m], 1)
394d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return m_out, state_out
395d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
396d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def _make_tf_features(self, input_feat):
397d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Make the frequency features.
398d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
399d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
400d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      input_feat: input Tensor, 2D, batch x num_units.
401d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
402d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
403d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A list of frequency features, with each element containing:
404d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
405d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        for that frequency index. Here output_dim is feature_size.
406d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
407d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if input_size cannot be inferred from static shape inference.
408d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
409d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    input_size = input_feat.get_shape().with_rank(2)[-1].value
410d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if input_size is None:
411d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      raise ValueError("Cannot infer input_size from static shape inference.")
412d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    num_feats = int((input_size - self._feature_size) / (
413d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        self._frequency_skip)) + 1
414d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = []
415d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for f in range(num_feats):
416d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
417d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                  [-1, self._feature_size])
418d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      freq_inputs.append(cur_input)
419d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return freq_inputs
420d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
421d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
42237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass GridLSTMCell(core_rnn_cell.RNNCell):
423d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Grid Long short-term memory unit (LSTM) recurrent network cell.
424d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
425d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The default is based on:
426d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
427d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Grid Long Short-Term Memory," Proc. ICLR 2016.
428d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    http://arxiv.org/abs/1507.01526
429d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
430d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  When peephole connections are used, the implementation is based on:
431d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Tara N. Sainath and Bo Li
432d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
433d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
434d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
435d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The code uses optional peephole connections, shared_weights and cell clipping.
436d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """
437d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
438d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
439d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               share_time_frequency_weights=False,
440d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               cell_clip=None, initializer=None,
441d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
4428fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               feature_size=None, frequency_skip=None,
4439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               num_frequency_blocks=None,
4449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               start_freqindex_list=None,
4459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               end_freqindex_list=None,
4468fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               couple_input_forget_gates=False,
4478fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               state_is_tuple=False):
448d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
449d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
450d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
451d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
4521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      use_peepholes: (optional) bool, default False. Set True to enable
4531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        diagonal/peephole connections.
4541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      share_time_frequency_weights: (optional) bool, default False. Set True to
4551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        enable shared cell weights between time and frequency LSTMs.
4561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      cell_clip: (optional) A float value, default None, if provided the cell
4571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state is clipped by this value prior to the cell output activation.
458d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
4591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        projection matrices, default None.
4601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_unit_shards: (optional) int, defualt 1, How to split the weight
4611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        matrix. If > 1,the weight matrix is stored across num_unit_shards.
4621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
4631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
464d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        of the training.
4651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      feature_size: (optional) int, default None, The size of the input feature
4661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the LSTM spans over.
4671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      frequency_skip: (optional) int, default None, The amount the LSTM filter
4681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        is shifted by in frequency.
4699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_frequency_blocks: [required] A list of frequency blocks needed to
4709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cover the whole input feature splitting defined by start_freqindex_list
4719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        and end_freqindex_list.
4729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      start_freqindex_list: [optional], list of ints, default None,  The
4739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        starting frequency index for each frequency block.
4749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      end_freqindex_list: [optional], list of ints, default None. The ending
4759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        frequency index for each frequency block.
4761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      couple_input_forget_gates: (optional) bool, default False, Whether to
4771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
4781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        model parameters and computation cost.
4798fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_is_tuple: If True, accepted and returned states are 2-tuples of
4808fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        the `c_state` and `m_state`.  By default (False), they are concatenated
4818fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        along the column axis.  This default behavior will soon be deprecated.
4829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    Raises:
4839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      ValueError: if the num_frequency_blocks list is not specified
484d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
4858fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if not state_is_tuple:
4868fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      logging.warn("%s: Using a concatenated state is slower and will soon be "
4878fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                   "deprecated.  Use state_is_tuple=True.", self)
488d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_units = num_units
489d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._use_peepholes = use_peepholes
490d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._share_time_frequency_weights = share_time_frequency_weights
4918fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._couple_input_forget_gates = couple_input_forget_gates
4928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._state_is_tuple = state_is_tuple
493d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._cell_clip = cell_clip
494d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._initializer = initializer
495d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_unit_shards = num_unit_shards
496d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._forget_bias = forget_bias
497d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._feature_size = feature_size
498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._frequency_skip = frequency_skip
4999e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._start_freqindex_list = start_freqindex_list
5009e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._end_freqindex_list = end_freqindex_list
5019e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._num_frequency_blocks = num_frequency_blocks
5029e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._total_blocks = 0
5039e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    if self._num_frequency_blocks is None:
5049e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      raise ValueError("Must specify num_frequency_blocks")
5059e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower
5069e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    for block_index in range(len(self._num_frequency_blocks)):
5079e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._total_blocks += int(self._num_frequency_blocks[block_index])
5088fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if state_is_tuple:
5098fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_names = ""
5109e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block_index in range(len(self._num_frequency_blocks)):
5119e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for freq_index in range(self._num_frequency_blocks[block_index]):
5129e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
5139e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
5148fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_tuple_type = collections.namedtuple(
5151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "GridLSTMStateTuple", state_names.strip(","))
5168fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_size = self._state_tuple_type(
5179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              *([num_units, num_units] * self._total_blocks))
5188fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    else:
5198fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_tuple_type = None
5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._state_size = num_units * self._total_blocks * 2
5219e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._output_size = num_units * self._total_blocks * 2
522d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
523d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
524d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def output_size(self):
525d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._output_size
526d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
527d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
528d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def state_size(self):
529d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._state_size
530d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
5318fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower  @property
5328fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower  def state_tuple_type(self):
5338fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    return self._state_tuple_type
5348fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower
535d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __call__(self, inputs, state, scope=None):
536d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Run one step of LSTM.
537d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
538d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
5391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs: input Tensor, 2D, [batch, feature_size].
5401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
5411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        flag self._state_is_tuple.
5421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      scope: (optional) VariableScope for the created subgraph; if None, it
5431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        defaults to "GridLSTMCell".
544d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
545d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
546d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A tuple containing:
5471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
548d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
549d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        Here output_dim is num_units.
5501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
551d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
552d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
553d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
554d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        a different dimension.
555d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
5561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    batch_size = int(inputs.get_shape()[0])
5571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    freq_inputs = self._make_tf_features(inputs)
55892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope or "grid_lstm_cell",
5591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           initializer=self._initializer):  # "GridLSTMCell"
5609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      m_out_lst = []
5619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      state_out_lst = []
5629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block in range(len(freq_inputs)):
5639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        m_out_lst_current, state_out_lst_current = self._compute(
5649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            freq_inputs[block], block, state, batch_size,
5659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            state_is_tuple=self._state_is_tuple)
5669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        m_out_lst.extend(m_out_lst_current)
5679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        state_out_lst.extend(state_out_lst_current)
5681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._state_is_tuple:
5691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_out = self._state_tuple_type(*state_out_lst)
5701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
5710e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower        state_out = array_ops.concat(state_out_lst, 1)
5720e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower      m_out = array_ops.concat(m_out_lst, 1)
5731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out, state_out
5741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
5759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower  def _compute(self, freq_inputs, block, state, batch_size,
5769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               state_prefix="state",
5771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               state_is_tuple=True):
5781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Run the actual computation of one step LSTM.
5791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
5801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
5811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      freq_inputs: list of Tensors, 2D, [batch, feature_size].
5829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      block: int, current frequency block index to process.
5831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
5841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the flag state_is_tuple.
5851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      batch_size: int32, batch size.
5861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state_prefix: (optional) string, name prefix for states, defaults to
5871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        "state".
5881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.
5891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
5901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Returns:
5911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      A tuple, containing:
5921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A list of [batch, output_dim] Tensors, representing the output of the
5931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        LSTM given the inputs and state.
5941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A list of [batch, state_size] Tensors, representing the LSTM state
5951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        values given the inputs and previous state.
5961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
597e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    sigmoid = math_ops.sigmoid
598e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    tanh = math_ops.tanh
5998fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    num_gates = 3 if self._couple_input_forget_gates else 4
6001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    dtype = freq_inputs[0].dtype
601d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
6021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    concat_w_f = _get_concat_variable(
6049e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        "W_f_%d" % block, [actual_input_size + 2 * self._num_units,
6059e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                           num_gates * self._num_units],
6061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        dtype, self._num_unit_shards)
6071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    b_f = vs.get_variable(
6084ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        "B_f_%d" % block,
6094ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        shape=[num_gates * self._num_units],
6104ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        initializer=init_ops.zeros_initializer(),
6114ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        dtype=dtype)
6121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if not self._share_time_frequency_weights:
6131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      concat_w_t = _get_concat_variable(
6149e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_t_%d" % block, [actual_input_size + 2 * self._num_units,
6159e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                             num_gates * self._num_units],
616d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          dtype, self._num_unit_shards)
6171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      b_t = vs.get_variable(
6184ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          "B_t_%d" % block,
6194ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          shape=[num_gates * self._num_units],
6204ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          initializer=init_ops.zeros_initializer(),
6214ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          dtype=dtype)
622d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
6231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if self._use_peepholes:
6241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Diagonal connections
6251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if not self._couple_input_forget_gates:
6261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_f_diag_freqf = vs.get_variable(
6279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_f_diag_freqt = vs.get_variable(
6299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
6301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_i_diag_freqf = vs.get_variable(
6319e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_i_diag_freqt = vs.get_variable(
6339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
6341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_o_diag_freqf = vs.get_variable(
6359e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_o_diag_freqt = vs.get_variable(
6379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
6381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if not self._share_time_frequency_weights:
6398fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        if not self._couple_input_forget_gates:
6401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          w_f_diag_timef = vs.get_variable(
6419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          w_f_diag_timet = vs.get_variable(
6439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_i_diag_timef = vs.get_variable(
6459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_i_diag_timet = vs.get_variable(
6479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6481855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_o_diag_timef = vs.get_variable(
6499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_o_diag_timet = vs.get_variable(
6519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # initialize the first freq state to be zero
6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
6551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
6561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    for freq_index in range(len(freq_inputs)):
6571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if state_is_tuple:
6589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
6591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_prev_time = getattr(state, name_prefix + "_c")
6601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_prev_time = getattr(state, name_prefix + "_m")
6611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
6621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_prev_time = array_ops.slice(
6631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            state, [0, 2 * freq_index * self._num_units],
6641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            [-1, self._num_units])
6651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_prev_time = array_ops.slice(
6661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            state, [0, (2 * freq_index + 1) * self._num_units],
6671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            [-1, self._num_units])
6681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
6700e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower      cell_inputs = array_ops.concat(
671d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
6721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM
6741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
6751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                                         concat_w_f), b_f)
6761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
677a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower        i_freq, j_freq, o_freq = array_ops.split(
678a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
6791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_freq = None
6801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
681a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower        i_freq, j_freq, f_freq, o_freq = array_ops.split(
682a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
6831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM
6841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._share_time_frequency_weights:
6851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_time = i_freq
6861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        j_time = j_freq
6871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_time = f_freq
6881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        o_time = o_freq
6891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
6901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
6911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                                           concat_w_t), b_t)
6928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        if self._couple_input_forget_gates:
693a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower          i_time, j_time, o_time = array_ops.split(
694a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
6951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_time = None
6968fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
697a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower          i_time, j_time, f_time, o_time = array_ops.split(
698a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
699d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
7001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM c_freq
7011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # input gate activations
7021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_freq_g = sigmoid(i_freq +
7041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_i_diag_freqf * c_prev_freq +
7051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_i_diag_freqt * c_prev_time)
7061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_freq_g = sigmoid(i_freq)
7081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # forget gate activations
7091855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
7101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_freq_g = 1.0 - i_freq_g
7111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
712d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
7131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_freq_g = sigmoid(f_freq + self._forget_bias +
7141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_f_diag_freqf * c_prev_freq +
7151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_f_diag_freqt * c_prev_time)
7161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        else:
7171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_freq_g = sigmoid(f_freq + self._forget_bias)
7181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # cell state
7191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
7201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._cell_clip is not None:
7211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: disable=invalid-unary-operand-type
7221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
7231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                        self._cell_clip)
7241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: enable=invalid-unary-operand-type
7251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM c_freq
7271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # input gate activations
7281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        if self._share_time_frequency_weights:
7301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          i_time_g = sigmoid(i_time +
7318fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                             w_i_diag_freqf * c_prev_freq +
7328fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                             w_i_diag_freqt * c_prev_time)
7338fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
7341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          i_time_g = sigmoid(i_time +
7351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_i_diag_timef * c_prev_freq +
7361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_i_diag_timet * c_prev_time)
7371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_time_g = sigmoid(i_time)
7391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # forget gate activations
7401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
7411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_time_g = 1.0 - i_time_g
7421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
743d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
744d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          if self._share_time_frequency_weights:
7451855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            f_time_g = sigmoid(f_time + self._forget_bias +
7461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_freqf * c_prev_freq +
7471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_freqt * c_prev_time)
748d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          else:
7491855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            f_time_g = sigmoid(f_time + self._forget_bias +
7501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_timef * c_prev_freq +
7511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_timet * c_prev_time)
7528fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
7531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_time_g = sigmoid(f_time + self._forget_bias)
7541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # cell state
7551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
7561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._cell_clip is not None:
7571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: disable=invalid-unary-operand-type
7581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
7591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                        self._cell_clip)
7601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: enable=invalid-unary-operand-type
7611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM m_freq
7631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_freq = sigmoid(o_freq +
7651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                         w_o_diag_freqf * c_freq +
7661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                         w_o_diag_freqt * c_time) * tanh(c_freq)
7671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_freq = sigmoid(o_freq) * tanh(c_freq)
769d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
7701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM m_time
7711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        if self._share_time_frequency_weights:
7731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          m_time = sigmoid(o_time +
7748fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                           w_o_diag_freqf * c_freq +
7751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_freqt * c_time) * tanh(c_time)
776d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        else:
7771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          m_time = sigmoid(o_time +
7781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_timef * c_freq +
7791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_timet * c_time) * tanh(c_time)
7808fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      else:
7811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_time = sigmoid(o_time) * tanh(c_time)
7821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      m_prev_freq = m_freq
7841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_prev_freq = c_freq
7851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
7861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if freq_index == 0:
7871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_out_lst = [c_time, m_time]
7881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_out_lst = [m_time, m_freq]
7891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_out_lst.extend([c_time, m_time])
7911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_out_lst.extend([m_time, m_freq])
792d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
7931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out_lst, state_out_lst
7941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  def _make_tf_features(self, input_feat, slice_offset=0):
796d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Make the frequency features.
797d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
798d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
7991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      input_feat: input Tensor, 2D, [batch, num_units].
8001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      slice_offset: (optional) Python int, default 0, the slicing offset is only
8011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        used for the backward processing in the BidirectionalGridLSTMCell. It
8021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        specifies a different starting point instead of always 0 to enable the
8031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forward and backward processing look at different frequency blocks.
804d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
805d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
806d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A list of frequency features, with each element containing:
8071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the time-frequency
8081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        feature for that frequency index. Here output_dim is feature_size.
809d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
810d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if input_size cannot be inferred from static shape inference.
811d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
812d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    input_size = input_feat.get_shape().with_rank(2)[-1].value
813d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if input_size is None:
814d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      raise ValueError("Cannot infer input_size from static shape inference.")
8151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if slice_offset > 0:
8161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Padding to the end
8171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = array_ops.pad(
8181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2],
8191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                         dtype=dtypes.int32),
8201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "CONSTANT")
8211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    elif slice_offset < 0:
8221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Padding to the front
8231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = array_ops.pad(
8241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2],
8251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                         dtype=dtypes.int32),
8261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "CONSTANT")
8271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      slice_offset = 0
8281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    else:
8291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = input_feat
830d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = []
8319e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    if not self._start_freqindex_list:
8329e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._num_frequency_blocks) != 1:
8339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of num_frequency_blocks"
8349e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " is not 1, but instead is %d",
8359e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._num_frequency_blocks))
8369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_feats = int((input_size - self._feature_size) / (
8379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          self._frequency_skip)) + 1
8389e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if num_feats != self._num_frequency_blocks[0]:
8399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError(
8409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "Invalid num_frequency_blocks, requires %d but gets %d, please"
8419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            " check the input size and filter config are correct." % (
8429e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                self._num_frequency_blocks[0], num_feats))
8439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      block_inputs = []
8449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for f in range(num_feats):
8459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cur_input = array_ops.slice(
8469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            inputs, [0, slice_offset + f * self._frequency_skip],
8479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            [-1, self._feature_size])
8489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_inputs.append(cur_input)
8499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      freq_inputs.append(block_inputs)
8509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    else:
8519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
8529e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of start and end freqindex_list"
8539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " does not match %d %d",
8549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._start_freqindex_list),
8559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._end_freqindex_list))
8569e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
8579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of num_frequency_blocks"
8589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " is not equal to start_freqindex_list %d %d",
8599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._num_frequency_blocks),
8609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._start_freqindex_list))
8619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for b in range(len(self._start_freqindex_list)):
8629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        start_index = self._start_freqindex_list[b]
8639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        end_index = self._end_freqindex_list[b]
8649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cur_size = end_index - start_index
8659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_feats = int((cur_size - self._feature_size) / (
8669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            self._frequency_skip)) + 1
8679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        if block_feats != self._num_frequency_blocks[b]:
8689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          raise ValueError(
8699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "Invalid num_frequency_blocks, requires %d but gets %d, please"
8709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              " check the input size and filter config are correct." % (
8719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                  self._num_frequency_blocks[b], block_feats))
8729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_inputs = []
8739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for f in range(block_feats):
8749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          cur_input = array_ops.slice(
8759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              inputs, [0, start_index + slice_offset + f *
8769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                       self._frequency_skip],
8779e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              [-1, self._feature_size])
8789e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          block_inputs.append(cur_input)
8799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        freq_inputs.append(block_inputs)
880d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return freq_inputs
8815cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
8825cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
8831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerclass BidirectionalGridLSTMCell(GridLSTMCell):
8841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  """Bidirectional GridLstm cell.
8851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
8861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  The bidirection connection is only used in the frequency direction, which
8871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  hence doesn't affect the time direction's real-time processing that is
8881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  required for online recognition systems.
8891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  The current implementation uses different weights for the two directions.
8901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  """
8911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
8921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
8931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               share_time_frequency_weights=False,
8941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               cell_clip=None, initializer=None,
8951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
8961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               feature_size=None, frequency_skip=None,
8979e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               num_frequency_blocks=None,
8989e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               start_freqindex_list=None,
8999e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               end_freqindex_list=None,
9001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               couple_input_forget_gates=False,
9011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               backward_slice_offset=0):
9021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
9031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
9051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
9061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      use_peepholes: (optional) bool, default False. Set True to enable
9071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        diagonal/peephole connections.
9081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      share_time_frequency_weights: (optional) bool, default False. Set True to
9091855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        enable shared cell weights between time and frequency LSTMs.
9101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      cell_clip: (optional) A float value, default None, if provided the cell
9111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state is clipped by this value prior to the cell output activation.
9121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
9131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        projection matrices, default None.
9141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_unit_shards: (optional) int, defualt 1, How to split the weight
9151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        matrix. If > 1,the weight matrix is stored across num_unit_shards.
9161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
9171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
9181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        of the training.
9191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      feature_size: (optional) int, default None, The size of the input feature
9201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the LSTM spans over.
9211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      frequency_skip: (optional) int, default None, The amount the LSTM filter
9221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        is shifted by in frequency.
9239e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_frequency_blocks: [required] A list of frequency blocks needed to
9249e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cover the whole input feature splitting defined by start_freqindex_list
9259e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        and end_freqindex_list.
9269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      start_freqindex_list: [optional], list of ints, default None,  The
9279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        starting frequency index for each frequency block.
9289e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      end_freqindex_list: [optional], list of ints, default None. The ending
9299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        frequency index for each frequency block.
9301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      couple_input_forget_gates: (optional) bool, default False, Whether to
9311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
9321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        model parameters and computation cost.
9331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      backward_slice_offset: (optional) int32, default 0, the starting offset to
9341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        slice the feature for backward processing.
9351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
9361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    super(BidirectionalGridLSTMCell, self).__init__(
9371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
9381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
9399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
9409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        couple_input_forget_gates=False,
9411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_is_tuple=True)
9421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._backward_slice_offset = int(backward_slice_offset)
9431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    state_names = ""
9441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    for direction in ["fwd", "bwd"]:
9459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block_index in range(len(self._num_frequency_blocks)):
9469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for freq_index in range(self._num_frequency_blocks[block_index]):
9479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
9489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                                                  block_index)
9499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
9501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._state_tuple_type = collections.namedtuple(
9511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
9521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._state_size = self._state_tuple_type(
9539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        *([num_units, num_units] * self._total_blocks * 2))
9549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._output_size = 2 * num_units * self._total_blocks * 2
9551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  def __call__(self, inputs, state, scope=None):
9571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Run one step of LSTM.
9581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
9601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs: input Tensor, 2D, [batch, num_units].
9611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: tuple of Tensors, 2D, [batch, state_size].
9621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      scope: (optional) VariableScope for the created subgraph; if None, it
9631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        defaults to "BidirectionalGridLSTMCell".
9641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Returns:
9661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      A tuple containing:
9671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
9681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        after reading "inputs" when previous state was "state".
9691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        Here output_dim is num_units.
9701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
9711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        after reading "inputs" when previous state was "state".
9721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Raises:
9731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
9741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        a different dimension.
9751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
9761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    batch_size = int(inputs.get_shape()[0])
9771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    fwd_inputs = self._make_tf_features(inputs)
9781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if self._backward_slice_offset:
9791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
9801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    else:
9811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      bwd_inputs = fwd_inputs
9821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # Forward processing
98492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope or "bidirectional_grid_lstm_cell",
9851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           initializer=self._initializer):
98692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      with vs.variable_scope("fwd"):
98792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo        fwd_m_out_lst = []
98892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo        fwd_state_out_lst = []
98992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo        for block in range(len(fwd_inputs)):
99092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
99192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo              fwd_inputs[block], block, state, batch_size,
99292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo              state_prefix="fwd_state", state_is_tuple=True)
99392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          fwd_m_out_lst.extend(fwd_m_out_lst_current)
99492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          fwd_state_out_lst.extend(fwd_state_out_lst_current)
99592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      # Backward processing
99692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      bwd_m_out_lst = []
99792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      bwd_state_out_lst = []
99892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      with vs.variable_scope("bwd"):
99992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo        for block in range(len(bwd_inputs)):
100092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          # Reverse the blocks
100192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          bwd_inputs_reverse = bwd_inputs[block][::-1]
100292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
100392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo              bwd_inputs_reverse, block, state, batch_size,
100492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo              state_prefix="bwd_state", state_is_tuple=True)
100592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          bwd_m_out_lst.extend(bwd_m_out_lst_current)
100692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          bwd_state_out_lst.extend(bwd_state_out_lst_current)
10071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
10081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # Outputs are always concated as it is never used separately.
10090e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
10101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out, state_out
10111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
10121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
# Module-level alias of the private linear-layer helper in
# core_rnn_cell_impl; AttentionCellWrapper below builds its projections with it.
_linear = core_rnn_cell_impl._linear  # pylint: disable=protected-access
10165cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10175cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
class AttentionCellWrapper(core_rnn_cell.RNNCell):
  """Basic attention cell wrapper.

  Implementation based on https://arxiv.org/abs/1409.0473.

  Besides the wrapped cell's own state, the wrapper carries two extra pieces
  of state. The first is the latest attention vector. The second is a sliding
  window of the last `attn_length` projected outputs, which is attended over
  at every step.
  """

  def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
               input_size=None, state_is_tuple=False):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
          cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
          on attention state and a size of the hidden layer built from
          base cell state. Equal to attn_size by default.
      input_size: integer, the size of a hidden linear layer,
          built from inputs and attention. Derived from the input tensor
          by default.
      state_is_tuple: If True, accepted and returned states are 3-tuples
        `(cell_state, attns, attn_states)`.  By default (False), the states
        are all concatenated along the column axis.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
          `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    if not isinstance(cell, core_rnn_cell.RNNCell):
      raise TypeError("The parameter cell is not RNNCell.")
    # The concatenated layout is sliced using an integer `cell.state_size`,
    # so a nested cell state requires state_is_tuple=True.
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError("Cell returns tuple of states, but the flag "
                       "state_is_tuple is not set. State size is: %s"
                       % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError("attn_length should be greater than zero, got %s"
                       % str(attn_length))
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    resolved_attn_size = cell.output_size if attn_size is None else attn_size
    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = (resolved_attn_size if attn_vec_size is None
                           else attn_vec_size)
    self._input_size = input_size
    self._attn_size = resolved_attn_size
    self._attn_length = attn_length

  @property
  def state_size(self):
    """Sizes of (cell state, attention vector, flattened attention window).

    Returned as a 3-tuple when `state_is_tuple` is set; otherwise summed into
    the width of the single concatenated state tensor.
    """
    sizes = (self._cell.state_size, self._attn_size,
             self._attn_size * self._attn_length)
    if not self._state_is_tuple:
      return sum(sizes)
    return sizes

  @property
  def output_size(self):
    """Outputs are projected to `attn_size` units."""
    return self._attn_size

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell with attention (LSTMA).

    Args:
      inputs: 2D input Tensor. Unless `input_size` was given, its second
        dimension must be statically known.
      state: a 3-tuple `(cell_state, attns, attn_states)` if `state_is_tuple`,
        otherwise those three parts concatenated along axis 1.
      scope: (optional) VariableScope; defaults to "attention_cell_wrapper".

    Returns:
      A pair `(output, new_state)`. `output` is the projection of the cell
      output and the new attention vector to `attn_size` units. `new_state`
      has the same layout as `state`.
    """
    with vs.variable_scope(scope or "attention_cell_wrapper"):
      if self._state_is_tuple:
        cell_state, attns, attn_states = state
      else:
        cell_size = self._cell.state_size
        cell_state = array_ops.slice(state, [0, 0], [-1, cell_size])
        attns = array_ops.slice(
            state, [0, cell_size], [-1, self._attn_size])
        attn_states = array_ops.slice(
            state, [0, cell_size + self._attn_size],
            [-1, self._attn_size * self._attn_length])
      attn_states = array_ops.reshape(
          attn_states, [-1, self._attn_length, self._attn_size])
      if self._input_size is None:
        input_size = inputs.get_shape().as_list()[1]
      else:
        input_size = self._input_size
      # Mix the previous attention vector into the cell input.
      inputs = _linear([inputs, attns], input_size, True)
      cell_output, new_cell_state = self._cell(inputs, cell_state)
      if self._state_is_tuple:
        query = array_ops.concat(nest.flatten(new_cell_state), 1)
      else:
        query = new_cell_state
      new_attns, new_attn_states = self._attention(query, attn_states)
      with vs.variable_scope("attn_output_projection"):
        output = _linear([cell_output, new_attns], self._attn_size, True)
      # Slide the window: `_attention` already dropped the oldest entry, so
      # appending this step's output keeps it at `attn_length` entries.
      new_attn_states = array_ops.concat(
          [new_attn_states, array_ops.expand_dims(output, 1)], 1)
      new_attn_states = array_ops.reshape(
          new_attn_states, [-1, self._attn_length * self._attn_size])
      new_state = (new_cell_state, new_attns, new_attn_states)
      if not self._state_is_tuple:
        new_state = array_ops.concat(list(new_state), 1)
      return output, new_state

  def _attention(self, query, attn_states):
    """Compute one step of additive attention over the stored window.

    Args:
      query: 2D Tensor derived from the wrapped cell's new state.
      attn_states: Tensor reshaped by the caller to
        [batch, attn_length, attn_size].

    Returns:
      A pair `(new_attns, new_attn_states)`. `new_attns` is the
      attention-weighted sum over the window, shaped [batch, attn_size].
      `new_attn_states` is `attn_states` with its first (oldest) time step
      sliced off.
    """
    with vs.variable_scope("attention"):
      attn_w = vs.get_variable(
          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
      attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
      # A 1x1 convolution applies the same attn_size -> attn_vec_size
      # projection to every slot of the window.
      window = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      window_features = nn_ops.conv2d(window, attn_w, [1, 1, 1, 1], "SAME")
      query_features = array_ops.reshape(
          _linear(query, self._attn_vec_size, True),
          [-1, 1, 1, self._attn_vec_size])
      scores = math_ops.reduce_sum(
          attn_v * math_ops.tanh(window_features + query_features), [2, 3])
      weights = nn_ops.softmax(scores)
      context = math_ops.reduce_sum(
          array_ops.reshape(weights, [-1, self._attn_length, 1, 1]) * window,
          [1, 2])
      new_attns = array_ops.reshape(context, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
114334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
114434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
114537fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
114634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  """LSTM unit with layer normalization and recurrent dropout.
114734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
114834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  This class adds layer normalization and recurrent dropout to a
114934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  basic LSTM unit. Layer normalization implementation is based on:
115034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
115134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    https://arxiv.org/abs/1607.06450.
115234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
115334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  "Layer Normalization"
115434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
115534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
115634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  and is applied before the internal nonlinearities.
115734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Recurrent dropout is base on:
115834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
115934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    https://arxiv.org/abs/1603.05118
116034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  "Recurrent Dropout without Memory Loss"
116234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
116334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  """
116434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def __init__(self, num_units, forget_bias=1.0,
116634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               input_size=None, activation=math_ops.tanh,
116734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               layer_norm=True, norm_gain=1.0, norm_shift=0.0,
116834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               dropout_keep_prob=1.0, dropout_prob_seed=None):
116934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """Initializes the basic LSTM cell.
117034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
117134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    Args:
117234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      num_units: int, The number of units in the LSTM cell.
117334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      forget_bias: float, The bias added to forget gates (see above).
117434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      input_size: Deprecated and unused.
117534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      activation: Activation function of the inner states.
117634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      layer_norm: If `True`, layer normalization will be applied.
117734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      norm_gain: float, The layer normalization gain initial value. If
117834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        `layer_norm` has been set to `False`, this argument will be ignored.
117934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      norm_shift: float, The layer normalization shift initial value. If
118034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        `layer_norm` has been set to `False`, this argument will be ignored.
118134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
118234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        recurrent dropout probability value. If float and 1.0, no dropout will
118334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        be applied.
118434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      dropout_prob_seed: (optional) integer, the randomness seed.
118534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """
118634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
118734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    if input_size is not None:
118834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      logging.warn("%s: The input_size parameter is deprecated.", self)
118934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
119034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._num_units = num_units
119134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._activation = activation
119234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._forget_bias = forget_bias
119334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._keep_prob = dropout_keep_prob
119434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._seed = dropout_prob_seed
119534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._layer_norm = layer_norm
119634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._g = norm_gain
119734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._b = norm_shift
119834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
119934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  @property
120034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def state_size(self):
120137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie    return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
120234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
120334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  @property
120434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def output_size(self):
120534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    return self._num_units
120634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
120734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def _norm(self, inp, scope):
120892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    shape = inp.get_shape()[-1:]
120992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    gamma_init = init_ops.constant_initializer(self._g)
121092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    beta_init = init_ops.constant_initializer(self._b)
121192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope):
121292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      # Initialize beta and gamma for use by layer_norm.
121392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      vs.get_variable("gamma", shape=shape, initializer=gamma_init)
121492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      vs.get_variable("beta", shape=shape, initializer=beta_init)
121592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
121692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    return normalized
121792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo
121892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo  def _linear(self, args):
121934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    out_size = 4 * self._num_units
122034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    proj_size = args.get_shape()[-1]
122192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    weights = vs.get_variable("weights", [proj_size, out_size])
122292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    out = math_ops.matmul(args, weights)
122392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    if not self._layer_norm:
122492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      bias = vs.get_variable("biases", [out_size])
122592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      out = nn_ops.bias_add(out, bias)
122692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    return out
122734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
122834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def __call__(self, inputs, state, scope=None):
122934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """LSTM cell with layer normalization and recurrent dropout."""
123034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
123192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"):
123234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      c, h = state
12330e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower      args = array_ops.concat([inputs, h], 1)
123434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      concat = self._linear(args)
123534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1236a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
123734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      if self._layer_norm:
123834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        i = self._norm(i, "input")
123934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        j = self._norm(j, "transform")
124034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        f = self._norm(f, "forget")
124134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        o = self._norm(o, "output")
124234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
124334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      g = self._activation(j)
124434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
124534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
124634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
124734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      new_c = (c * math_ops.sigmoid(f + self._forget_bias)
124834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               + math_ops.sigmoid(i) * g)
124934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      if self._layer_norm:
125034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        new_c = self._norm(new_c, "state")
125134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      new_h = self._activation(new_c) * math_ops.sigmoid(o)
125234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
125337fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
125434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      return new_h, new_state
1255bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
1256bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
# Lazily populated cache of `op_def_registry.get_registered_ops()`, a mapping
# from registered op names to their op definitions. It remains `None` until a
# `CompiledWrapper` first needs to check whether an op is stateful.
_REGISTERED_OPS = None


class CompiledWrapper(core_rnn_cell.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    """The wrapped cell's `state_size`."""
    return self._cell.state_size

  @property
  def output_size(self):
    """The wrapped cell's `output_size`."""
    return self._cell.output_size

  def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell inside `jit.experimental_jit_scope`.

    If `compile_stateful` is True, every op created by the wrapped cell is
    marked for compilation. Otherwise an op is marked only if its type is in
    the op registry and its definition is not stateful.

    Ops whose type is missing from the registry are left uncompiled rather
    than raising `KeyError` during graph construction. Such ops are
    presumably, e.g., calls to graph-library functions.

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Forwarded unchanged to the wrapped cell.

    Returns:
      Whatever the wrapped cell returns.
    """
    if self._compile_stateful:
      compile_ops = True
    else:
      def compile_ops(node_def):
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        op_def = _REGISTERED_OPS.get(node_def.op)
        # An op type we cannot look up cannot be shown to be stateless, so
        # err on the side of not compiling it instead of failing.
        return op_def is not None and not op_def.is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
1294