rnn_cell.py revision e8482ab23bd0fce5c2941f6a190158bca2610a35
1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
3d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
4d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# you may not use this file except in compliance with the License.
5d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# You may obtain a copy of the License at
6d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
7d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
8d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
9d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
10d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
11d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# See the License for the specific language governing permissions and
13d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# limitations under the License.
14d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# ==============================================================================
15d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
16d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower"""Module for constructing RNN Cells."""
17d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import absolute_import
18d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import division
19d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import print_function
20d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlowerimport collections
22d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerimport math
23d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
24bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.contrib.compiler import jit
2534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.contrib.layers.python.layers import layers
2637fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell
2737fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerfrom tensorflow.python.framework import dtypes
29bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.python.framework import op_def_registry
30d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.framework import ops
31d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops
32d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import clip_ops
3334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.python.ops import init_ops
34d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops
35d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import nn_ops
36bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlowerfrom tensorflow.python.ops import random_ops
37d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import variable_scope as vs
385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging
394c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenechfrom tensorflow.python.util import nest
40d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
4137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie
# Module-local alias for the private `_checked_scope` helper defined in
# core_rnn_cell_impl.  NOTE(review): no use is visible in this part of the
# file (the cells call `super().__init__(_reuse=...)` instead) -- confirm
# whether the alias is still needed before removing it.
# pylint: disable=protected-access
_checked_scope = core_rnn_cell_impl._checked_scope
# pylint: enable=protected-access
4354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo
4454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo
def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The variable is created (or reused) as `num_shards` pieces via
  `_get_sharded_variable`.  With a single shard that variable is returned
  directly; otherwise the shards are concatenated along axis 0.  The
  concatenated tensor is recorded in the `CONCATENATED_VARIABLES` collection
  and looked up there by name, so a later call that resolves to the same
  tensor name returns the existing concat instead of building a new one.

  NOTE(review): the expected name is derived from the *variable* scope, while
  the concat op is created under the current *name* scope; the cache hit
  presumably requires the two to agree.

  Args:
    name: Base name of the variable; shards are named `name + "_<i>"`.
    shape: Full (unsharded) shape of the variable.
    dtype: dtype of the shard variables.
    num_shards: Number of row-wise shards.

  Returns:
    The single shard variable when there is only one shard, otherwise a
    tensor of shape `shape` holding the concatenated shards.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  scope_name = vs.get_variable_scope().name
  # At the root variable scope the name is "", and joining it with "/" would
  # yield "/<name>/concat:0", which no tensor name can ever match.
  prefix = scope_name + "/" if scope_name else ""
  concat_full_name = prefix + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable
61d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
62d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
63d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerdef _get_sharded_variable(name, shape, dtype, num_shards):
64d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Get a list of sharded variables with the given dtype."""
65d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  if num_shards > shape[0]:
66d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    raise ValueError("Too many shards: shape=%s, num_shards=%d" %
67d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                     (shape, num_shards))
68d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  unit_shard_size = int(math.floor(shape[0] / num_shards))
69d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  remaining_rows = shape[0] - unit_shard_size * num_shards
70d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
71d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  shards = []
72d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  for i in range(num_shards):
73d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    current_size = unit_shard_size
74d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if i < remaining_rows:
75d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      current_size += 1
76d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:],
77d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                  dtype=dtype))
78d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  return shards
79d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
80d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.  Because the input gate is tied to the forget gate (it is computed as
  `1 - f`), only three gate pre-activations (`j`, `f`, `o`) are learned.
  """

  def __init__(self, num_units, use_peepholes=False,
               initializer=None, num_proj=None, proj_clip=None,
               num_unit_shards=1, num_proj_shards=1,
               forget_bias=1.0, state_is_tuple=True,
               activation=math_ops.tanh, reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices (and the peephole weights).  It becomes the
        default initializer of the variable scope in which `call` creates its
        variables; the bias `B` always uses a zeros initializer.  If None, the
        enclosing scope's default is used.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`.  If False, they are
        concatenated along the column axis; that behavior is slower and will
        soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

    if num_proj:
      self._state_size = (
          core_rnn_cell.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          core_rnn_cell.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x input_size.  The column dimension
        must be statically known.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Re-enter the current variable scope with `self._initializer` as its
    # default initializer so the constructor argument actually takes effect
    # (it was previously stored but never used).  Passing the scope object
    # keeps the variable names unchanged.
    with vs.variable_scope(vs.get_variable_scope(),
                           initializer=self._initializer):
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 3 * self._num_units],
          dtype, self._num_unit_shards)

      b = vs.get_variable(
          "B",
          shape=[3 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

      # j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([inputs, m_prev], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=3, axis=1)

      # Diagonal (peephole) connections.
      if self._use_peepholes:
        w_f_diag = vs.get_variable(
            "W_F_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = vs.get_variable(
            "W_O_diag", shape=[self._num_units], dtype=dtype)
        f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
      else:
        f_act = sigmoid(f + self._forget_bias)
      # Coupled gates: the input gate is (1 - forget gate).
      c = f_act * c_prev + (1 - f_act) * self._activation(j)

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        concat_w_proj = _get_concat_variable(
            "W_P", [self._num_units, self._num_proj],
            dtype, self._num_proj_shards)

        m = math_ops.matmul(m, concat_w_proj)
        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state
2621032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
2631032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
26437fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass TimeFreqLSTMCell(core_rnn_cell.RNNCell):
265d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.
266d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
267d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  This implementation is based on:
268d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
269d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Tara N. Sainath and Bo Li
270d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
271d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
272d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
273d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  It uses peep-hole connections and optional cell clipping.
274d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """
275d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
276d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
277d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               cell_clip=None, initializer=None,
278d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
27954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               feature_size=None, frequency_skip=None,
28054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               reuse=None):
281d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
282d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
283d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
284d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
285d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      use_peepholes: bool, set True to enable diagonal/peephole connections.
286d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      cell_clip: (optional) A float value, if provided the cell state is clipped
287d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        by this value prior to the cell output activation.
288d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
289d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        projection matrices.
290d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
291d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        matrix is stored across num_unit_shards.
292d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      forget_bias: float, Biases of the forget gate are initialized by default
293d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        to 1 in order to reduce the scale of forgetting at the beginning
294d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        of the training.
295d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      feature_size: int, The size of the input feature the LSTM spans over.
296d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      frequency_skip: int, The amount the LSTM filter is shifted by in
297d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        frequency.
29854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
29954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
30054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
301d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
302e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
303d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_units = num_units
304d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._use_peepholes = use_peepholes
305d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._cell_clip = cell_clip
306d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._initializer = initializer
307d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_unit_shards = num_unit_shards
308d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._forget_bias = forget_bias
309d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._feature_size = feature_size
310d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._frequency_skip = frequency_skip
311d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._state_size = 2 * num_units
312d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._output_size = num_units
31354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
314d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
315d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
316d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def output_size(self):
317d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._output_size
318d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
319d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
320d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def state_size(self):
321d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._state_size
322d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
323e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
324d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Run one step of LSTM.
325d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
326d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
327d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      inputs: input Tensor, 2D, batch x num_units.
328d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      state: state Tensor, 2D, batch x state_size.
329d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
330d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
331d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A tuple containing:
332d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x output_dim, Tensor representing the output of the LSTM
333d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
334d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        Here output_dim is num_units.
335d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x state_size, Tensor representing the new state of LSTM
336d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
337d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
338d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
339d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        a different dimension.
340d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
341e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    sigmoid = math_ops.sigmoid
342e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    tanh = math_ops.tanh
343d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
344d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = self._make_tf_features(inputs)
345d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    dtype = inputs.dtype
346d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
347d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
348e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    concat_w = _get_concat_variable(
349e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
350e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        dtype, self._num_unit_shards)
351d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
352e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    b = vs.get_variable(
353e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        "B",
354e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        shape=[4 * self._num_units],
355e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        initializer=init_ops.zeros_initializer(),
356e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        dtype=dtype)
357d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
358e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Diagonal connections
359e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._use_peepholes:
360e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      w_f_diag = vs.get_variable(
361e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          "W_F_diag", shape=[self._num_units], dtype=dtype)
362e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      w_i_diag = vs.get_variable(
363e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          "W_I_diag", shape=[self._num_units], dtype=dtype)
364e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      w_o_diag = vs.get_variable(
365e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          "W_O_diag", shape=[self._num_units], dtype=dtype)
366d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
367e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # initialize the first freq state to be zero
368e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
369e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                                   self._num_units], dtype)
370e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    for fq in range(len(freq_inputs)):
371e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
372e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                               [-1, self._num_units])
373e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
374e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                               [-1, self._num_units])
375e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
376e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
377e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                                     1)
378e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
379e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      i, j, f, o = array_ops.split(
380e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          value=lstm_matrix, num_or_size_splits=4, axis=1)
381e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
382e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      if self._use_peepholes:
383e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
384e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo             sigmoid(i + w_i_diag * c_prev) * tanh(j))
385e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      else:
386e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
387e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
388e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      if self._cell_clip is not None:
389e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        # pylint: disable=invalid-unary-operand-type
390e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
391e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        # pylint: enable=invalid-unary-operand-type
392e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
393e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      if self._use_peepholes:
394e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        m = sigmoid(o + w_o_diag * c) * tanh(c)
395e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      else:
396e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        m = sigmoid(o) * tanh(c)
397e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_prev_freq = m
398e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      if fq == 0:
399e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        state_out = array_ops.concat([c, m], 1)
400e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        m_out = m
401e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      else:
402e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        state_out = array_ops.concat([state_out, c, m], 1)
403e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        m_out = array_ops.concat([m_out, m], 1)
404d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return m_out, state_out
405d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
406d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def _make_tf_features(self, input_feat):
407d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Make the frequency features.
408d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
409d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
410d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      input_feat: input Tensor, 2D, batch x num_units.
411d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
412d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
413d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A list of frequency features, with each element containing:
414d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
415d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        for that frequency index. Here output_dim is feature_size.
416d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
417d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if input_size cannot be inferred from static shape inference.
418d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
419d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    input_size = input_feat.get_shape().with_rank(2)[-1].value
420d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if input_size is None:
421d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      raise ValueError("Cannot infer input_size from static shape inference.")
422d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    num_feats = int((input_size - self._feature_size) / (
423d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        self._frequency_skip)) + 1
424d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = []
425d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for f in range(num_feats):
426d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip],
427d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower                                  [-1, self._feature_size])
428d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      freq_inputs.append(cur_input)
429d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return freq_inputs
430d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
431d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
43237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass GridLSTMCell(core_rnn_cell.RNNCell):
433d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Grid Long short-term memory unit (LSTM) recurrent network cell.
434d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
435d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The default is based on:
436d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
437d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Grid Long Short-Term Memory," Proc. ICLR 2016.
438d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    http://arxiv.org/abs/1507.01526
439d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
440d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  When peephole connections are used, the implementation is based on:
441d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Tara N. Sainath and Bo Li
442d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
443d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
444d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
445d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The code uses optional peephole connections, shared_weights and cell clipping.
446d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """
447d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
448d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
449d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               share_time_frequency_weights=False,
450d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               cell_clip=None, initializer=None,
451d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
4528fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               feature_size=None, frequency_skip=None,
4539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               num_frequency_blocks=None,
4549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               start_freqindex_list=None,
4559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               end_freqindex_list=None,
4568fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               couple_input_forget_gates=False,
457c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower               state_is_tuple=True,
45854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               reuse=None):
459d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
460d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
461d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
462d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
4631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      use_peepholes: (optional) bool, default False. Set True to enable
4641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        diagonal/peephole connections.
4651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      share_time_frequency_weights: (optional) bool, default False. Set True to
4661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        enable shared cell weights between time and frequency LSTMs.
4671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      cell_clip: (optional) A float value, default None, if provided the cell
4681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state is clipped by this value prior to the cell output activation.
469d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
4701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        projection matrices, default None.
4711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_unit_shards: (optional) int, defualt 1, How to split the weight
4721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        matrix. If > 1,the weight matrix is stored across num_unit_shards.
4731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
4741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
475d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        of the training.
4761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      feature_size: (optional) int, default None, The size of the input feature
4771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the LSTM spans over.
4781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      frequency_skip: (optional) int, default None, The amount the LSTM filter
4791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        is shifted by in frequency.
4809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_frequency_blocks: [required] A list of frequency blocks needed to
4819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cover the whole input feature splitting defined by start_freqindex_list
4829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        and end_freqindex_list.
4839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      start_freqindex_list: [optional], list of ints, default None,  The
4849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        starting frequency index for each frequency block.
4859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      end_freqindex_list: [optional], list of ints, default None. The ending
4869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        frequency index for each frequency block.
4871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      couple_input_forget_gates: (optional) bool, default False, Whether to
4881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
4891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        model parameters and computation cost.
4908fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_is_tuple: If True, accepted and returned states are 2-tuples of
4918fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        the `c_state` and `m_state`.  By default (False), they are concatenated
4928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        along the column axis.  This default behavior will soon be deprecated.
49354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
49454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
49554d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
4969e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    Raises:
4979e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      ValueError: if the num_frequency_blocks list is not specified
498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
499e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(GridLSTMCell, self).__init__(_reuse=reuse)
5008fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if not state_is_tuple:
5018fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      logging.warn("%s: Using a concatenated state is slower and will soon be "
5028fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                   "deprecated.  Use state_is_tuple=True.", self)
503d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_units = num_units
504d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._use_peepholes = use_peepholes
505d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._share_time_frequency_weights = share_time_frequency_weights
5068fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._couple_input_forget_gates = couple_input_forget_gates
5078fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._state_is_tuple = state_is_tuple
508d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._cell_clip = cell_clip
509d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._initializer = initializer
510d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_unit_shards = num_unit_shards
511d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._forget_bias = forget_bias
512d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._feature_size = feature_size
513d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._frequency_skip = frequency_skip
5149e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._start_freqindex_list = start_freqindex_list
5159e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._end_freqindex_list = end_freqindex_list
5169e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._num_frequency_blocks = num_frequency_blocks
5179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._total_blocks = 0
51854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
5199e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    if self._num_frequency_blocks is None:
5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      raise ValueError("Must specify num_frequency_blocks")
5219e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower
5229e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    for block_index in range(len(self._num_frequency_blocks)):
5239e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._total_blocks += int(self._num_frequency_blocks[block_index])
5248fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if state_is_tuple:
5258fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_names = ""
5269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block_index in range(len(self._num_frequency_blocks)):
5279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for freq_index in range(self._num_frequency_blocks[block_index]):
5289e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
5299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
5308fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_tuple_type = collections.namedtuple(
5311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "GridLSTMStateTuple", state_names.strip(","))
5328fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_size = self._state_tuple_type(
5339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              *([num_units, num_units] * self._total_blocks))
5348fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    else:
5358fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_tuple_type = None
5369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._state_size = num_units * self._total_blocks * 2
5379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._output_size = num_units * self._total_blocks * 2
538d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
539d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
540d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def output_size(self):
541d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._output_size
542d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
543d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  @property
544d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  def state_size(self):
545d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return self._state_size
546d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
5478fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower  @property
5488fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower  def state_tuple_type(self):
5498fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    return self._state_tuple_type
5508fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower
551e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
552d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Run one step of LSTM.
553d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
554d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
5551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs: input Tensor, 2D, [batch, feature_size].
5561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
5571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        flag self._state_is_tuple.
558d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
559d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
560d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A tuple containing:
5611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
562d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
563d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        Here output_dim is num_units.
5641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
565d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
566d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
567d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
568d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        a different dimension.
569d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
570499166454f0c7dd6c724c364f6dc5d99357514b1A. Unique TensorFlower    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
5711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    freq_inputs = self._make_tf_features(inputs)
572e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_out_lst = []
573e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    state_out_lst = []
574e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    for block in range(len(freq_inputs)):
575e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_out_lst_current, state_out_lst_current = self._compute(
576e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          freq_inputs[block], block, state, batch_size,
577e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          state_is_tuple=self._state_is_tuple)
578e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_out_lst.extend(m_out_lst_current)
579e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out_lst.extend(state_out_lst_current)
580e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._state_is_tuple:
581e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out = self._state_tuple_type(*state_out_lst)
582e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    else:
583e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out = array_ops.concat(state_out_lst, 1)
584e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_out = array_ops.concat(m_out_lst, 1)
5851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out, state_out
5861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
5879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower  def _compute(self, freq_inputs, block, state, batch_size,
5889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               state_prefix="state",
5891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               state_is_tuple=True):
5901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Run the actual computation of one step LSTM.
5911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
5921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
5931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      freq_inputs: list of Tensors, 2D, [batch, feature_size].
5949e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      block: int, current frequency block index to process.
5951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
5961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the flag state_is_tuple.
5971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      batch_size: int32, batch size.
5981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state_prefix: (optional) string, name prefix for states, defaults to
5991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        "state".
6001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.
6011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Returns:
6031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      A tuple, containing:
6041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A list of [batch, output_dim] Tensors, representing the output of the
6051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        LSTM given the inputs and state.
6061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A list of [batch, state_size] Tensors, representing the LSTM state
6071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        values given the inputs and previous state.
6081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
609e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    sigmoid = math_ops.sigmoid
610e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    tanh = math_ops.tanh
6118fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    num_gates = 3 if self._couple_input_forget_gates else 4
6121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    dtype = freq_inputs[0].dtype
613d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
6141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    concat_w_f = _get_concat_variable(
6169e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        "W_f_%d" % block, [actual_input_size + 2 * self._num_units,
6179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                           num_gates * self._num_units],
6181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        dtype, self._num_unit_shards)
6191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    b_f = vs.get_variable(
6204ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        "B_f_%d" % block,
6214ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        shape=[num_gates * self._num_units],
6224ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        initializer=init_ops.zeros_initializer(),
6234ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist        dtype=dtype)
6241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if not self._share_time_frequency_weights:
6251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      concat_w_t = _get_concat_variable(
6269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_t_%d" % block, [actual_input_size + 2 * self._num_units,
6279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                             num_gates * self._num_units],
628d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          dtype, self._num_unit_shards)
6291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      b_t = vs.get_variable(
6304ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          "B_t_%d" % block,
6314ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          shape=[num_gates * self._num_units],
6324ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          initializer=init_ops.zeros_initializer(),
6334ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist          dtype=dtype)
634d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
6351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if self._use_peepholes:
6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Diagonal connections
6371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if not self._couple_input_forget_gates:
6381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_f_diag_freqf = vs.get_variable(
6399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_f_diag_freqt = vs.get_variable(
6419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype)
6421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_i_diag_freqf = vs.get_variable(
6439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_i_diag_freqt = vs.get_variable(
6459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
6461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_o_diag_freqf = vs.get_variable(
6479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype)
6481855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      w_o_diag_freqt = vs.get_variable(
6499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype)
6501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if not self._share_time_frequency_weights:
6518fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        if not self._couple_input_forget_gates:
6521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          w_f_diag_timef = vs.get_variable(
6539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          w_f_diag_timet = vs.get_variable(
6559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_i_diag_timef = vs.get_variable(
6579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_i_diag_timet = vs.get_variable(
6599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_o_diag_timef = vs.get_variable(
6619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype)
6621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        w_o_diag_timet = vs.get_variable(
6639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype)
6641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # initialize the first freq state to be zero
6661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
6671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
6681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    for freq_index in range(len(freq_inputs)):
6691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if state_is_tuple:
6709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
6711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_prev_time = getattr(state, name_prefix + "_c")
6721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_prev_time = getattr(state, name_prefix + "_m")
6731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
6741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_prev_time = array_ops.slice(
6751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            state, [0, 2 * freq_index * self._num_units],
6761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            [-1, self._num_units])
6771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_prev_time = array_ops.slice(
6781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            state, [0, (2 * freq_index + 1) * self._num_units],
6791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            [-1, self._num_units])
6801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
6820e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower      cell_inputs = array_ops.concat(
683d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower          [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1)
6841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
6851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM
6861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
6871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                                         concat_w_f), b_f)
6881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
689a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower        i_freq, j_freq, o_freq = array_ops.split(
690a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
6911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_freq = None
6921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
693a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower        i_freq, j_freq, f_freq, o_freq = array_ops.split(
694a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower            value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
6951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM
6961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._share_time_frequency_weights:
6971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_time = i_freq
6981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        j_time = j_freq
6991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_time = f_freq
7001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        o_time = o_freq
7011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
7031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                                           concat_w_t), b_t)
7048fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        if self._couple_input_forget_gates:
705a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower          i_time, j_time, o_time = array_ops.split(
706a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
7071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_time = None
7088fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
709a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower          i_time, j_time, f_time, o_time = array_ops.split(
710a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower              value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
711d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
7121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM c_freq
7131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # input gate activations
7141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_freq_g = sigmoid(i_freq +
7161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_i_diag_freqf * c_prev_freq +
7171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_i_diag_freqt * c_prev_time)
7181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_freq_g = sigmoid(i_freq)
7201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # forget gate activations
7211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
7221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_freq_g = 1.0 - i_freq_g
7231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
724d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
7251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_freq_g = sigmoid(f_freq + self._forget_bias +
7261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_f_diag_freqf * c_prev_freq +
7271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_f_diag_freqt * c_prev_time)
7281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        else:
7291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_freq_g = sigmoid(f_freq + self._forget_bias)
7301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # cell state
7311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
7321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._cell_clip is not None:
7331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: disable=invalid-unary-operand-type
7341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
7351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                        self._cell_clip)
7361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: enable=invalid-unary-operand-type
7371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM c_freq
7391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # input gate activations
7401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        if self._share_time_frequency_weights:
7421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          i_time_g = sigmoid(i_time +
7438fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                             w_i_diag_freqf * c_prev_freq +
7448fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                             w_i_diag_freqt * c_prev_time)
7458fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
7461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          i_time_g = sigmoid(i_time +
7471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_i_diag_timef * c_prev_freq +
7481855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                             w_i_diag_timet * c_prev_time)
7491855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        i_time_g = sigmoid(i_time)
7511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # forget gate activations
7521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._couple_input_forget_gates:
7531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        f_time_g = 1.0 - i_time_g
7541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
755d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        if self._use_peepholes:
756d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          if self._share_time_frequency_weights:
7571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            f_time_g = sigmoid(f_time + self._forget_bias +
7581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_freqf * c_prev_freq +
7591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_freqt * c_prev_time)
760d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower          else:
7611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower            f_time_g = sigmoid(f_time + self._forget_bias +
7621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_timef * c_prev_freq +
7631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                               w_f_diag_timet * c_prev_time)
7648fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        else:
7651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          f_time_g = sigmoid(f_time + self._forget_bias)
7661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # cell state
7671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
7681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._cell_clip is not None:
7691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: disable=invalid-unary-operand-type
7701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
7711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                        self._cell_clip)
7721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        # pylint: enable=invalid-unary-operand-type
7731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # F-LSTM m_freq
7751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_freq = sigmoid(o_freq +
7771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                         w_o_diag_freqf * c_freq +
7781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                         w_o_diag_freqt * c_time) * tanh(c_freq)
7791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
7801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_freq = sigmoid(o_freq) * tanh(c_freq)
781d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
7821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # T-LSTM m_time
7831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if self._use_peepholes:
7841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        if self._share_time_frequency_weights:
7851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          m_time = sigmoid(o_time +
7868fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                           w_o_diag_freqf * c_freq +
7871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_freqt * c_time) * tanh(c_time)
788d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        else:
7891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          m_time = sigmoid(o_time +
7901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_timef * c_freq +
7911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                           w_o_diag_timet * c_time) * tanh(c_time)
7928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      else:
7931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_time = sigmoid(o_time) * tanh(c_time)
7941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
7951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      m_prev_freq = m_freq
7961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      c_prev_freq = c_freq
7971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Concatenate the outputs for T-LSTM and F-LSTM for each shift
7981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      if freq_index == 0:
7991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_out_lst = [c_time, m_time]
8001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_out_lst = [m_time, m_freq]
8011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      else:
8021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state_out_lst.extend([c_time, m_time])
8031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        m_out_lst.extend([m_time, m_freq])
804d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
8051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out_lst, state_out_lst
8061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
8071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  def _make_tf_features(self, input_feat, slice_offset=0):
808d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Make the frequency features.
809d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
810d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
8111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      input_feat: input Tensor, 2D, [batch, num_units].
8121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      slice_offset: (optional) Python int, default 0, the slicing offset is only
8131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        used for the backward processing in the BidirectionalGridLSTMCell. It
8141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        specifies a different starting point instead of always 0 to enable the
8151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forward and backward processing look at different frequency blocks.
816d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
817d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
818d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A list of frequency features, with each element containing:
8191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the time-frequency
8201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        feature for that frequency index. Here output_dim is feature_size.
821d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
822d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if input_size cannot be inferred from static shape inference.
823d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
824d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    input_size = input_feat.get_shape().with_rank(2)[-1].value
825d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if input_size is None:
826d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      raise ValueError("Cannot infer input_size from static shape inference.")
8271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if slice_offset > 0:
8281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Padding to the end
8291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = array_ops.pad(
8301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2],
8311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                         dtype=dtypes.int32),
8321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "CONSTANT")
8331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    elif slice_offset < 0:
8341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      # Padding to the front
8351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = array_ops.pad(
8361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2],
8371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower                                         dtype=dtypes.int32),
8381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower          "CONSTANT")
8391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      slice_offset = 0
8401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    else:
8411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs = input_feat
842d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    freq_inputs = []
8439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    if not self._start_freqindex_list:
8449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._num_frequency_blocks) != 1:
8459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of num_frequency_blocks"
8469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " is not 1, but instead is %d",
8479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._num_frequency_blocks))
8489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_feats = int((input_size - self._feature_size) / (
8499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          self._frequency_skip)) + 1
8509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if num_feats != self._num_frequency_blocks[0]:
8519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError(
8529e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            "Invalid num_frequency_blocks, requires %d but gets %d, please"
8539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            " check the input size and filter config are correct." % (
8549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                self._num_frequency_blocks[0], num_feats))
8559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      block_inputs = []
8569e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for f in range(num_feats):
8579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cur_input = array_ops.slice(
8589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            inputs, [0, slice_offset + f * self._frequency_skip],
8599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            [-1, self._feature_size])
8609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_inputs.append(cur_input)
8619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      freq_inputs.append(block_inputs)
8629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    else:
8639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._start_freqindex_list) != len(self._end_freqindex_list):
8649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of start and end freqindex_list"
8659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " does not match %d %d",
8669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._start_freqindex_list),
8679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._end_freqindex_list))
8689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
8699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        raise ValueError("Length of num_frequency_blocks"
8709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         " is not equal to start_freqindex_list %d %d",
8719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._num_frequency_blocks),
8729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                         len(self._start_freqindex_list))
8739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for b in range(len(self._start_freqindex_list)):
8749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        start_index = self._start_freqindex_list[b]
8759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        end_index = self._end_freqindex_list[b]
8769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cur_size = end_index - start_index
8779e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_feats = int((cur_size - self._feature_size) / (
8789e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower            self._frequency_skip)) + 1
8799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        if block_feats != self._num_frequency_blocks[b]:
8809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          raise ValueError(
8819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              "Invalid num_frequency_blocks, requires %d but gets %d, please"
8829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              " check the input size and filter config are correct." % (
8839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                  self._num_frequency_blocks[b], block_feats))
8849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        block_inputs = []
8859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for f in range(block_feats):
8869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          cur_input = array_ops.slice(
8879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              inputs, [0, start_index + slice_offset + f *
8889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                       self._frequency_skip],
8899e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower              [-1, self._feature_size])
8909e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          block_inputs.append(cur_input)
8919e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        freq_inputs.append(block_inputs)
892d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    return freq_inputs
8935cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
8945cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
8951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerclass BidirectionalGridLSTMCell(GridLSTMCell):
8961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  """Bidirectional GridLstm cell.
8971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
8981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  The bidirection connection is only used in the frequency direction, which
8991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  hence doesn't affect the time direction's real-time processing that is
9001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  required for online recognition systems.
9011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  The current implementation uses different weights for the two directions.
9021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  """
9031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower  def __init__(self, num_units, use_peepholes=False,
9051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               share_time_frequency_weights=False,
9061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               cell_clip=None, initializer=None,
9071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               num_unit_shards=1, forget_bias=1.0,
9081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               feature_size=None, frequency_skip=None,
9099e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               num_frequency_blocks=None,
9109e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               start_freqindex_list=None,
9119e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               end_freqindex_list=None,
9121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower               couple_input_forget_gates=False,
91354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               backward_slice_offset=0,
91454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               reuse=None):
9151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
9161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
9181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
9191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      use_peepholes: (optional) bool, default False. Set True to enable
9201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        diagonal/peephole connections.
9211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      share_time_frequency_weights: (optional) bool, default False. Set True to
9221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        enable shared cell weights between time and frequency LSTMs.
9231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      cell_clip: (optional) A float value, default None, if provided the cell
9241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state is clipped by this value prior to the cell output activation.
9251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
9261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        projection matrices, default None.
9271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      num_unit_shards: (optional) int, defualt 1, How to split the weight
9281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        matrix. If > 1,the weight matrix is stored across num_unit_shards.
9291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
9301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
9311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        of the training.
9321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      feature_size: (optional) int, default None, The size of the input feature
9331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the LSTM spans over.
9341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      frequency_skip: (optional) int, default None, The amount the LSTM filter
9351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        is shifted by in frequency.
9369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_frequency_blocks: [required] A list of frequency blocks needed to
9379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cover the whole input feature splitting defined by start_freqindex_list
9389e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        and end_freqindex_list.
9399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      start_freqindex_list: [optional], list of ints, default None,  The
9409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        starting frequency index for each frequency block.
9419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      end_freqindex_list: [optional], list of ints, default None. The ending
9429e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        frequency index for each frequency block.
9431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      couple_input_forget_gates: (optional) bool, default False, Whether to
9441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
9451855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        model parameters and computation cost.
9461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      backward_slice_offset: (optional) int32, default 0, the starting offset to
9471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        slice the feature for backward processing.
94854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
94954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
95054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
9511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
9521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    super(BidirectionalGridLSTMCell, self).__init__(
9531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
9541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
9559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
956a5493db6082a6fe29ed11e71f2aab5bfdf5f98c7A. Unique TensorFlower        couple_input_forget_gates, True, reuse)
9571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._backward_slice_offset = int(backward_slice_offset)
9581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    state_names = ""
9591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    for direction in ["fwd", "bwd"]:
9609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block_index in range(len(self._num_frequency_blocks)):
9619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for freq_index in range(self._num_frequency_blocks[block_index]):
9629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
9639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower                                                  block_index)
9649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
9651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._state_tuple_type = collections.namedtuple(
9661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        "BidirectionalGridLSTMStateTuple", state_names.strip(","))
9671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    self._state_size = self._state_tuple_type(
9689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        *([num_units, num_units] * self._total_blocks * 2))
9699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._output_size = 2 * num_units * self._total_blocks * 2
9701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
971e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
9721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """Run one step of LSTM.
9731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Args:
9751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs: input Tensor, 2D, [batch, num_units].
9761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: tuple of Tensors, 2D, [batch, state_size].
9771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Returns:
9791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      A tuple containing:
9801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
9811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        after reading "inputs" when previous state was "state".
9821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        Here output_dim is num_units.
9831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
9841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        after reading "inputs" when previous state was "state".
9851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    Raises:
9861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
9871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        a different dimension.
9881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    """
989499166454f0c7dd6c724c364f6dc5d99357514b1A. Unique TensorFlower    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
9901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    fwd_inputs = self._make_tf_features(inputs)
9911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    if self._backward_slice_offset:
9921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset)
9931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    else:
9941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      bwd_inputs = fwd_inputs
9951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
9961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # Forward processing
997e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    with vs.variable_scope("fwd"):
998e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      fwd_m_out_lst = []
999e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      fwd_state_out_lst = []
1000e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      for block in range(len(fwd_inputs)):
1001e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
1002e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo            fwd_inputs[block], block, state, batch_size,
1003e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo            state_prefix="fwd_state", state_is_tuple=True)
1004e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        fwd_m_out_lst.extend(fwd_m_out_lst_current)
1005e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        fwd_state_out_lst.extend(fwd_state_out_lst_current)
1006e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Backward processing
1007e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    bwd_m_out_lst = []
1008e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    bwd_state_out_lst = []
1009e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    with vs.variable_scope("bwd"):
1010e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      for block in range(len(bwd_inputs)):
1011e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        # Reverse the blocks
1012e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        bwd_inputs_reverse = bwd_inputs[block][::-1]
1013e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
1014e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo            bwd_inputs_reverse, block, state, batch_size,
1015e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo            state_prefix="bwd_state", state_is_tuple=True)
1016e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        bwd_m_out_lst.extend(bwd_m_out_lst_current)
1017e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        bwd_state_out_lst.extend(bwd_state_out_lst_current)
10181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
10191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    # Outputs are always concated as it is never used separately.
10200e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
10211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out, state_out
10221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
10231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
10245cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: disable=protected-access
102537fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie_linear = core_rnn_cell_impl._linear
10265cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: enable=protected-access
10275cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10285cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
102937fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass AttentionCellWrapper(core_rnn_cell.RNNCell):
10305cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  """Basic attention cell wrapper.
10315cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
1032ec3f4d62979ef1e70e8e12e2568b13dad45fd39eA. Unique TensorFlower  Implementation based on https://arxiv.org/abs/1409.0473.
10335cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  """
10345cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None,
1036c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower               input_size=None, state_is_tuple=True, reuse=None):
10375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    """Create a cell with attention.
10385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10395cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    Args:
10405cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      cell: an RNNCell, an attention is added to it.
10415cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      attn_length: integer, the size of an attention window.
10425cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      attn_size: integer, the size of an attention vector. Equal to
10435cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          cell.output_size by default.
10445cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      attn_vec_size: integer, the number of convolutional features calculated
10455cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          on attention state and a size of the hidden layer built from
10465cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          base cell state. Equal attn_size to by default.
10475cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      input_size: integer, the size of a hidden linear layer,
10485cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          built from inputs and attention. Derived from the input tensor
10495cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          by default.
10505cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      state_is_tuple: If True, accepted and returned states are n-tuples, where
10515cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower        `n = len(cells)`.  By default (False), the states are all
10525cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower        concatenated along the column axis.
105354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
105454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
105554d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
10565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    Raises:
10585cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      TypeError: if cell is not an RNNCell.
10595cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      ValueError: if cell returns a state tuple but the flag
10605cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          `state_is_tuple` is `False` or if attn_length is zero or less.
10615cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    """
1062e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
106337fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie    if not isinstance(cell, core_rnn_cell.RNNCell):
10645cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      raise TypeError("The parameter cell is not RNNCell.")
10654c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenech    if nest.is_sequence(cell.state_size) and not state_is_tuple:
10665cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      raise ValueError("Cell returns tuple of states, but the flag "
10675cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower                       "state_is_tuple is not set. State size is: %s"
10685cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower                       % str(cell.state_size))
10695cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    if attn_length <= 0:
10705cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      raise ValueError("attn_length should be greater than zero, got %s"
10715cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower                       % str(attn_length))
10725cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    if not state_is_tuple:
10735cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      logging.warn(
10745cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          "%s: Using a concatenated state is slower and will soon be "
1075d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower          "deprecated.  Use state_is_tuple=True.", self)
10765cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    if attn_size is None:
10775cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      attn_size = cell.output_size
10785cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    if attn_vec_size is None:
10795cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      attn_vec_size = attn_size
10805cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._state_is_tuple = state_is_tuple
10815cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._cell = cell
10825cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._attn_vec_size = attn_vec_size
10835cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._input_size = input_size
10845cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._attn_size = attn_size
10855cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    self._attn_length = attn_length
108654d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
10875cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10885cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  @property
10895cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  def state_size(self):
10905cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    size = (self._cell.state_size, self._attn_size,
10915cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower            self._attn_size * self._attn_length)
10925cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    if self._state_is_tuple:
10935cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      return size
10945cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    else:
10955cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      return sum(list(size))
10965cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
10975cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  @property
10985cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  def output_size(self):
10995cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    return self._attn_size
11005cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
1101e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
11025cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower    """Long short-term memory cell with attention (LSTMA)."""
1103e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._state_is_tuple:
1104e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state, attns, attn_states = state
1105e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    else:
1106e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      states = state
1107e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
1108e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      attns = array_ops.slice(
1109e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          states, [0, self._cell.state_size], [-1, self._attn_size])
1110e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      attn_states = array_ops.slice(
1111e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          states, [0, self._cell.state_size + self._attn_size],
1112e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          [-1, self._attn_size * self._attn_length])
1113e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    attn_states = array_ops.reshape(attn_states,
1114e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                                    [-1, self._attn_length, self._attn_size])
1115e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    input_size = self._input_size
1116e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if input_size is None:
1117e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      input_size = inputs.get_shape().as_list()[1]
1118e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    inputs = _linear([inputs, attns], input_size, True)
1119e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    lstm_output, new_state = self._cell(inputs, state)
1120e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._state_is_tuple:
1121e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
1122e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    else:
1123e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      new_state_cat = new_state
1124e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
1125e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    with vs.variable_scope("attn_output_projection"):
1126e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      output = _linear([lstm_output, new_attns], self._attn_size, True)
1127e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_attn_states = array_ops.concat(
1128e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
1129e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_attn_states = array_ops.reshape(
1130e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo        new_attn_states, [-1, self._attn_length * self._attn_size])
1131e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_state = (new_state, new_attns, new_attn_states)
1132e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if not self._state_is_tuple:
1133e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      new_state = array_ops.concat(list(new_state), 1)
1134e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    return output, new_state
11355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
11365cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower  def _attention(self, query, attn_states):
1137e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    conv2d = nn_ops.conv2d
1138e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    reduce_sum = math_ops.reduce_sum
1139e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    softmax = nn_ops.softmax
1140e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo    tanh = math_ops.tanh
1141e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo
114292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope("attention"):
114392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      k = vs.get_variable(
114492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
114592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      v = vs.get_variable("attn_v", [self._attn_vec_size])
11465cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      hidden = array_ops.reshape(attn_states,
11475cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower                                 [-1, self._attn_length, 1, self._attn_size])
11485cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
11495cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      y = _linear(query, self._attn_vec_size, True)
11505cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
11515cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
11525cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      a = softmax(s)
11535cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      d = reduce_sum(
11545cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower          array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2])
11555cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      new_attns = array_ops.reshape(d, [-1, self._attn_size])
11565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
11575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower      return new_attns, new_attn_states
115834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
115934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116037fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
116134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  """LSTM unit with layer normalization and recurrent dropout.
116234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  This class adds layer normalization and recurrent dropout to a
116434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  basic LSTM unit. Layer normalization implementation is based on:
116534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    https://arxiv.org/abs/1607.06450.
116734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
116834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  "Layer Normalization"
116934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
117034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
117134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  and is applied before the internal nonlinearities.
117234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Recurrent dropout is base on:
117334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
117434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    https://arxiv.org/abs/1603.05118
117534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
117634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  "Recurrent Dropout without Memory Loss"
117734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
117834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  """
117934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
118034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def __init__(self, num_units, forget_bias=1.0,
118134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               input_size=None, activation=math_ops.tanh,
118234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower               layer_norm=True, norm_gain=1.0, norm_shift=0.0,
118354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               dropout_keep_prob=1.0, dropout_prob_seed=None,
118454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               reuse=None):
118534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """Initializes the basic LSTM cell.
118634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
118734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    Args:
118834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      num_units: int, The number of units in the LSTM cell.
118934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      forget_bias: float, The bias added to forget gates (see above).
119034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      input_size: Deprecated and unused.
119134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      activation: Activation function of the inner states.
119234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      layer_norm: If `True`, layer normalization will be applied.
119334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      norm_gain: float, The layer normalization gain initial value. If
119434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        `layer_norm` has been set to `False`, this argument will be ignored.
119534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      norm_shift: float, The layer normalization shift initial value. If
119634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        `layer_norm` has been set to `False`, this argument will be ignored.
119734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
119834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        recurrent dropout probability value. If float and 1.0, no dropout will
119934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower        be applied.
120034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      dropout_prob_seed: (optional) integer, the randomness seed.
120154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
120254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
120354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
120434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """
1205e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)
120634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
120734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    if input_size is not None:
120834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower      logging.warn("%s: The input_size parameter is deprecated.", self)
120934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
121034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._num_units = num_units
121134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._activation = activation
121234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._forget_bias = forget_bias
121334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._keep_prob = dropout_keep_prob
121434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._seed = dropout_prob_seed
121534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._layer_norm = layer_norm
121634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._g = norm_gain
121734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    self._b = norm_shift
121854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
121934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
122034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  @property
122134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def state_size(self):
122237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie    return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
122334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
122434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  @property
122534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def output_size(self):
122634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    return self._num_units
122734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
122834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower  def _norm(self, inp, scope):
122992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    shape = inp.get_shape()[-1:]
123092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    gamma_init = init_ops.constant_initializer(self._g)
123192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    beta_init = init_ops.constant_initializer(self._b)
123292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    with vs.variable_scope(scope):
123392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      # Initialize beta and gamma for use by layer_norm.
123492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      vs.get_variable("gamma", shape=shape, initializer=gamma_init)
123592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      vs.get_variable("beta", shape=shape, initializer=beta_init)
123692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    normalized = layers.layer_norm(inp, reuse=True, scope=scope)
123792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    return normalized
123892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo
123992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo  def _linear(self, args):
124034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    out_size = 4 * self._num_units
124134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    proj_size = args.get_shape()[-1]
124292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    weights = vs.get_variable("weights", [proj_size, out_size])
124392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    out = math_ops.matmul(args, weights)
124492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    if not self._layer_norm:
124592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      bias = vs.get_variable("biases", [out_size])
124692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo      out = nn_ops.bias_add(out, bias)
124792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo    return out
124834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1249e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
125034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower    """LSTM cell with layer normalization and recurrent dropout."""
1251e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    c, h = state
1252e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    args = array_ops.concat([inputs, h], 1)
1253e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    concat = self._linear(args)
125434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1255e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
1256e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._layer_norm:
1257e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      i = self._norm(i, "input")
1258e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      j = self._norm(j, "transform")
1259e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      f = self._norm(f, "forget")
1260e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      o = self._norm(o, "output")
126134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1262e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    g = self._activation(j)
1263e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
1264e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
126534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1266e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_c = (c * math_ops.sigmoid(f + self._forget_bias)
1267e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo             + math_ops.sigmoid(i) * g)
1268e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._layer_norm:
1269e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      new_c = self._norm(new_c, "state")
1270e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_h = self._activation(new_c) * math_ops.sigmoid(o)
127134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
1272e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
1273e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    return new_h, new_state
1274bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
1275bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
class NASCell(core_rnn_cell.RNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  The class uses an optional projection layer.
  """

  def __init__(self, num_units, num_proj=None,
               use_biases=False, reuse=None):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      use_biases: (optional) bool, If True, a single bias vector of size
        `8 * num_units` is added to the recurrent (hidden-state) projection.
        This is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(NASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
    self._reuse = reuse

    # The hidden state `m` has the projected width when a projection is used;
    # the cell state `c` is always `num_units` wide.
    if num_proj is not None:
      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x input_size. The input size (last
        dimension) must be known statically; it sizes the input weights.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`: `c_state` is `num_units` wide and
        `m_state` is `num_proj` wide (`num_units` if no projection is used).

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    relu = nn_ops.relu

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    (c_prev, m_prev) = state

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # Variables for the NAS cell. W_m holds all matrices multiplying the
    # hidden state and W_inputs holds all matrices multiplying the inputs;
    # each is fused into a single matrix with 8 * num_units columns (one
    # num_units-wide slice per branch).
    concat_w_m = vs.get_variable(
        "recurrent_weights", [num_proj, 8 * self._num_units],
        dtype)
    concat_w_inputs = vs.get_variable(
        "weights", [input_size.value, 8 * self._num_units],
        dtype)

    m_matrix = math_ops.matmul(m_prev, concat_w_m)
    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)

    if self._use_biases:
      b = vs.get_variable(
          "bias",
          shape=[8 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
      m_matrix = nn_ops.bias_add(m_matrix, b)

    # The NAS cell branches into 8 different splits for both the hidden state
    # and the input.
    m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                      value=m_matrix)
    inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                           value=inputs_matrix)

    # First layer. Branch 3 combines its input and hidden halves with a
    # product; every other branch uses a sum.
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the previous cell state
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer; its first branch is also the new cell state.
    new_c = l2_0 * l2_1
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(new_c * l3_1)

    # Projection layer if specified
    if self._num_proj is not None:
      concat_w_proj = vs.get_variable(
          "projection_weights", [self._num_units, self._num_proj],
          dtype)
      new_m = math_ops.matmul(new_m, concat_w_proj)

    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
    return new_m, new_state
14221e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14231e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
class UGRNNCell(core_rnn_cell.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  A middle ground between an LSTM/GRU and a vanilla RNN: the cell has a single
  gate, which decides per unit how much of the previous state to keep versus
  how much to replace with a freshly computed candidate, i.e. whether the unit
  integrates over time or computes instantaneously.  This is the recurrent
  counterpart of the feedforward Highway Network.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self, num_units, initializer=None, forget_bias=1.0,
               activation=math_ops.tanh, reuse=None):
    """Creates a UGRNN cell.

    Args:
      num_units: int, number of units; this is both the state and output width.
      initializer: (optional) initializer installed on the variable scope used
        to create the weight matrices.
      forget_bias: (optional) float, default 1.0. Constant added to the gate
        pre-activation at every step, biasing the cell towards keeping its
        previous state (less forgetting early in training).
      activation: (optional) activation applied to the candidate state.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    """The state is a single Tensor `num_units` wide."""
    return self._num_units

  @property
  def output_size(self):
    """The output is the new state, so it is also `num_units` wide."""
    return self._num_units

  def call(self, inputs, state):
    """Advances the UGRNN by one time step.

    Args:
      inputs: 2-D input Tensor, `[batch, input_size]`; `input_size` must be
        known statically.
      state: 2-D state Tensor, `[batch, num_units]`.

    Returns:
      `(new_output, new_state)`, both the same `[batch, num_units]` Tensor:
      the gated blend of `state` and the activated candidate.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    input_depth = inputs.get_shape().with_rank(2)[1]
    if input_depth.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with vs.variable_scope(vs.get_variable_scope(),
                           initializer=self._initializer):
      # One fused projection yields both the gate and candidate logits.
      projected = _linear(array_ops.concat([inputs, state], 1),
                          2 * self._num_units, True)
      gate_logits, candidate_logits = array_ops.split(
          axis=1, num_or_size_splits=2, value=projected)

      candidate = self._activation(candidate_logits)
      update_gate = math_ops.sigmoid(gate_logits + self._forget_bias)
      kept = update_gate * state
      written = (1.0 - update_gate) * candidate
      new_state = kept + written

    return new_state, new_state
1509d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1510d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1511d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlowerclass IntersectionRNNCell(core_rnn_cell.RNNCell):
1512d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  """Intersection Recurrent Neural Network (+RNN) cell.
1513d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1514d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  Architecture with coupled recurrent gate as well as coupled depth
1515d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  gate, designed to improve information flow through stacked RNNs. As the
1516d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  architecture uses depth gating, the dimensionality of the depth
1517d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  output (y) also should not change through depth (input size == output size).
1518d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  To achieve this, the first layer of a stacked Intersection RNN projects
1519d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  the inputs to N (num units) dimensions. Therefore when initializing an
1520d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
1521d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  and use default settings for subsequent layers.
1522d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1523d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  This implements the recurrent cell from the paper:
1524d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1525d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    https://arxiv.org/abs/1611.09913
1526d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1527d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
1528d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
1529d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1530d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  The Intersection RNN is built for use in deeply stacked
1531d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  RNNs so it may not achieve best performance with depth 1.
1532d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  """
1533d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1534d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  def __init__(self, num_units, num_in_proj=None,
1535d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower               initializer=None, forget_bias=1.0,
1536d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower               y_activation=nn_ops.relu, reuse=None):
1537d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    """Initialize the parameters for an +RNN cell.
1538d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1539d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    Args:
1540d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      num_units: int, The number of units in the +RNN cell
1541d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      num_in_proj: (optional) int, The input dimensionality for the RNN.
1542d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        If creating the first layer of an +RNN, this should be set to
1543d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        `num_units`. Otherwise, this should be set to `None` (default).
1544d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        If `None`, dimensionality of `inputs` should be equal to `num_units`,
1545d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        otherwise ValueError is thrown.
1546d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      initializer: (optional) The initializer to use for the weight matrices.
1547d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
1548d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
1549d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        of the training.
1550d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      y_activation: (optional) Activation function of the states passed
1551d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        through depth. Default is 'tf.nn.relu`.
1552d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      reuse: (optional) Python boolean describing whether to reuse variables
1553d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        in an existing scope.  If not `True`, and the existing scope already has
1554d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        the given variables, an error is raised.
1555d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    """
1556e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
1557d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._num_units = num_units
1558d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._initializer = initializer
1559d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._forget_bias = forget_bias
1560d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._num_input_proj = num_in_proj
1561d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._y_activation = y_activation
1562d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    self._reuse = reuse
1563d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1564d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  @property
1565d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  def state_size(self):
1566d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    return self._num_units
1567d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1568d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  @property
1569d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower  def output_size(self):
1570d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    return self._num_units
1571d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1572e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
1573d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    """Run one step of the Intersection RNN.
1574d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1575d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    Args:
1576d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      inputs: input Tensor, 2D, batch x input size.
1577d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      state: state Tensor, 2D, batch x num units.
1578d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1579d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    Returns:
1580d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      new_y: batch x num units, Tensor representing the output of the +RNN
1581d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        after reading `inputs` when previous state was `state`.
1582d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      new_state: batch x num units, Tensor representing the state of the +RNN
1583d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        after reading `inputs` when previous state was `state`.
1584d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1585d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    Raises:
1586d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      ValueError: If input size cannot be inferred from `inputs` via
1587d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        static shape inference.
1588d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      ValueError: If input size != output size (these must be equal when
1589d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        using the Intersection RNN).
1590d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    """
1591d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    sigmoid = math_ops.sigmoid
1592d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    tanh = math_ops.tanh
1593d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1594d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    input_size = inputs.get_shape().with_rank(2)[1]
1595d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    if input_size.value is None:
1596d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1597d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1598e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    with vs.variable_scope(vs.get_variable_scope(),
1599e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo                           initializer=self._initializer):
1600d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      # read-in projections (should be used for first layer in deep +RNN
1601d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      # to transform size of inputs from I --> N)
1602d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      if input_size.value != self._num_units:
1603d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        if self._num_input_proj:
1604d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower          with vs.variable_scope("in_projection"):
1605d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower            inputs = _linear(inputs, self._num_units, True)
1606d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower        else:
1607d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower          raise ValueError("Must have input size == output size for "
1608d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower                           "Intersection RNN. To fix, num_in_proj should "
1609d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower                           "be set to num_units at cell init.")
1610d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1611d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      n_dim = i_dim = self._num_units
1612d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      cell_inputs = array_ops.concat([inputs, state], 1)
1613d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      rnn_matrix = _linear(cell_inputs, 2*n_dim + 2*i_dim, True)
1614d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1615d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      gh_act = rnn_matrix[:, :n_dim]                           # b x n
1616d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      h_act = rnn_matrix[:, n_dim:2*n_dim]                     # b x n
1617d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim]            # b x i
1618d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim]     # b x i
1619d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1620d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      h = tanh(h_act)
1621d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      y = self._y_activation(y_act)
1622d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      gh = sigmoid(gh_act + self._forget_bias)
1623d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      gy = sigmoid(gy_act + self._forget_bias)
1624d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1625d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      new_state = gh * state + (1.0 - gh) * h  # passed thru time
1626d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower      new_y = gy * inputs + (1.0 - gy) * y  # passed thru depth
1627d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1628d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower    return new_y, new_state
1629d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1630d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
# Lazily filled cache of the op registry (keyed by op name), shared by every
# CompiledWrapper; populated on first use inside `CompiledWrapper.__call__`.
_REGISTERED_OPS = None


class CompiledWrapper(core_rnn_cell.RNNCell):
  """Wraps step execution in an XLA JIT scope.

  State/output sizes and the zero state are delegated to the wrapped cell;
  each step runs the wrapped cell inside `jit.experimental_jit_scope`.
  """

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    # Run the base RNNCell constructor, as IntersectionRNNCell and
    # PhasedLSTMCell do; skipping it leaves the inherited state of this
    # wrapper uninitialized.
    super(CompiledWrapper, self).__init__()
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    """Runs one step of the wrapped cell under an XLA JIT scope.

    Args:
      inputs: step inputs, forwarded unchanged to the wrapped cell.
      state: previous state, forwarded unchanged to the wrapped cell.
      scope: optional variable scope, forwarded to the wrapped cell.

    Returns:
      Whatever the wrapped cell returns for this step.
    """
    if self._compile_stateful:
      compile_ops = True
    else:
      def compile_ops(node_def):
        # Compile an op only when its registered definition is not stateful.
        global _REGISTERED_OPS
        if _REGISTERED_OPS is None:
          _REGISTERED_OPS = op_def_registry.get_registered_ops()
        return not _REGISTERED_OPS[node_def.op].is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
1672bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
1673bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
def _random_exp_initializer(minval,
                            maxval,
                            seed=None,
                            dtype=dtypes.float32):
  """Returns an initializer drawing log-uniform values between two bounds.

  Each element is exp(u) with u ~ U(log(minval), log(maxval)), so values are
  spread evenly on a logarithmic scale between `minval` and `maxval`.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
        range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
        the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The default data type of the generated values; the returned
        initializer's own `dtype` argument overrides it.

  Returns:
    An initializer `fn(shape, dtype=dtype, partition_info=None)`;
    `partition_info` is ignored.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    del partition_info  # Unused.
    # Bring the bounds to the requested dtype before taking the log. Without
    # the cast a Python float becomes a float32 tensor, which random_uniform
    # cannot accept for another float dtype (e.g. float64), and an integer
    # bound becomes an int32 tensor that `log` does not support. For float32
    # the cast is a no-op.
    log_minval = math_ops.log(math_ops.cast(minval, dtype))
    log_maxval = math_ops.log(math_ops.cast(maxval, dtype))
    return math_ops.exp(
        random_ops.random_uniform(
            shape, log_minval, log_maxval, dtype, seed=seed))

  return _initializer
1703bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
1704bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
class PhasedLSTMCell(core_rnn_cell.RNNCell):
  """Phased LSTM recurrent network cell.

  An LSTM whose state updates are additionally gated by a per-unit periodic
  time gate driven by the timestamp of each step.

  https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Initialize the Phased LSTM cell.

    Args:
      num_units: int, The number of units in the Phased LSTM cell.
      use_peepholes: bool, set True to enable peephole connections: the
        previous cell state then feeds the input/forget gates and the new
        cell state feeds the output gate.
      leak: float or scalar float Tensor with value in [0, 1]. Slope of the
        time gate while it is closed. This cell applies the leak on every
        call (there is no training/inference switch); pass 0 to disable it.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
        period during which the gates are open; used as the constant initial
        value of the `ratio_on` variable.
      trainable_ratio_on: bool, whether the `ratio_on` variable is trainable.
      period_init_min: float or scalar float Tensor. With value > 0.
        Minimum value of the initialized period.
        The period values are initialized by drawing from the distribution:
          e^U(log(period_init_min), log(period_init_max))
        where U(.,.) is the uniform distribution.
      period_init_max: float or scalar float Tensor.
        With value > period_init_min. Maximum value of the initialized period.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max
    self._reuse = reuse

  @property
  def state_size(self):
    """An LSTMStateTuple of (num_units, num_units) for (c, h)."""
    return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    """The output width, `num_units`."""
    return self._num_units

  def _mod(self, x, y):
    """Modulo function that propagates x gradients."""
    # Forward value is mod(x, y); the stop_gradient wrapper makes the backward
    # pass see `+ x` only, i.e. an identity in x.
    remainder = math_ops.mod(x, y)
    return array_ops.stop_gradient(remainder - x) + x

  def _get_cycle_ratio(self, time, phase, period):
    """Compute the cycle ratio in the dtype of the time, returned as float32."""
    target_dtype = time.dtype
    phase_t = math_ops.cast(phase, dtype=target_dtype)
    period_t = math_ops.cast(period, dtype=target_dtype)
    cycle = self._mod(time - phase_t, period_t) / period_t
    return math_ops.cast(cycle, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Runs one Phased LSTM step.

    Args:
      inputs: A tuple of 2 Tensors `(time, features)`.
        `time` has shape [batch, 1] and type float32 or float64; it stores
        the time of the step.
        `features` has shape [batch, features_size] and type float32.
      state: core_rnn_cell.LSTMStateTuple, state from previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing the
        output of the cell.
      - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    c_prev, h_prev = state
    time, x = inputs

    # Standard LSTM update, with optional peepholes.
    prev_peephole = [c_prev] if self._use_peepholes else []
    with vs.variable_scope("mask_gates"):
      gates = math_ops.sigmoid(
          _linear([x, h_prev] + prev_peephole, 2 * self._num_units, True))
      input_gate, forget_gate = array_ops.split(
          axis=1, num_or_size_splits=2, value=gates)

    with vs.variable_scope("new_input"):
      candidate = math_ops.tanh(
          _linear([x, h_prev], self._num_units, True))

    lstm_c = (c_prev * forget_gate + input_gate * candidate)

    new_peephole = [lstm_c] if self._use_peepholes else []
    with vs.variable_scope("output_gate"):
      output_gate = math_ops.sigmoid(
          _linear([x, h_prev] + new_peephole, self._num_units, True))

    lstm_h = math_ops.tanh(lstm_c) * output_gate

    # Per-unit time-gate parameters.
    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(
            self._period_init_min, self._period_init_max))
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(
            0., period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    # Piecewise-linear time gate k over the cycle ratio phi:
    #   phi < ratio_on / 2             -> rises 0 -> 1   (k_up)
    #   ratio_on / 2 <= phi < ratio_on -> falls 1 -> 0   (k_down)
    #   otherwise (gate closed)        -> leak * phi     (k_closed)
    k_up = 2 * cycle_ratio / ratio_on
    k_down = 2 - k_up
    k_closed = self._leak * cycle_ratio

    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

    # Blend the LSTM proposal with the previous state according to k.
    new_c = k * lstm_c + (1 - k) * c_prev
    new_h = k * lstm_h + (1 - k) * h_prev

    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
1846