1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
3d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License");
4d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# you may not use this file except in compliance with the License.
5d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# You may obtain a copy of the License at
6d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
7d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#     http://www.apache.org/licenses/LICENSE-2.0
8d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower#
9d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software
10d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS,
11d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# See the License for the specific language governing permissions and
13d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# limitations under the License.
14d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# ==============================================================================
15d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower"""Module for constructing RNN Cells."""
16d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import absolute_import
17d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import division
18d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import print_function
19d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
208fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlowerimport collections
21d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerimport math
22d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
23bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.contrib.compiler import jit
2434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.contrib.layers.python.layers import layers
2516fa134cfb576bfa690d7006864e555dc42c6b62Eugene Brevdofrom tensorflow.contrib.rnn.python.ops import core_rnn_cell
261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerfrom tensorflow.python.framework import dtypes
27bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.python.framework import op_def_registry
28d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.framework import ops
2928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlowerfrom tensorflow.python.framework import tensor_shape
3020765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyenfrom tensorflow.python.layers import base as base_layer
31d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops
32d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import clip_ops
3334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.python.ops import init_ops
34d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops
35b9018073ec1afc7dfc302ab171db8bf5b177c2ddYifei Fengfrom tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
36d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import nn_ops
37b9018073ec1afc7dfc302ab171db8bf5b177c2ddYifei Fengfrom tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
38bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlowerfrom tensorflow.python.ops import random_ops
393038bc913713bc31d0b67150e7cf7c056baba7e4Adria Puigdomenechfrom tensorflow.python.ops import rnn_cell_impl
40d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import variable_scope as vs
415cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging
424c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenechfrom tensorflow.python.util import nest
43d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
4437fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie
def _get_concat_variable(name, shape, dtype, num_shards):
  """Return the sharded variable `name` as one tensor joined along axis 0.

  A single shard is returned as-is.  With several shards, a concat op named
  `<name>/concat` is created and recorded in the `CONCATENATED_VARIABLES`
  collection; if that collection already holds a tensor with the expected
  full name, that tensor is returned instead of building a new op.
  """
  shards = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(shards) == 1:
    return shards[0]

  concat_name = name + "/concat"
  expected_tensor_name = (
      vs.get_variable_scope().name + "/" + concat_name + ":0")

  # Look for a concat tensor registered earlier under the same full name.
  registered = ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES)
  cached = next(
      (tensor for tensor in registered
       if tensor.name == expected_tensor_name), None)
  if cached is not None:
    return cached

  joined = array_ops.concat(shards, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, joined)
  return joined
60d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
61d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
62d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerdef _get_sharded_variable(name, shape, dtype, num_shards):
63d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Get a list of sharded variables with the given dtype."""
64d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  if num_shards > shape[0]:
65ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
66ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie                                                                   num_shards))
67d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  unit_shard_size = int(math.floor(shape[0] / num_shards))
68d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  remaining_rows = shape[0] - unit_shard_size * num_shards
69d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
70d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  shards = []
71d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  for i in range(num_shards):
72d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    current_size = unit_shard_size
73d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    if i < remaining_rows:
74d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      current_size += 1
75ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie    shards.append(
76ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        vs.get_variable(
77ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            name + "_%d" % i, [current_size] + shape[1:], dtype=dtype))
78d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  return shards
79d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
80d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
def _norm(g, b, inp, scope):
  """Layer-normalize `inp`, with gain and shift initialized to `g` and `b`.

  The `gamma` and `beta` variables are created first inside `scope` using
  constant initializers; `layers.layer_norm` is then called on the same
  `scope` with `reuse=True`.
  """
  last_dim = inp.get_shape()[-1:]
  with vs.variable_scope(scope):
    # Create gamma and beta ahead of time for use by layer_norm.
    vs.get_variable(
        "gamma", shape=last_dim, initializer=init_ops.constant_initializer(g))
    vs.get_variable(
        "beta", shape=last_dim, initializer=init_ops.constant_initializer(b))
  return layers.layer_norm(inp, reuse=True, scope=scope)
91b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
92b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://www.bioinf.jku.at/publications/older/2604.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.

  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=1,
               num_proj_shards=1,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=math_ops.tanh,
               reuse=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.  If None, the initializer of the enclosing
        variable scope applies.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip`
        is provided, then the projected values are clipped elementwise to
        within `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Constant added to the forget gate pre-activation.  It
        defaults to 1 in order to reduce the scale of forgetting at the
        beginning of the training.
      state_is_tuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`.  If False, they are
        concatenated along the column axis; this latter behavior will soon be
        deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    # With a projection the output/m_state width is num_proj; otherwise it
    # equals num_units.
    if num_proj:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _get_weight_matrix(self, name, shape, dtype, num_shards):
    """Fetch a (possibly sharded) weight matrix, honoring `initializer`.

    When no initializer was given to the constructor the variable is requested
    directly in the current scope, exactly as if this helper did not exist.
    Otherwise the current variable scope is re-entered with the configured
    initializer so that the newly created shards pick it up.
    """
    if self._initializer is None:
      return _get_concat_variable(name, shape, dtype, num_shards)
    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      return _get_concat_variable(name, shape, dtype, num_shards)

  def call(self, inputs, state):
    """Run one step of LSTM.

    The input and forget gates are coupled: the new cell state is
    `f * c_prev + (1 - f) * activation(j)`, where `f` is the forget gate
    activation and `j` the new input.

    Args:
      inputs: input Tensor, 2D, batch x input_size.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      # Concatenated state layout: [c_state | m_state] along the column axis.
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    concat_w = self._get_weight_matrix(
        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
        self._num_unit_shards)

    # The bias variable is created even when layer normalization is enabled
    # (where it is not added), so the set of variables does not depend on
    # `layer_norm`.
    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)

    # If layer normalization is applied, do not add bias
    if not self._layer_norm:
      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)

    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

    # Apply layer normalization
    if self._layer_norm:
      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
      o = _norm(self._norm_gain, self._norm_shift, o, "output")

    # Diagonal (peephole) connections feed the cell state into the gates.
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    # Coupled gates: the input gate is (1 - forget gate).
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

    # Apply layer normalization
    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = self._get_weight_matrix(
          "W_P", [self._num_units, self._num_proj], dtype,
          self._num_proj_shards)

      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state
3161032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
3171032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower
class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.

  At every call the input is cut into overlapping frequency windows of
  `feature_size` columns, shifted by `frequency_skip`. A single set of LSTM
  weights is applied to each window in turn; each step sees the window, its
  own previous-time output and the output of the previous frequency step.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=1,
               reuse=None):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_unit_shards: int, How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      forget_bias: float, Biases of the forget gate are initialized by default
        to 1 in order to reduce the scale of forgetting at the beginning
        of the training.
      feature_size: int, The size of the input feature the LSTM spans over.
        Although it defaults to None, it must be set before the cell is
        called.
      frequency_skip: int, The amount the LSTM filter is shifted by in
        frequency.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_unit_shards = num_unit_shards
    self._forget_bias = forget_bias
    self._feature_size = feature_size
    self._frequency_skip = frequency_skip
    # NOTE(review): call() reads and writes 2 * num_units state columns and
    # emits num_units output columns *per frequency step*, while the sizes
    # below describe a single step -- confirm what callers rely on.
    self._state_size = 2 * num_units
    self._output_size = num_units
    self._reuse = reuse

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x input_size. The last dimension must be
        statically known.
      state: state Tensor, 2D. For frequency step `fq` the previous cell state
        is read from columns starting at `2 * fq * num_units` and the previous
        output from columns starting at `(2 * fq + 1) * num_units`, each
        `num_units` wide.

    Returns:
      A tuple containing:
      - A 2D Tensor holding the outputs of all frequency steps concatenated
        along axis 1 (`num_units` columns per step).
      - A 2D Tensor holding the new `c` and `m` of every frequency step
        concatenated along axis 1, in the same layout as `state`.
    Raises:
      ValueError: if `feature_size` was not specified or the size of the last
        dimension of `inputs` cannot be inferred statically.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh

    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]

    # One weight matrix shared by all frequency steps; its rows cover the
    # frequency window, the previous-time output and the previous-frequency
    # output, its columns the four gates.
    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    # Initialize the first freq state to be zero. Use the static batch size
    # when it is known and the dynamic shape otherwise, so that inputs with an
    # unknown batch dimension are supported.
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)

    m_out_lst = []
    state_out_lst = []
    for fq, freq_input in enumerate(freq_inputs):
      c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_input, m_prev, m_prev_freq], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      if self._use_peepholes:
        c = (
            sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
            sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)

      # The output of this step feeds the next frequency step.
      m_prev_freq = m
      m_out_lst.append(m)
      state_out_lst.extend([c, m])

    m_out = array_ops.concat(m_out_lst, 1)
    state_out = array_ops.concat(state_out_lst, 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
    """Make the frequency features.

    Args:
      input_feat: input Tensor, 2D, batch x input_size.

    Returns:
      A list of frequency features, with each element containing:
      - A 2D, batch x output_dim, Tensor representing the time-frequency feature
        for that frequency index. Here output_dim is feature_size.
    Raises:
      ValueError: if feature_size was not specified, or if input_size cannot
        be inferred from static shape inference.
    """
    if self._feature_size is None:
      raise ValueError(
          "feature_size must be specified to make frequency features.")
    input_size = input_feat.get_shape().with_rank(2)[-1].value
    if input_size is None:
      raise ValueError("Cannot infer input_size from static shape inference.")
    # `/` is true division in this module (from __future__ import division);
    # int() then truncates the quotient.
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    return [
        array_ops.slice(input_feat, [0, f * self._frequency_skip],
                        [-1, self._feature_size])
        for f in range(num_feats)
    ]
490d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
491d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
492827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdoclass GridLSTMCell(rnn_cell_impl.RNNCell):
493d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """Grid Long short-term memory unit (LSTM) recurrent network cell.
494d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
495d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The default is based on:
496d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Nal Kalchbrenner, Ivo Danihelka and Alex Graves
497d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Grid Long Short-Term Memory," Proc. ICLR 2016.
498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    http://arxiv.org/abs/1507.01526
499d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
500d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  When peephole connections are used, the implementation is based on:
501d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Tara N. Sainath and Bo Li
502d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
503d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    for LVCSR Tasks." submitted to INTERSPEECH, 2016.
504d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
505d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  The code uses optional peephole connections, shared_weights and cell clipping.
506d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower  """
507d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
508ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie  def __init__(self,
509ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               num_units,
510ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               use_peepholes=False,
511d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower               share_time_frequency_weights=False,
512ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               cell_clip=None,
513ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               initializer=None,
514ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               num_unit_shards=1,
515ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               forget_bias=1.0,
516ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               feature_size=None,
517ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie               frequency_skip=None,
5189e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               num_frequency_blocks=None,
5199e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               start_freqindex_list=None,
5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower               end_freqindex_list=None,
5218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower               couple_input_forget_gates=False,
522c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower               state_is_tuple=True,
52354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo               reuse=None):
524d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Initialize the parameters for an LSTM cell.
525d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
526d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
527d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      num_units: int, The number of units in the LSTM cell
5281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      use_peepholes: (optional) bool, default False. Set True to enable
5291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        diagonal/peephole connections.
5301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      share_time_frequency_weights: (optional) bool, default False. Set True to
5311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        enable shared cell weights between time and frequency LSTMs.
5321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      cell_clip: (optional) A float value, default None, if provided the cell
5331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        state is clipped by this value prior to the cell output activation.
534d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      initializer: (optional) The initializer to use for the weight and
5351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        projection matrices, default None.
5361b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu      num_unit_shards: (optional) int, default 1, How to split the weight
5371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        matrix. If > 1,the weight matrix is stored across num_unit_shards.
5381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      forget_bias: (optional) float, default 1.0, The initial bias of the
5391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        forget gates, used to reduce the scale of forgetting at the beginning
540d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        of the training.
5411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      feature_size: (optional) int, default None, The size of the input feature
5421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        the LSTM spans over.
5431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      frequency_skip: (optional) int, default None, The amount the LSTM filter
5441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        is shifted by in frequency.
5459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      num_frequency_blocks: [required] A list of frequency blocks needed to
5469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        cover the whole input feature splitting defined by start_freqindex_list
5479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        and end_freqindex_list.
5489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      start_freqindex_list: [optional], list of ints, default None,  The
5499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        starting frequency index for each frequency block.
5509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      end_freqindex_list: [optional], list of ints, default None. The ending
5519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        frequency index for each frequency block.
5521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      couple_input_forget_gates: (optional) bool, default False, Whether to
5531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce
5541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        model parameters and computation cost.
5558fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_is_tuple: If True, accepted and returned states are 2-tuples of
5568fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        the `c_state` and `m_state`.  By default (False), they are concatenated
5578fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower        along the column axis.  This default behavior will soon be deprecated.
55854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
55954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
56054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
5619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    Raises:
5629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      ValueError: if the num_frequency_blocks list is not specified
563d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
564e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(GridLSTMCell, self).__init__(_reuse=reuse)
5658fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if not state_is_tuple:
5668fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      logging.warn("%s: Using a concatenated state is slower and will soon be "
5678fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower                   "deprecated.  Use state_is_tuple=True.", self)
568d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_units = num_units
569d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._use_peepholes = use_peepholes
570d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._share_time_frequency_weights = share_time_frequency_weights
5718fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._couple_input_forget_gates = couple_input_forget_gates
5728fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    self._state_is_tuple = state_is_tuple
573d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._cell_clip = cell_clip
574d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._initializer = initializer
575d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._num_unit_shards = num_unit_shards
576d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._forget_bias = forget_bias
577d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._feature_size = feature_size
578d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    self._frequency_skip = frequency_skip
5799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._start_freqindex_list = start_freqindex_list
5809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._end_freqindex_list = end_freqindex_list
5819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._num_frequency_blocks = num_frequency_blocks
5829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._total_blocks = 0
58354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
5849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    if self._num_frequency_blocks is None:
5859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      raise ValueError("Must specify num_frequency_blocks")
5869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower
5879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    for block_index in range(len(self._num_frequency_blocks)):
5889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._total_blocks += int(self._num_frequency_blocks[block_index])
5898fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    if state_is_tuple:
5908fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      state_names = ""
5919e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      for block_index in range(len(self._num_frequency_blocks)):
5929e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower        for freq_index in range(self._num_frequency_blocks[block_index]):
5939e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          name_prefix = "state_f%02d_b%02d" % (freq_index, block_index)
5949e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower          state_names += ("%s_c, %s_m," % (name_prefix, name_prefix))
595ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie      self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple",
596ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie                                                      state_names.strip(","))
597ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie      self._state_size = self._state_tuple_type(*(
598ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          [num_units, num_units] * self._total_blocks))
5998fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower    else:
6008fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower      self._state_tuple_type = None
6019e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower      self._state_size = num_units * self._total_blocks * 2
6029e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower    self._output_size = num_units * self._total_blocks * 2
603d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
  @property
  def output_size(self):
    """Output width computed in `__init__`: num_units * total blocks * 2."""
    return self._output_size
607d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
  @property
  def state_size(self):
    """State size computed in `__init__`.

    A named tuple of per-entry sizes when `state_is_tuple` is True, otherwise
    the int num_units * total blocks * 2.
    """
    return self._state_size
611d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
  @property
  def state_tuple_type(self):
    """The namedtuple class used for tuple states, or None if not a tuple."""
    return self._state_tuple_type
6158fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower
616e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
617d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """Run one step of LSTM.
618d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
619d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Args:
6201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      inputs: input Tensor, 2D, [batch, feature_size].
6211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
6221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower        flag self._state_is_tuple.
623d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower
624d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Returns:
625d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      A tuple containing:
6261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, output_dim], Tensor representing the output of the LSTM
627d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
628d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        Here output_dim is num_units.
6291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower      - A 2D, [batch, state_size], Tensor representing the new state of LSTM
630d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        after reading "inputs" when previous state was "state".
631d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    Raises:
632d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower      ValueError: if an input_size was specified and the provided inputs have
633d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower        a different dimension.
634d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower    """
635499166454f0c7dd6c724c364f6dc5d99357514b1A. Unique TensorFlower    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    freq_inputs = self._make_tf_features(inputs)
637e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_out_lst = []
638e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    state_out_lst = []
639e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    for block in range(len(freq_inputs)):
640e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_out_lst_current, state_out_lst_current = self._compute(
641ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          freq_inputs[block],
642ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          block,
643ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          state,
644ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          batch_size,
645e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          state_is_tuple=self._state_is_tuple)
646e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_out_lst.extend(m_out_lst_current)
647e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out_lst.extend(state_out_lst_current)
648e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._state_is_tuple:
649e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out = self._state_tuple_type(*state_out_lst)
650e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    else:
651e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      state_out = array_ops.concat(state_out_lst, 1)
652e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_out = array_ops.concat(m_out_lst, 1)
6531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower    return m_out, state_out
6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
def _compute(self,
             freq_inputs,
             block,
             state,
             batch_size,
             state_prefix="state",
             state_is_tuple=True):
  """Build one time step of the grid LSTM for a single frequency block.

  The frequency steps in `freq_inputs` are processed in order. At each step
  a frequency LSTM (F-LSTM) and a time LSTM (T-LSTM) are evaluated on the
  concatenation of the step input, the previous-time output and the
  previous-frequency output. The F-LSTM output and cell state are carried
  to the next frequency step; the T-LSTM output and cell state are
  collected as the new recurrent state.

  Args:
    freq_inputs: list of Tensors, 2D, [batch, feature_size], one per
      frequency step of this block.
    block: int, index of the frequency block being processed. It is part of
      every variable name and of the state field names.
    state: Tensor or tuple of Tensors, 2D, [batch, state_size], depending on
      the flag state_is_tuple. In the tuple case the per-step entries are
      read as attributes named "<state_prefix>_fXX_bYY_c" / "..._m".
    batch_size: batch size, used as the leading dimension of the zero
      tensors that seed the frequency recurrence.
    state_prefix: (optional) string, name prefix for states, defaults to
      "state".
    state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

  Returns:
    A tuple, containing:
    - A list of [batch, output_dim] Tensors: for each frequency step, the
      T-LSTM output followed by the F-LSTM output.
    - A list of [batch, state_size] Tensors: for each frequency step, the
      T-LSTM cell state followed by the T-LSTM output.
  """
  sigmoid = math_ops.sigmoid
  tanh = math_ops.tanh
  units = self._num_units
  coupled = self._couple_input_forget_gates
  shared = self._share_time_frequency_weights
  peepholes = self._use_peepholes
  gate_count = 3 if coupled else 4
  dtype = freq_inputs[0].dtype
  # Rows: step input plus the two recurrent outputs; columns: all gates.
  in_dim = freq_inputs[0].get_shape().as_list()[1] + 2 * units
  out_dim = gate_count * units

  def split_gates(lstm_matrix):
    # Returns (i, j, f, o); f is None when input/forget gates are coupled.
    pieces = array_ops.split(
        value=lstm_matrix, num_or_size_splits=gate_count, axis=1)
    if coupled:
      i_part, j_part, o_part = pieces
      return i_part, j_part, None, o_part
    i_part, j_part, f_part, o_part = pieces
    return i_part, j_part, f_part, o_part

  def diag_weight(name_format):
    # One diagonal (peephole) weight vector with an entry per unit.
    return vs.get_variable(name_format % block, shape=[units], dtype=dtype)

  def add_peephole(pre_activation, diag_pair, c_along_freq, c_along_time):
    # Adds both diagonal connections to a gate pre-activation.
    diag_f, diag_t = diag_pair
    return pre_activation + diag_f * c_along_freq + diag_t * c_along_time

  w_freq = _get_concat_variable("W_f_%d" % block, [in_dim, out_dim], dtype,
                                self._num_unit_shards)
  b_freq = vs.get_variable(
      "B_f_%d" % block,
      shape=[out_dim],
      initializer=init_ops.zeros_initializer(),
      dtype=dtype)
  if not shared:
    w_time = _get_concat_variable("W_t_%d" % block, [in_dim, out_dim], dtype,
                                  self._num_unit_shards)
    b_time = vs.get_variable(
        "B_t_%d" % block,
        shape=[out_dim],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

  if peepholes:
    # Diagonal connections. Each pair holds (weight multiplied with the
    # frequency-direction cell state, weight multiplied with the
    # time-direction cell state).
    if not coupled:
      freq_f_diag = (diag_weight("W_F_diag_freqf_%d"),
                     diag_weight("W_F_diag_freqt_%d"))
    freq_i_diag = (diag_weight("W_I_diag_freqf_%d"),
                   diag_weight("W_I_diag_freqt_%d"))
    freq_o_diag = (diag_weight("W_O_diag_freqf_%d"),
                   diag_weight("W_O_diag_freqt_%d"))
    if shared:
      # With shared weights the T-LSTM gates use the F-LSTM diagonals.
      if not coupled:
        time_f_diag = freq_f_diag
      time_i_diag = freq_i_diag
      time_o_diag = freq_o_diag
    else:
      if not coupled:
        time_f_diag = (diag_weight("W_F_diag_timef_%d"),
                       diag_weight("W_F_diag_timet_%d"))
      time_i_diag = (diag_weight("W_I_diag_timef_%d"),
                     diag_weight("W_I_diag_timet_%d"))
      time_o_diag = (diag_weight("W_O_diag_timef_%d"),
                     diag_weight("W_O_diag_timet_%d"))

  # initialize the first freq state to be zero
  m_prev_freq = array_ops.zeros([batch_size, units], dtype)
  c_prev_freq = array_ops.zeros([batch_size, units], dtype)
  m_out_lst = []
  state_out_lst = []
  for freq_index, freq_input in enumerate(freq_inputs):
    if state_is_tuple:
      name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
      c_prev_time = getattr(state, name_prefix + "_c")
      m_prev_time = getattr(state, name_prefix + "_m")
    else:
      # NOTE(review): this offset depends only on freq_index, not on `block`
      # or `state_prefix` -- confirm that is intended when the concatenated
      # state holds more than one block.
      c_offset = 2 * freq_index * units
      c_prev_time = array_ops.slice(state, [0, c_offset], [-1, units])
      m_prev_time = array_ops.slice(state, [0, c_offset + units],
                                    [-1, units])

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([freq_input, m_prev_time, m_prev_freq], 1)

    # F-LSTM
    i_freq, j_freq, f_freq, o_freq = split_gates(
        nn_ops.bias_add(math_ops.matmul(cell_inputs, w_freq), b_freq))
    # T-LSTM
    if shared:
      i_time, j_time, f_time, o_time = i_freq, j_freq, f_freq, o_freq
    else:
      i_time, j_time, f_time, o_time = split_gates(
          nn_ops.bias_add(math_ops.matmul(cell_inputs, w_time), b_time))

    # F-LSTM c_freq
    # input gate activations
    if peepholes:
      i_freq_g = sigmoid(
          add_peephole(i_freq, freq_i_diag, c_prev_freq, c_prev_time))
    else:
      i_freq_g = sigmoid(i_freq)
    # forget gate activations
    if coupled:
      f_freq_g = 1.0 - i_freq_g
    elif peepholes:
      f_freq_g = sigmoid(
          add_peephole(f_freq + self._forget_bias, freq_f_diag, c_prev_freq,
                       c_prev_time))
    else:
      f_freq_g = sigmoid(f_freq + self._forget_bias)
    # cell state
    c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                      self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    # T-LSTM c_time
    # input gate activations
    if peepholes:
      i_time_g = sigmoid(
          add_peephole(i_time, time_i_diag, c_prev_freq, c_prev_time))
    else:
      i_time_g = sigmoid(i_time)
    # forget gate activations
    if coupled:
      f_time_g = 1.0 - i_time_g
    elif peepholes:
      f_time_g = sigmoid(
          add_peephole(f_time + self._forget_bias, time_f_diag, c_prev_freq,
                       c_prev_time))
    else:
      f_time_g = sigmoid(f_time + self._forget_bias)
    # cell state
    c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                      self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    # Outputs: the output-gate diagonals see the freshly computed cell
    # states. F-LSTM m_freq first, then T-LSTM m_time.
    if peepholes:
      m_freq = sigmoid(
          add_peephole(o_freq, freq_o_diag, c_freq, c_time)) * tanh(c_freq)
      m_time = sigmoid(
          add_peephole(o_time, time_o_diag, c_freq, c_time)) * tanh(c_time)
    else:
      m_freq = sigmoid(o_freq) * tanh(c_freq)
      m_time = sigmoid(o_time) * tanh(c_time)

    # Collect the outputs for T-LSTM and F-LSTM for each shift
    state_out_lst.extend([c_time, m_time])
    m_out_lst.extend([m_time, m_freq])
    # The F-LSTM output and cell state feed the next frequency step.
    m_prev_freq = m_freq
    c_prev_freq = c_freq

  return m_out_lst, state_out_lst
8671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
def _make_tf_features(self, input_feat, slice_offset=0):
  """Make the frequency features.

  The input is cut along its second dimension into windows of
  `self._feature_size` columns, advancing `self._frequency_skip` columns
  between consecutive windows. Without `self._start_freqindex_list` the
  whole input forms a single block; otherwise each
  (start_freqindex, end_freqindex) pair defines one block.

  Args:
    input_feat: input Tensor, 2D, [batch, num_units].
    slice_offset: (optional) Python int, default 0, the slicing offset is only
      used for the backward processing in the BidirectionalGridLSTMCell. It
      specifies a different starting point instead of always 0 to enable the
      forward and backward processing look at different frequency blocks.
      A positive value zero-pads the end of the input and shifts every window
      start by that amount; a negative value zero-pads the front by its
      magnitude and leaves the window starts unshifted.

  Returns:
    A list with one element per frequency block. Each element is a list of
    2D, [batch, output_dim], Tensors, one per frequency index of that block,
    representing the time-frequency feature for that frequency index. Here
    output_dim is feature_size.

  Raises:
    ValueError: if input_size cannot be inferred from static shape inference,
      if the lengths of num_frequency_blocks, start_freqindex_list and
      end_freqindex_list are inconsistent, or if the number of windows that
      fit in a block differs from the configured num_frequency_blocks entry.
  """
  input_size = input_feat.get_shape().with_rank(2)[-1].value
  if input_size is None:
    raise ValueError("Cannot infer input_size from static shape inference.")
  if slice_offset > 0:
    # Padding to the end
    inputs = array_ops.pad(input_feat,
                           array_ops.constant(
                               [0, 0, 0, slice_offset],
                               shape=[2, 2],
                               dtype=dtypes.int32), "CONSTANT")
  elif slice_offset < 0:
    # Padding to the front
    inputs = array_ops.pad(input_feat,
                           array_ops.constant(
                               [0, 0, -slice_offset, 0],
                               shape=[2, 2],
                               dtype=dtypes.int32), "CONSTANT")
    # The front padding already moved the data, so windows start at 0.
    slice_offset = 0
  else:
    inputs = input_feat

  # Work out the [start, end) column range covered by every block.
  if not self._start_freqindex_list:
    if len(self._num_frequency_blocks) != 1:
      raise ValueError("Length of num_frequency_blocks"
                       " is not 1, but instead is %d" %
                       len(self._num_frequency_blocks))
    block_ranges = [(0, input_size)]
  else:
    if len(self._start_freqindex_list) != len(self._end_freqindex_list):
      raise ValueError("Length of start and end freqindex_list"
                       " does not match %d %d" %
                       (len(self._start_freqindex_list),
                        len(self._end_freqindex_list)))
    if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
      raise ValueError("Length of num_frequency_blocks"
                       " is not equal to start_freqindex_list %d %d" %
                       (len(self._num_frequency_blocks),
                        len(self._start_freqindex_list)))
    block_ranges = zip(self._start_freqindex_list, self._end_freqindex_list)

  freq_inputs = []
  for block, (start_index, end_index) in enumerate(block_ranges):
    cur_size = end_index - start_index
    # Number of feature_size-wide windows that fit when stepping by
    # frequency_skip.
    block_feats = int(
        (cur_size - self._feature_size) / (self._frequency_skip)) + 1
    if block_feats != self._num_frequency_blocks[block]:
      raise ValueError(
          "Invalid num_frequency_blocks, requires %d but gets %d, please"
          " check the input size and filter config are correct." %
          (self._num_frequency_blocks[block], block_feats))
    first_column = start_index + slice_offset
    freq_inputs.append([
        array_ops.slice(inputs,
                        [0, first_column + f * self._frequency_skip],
                        [-1, self._feature_size])
        for f in range(block_feats)
    ])
  return freq_inputs
9565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
9575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
class BidirectionalGridLSTMCell(GridLSTMCell):
  """Grid LSTM cell that is bidirectional along the frequency axis.

  Both directions are only run over the frequency dimension, so the time
  direction is left untouched and the cell remains usable for the real-time
  processing required by online recognition systems.
  The forward and the backward pass currently use separate weights.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               share_time_frequency_weights=False,
               cell_clip=None,
               initializer=None,
               num_unit_shards=1,
               forget_bias=1.0,
               feature_size=None,
               frequency_skip=None,
               num_frequency_blocks=None,
               start_freqindex_list=None,
               end_freqindex_list=None,
               couple_input_forget_gates=False,
               backward_slice_offset=0,
               reuse=None):
    """Set up the cell parameters and the named state tuple.

    Args:
      num_units: int, number of units in the LSTM cell.
      use_peepholes: (optional) bool, default False. Enables
        diagonal/peephole connections when True.
      share_time_frequency_weights: (optional) bool, default False. When True
        the time and frequency LSTMs share their cell weights.
      cell_clip: (optional) float, default None. If given, the cell state is
        clipped by this value before the cell output activation.
      initializer: (optional) initializer for the weight and projection
        matrices, default None.
      num_unit_shards: (optional) int, default 1. Number of shards the weight
        matrix is stored across when greater than 1.
      forget_bias: (optional) float, default 1.0. Initial bias of the forget
        gates, used to reduce the scale of forgetting at the beginning of
        training.
      feature_size: (optional) int, default None. Size of the input feature
        window the LSTM spans over.
      frequency_skip: (optional) int, default None. Amount by which the LSTM
        filter is shifted in frequency.
      num_frequency_blocks: [required] list with the number of frequency
        blocks needed to cover each input feature split defined by
        start_freqindex_list and end_freqindex_list. Its entries are iterated
        here to name the state fields.
      start_freqindex_list: (optional) list of ints, default None. Starting
        frequency index of each frequency block.
      end_freqindex_list: (optional) list of ints, default None. Ending
        frequency index of each frequency block.
      couple_input_forget_gates: (optional) bool, default False. Couples the
        input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce model
        parameters and computation cost.
      backward_slice_offset: (optional) int32, default 0. Starting offset used
        when slicing the features for the backward pass.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    # The literal `True` is forwarded positionally to the parent constructor
    # (presumably its state_is_tuple flag -- TODO confirm against
    # GridLSTMCell.__init__).
    super(BidirectionalGridLSTMCell, self).__init__(
        num_units, use_peepholes, share_time_frequency_weights, cell_clip,
        initializer, num_unit_shards, forget_bias, feature_size, frequency_skip,
        num_frequency_blocks, start_freqindex_list, end_freqindex_list,
        couple_input_forget_gates, True, reuse)
    self._backward_slice_offset = int(backward_slice_offset)

    # One (c, m) pair of state fields per direction, block and frequency step,
    # listed with all forward entries first and all backward entries after.
    name_chunks = []
    for direction in ["fwd", "bwd"]:
      for block_index, block_len in enumerate(self._num_frequency_blocks):
        for freq_index in range(block_len):
          name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index,
                                                  block_index)
          name_chunks.append("%s_c, %s_m," % (name_prefix, name_prefix))
    state_names = "".join(name_chunks)
    self._state_tuple_type = collections.namedtuple(
        "BidirectionalGridLSTMStateTuple", state_names.strip(","))

    # Every state field has size num_units; the factor 2 accounts for the two
    # directions.
    field_sizes = [num_units, num_units] * self._total_blocks * 2
    self._state_size = self._state_tuple_type._make(field_sizes)
    self._output_size = 2 * num_units * self._total_blocks * 2

  def call(self, inputs, state):
    """Run one step of the bidirectional grid LSTM.

    Args:
      inputs: 2D input Tensor whose first dimension is the batch.
      state: state tuple, passed unchanged to `_compute` for both directions.

    Returns:
      A tuple `(m_out, state_out)` where:
      - `m_out` is the concatenation along axis 1 of all forward outputs
        followed by all backward outputs.
      - `state_out` is an instance of the cell's state tuple type holding the
        forward states followed by the backward states.

    Raises:
      ValueError: may be raised by the feature-splitting helper
        `_make_tf_features` (not verified here -- see that method).
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    fwd_inputs = self._make_tf_features(inputs)
    # Without an offset the backward pass re-uses the forward feature slices.
    bwd_inputs = (
        self._make_tf_features(inputs, self._backward_slice_offset)
        if self._backward_slice_offset else fwd_inputs)

    # Forward pass over every frequency block.
    fwd_outputs = []
    fwd_states = []
    with vs.variable_scope("fwd"):
      for block_index, block_features in enumerate(fwd_inputs):
        block_outputs, block_states = self._compute(
            block_features,
            block_index,
            state,
            batch_size,
            state_prefix="fwd_state",
            state_is_tuple=True)
        fwd_outputs.extend(block_outputs)
        fwd_states.extend(block_states)

    # Backward pass: same blocks, each one traversed in reverse order.
    bwd_outputs = []
    bwd_states = []
    with vs.variable_scope("bwd"):
      for block_index, block_features in enumerate(bwd_inputs):
        reversed_features = block_features[::-1]
        block_outputs, block_states = self._compute(
            reversed_features,
            block_index,
            state,
            batch_size,
            state_prefix="bwd_state",
            state_is_tuple=True)
        bwd_outputs.extend(block_outputs)
        bwd_states.extend(block_states)

    state_out = self._state_tuple_type._make(fwd_states + bwd_states)
    # The outputs are always concatenated since they are never used separately.
    m_out = array_ops.concat(fwd_outputs + bwd_outputs, 1)
    return m_out, state_out
10981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
10991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower
# pylint: disable=protected-access
# Module-local alias for the private `_Linear` class of `core_rnn_cell`; it is
# instantiated by `AttentionCellWrapper` in this file.
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# pylint: enable=protected-access
11045cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
11055cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower
class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Wraps an RNN cell with a basic attention mechanism.

  Implementation based on https://arxiv.org/abs/1409.0473.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window.
      attn_size: integer, the size of an attention vector. Equal to
          cell.output_size by default.
      attn_vec_size: integer, the number of convolutional features calculated
          on attention state and a size of the hidden layer built from
          base cell state. Equal to attn_size by default.
      input_size: integer, the size of a hidden linear layer,
          built from inputs and attention. Derived from the input tensor
          by default.
      state_is_tuple: If True (the default), accepted and returned states are
        3-tuples `(cell_state, attns, attn_states)`.  If False, the three
        parts are concatenated along the column axis.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
          `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)

    # Argument validation; the order of the checks decides which error wins.
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("The parameter cell is not RNNCell.")
    cell_state_is_nested = nest.is_sequence(cell.state_size)
    if cell_state_is_nested and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" % str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" % str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)

    # Resolve defaults: attn_size falls back to the wrapped cell's output size
    # and attn_vec_size falls back to the resolved attn_size.
    resolved_attn_size = cell.output_size if attn_size is None else attn_size
    resolved_attn_vec_size = (
        resolved_attn_size if attn_vec_size is None else attn_vec_size)

    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = resolved_attn_vec_size
    self._input_size = input_size
    self._attn_size = resolved_attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    # The linear layers are created lazily on first use and then re-used.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    """Sizes of (cell state, attention vector, attention window).

    Returned as a 3-tuple when `state_is_tuple` is set, otherwise as the sum of
    the three sizes.
    """
    parts = (self._cell.state_size, self._attn_size,
             self._attn_size * self._attn_length)
    return parts if self._state_is_tuple else sum(parts)

  @property
  def output_size(self):
    """Size of the cell output, which equals the attention vector size."""
    return self._attn_size

  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      # Concatenated layout along axis 1: [cell state | attns | attn window].
      packed = state
      cell_width = self._cell.state_size
      window_offset = cell_width + self._attn_size
      state = array_ops.slice(packed, [0, 0], [-1, cell_width])
      attns = array_ops.slice(packed, [0, cell_width], [-1, self._attn_size])
      attn_states = array_ops.slice(
          packed, [0, window_offset],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])

    # Mix the inputs with the previous attention vector before the cell.
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    cell_inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(cell_inputs, state)

    # The attention query is the (flattened) new state of the wrapped cell.
    new_state_cat = (
        array_ops.concat(nest.flatten(new_state), 1)
        if self._state_is_tuple else new_state)
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)

    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])

    # Append the new output to the window returned by `_attention` (which has
    # already dropped the oldest entry) and flatten it again.
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])

    next_state = (new_state, new_attns, new_attn_states)
    if self._state_is_tuple:
      return output, next_state
    return output, array_ops.concat(list(next_state), 1)

  def _attention(self, query, attn_states):
    """Compute the attention vector and advance the attention window.

    Args:
      query: Tensor fed through a linear layer of `attn_vec_size` outputs.
      attn_states: Tensor that is reshaped here to
        [-1, attn_length, 1, attn_size].

    Returns:
      A tuple `(new_attns, new_attn_states)`: the attention vector reshaped to
      [-1, attn_size], and `attn_states` without its first entry along axis 1.
    """
    with vs.variable_scope("attention"):
      k = vs.get_variable("attn_w",
                          [1, 1, self._attn_size, self._attn_vec_size])
      v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(attn_states,
                                 [-1, self._attn_length, 1, self._attn_size])
      # 1x1 convolution over every entry of the attention window.
      hidden_features = nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      y = self._linear3(query)
      y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size])
      # Attention scores, normalised with a softmax.
      activated = math_ops.tanh(hidden_features + y)
      s = math_ops.reduce_sum(v * activated, [2, 3])
      a = nn_ops.softmax(s)
      # Weighted sum of the window entries using the attention weights.
      weights = array_ops.reshape(a, [-1, self._attn_length, 1, 1])
      d = math_ops.reduce_sum(weights * hidden, [1, 2])
      new_attns = array_ops.reshape(d, [-1, self._attn_size])
      # Drop the oldest entry of the window.
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
124934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
125034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower
class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that gates between a cell's input and its output.

  The wrapped cell's output is combined with the cell input through learned
  carry/transform gates (a highway connection).

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean. If true, the transform gate is
        computed as `1 - carry`; otherwise it gets its own weights and bias.
      carry_bias_init: float, initial value of the carry gate bias. The
        negated value initializes the transform gate bias when the gates are
        not coupled.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    """State size of the wrapped cell."""
    return self._cell.state_size

  @property
  def output_size(self):
    """Output size of the wrapped cell."""
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    """Returns the wrapped cell's zero state, built inside a name scope."""
    scope_name = type(self).__name__ + "ZeroState"
    with ops.name_scope(scope_name, values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    """Computes `inp * carry + out * transform` for one input/output pair.

    Args:
      inp: rank-2 Tensor, the cell input; its second dimension sizes the gate
        variables.
      out: Tensor, the wrapped cell's output for `inp`.

    Returns:
      The gated combination of `inp` and `out`.
    """
    width = inp.get_shape().with_rank(2)[1].value
    carry_w = vs.get_variable("carry_w", [width, width])
    carry_b = vs.get_variable(
        "carry_b", [width],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry_gate = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_w, carry_b))

    if not self._couple_carry_transform_gates:
      # Independent transform gate; its bias starts at the negated carry bias.
      transform_w = vs.get_variable("transform_w", [width, width])
      transform_b = vs.get_variable(
          "transform_b", [width],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform_gate = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_w, transform_b))
    else:
      transform_gate = 1 - carry_gate

    carried = inp * carry_gate
    transformed = out * transform_gate
    return carried + transformed

  def __call__(self, inputs, state, scope=None):
    """Run the cell and combine its outputs with its inputs via highway gates.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of gated cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    cell_outputs, next_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, cell_outputs)

    # Every input must be shape-compatible with the matching output before
    # the two can be blended element-wise.
    nest.map_structure(
        lambda inp, out: inp.get_shape().assert_is_compatible_with(
            out.get_shape()),
        inputs, cell_outputs)

    gated_outputs = nest.map_structure(self._highway, inputs, cell_outputs)
    return (gated_outputs, next_state)
133203327190420dd5b1c34a5ffdd0000aff40980ed5A. Unique TensorFlower
133303327190420dd5b1c34a5ffdd0000aff40980ed5A. Unique TensorFlower
class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, added to the forget gate pre-activation before the
        sigmoid is applied.
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1, passed to the
        dropout op as the keep probability of the recurrent dropout. Dropout
        is skipped only when this is a Python float that is not below 1.
      dropout_prob_seed: (optional) integer, the randomness seed.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)

    # Core LSTM configuration.
    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    # Recurrent dropout configuration.
    self._keep_prob = dropout_keep_prob
    self._seed = dropout_prob_seed
    # Layer normalization configuration.
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift
    self._reuse = reuse

  @property
  def state_size(self):
    """`LSTMStateTuple` whose two entries are both `num_units`."""
    units = self._num_units
    return rnn_cell_impl.LSTMStateTuple(units, units)

  @property
  def output_size(self):
    """Size of the cell output, `num_units`."""
    return self._num_units

  def _norm(self, inp, scope, dtype=dtypes.float32):
    """Layer-normalizes `inp` with gain/shift variables under `scope`.

    Args:
      inp: Tensor to normalize; its last dimension sizes the variables.
      scope: variable scope name holding the `gamma` and `beta` variables.
      dtype: dtype used when creating `gamma` and `beta`.

    Returns:
      The result of `layers.layer_norm` applied to `inp`.
    """
    param_shape = inp.get_shape()[-1:]
    with vs.variable_scope(scope):
      # Create gamma and beta up front with the configured initial values;
      # layer_norm below is invoked with reuse=True on the same scope name.
      vs.get_variable(
          "gamma",
          shape=param_shape,
          initializer=init_ops.constant_initializer(self._norm_gain),
          dtype=dtype)
      vs.get_variable(
          "beta",
          shape=param_shape,
          initializer=init_ops.constant_initializer(self._norm_shift),
          dtype=dtype)
    return layers.layer_norm(inp, reuse=True, scope=scope)

  def _linear(self, args):
    """Projects `args` to the four concatenated gate pre-activations.

    Args:
      args: Tensor whose last dimension sizes the `kernel` variable.

    Returns:
      `args` multiplied by the kernel, with a bias added only when layer
      normalization is disabled.
    """
    gate_width = 4 * self._num_units
    in_dtype = args.dtype
    kernel = vs.get_variable(
        "kernel", [args.get_shape()[-1], gate_width], dtype=in_dtype)
    projected = math_ops.matmul(args, kernel)
    if self._layer_norm:
      return projected
    bias = vs.get_variable("bias", [gate_width], dtype=in_dtype)
    return nn_ops.bias_add(projected, bias)

  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    prev_c, prev_h = state
    lstm_input = array_ops.concat([inputs, prev_h], 1)
    gates = self._linear(lstm_input)
    dtype = lstm_input.dtype

    # Pre-activations in split order: input, transform, forget, output.
    pre_acts = array_ops.split(value=gates, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      gate_scopes = ("input", "transform", "forget", "output")
      pre_acts = [
          self._norm(pre_act, gate_scope, dtype=dtype)
          for pre_act, gate_scope in zip(pre_acts, gate_scopes)
      ]
    i, j, f, o = pre_acts

    candidate = self._activation(j)
    keep_prob = self._keep_prob
    # Only a Python float that is not below 1 disables dropout; any other
    # value (e.g. a Tensor) is handed to the dropout op without comparison.
    skip_dropout = isinstance(keep_prob, float) and not keep_prob < 1
    if not skip_dropout:
      candidate = nn_ops.dropout(candidate, keep_prob, seed=self._seed)

    forget_gate = math_ops.sigmoid(f + self._forget_bias)
    retained = prev_c * forget_gate
    input_gate = math_ops.sigmoid(i)
    new_c = retained + input_gate * candidate
    if self._layer_norm:
      new_c = self._norm(new_c, "state", dtype=dtype)

    cell_activation = self._activation(new_c)
    new_h = cell_activation * math_ops.sigmoid(o)
    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)
1456bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
1457bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo
1458827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdoclass NASCell(rnn_cell_impl.RNNCell):
14591e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  """Neural Architecture Search (NAS) recurrent network cell.
14601e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14611e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  This implements the recurrent cell from the paper:
14621e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14631e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    https://arxiv.org/abs/1611.01578
14641e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14651e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  Barret Zoph and Quoc V. Le.
14661e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.
14671e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14681e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  The class uses an optional projection layer.
14691e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower  """
14701e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
1471ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
14721e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    """Initialize the parameters for a NAS cell.
14731e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14741e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    Args:
14751e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      num_units: int, The number of units in the NAS cell
14761e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      num_proj: (optional) int, The output dimensionality for the projection
14771e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        matrices.  If None, no projection is performed.
14781e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      use_biases: (optional) bool, If True then use biases within the cell. This
14791e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        is False by default.
148054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo      reuse: (optional) Python boolean describing whether to reuse variables
148154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        in an existing scope.  If not `True`, and the existing scope already has
148254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo        the given variables, an error is raised.
14831e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    """
1484e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    super(NASCell, self).__init__(_reuse=reuse)
14851e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    self._num_units = num_units
14861e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    self._num_proj = num_proj
14871e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    self._use_biases = use_biases
148854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo    self._reuse = reuse
14891e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
14901e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    if num_proj is not None:
1491827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdo      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
14921e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      self._output_size = num_proj
14931e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    else:
1494827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdo      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
14951e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      self._output_size = num_units
14961e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
  @property
  def state_size(self):
    """Returns the `LSTMStateTuple` of state sizes stored by `__init__`."""
    return self._state_size
15001e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
  @property
  def output_size(self):
    """Returns the output size stored by `__init__`.

    This is `num_proj` when a projection was requested, else `num_units`.
    """
    return self._output_size
15041e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
1505e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo  def call(self, inputs, state):
15061e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    """Run one step of NAS Cell.
15071e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15081e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    Args:
15091e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      inputs: input Tensor, 2D, batch x num_units.
15101e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      state: This must be a tuple of state Tensors, both `2-D`, with column
15111e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        sizes `c_state` and `m_state`.
15121e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15131e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    Returns:
15141e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      A tuple containing:
15151e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
15161e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        NAS Cell after reading `inputs` when previous state was `state`.
15171e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        Here output_dim is:
15181e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower           num_proj if num_proj was set,
15191e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower           num_units otherwise.
15201e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
15211e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        when the previous state was `state`.  Same type and shape(s) as `state`.
15221e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15231e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    Raises:
15241e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      ValueError: If input size cannot be inferred from inputs via
15251e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower        static shape inference.
15261e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    """
15271e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    sigmoid = math_ops.sigmoid
15281e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    tanh = math_ops.tanh
15291e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    relu = nn_ops.relu
15301e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15311e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    num_proj = self._num_units if self._num_proj is None else self._num_proj
15321e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15331e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    (c_prev, m_prev) = state
15341e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
15351e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    dtype = inputs.dtype
15361e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    input_size = inputs.get_shape().with_rank(2)[1]
15371e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower    if input_size.value is None:
15381e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
1539e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Variables for the NAS cell. W_m is all matrices multiplying the
1540e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # hiddenstate and W_inputs is all matrices multiplying the inputs.
1541ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie    concat_w_m = vs.get_variable("recurrent_kernel",
1542ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie                                 [num_proj, 8 * self._num_units], dtype)
1543e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    concat_w_inputs = vs.get_variable(
1544ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        "kernel", [input_size.value, 8 * self._num_units], dtype)
1545e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1546e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    m_matrix = math_ops.matmul(m_prev, concat_w_m)
1547e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
1548e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1549e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._use_biases:
1550e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      b = vs.get_variable(
1551e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          "bias",
1552e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          shape=[8 * self._num_units],
1553e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          initializer=init_ops.zeros_initializer(),
1554e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo          dtype=dtype)
1555e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      m_matrix = nn_ops.bias_add(m_matrix, b)
1556e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1557e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # The NAS cell branches into 8 different splits for both the hiddenstate
1558e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # and the input
1559ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie    m_matrix_splits = array_ops.split(
1560ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        axis=1, num_or_size_splits=8, value=m_matrix)
1561ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie    inputs_matrix_splits = array_ops.split(
1562ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        axis=1, num_or_size_splits=8, value=inputs_matrix)
1563e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1564e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # First layer
1565e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
1566e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
1567e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
1568e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
1569e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
1570e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
1571e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
1572e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
1573e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1574e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Second layer
1575e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l2_0 = tanh(layer1_0 * layer1_1)
1576e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l2_1 = tanh(layer1_2 + layer1_3)
1577e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l2_2 = tanh(layer1_4 * layer1_5)
1578e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l2_3 = sigmoid(layer1_6 + layer1_7)
1579e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1580e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Inject the cell
1581e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l2_0 = tanh(l2_0 + c_prev)
1582e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1583e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Third layer
1584e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l3_0_pre = l2_0 * l2_1
1585e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_c = l3_0_pre  # create new cell
1586e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l3_0 = l3_0_pre
1587e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    l3_1 = tanh(l2_2 + l2_3)
1588e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1589e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Final layer
1590e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    new_m = tanh(l3_0 * l3_1)
1591e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo
1592e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    # Projection layer if specified
1593e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    if self._num_proj is not None:
1594ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie      concat_w_proj = vs.get_variable("projection_weights",
1595ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie                                      [self._num_units, self._num_proj], dtype)
1596e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo      new_m = math_ops.matmul(new_m, concat_w_proj)
15971e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
1598827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdo    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m)
1599e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo    return new_m, new_state
16001e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
16011e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower
class UGRNNCell(rnn_cell_impl.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  A middle ground between a gated cell (LSTM/GRU) and a plain RNN.  A single
  gate decides, per unit, whether the previous state keeps being integrated
  or is replaced by a freshly computed value.  This is the recurrent
  counterpart of the feedforward Highway Network.

  The cell is the one described in:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self,
               num_units,
               initializer=None,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Store the hyper-parameters of the UGRNN cell.

    Args:
      num_units: int, number of units in the UGRNN cell.  It is both the
        state size and the output size.
      initializer: (optional) Initializer installed on the variable scope in
        which the weight matrices are created.
      forget_bias: (optional) float, default 1.0.  Constant added to the gate
        pre-activation before the sigmoid, so that little is forgotten at the
        beginning of training.
      activation: (optional) Activation applied to the candidate state.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._activation = activation
    self._initializer = initializer
    # The linear layer is built on the first `call` and reused afterwards.
    self._linear = None

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Advance the UGRNN by one time step.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      A pair `(new_output, new_state)`.  Both entries are the same Tensor of
      shape batch x num units: the state of the UGRNN after reading `inputs`
      when the previous state was `state`.

    Raises:
      ValueError: If the input size cannot be inferred from `inputs` via
        static shape inference.
    """
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    current_scope = vs.get_variable_scope()
    with vs.variable_scope(current_scope, initializer=self._initializer):
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear is None:
        self._linear = _Linear(cell_inputs, 2 * self._num_units, True)

      # One affine map yields both pre-activations: the first half drives the
      # gate, the second half the candidate state.
      gate_pre, candidate_pre = array_ops.split(
          value=self._linear(cell_inputs), num_or_size_splits=2, axis=1)

      candidate = self._activation(candidate_pre)
      gate = math_ops.sigmoid(gate_pre + self._forget_bias)

      # Convex mix of the old state and the candidate, weighted by the gate.
      kept = gate * state
      replaced = (1.0 - gate) * candidate
      next_state = kept + replaced

    return next_state, next_state
1694d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1695d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
class IntersectionRNNCell(rnn_cell_impl.RNNCell):
  """Intersection Recurrent Neural Network (+RNN) cell.

  The architecture couples a recurrent gate with a depth gate and is meant
  to improve information flow through stacked RNNs.  Because of the depth
  gating, the depth output (y) must keep the same dimensionality from layer
  to layer (input size == output size).  The first layer of a stacked
  Intersection RNN therefore projects its inputs to N (num units)
  dimensions: pass `num_in_proj = N` when building the first layer and keep
  the defaults for the layers above it.

  The cell is the one described in:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The Intersection RNN is built for use in deeply stacked
  RNNs so it may not achieve best performance with depth 1.
  """

  def __init__(self,
               num_units,
               num_in_proj=None,
               initializer=None,
               forget_bias=1.0,
               y_activation=nn_ops.relu,
               reuse=None):
    """Store the hyper-parameters of the +RNN cell.

    Args:
      num_units: int, number of units in the +RNN cell.  It is both the state
        size and the output size.
      num_in_proj: (optional) int, The input dimensionality for the RNN.
        For the first layer of a +RNN this should be set to `num_units`;
        otherwise leave it as `None` (default).  When it is `None`, the
        dimensionality of `inputs` must equal `num_units`, otherwise a
        ValueError is raised by `call`.
      initializer: (optional) Initializer installed on the variable scope in
        which the weight matrices are created.
      forget_bias: (optional) float, default 1.0.  Constant added to both
        gate pre-activations before the sigmoid, so that little is forgotten
        at the beginning of training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is `tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._num_input_proj = num_in_proj
    self._forget_bias = forget_bias
    self._y_activation = y_activation
    self._initializer = initializer
    # Both linear layers are built on first use inside `call`.
    self._linear1 = None  # optional read-in projection
    self._linear2 = None  # joint gate/candidate transform

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    """Advance the Intersection RNN by one time step.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
        after reading `inputs` when previous state was `state`.
      new_state: batch x num units, Tensor representing the state of the +RNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If the input size cannot be inferred from `inputs` via
        static shape inference.
      ValueError: If input size != output size and no read-in projection was
        requested (the two must be equal when using the Intersection RNN).
    """
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    current_scope = vs.get_variable_scope()
    with vs.variable_scope(current_scope, initializer=self._initializer):
      size_mismatch = input_size.value != self._num_units
      if size_mismatch and not self._num_input_proj:
        raise ValueError("Must have input size == output size for "
                         "Intersection RNN. To fix, num_in_proj should "
                         "be set to num_units at cell init.")
      if size_mismatch:
        # Read-in projection: used by the first layer of a deep +RNN to map
        # the inputs from I --> N dimensions.
        with vs.variable_scope("in_projection"):
          if self._linear1 is None:
            self._linear1 = _Linear(inputs, self._num_units, True)
          inputs = self._linear1(inputs)

      state_dim = depth_dim = self._num_units
      cell_inputs = array_ops.concat([inputs, state], 1)
      if self._linear2 is None:
        self._linear2 = _Linear(
            cell_inputs, 2 * state_dim + 2 * depth_dim, True)
      rnn_matrix = self._linear2(cell_inputs)

      # Column layout of rnn_matrix:
      #   [ recurrent gate | recurrent candidate | depth gate | depth candidate ]
      depth_gate_start = 2 * state_dim
      depth_cand_start = depth_gate_start + depth_dim
      depth_cand_end = depth_cand_start + depth_dim
      gh_pre = rnn_matrix[:, :state_dim]  # b x n
      h_pre = rnn_matrix[:, state_dim:depth_gate_start]  # b x n
      gy_pre = rnn_matrix[:, depth_gate_start:depth_cand_start]  # b x i
      y_pre = rnn_matrix[:, depth_cand_start:depth_cand_end]  # b x i

      h_cand = math_ops.tanh(h_pre)
      y_cand = self._y_activation(y_pre)
      time_gate = math_ops.sigmoid(gh_pre + self._forget_bias)
      depth_gate = math_ops.sigmoid(gy_pre + self._forget_bias)

      # Passed through time: mix of previous state and recurrent candidate.
      state_kept = time_gate * state
      state_fresh = (1.0 - time_gate) * h_cand
      new_state = state_kept + state_fresh

      # Passed through depth: mix of (projected) inputs and depth candidate.
      input_kept = depth_gate * inputs
      input_fresh = (1.0 - depth_gate) * y_cand
      new_y = input_kept + input_fresh

    return new_y, new_state
1824d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
1825d58fe97fdefcd968c28f4cff916ba6a26e234d4fA. Unique TensorFlower
# Cache of the registered op definitions.  It is filled in on first use by
# `CompiledWrapper.__call__` when it has to tell stateful ops from stateless
# ones.
_REGISTERED_OPS = None


class CompiledWrapper(rnn_cell_impl.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell` whose step is run inside the JIT scope.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._compile_stateful = compile_stateful
    self._cell = cell

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    """Return the wrapped cell's zero state, built under a dedicated name scope."""
    scope_name = type(self).__name__ + "ZeroState"
    with ops.name_scope(scope_name, values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    """Run one step of the wrapped cell inside an experimental JIT scope.

    When `compile_stateful` is false, the scope is given a predicate that
    accepts only ops whose registered definition is not stateful; otherwise
    it is given `True`.

    Args:
      inputs: Forwarded unchanged to the wrapped cell.
      state: Forwarded unchanged to the wrapped cell.
      scope: Forwarded unchanged to the wrapped cell.

    Returns:
      Whatever the wrapped cell returns for `(inputs, state, scope=scope)`.
    """

    def _is_stateless(node_def):
      global _REGISTERED_OPS
      # Look the registry up once and keep it for later calls.
      if _REGISTERED_OPS is None:
        _REGISTERED_OPS = op_def_registry.get_registered_ops()
      return not _REGISTERED_OPS[node_def.op].is_stateful

    compile_ops = True if self._compile_stateful else _is_stateless
    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
1868bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
1869bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
  """Builds an initializer that draws values log-uniformly from a range.

  Each value is produced by sampling uniformly between `log(minval)` and
  `log(maxval)` and exponentiating the sample.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
        range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
        the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The data type used when the initializer is called without one.

  Returns:
    An initializer that generates tensors with an exponential distribution.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    del partition_info  # Unused.
    log_low = math_ops.log(minval)
    log_high = math_ops.log(maxval)
    log_samples = random_ops.random_uniform(
        shape, log_low, log_high, dtype, seed=seed)
    return math_ops.exp(log_samples)

  return _initializer
1893bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
1894bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower
class PhasedLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM cell whose state updates are modulated by a periodic time gate.

  Reference: https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Creates a Phased LSTM cell.

    Args:
      num_units: int, number of units in the cell. Both entries of the state
          and the output have this size.
      use_peepholes: bool. When True, the cell state is appended to the inputs
          of the input/forget gates and of the output gate.
      leak: float or scalar float Tensor with value in [0, 1]. Factor applied
          to the cycle ratio to obtain the gate openness while the time gate
          is closed.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Initial
          value of the fraction of the period during which the gate is open.
      trainable_ratio_on: bool, whether the `ratio_on` variable is trainable.
      period_init_min: float or scalar float Tensor, with value > 0. Lower
          bound of the initial periods, which are drawn from
          e^U(log(period_init_min), log(period_init_max)), U(.,.) being the
          uniform distribution.
      period_init_max: float or scalar float Tensor, with
          value > period_init_min. Upper bound of the initial periods.
      reuse: (optional) Python boolean describing whether to reuse variables
          in an existing scope. If not `True`, and the existing scope already
          has the given variables, an error is raised.
    """
    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
    self._reuse = reuse

    # Cell hyper-parameters.
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak

    # Time-gate hyper-parameters.
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max

    # Linear layers are built on the first `call` and cached afterwards.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    units = self._num_units
    return rnn_cell_impl.LSTMStateTuple(units, units)

  @property
  def output_size(self):
    return self._num_units

  def _mod(self, x, y):
    """Computes `x mod y` in the forward pass with identity gradient w.r.t. x."""
    correction = math_ops.mod(x, y) - x
    return array_ops.stop_gradient(correction) + x

  def _get_cycle_ratio(self, time, phase, period):
    """Returns the position within the cycle, as a float32 fraction.

    The arithmetic is carried out in the dtype of `time`; only the final
    result is cast to float32.
    """
    time_dtype = time.dtype
    phase_in_time_dtype = math_ops.cast(phase, dtype=time_dtype)
    period_in_time_dtype = math_ops.cast(period, dtype=time_dtype)
    time_since_phase = time - phase_in_time_dtype
    ratio = (self._mod(time_since_phase, period_in_time_dtype) /
             period_in_time_dtype)
    return math_ops.cast(ratio, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Runs one step of the Phased LSTM.

    Args:
      inputs: A tuple of 2 Tensors `(time, features)`.
         `time` has shape [batch, 1] and type float32 or float64.
         `features` has shape [batch, features_size] and type float32.
      state: rnn_cell_impl.LSTMStateTuple, state from the previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing the
        output of the cell.
      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    prev_c, prev_h = state
    time, features = inputs

    # Input and forget gates (optionally peeking at the previous cell state).
    mask_gate_inputs = [features, prev_h]
    if self._use_peepholes:
      mask_gate_inputs = mask_gate_inputs + [prev_c]

    with vs.variable_scope("mask_gates"):
      if self._linear1 is None:
        self._linear1 = _Linear(mask_gate_inputs, 2 * self._num_units, True)

      gate_activations = math_ops.sigmoid(self._linear1(mask_gate_inputs))
      input_gate, forget_gate = array_ops.split(
          axis=1, num_or_size_splits=2, value=gate_activations)

    # Candidate values to write into the cell.
    with vs.variable_scope("new_input"):
      if self._linear2 is None:
        self._linear2 = _Linear([features, prev_h], self._num_units, True)
      candidate = math_ops.tanh(self._linear2([features, prev_h]))

    proposed_c = (prev_c * forget_gate + input_gate * candidate)

    # Output gate (optionally peeking at the proposed cell state).
    output_gate_inputs = [features, prev_h]
    if self._use_peepholes:
      output_gate_inputs = output_gate_inputs + [proposed_c]

    with vs.variable_scope("output_gate"):
      if self._linear3 is None:
        self._linear3 = _Linear(output_gate_inputs, self._num_units, True)
      output_gate = math_ops.sigmoid(self._linear3(output_gate_inputs))

    proposed_h = math_ops.tanh(proposed_c) * output_gate

    # Time-gate parameters, one per unit.
    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(self._period_init_min,
                                            self._period_init_max))
    # The phase is drawn uniformly between 0 and the initial period.
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(0.,
                                                        period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    # Openness of the time gate in each of its three phases.
    opening = 2 * cycle_ratio / ratio_on
    closing = 2 - opening
    leaking = self._leak * cycle_ratio

    # First half of the "on" window -> opening, second half -> closing,
    # everything else -> leaking.
    openness = array_ops.where(cycle_ratio < ratio_on, closing, leaking)
    openness = array_ops.where(cycle_ratio < 0.5 * ratio_on, opening, openness)

    # Blend the proposed values with the previous ones according to openness.
    new_c = openness * proposed_c + (1 - openness) * prev_c
    new_h = openness * proposed_h + (1 - openness) * prev_h

    return new_h, rnn_cell_impl.LSTMStateTuple(new_c, new_h)
2043ee112cff56081fb9d0b74c987a8935acc360b05cBenoit Steiner
2044ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie
class ConvLSTMCell(rnn_cell_impl.RNNCell):
  """Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self,
               conv_ndims,
               input_shape,
               output_channels,
               kernel_shape,
               use_bias=True,
               skip_connection=False,
               forget_bias=1.0,
               initializers=None,
               name="conv_lstm_cell"):
    """Construct ConvLSTMCell.

    Args:
      conv_ndims: Convolution dimensionality (1, 2 or 3).
      input_shape: Shape of the input as int tuple or list, excluding the
        batch size. Its last entry is the number of input channels.
      output_channels: int, number of output channels of the conv LSTM.
      kernel_shape: Shape of kernel as int tuple or list (of size 1, 2 or 3).
      use_bias: Use bias in convolutions.
      skip_connection: If set to `True`, concatenate the input to the
        output of the conv LSTM. Default: `False`.
      forget_bias: Forget bias.
      initializers: Unused.
      name: Name of the module.

    Raises:
      ValueError: If `input_shape` is incompatible with `conv_ndims`.
    """
    super(ConvLSTMCell, self).__init__(name=name)

    if conv_ndims != len(input_shape) - 1:
      raise ValueError("Invalid input_shape {} for conv_ndims={}.".format(
          input_shape, conv_ndims))

    self._conv_ndims = conv_ndims
    self._input_shape = input_shape
    self._output_channels = output_channels
    # Stored as a list because it is later concatenated with a list to form
    # the kernel variable shape; a tuple would raise TypeError there.
    self._kernel_shape = list(kernel_shape)
    self._use_bias = use_bias
    self._forget_bias = forget_bias
    self._skip_connection = skip_connection

    self._total_output_channels = output_channels
    if self._skip_connection:
      self._total_output_channels += self._input_shape[-1]

    # Convert to a list first so that a tuple `input_shape` (the documented
    # type) can be extended with the channel count.
    spatial_dims = list(self._input_shape[:-1])
    state_size = tensor_shape.TensorShape(
        spatial_dims + [self._output_channels])
    self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size)
    self._output_size = tensor_shape.TensorShape(
        spatial_dims + [self._total_output_channels])

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def call(self, inputs, state, scope=None):
    """Runs one step of the convolutional LSTM.

    Args:
      inputs: Input Tensor for this step.
      state: LSTMStateTuple `(cell, hidden)` from the previous step.
      scope: Unused.

    Returns:
      A tuple `(output, new_state)`. `output` is the new hidden state, with
      `inputs` concatenated on the last axis when `skip_connection` is set.
      `new_state` is an LSTMStateTuple of the new cell and hidden states.
    """
    cell, hidden = state
    # A single convolution produces the pre-activations of all four gates.
    new_hidden = _conv([inputs, hidden], self._kernel_shape,
                       4 * self._output_channels, self._use_bias)
    gates = array_ops.split(
        value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1)

    input_gate, new_input, forget_gate, output_gate = gates
    new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell
    new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input)
    output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate)

    # Build the state before the skip connection is applied: `state_size`
    # declares `output_channels` channels for the hidden state, so storing the
    # concatenated output would contradict it and change the depth of the
    # tensor fed to the convolution on the next step.
    new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
    if self._skip_connection:
      output = array_ops.concat([output, inputs], axis=-1)
    return output, new_state
212428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower
2125ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie
class Conv1DLSTMCell(ConvLSTMCell):
  """1D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_1d_lstm_cell", **kwargs):
    """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.

    Args:
      name: Name of the module.
      **kwargs: Remaining `ConvLSTMCell` constructor arguments, except
        `conv_ndims`, which is fixed to 1.
    """
    # Forward `name`; otherwise it is silently dropped and the base class
    # default name is used instead.
    super(Conv1DLSTMCell, self).__init__(conv_ndims=1, name=name, **kwargs)
213528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower
2136ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie
class Conv2DLSTMCell(ConvLSTMCell):
  """2D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.

    Args:
      name: Name of the module.
      **kwargs: Remaining `ConvLSTMCell` constructor arguments, except
        `conv_ndims`, which is fixed to 2.
    """
    # Forward `name`; otherwise it is silently dropped and the base class
    # default name is used instead.
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)
214628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower
2147ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie
class Conv3DLSTMCell(ConvLSTMCell):
  """3D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.

    Args:
      name: Name of the module.
      **kwargs: Remaining `ConvLSTMCell` constructor arguments, except
        `conv_ndims`, which is fixed to 3.
    """
    # Forward `name`; otherwise it is silently dropped and the base class
    # default name is used instead.
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)
215728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower
2158b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
def _conv(args, filter_size, num_features, bias, bias_start=0.0):
  """Convolution of the channel-wise concatenation of `args`.

  The convolution uses unit strides and "SAME" padding. The 1D, 2D or 3D
  convolution op is picked from the rank of the arguments.

  Args:
    args: a list of Tensors of dimension 3D, 4D or 5D, all of the same rank.
      They are concatenated along the last axis before the convolution.
    filter_size: int tuple or list with the spatial size of the filter.
    num_features: int, number of features.
    bias: bool, whether to add a bias term to the result.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """

  # Validate the ranks and accumulate the total depth (last dimension).
  shapes = [a.get_shape().as_list() for a in args]
  shape_length = len(shapes[0])
  total_arg_size_depth = 0
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D "
                       "or 5D arguments: %s" % str(shapes))
    if len(shape) != shape_length:
      raise ValueError("Conv Linear expects all args "
                       "to be of same Dimension: %s" % str(shapes))
    if shape[-1] is None:
      # The depth sizes the kernel variable, so it must be statically known.
      raise ValueError("Conv Linear expects the last dimension of all args "
                       "to be defined: %s" % str(shapes))
    total_arg_size_depth += shape[-1]
  dtype = args[0].dtype

  # determine correct conv operation
  if shape_length == 3:
    conv_op = nn_ops.conv1d
    strides = 1
  elif shape_length == 4:
    conv_op = nn_ops.conv2d
    strides = shape_length * [1]
  elif shape_length == 5:
    conv_op = nn_ops.conv3d
    strides = shape_length * [1]

  # Now the computation.
  # `list(...)` lets `filter_size` be a tuple as well as a list.
  kernel_shape = list(filter_size) + [total_arg_size_depth, num_features]
  kernel = vs.get_variable("kernel", kernel_shape, dtype=dtype)
  if len(args) == 1:
    res = conv_op(args[0], kernel, strides, padding="SAME")
  else:
    res = conv_op(
        array_ops.concat(axis=shape_length - 1, values=args),
        kernel,
        strides,
        padding="SAME")
  if not bias:
    return res
  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
  return res + bias_term
2217ee112cff56081fb9d0b74c987a8935acc360b05cBenoit Steiner
2218ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie
class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.

  At every step the input and the previous output `m` are each cut column-wise
  into `number_of_groups` contiguous slices.  Group `k` concatenates its two
  slices and maps them through its own kernel to that group's share of the
  four gate pre-activations.
  """

  def __init__(self,
               num_units,
               initializer=None,
               num_proj=None,
               number_of_groups=1,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, then it should be equivalent to LSTM cell
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    if self._num_proj and self._num_proj % self._number_of_groups != 0:
      raise ValueError("num_proj must be divisible by number_of_groups")

    # The recurrent output `m` has `num_proj` columns when a projection is
    # configured and `num_units` columns otherwise.
    recurrent_size = num_proj if num_proj else num_units

    # _group_shape[0]: columns of `m` owned by each group.
    # _group_shape[1]: cell units owned by each group.
    self._group_shape = [
        int(recurrent_size / self._number_of_groups),
        int(self._num_units / self._number_of_groups)
    ]

    self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, recurrent_size)
    self._output_size = recurrent_size

    # Per-group gate maps and the projection map are created lazily on the
    # first `call` and reused afterwards.
    self._linear1 = [None] * number_of_groups
    self._linear2 = None

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices out the columns of `inputs` that belong to one group.

    Reads `self._batch_size`, which `call` sets before slicing.

    Args:
      inputs: cell input or its previous output state, a 2D Tensor whose
        second dimension is split across the groups.
      group_id: int, index of the group to prepare input for.
      group_size: int, number of columns each group receives.

    Returns:
      A 2D Tensor, [batch x group_size], holding columns
      `[group_id * group_size, (group_id + 1) * group_size)` of `inputs`.
    """
    return array_ops.slice(
        input_=inputs,
        begin=[0, group_id * group_size],
        size=[self._batch_size, group_size],
        name=("GLSTM_group%d_input_generation" % group_id))

  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x input_size].  When `input_size` is
        statically known it must be divisible by `number_of_groups`.
      state: this must be a tuple of state Tensors, both `2-D`,
      with column sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If the statically known input size is not divisible by
        `number_of_groups`.
    """
    (c_prev, m_prev) = state

    self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]

    # Size the per-group input slice from the actual input width.  Slicing the
    # input with the recurrent group width instead would silently drop
    # trailing input columns (or run out of range) whenever the input width
    # differs from the width of `m`.
    input_size = inputs.shape[1].value
    if input_size is None:
      # Width unknown at graph-construction time: keep the historical
      # behaviour of slicing with the recurrent group width.
      input_group_size = self._group_shape[0]
    else:
      if input_size % self._number_of_groups != 0:
        raise ValueError(
            "input size (%d) must be divisible by number_of_groups (%d)" %
            (input_size, self._number_of_groups))
      input_group_size = int(input_size / self._number_of_groups)

    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          # This group's view: its input columns next to its columns of `m`.
          x_g_id = array_ops.concat(
              [
                  self._get_input_for_group(inputs, group_id,
                                            input_group_size),
                  self._get_input_for_group(m_prev, group_id,
                                            self._group_shape[0])
              ],
              axis=1)
          linear = self._linear1[group_id]
          if linear is None:
            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
            self._linear1[group_id] = linear
          # One map yields all four gate pre-activations for this group.
          R_k = linear(x_g_id)  # pylint: disable=invalid-name
          i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)

      # Gate biases span all `num_units` and are shared across groups; they
      # are added after the per-group pieces are concatenated back together.
      bi = vs.get_variable(
          name="bias_i",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bj = vs.get_variable(
          name="bias_j",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bf = vs.get_variable(
          name="bias_f",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))
      bo = vs.get_variable(
          name="bias_o",
          shape=[self._num_units],
          dtype=dtype,
          initializer=init_ops.constant_initializer(0.0, dtype=dtype))

      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

    # `forget_bias` is a constant offset on the forget gate, not a variable.
    c = (
        math_ops.sigmoid(f + self._forget_bias) * c_prev +
        math_ops.sigmoid(i) * math_ops.tanh(j))
    m = math_ops.sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      with vs.variable_scope("projection"):
        if self._linear2 is None:
          self._linear2 = _Linear(m, self._num_proj, False)
        m = self._linear2(m)

    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
    return m, new_state
2414b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2415b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2416b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Fengclass LayerNormLSTMCell(rnn_cell_impl.RNNCell):
2417b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  """Long short-term memory unit (LSTM) recurrent network cell.
2418b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2419b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  The default non-peephole implementation is based on:
2420b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2421b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    http://www.bioinf.jku.at/publications/older/2604.pdf
2422b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2423b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  S. Hochreiter and J. Schmidhuber.
2424b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
2425b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2426b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  The peephole implementation is based on:
2427b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2428b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    https://research.google.com/pubs/archive/43905.pdf
2429b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2430b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  Hasim Sak, Andrew Senior, and Francoise Beaufays.
2431b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  "Long short-term memory recurrent neural network architectures for
2432b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng   large scale acoustic modeling." INTERSPEECH, 2014.
2433b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2434b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  The class uses optional peep-hole connections, optional cell clipping, and
2435b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  an optional projection layer.
2436b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2437b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  Layer normalization implementation is based on:
2438b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2439b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    https://arxiv.org/abs/1607.06450.
2440b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2441b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  "Layer Normalization"
2442b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
2443b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2444b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  and is applied before the internal nonlinearities.
2445b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2446b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  """
2447b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2448b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  def __init__(self,
2449b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               num_units,
2450b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               use_peepholes=False,
2451b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               cell_clip=None,
2452b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               initializer=None,
2453b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               num_proj=None,
2454b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               proj_clip=None,
2455b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               forget_bias=1.0,
2456b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               activation=None,
2457b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               layer_norm=False,
2458b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               norm_gain=1.0,
2459b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               norm_shift=0.0,
2460b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng               reuse=None):
2461b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    """Initialize the parameters for an LSTM cell.
2462b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2463b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Args:
2464b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      num_units: int, The number of units in the LSTM cell
2465b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      use_peepholes: bool, set True to enable diagonal/peephole connections.
2466b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      cell_clip: (optional) A float value, if provided the cell state is clipped
2467b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        by this value prior to the cell output activation.
2468b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      initializer: (optional) The initializer to use for the weight and
2469b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        projection matrices.
2470b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      num_proj: (optional) int, The output dimensionality for the projection
2471b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        matrices.  If None, no projection is performed.
2472b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
2473b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        provided, then the projected values are clipped elementwise to within
2474b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        `[-proj_clip, proj_clip]`.
2475b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      forget_bias: Biases of the forget gate are initialized by default to 1
2476b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        in order to reduce the scale of forgetting at the beginning of
2477b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        the training. Must set it manually to `0.0` when restoring from
2478b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        CudnnLSTM trained checkpoints.
2479b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      activation: Activation function of the inner states.  Default: `tanh`.
2480b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      layer_norm: If `True`, layer normalization will be applied.
2481b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      norm_gain: float, The layer normalization gain initial value. If
2482b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        `layer_norm` has been set to `False`, this argument will be ignored.
2483b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      norm_shift: float, The layer normalization shift initial value. If
2484b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        `layer_norm` has been set to `False`, this argument will be ignored.
2485b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      reuse: (optional) Python boolean describing whether to reuse variables
2486b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        in an existing scope.  If not `True`, and the existing scope already has
2487b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        the given variables, an error is raised.
2488b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2489b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      When restoring from CudnnLSTM-trained checkpoints, must use
2490b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      CudnnCompatibleLSTMCell instead.
2491b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    """
2492b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    super(LayerNormLSTMCell, self).__init__(_reuse=reuse)
2493b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2494b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._num_units = num_units
2495b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._use_peepholes = use_peepholes
2496b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._cell_clip = cell_clip
2497b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._initializer = initializer
2498b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._num_proj = num_proj
2499b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._proj_clip = proj_clip
2500b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._forget_bias = forget_bias
2501b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._activation = activation or math_ops.tanh
2502b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._layer_norm = layer_norm
2503b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._norm_gain = norm_gain
2504b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    self._norm_shift = norm_shift
2505b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2506b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    if num_proj:
2507b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
2508b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      self._output_size = num_proj
2509b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    else:
2510b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
2511b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      self._output_size = num_units
2512b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2513b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  @property
2514b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  def state_size(self):
2515b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    return self._state_size
2516b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2517b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  @property
2518b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  def output_size(self):
2519b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    return self._output_size
2520b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2521b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng  def _linear(self,
2522b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              args,
2523b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              output_size,
2524b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              bias,
2525b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              bias_initializer=None,
2526b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              kernel_initializer=None,
2527b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng              layer_norm=False):
2528b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.
2529b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2530b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Args:
2531b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2532b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      output_size: int, second dimension of W[i].
2533b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      bias: boolean, whether to add a bias term or not.
2534b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      bias_initializer: starting value to initialize the bias
2535b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        (default is all zeros).
2536b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      kernel_initializer: starting value to initialize the weight.
2537b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      layer_norm: boolean, whether to apply layer normalization.
2538b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2539b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2540b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Returns:
2541b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      A 2D Tensor with shape [batch x output_size] taking value
2542b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.
2543b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2544b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    Raises:
2545b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      ValueError: if some of the arguments has unspecified or wrong shape.
2546b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    """
2547b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    if args is None or (nest.is_sequence(args) and not args):
2548b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      raise ValueError("`args` must be specified")
2549b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    if not nest.is_sequence(args):
2550b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      args = [args]
2551b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2552b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    # Calculate the total size of arguments on dimension 1.
2553b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    total_arg_size = 0
2554b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    shapes = [a.get_shape() for a in args]
2555b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    for shape in shapes:
2556b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      if shape.ndims != 2:
2557b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2558b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      if shape[1].value is None:
2559b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2560b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng                         "but saw %s" % (shape, shape[1]))
2561b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      else:
2562b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        total_arg_size += shape[1].value
2563b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2564b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    dtype = [a.dtype for a in args][0]
2565b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2566b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    # Now the computation.
2567b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    scope = vs.get_variable_scope()
2568b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    with vs.variable_scope(scope) as outer_scope:
2569b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      weights = vs.get_variable(
2570b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng          "kernel", [total_arg_size, output_size],
2571b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng          dtype=dtype,
2572b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng          initializer=kernel_initializer)
2573b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      if len(args) == 1:
2574b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        res = math_ops.matmul(args[0], weights)
2575b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      else:
2576b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2577b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      if not bias:
2578b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        return res
2579b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      with vs.variable_scope(outer_scope) as inner_scope:
2580b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        inner_scope.set_partitioner(None)
2581b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        if bias_initializer is None:
2582b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2583b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng        biases = vs.get_variable(
2584b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng            "bias", [output_size], dtype=dtype, initializer=bias_initializer)
2585b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2586b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    if not layer_norm:
2587b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng      res = nn_ops.bias_add(res, biases)
2588b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
2589b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng    return res
2590b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng
def call(self, inputs, state):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D. Its second dimension must be statically known.
    state: a tuple `(c_prev, m_prev)` of state Tensors, both `2-D`.

  Returns:
    A tuple `(m, new_state)` where:

    - `m` is a `2-D, [batch x output_dim]` Tensor holding the output of the
      LSTM after reading `inputs` when the previous state was `state`.
      `output_dim` is `num_proj` if `num_proj` was set, `num_units`
      otherwise.
    - `new_state` is an `LSTMStateTuple(c, m)` holding the new cell state and
      the new output.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  sigmoid = math_ops.sigmoid
  use_peepholes = self._use_peepholes

  c_prev, m_prev = state

  dtype = inputs.dtype
  input_size = inputs.get_shape().with_rank(2)[1]
  if input_size.value is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  with vs.variable_scope(
      vs.get_variable_scope(), initializer=self._initializer) as cell_scope:
    # One fused projection of [inputs, m_prev] yields all four gate
    # pre-activations side by side:
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = self._linear(
        [inputs, m_prev],
        4 * self._num_units,
        bias=True,
        bias_initializer=None,
        layer_norm=self._layer_norm)
    i, j, f, o = array_ops.split(
        value=lstm_matrix, num_or_size_splits=4, axis=1)

    if self._layer_norm:
      # Each gate is normalized by its own `_norm` call, in i, j, f, o order.
      i, j, f, o = [
          _norm(self._norm_gain, self._norm_shift, gate, label)
          for gate, label in zip((i, j, f, o),
                                 ("input", "transform", "forget", "output"))
      ]

    if use_peepholes:
      # Diagonal connections: per-unit weights applied to the cell state.
      with vs.variable_scope(cell_scope):
        w_f_diag = vs.get_variable(
            "w_f_diag", shape=[self._num_units], dtype=dtype)
        w_i_diag = vs.get_variable(
            "w_i_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = vs.get_variable(
            "w_o_diag", shape=[self._num_units], dtype=dtype)
      retained = (
          sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev)
      admitted = sigmoid(i + w_i_diag * c_prev) * self._activation(j)
    else:
      retained = sigmoid(f + self._forget_bias) * c_prev
      admitted = sigmoid(i) * self._activation(j)
    c = retained + admitted

    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    # The output gate peeks at the *new* cell state when peepholes are on.
    if use_peepholes:
      output_gate = sigmoid(o + w_o_diag * c)
    else:
      output_gate = sigmoid(o)
    m = output_gate * self._activation(c)

    if self._num_proj is not None:
      with vs.variable_scope("projection"):
        m = self._linear(m, self._num_proj, bias=False)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

  new_state = rnn_cell_impl.LSTMStateTuple(c, m)
  return m, new_state
268420765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen
268520765b3e1ae3b718699592c98aa9805cb874b6d1Patrick Nguyen
class SRUCell(rnn_cell_impl.LayerRNNCell):
  """SRU, Simple Recurrent Unit.

  Implementation based on
  Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).

  This variation of RNN cell is characterized by the simplified data
  dependence between hidden states of two consecutive time steps.
  Traditionally, hidden states from a cell at time step t-1 need to be
  multiplied with a matrix W_hh before being fed into the ensuing cell at
  time step t. This flavor of RNN replaces the matrix multiplication between
  h_{t-1} and W_hh with a pointwise multiplication, resulting in performance
  gain.

  Args:
    num_units: int, The number of units in the SRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: (optional) String, the name of the layer. Layers with the same name
      will share weights, but to avoid mistakes we require reuse=True in such
      cases.
  """

  def __init__(self, num_units, activation=None, reuse=None, name=None):
    super(SRUCell, self).__init__(_reuse=reuse, name=name)
    self._num_units = num_units
    # Any falsy `activation` (including the default None) selects tanh.
    self._activation = activation if activation else math_ops.tanh

    # Restrict inputs to be 2-dimensional matrices
    self.input_spec = base_layer.InputSpec(ndim=2)

  @property
  def state_size(self):
    """Width of the single state tensor: `num_units`."""
    return self._num_units

  @property
  def output_size(self):
    """Width of the output tensor: `num_units`."""
    return self._num_units

  def build(self, inputs_shape):
    """Create the kernel and bias variables for the given input shape.

    Args:
      inputs_shape: shape of the inputs; dimension 1 must be statically known.

    Raises:
      ValueError: if dimension 1 of `inputs_shape` is unknown.
    """
    input_depth = inputs_shape[1].value
    if input_depth is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    units = self._num_units

    # One kernel produces four `num_units`-wide slices per step (see `call`).
    self._kernel = self.add_variable(
        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * units])

    # Only the two gate slices receive a bias, hence 2 * num_units.
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[2 * units],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    self._built = True

  def call(self, inputs, state):
    """Simple recurrent unit (SRU) with num_units cells.

    Computes `c = f * state + (1 - f) * x_bar` and
    `h = r * activation(c) + (1 - r) * x_tx`, where `x_bar`, `x_tx` and the
    pre-activations of the gates `f` and `r` are the four column slices of
    `inputs @ kernel`.

    Args:
      inputs: 2-D input tensor.
      state: previous state tensor `c`.

    Returns:
      A tuple `(h, c)` of the step output and the new state.
    """
    projected = math_ops.matmul(inputs, self._kernel)
    x_bar, f_pre, r_pre, x_tx = array_ops.split(
        value=projected, num_or_size_splits=4, axis=1)

    # Both gates share one bias_add + sigmoid, then are split apart again.
    gate_logits = nn_ops.bias_add(
        array_ops.concat([f_pre, r_pre], 1), self._bias)
    gates = math_ops.sigmoid(gate_logits)
    f, r = array_ops.split(value=gates, num_or_size_splits=2, axis=1)

    c = f * state + (1.0 - f) * x_bar
    h = r * self._activation(c) + (1.0 - r) * x_tx

    return h, c
2762d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2763d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2764d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xieclass WeightNormLSTMCell(rnn_cell_impl.RNNCell):
2765d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
2766d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2767d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    The weight-norm implementation is based on:
2768d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    https://arxiv.org/abs/1602.07868
2769d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Tim Salimans, Diederik P. Kingma.
2770d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Weight Normalization: A Simple Reparameterization to Accelerate
2771d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Training of Deep Neural Networks
2772d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2773d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    The default LSTM implementation based on:
2774d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    http://www.bioinf.jku.at/publications/older/2604.pdf
2775d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    S. Hochreiter and J. Schmidhuber.
2776d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
2777d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2778d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    The class uses optional peephole connections, optional cell clipping
2779d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    and an optional projection layer.
2780d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2781d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    The optional peephole implementation is based on:
2782d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    https://research.google.com/pubs/archive/43905.pdf
2783d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Hasim Sak, Andrew Senior, and Francoise Beaufays.
2784d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    "Long short-term memory recurrent neural network architectures for
2785d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    large scale acoustic modeling." INTERSPEECH, 2014.
2786d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie  """
2787d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
def __init__(self,
             num_units,
             norm=True,
             use_peepholes=False,
             cell_clip=None,
             initializer=None,
             num_proj=None,
             proj_clip=None,
             forget_bias=1,
             activation=None,
             reuse=None):
  """Initialize the parameters of a weight-normalized LSTM cell.

  Args:
    num_units: int, The number of units in the LSTM cell
    norm: If `True`, apply normalization to the weight matrices. If False,
      the result is identical to that obtained from `rnn_cell_impl.LSTMCell`
    use_peepholes: bool, set `True` to enable diagonal/peephole connections.
    cell_clip: (optional) A float value, if provided the cell state is clipped
      by this value prior to the cell output activation.
    initializer: (optional) The initializer to use for the weight matrices.
    num_proj: (optional) int, The output dimensionality for the projection
      matrices.  If None, no projection is performed.
    proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
      provided, then the projected values are clipped elementwise to within
      `[-proj_clip, proj_clip]`.
    forget_bias: Biases of the forget gate are initialized by default to 1
      in order to reduce the scale of forgetting at the beginning of
      the training.
    activation: Activation function of the inner states.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
  """
  super(WeightNormLSTMCell, self).__init__(_reuse=reuse)

  self._scope = "wn_lstm_cell"

  # Plain configuration, stored as given.
  self._num_units = num_units
  self._norm = norm
  self._initializer = initializer
  self._use_peepholes = use_peepholes
  self._cell_clip = cell_clip
  self._num_proj = num_proj
  self._proj_clip = proj_clip
  # Any falsy `activation` (including the default None) selects tanh.
  self._activation = activation if activation else math_ops.tanh
  self._forget_bias = forget_bias

  self._weights_variable_name = "kernel"
  self._bias_variable_name = "bias"

  # A falsy `num_proj` (None or 0) means the output keeps width `num_units`;
  # otherwise the output half of the state has width `num_proj`.
  output_dim = num_proj if num_proj else num_units
  self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, output_dim)
  self._output_size = output_dim
2844d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
@property
def state_size(self):
  """Return the state size stored on the instance as `_state_size`."""
  return self._state_size
2848d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
@property
def output_size(self):
  """Return the output size stored on the instance as `_output_size`."""
  return self._output_size
2852d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
def _normalize(self, weight, name):
  """Apply weight normalization.

  L2-normalizes `weight` along dimension 0 and rescales the result by a
  variable of one entry per column, obtained via `vs.get_variable(name, ...)`.

  Args:
    weight: a 2D tensor with known number of columns.
    name: string, variable name for the normalizer.

  Returns:
    A tensor with the same shape as `weight`.
  """
  num_columns = weight.get_shape().as_list()[1]
  # The scale variable is fetched before the normalize op is built.
  gain = vs.get_variable(name, [num_columns], dtype=weight.dtype)
  unit_columns = nn_impl.l2_normalize(weight, dim=0)
  return unit_columns * gain
2866d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2867ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie  def _linear(self,
2868ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie              args,
2869d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie              output_size,
2870d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie              norm,
2871d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie              bias,
2872d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie              bias_initializer=None,
2873d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie              kernel_initializer=None):
2874d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
2875d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2876d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Args:
2877d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      args: a 2D Tensor or a list of 2D, batch x n, Tensors.
2878d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      output_size: int, second dimension of W[i].
2879d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      bias: boolean, whether to add a bias term or not.
2880d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      bias_initializer: starting value to initialize the bias
2881d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        (default is all zeros).
2882d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      kernel_initializer: starting value to initialize the weight.
2883d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2884d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Returns:
2885d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      A 2D Tensor with shape [batch x output_size] equal to
2886d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      sum_i(args[i] * W[i]), where W[i]s are newly created matrices.
2887d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2888d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Raises:
2889d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      ValueError: if some of the arguments has unspecified or wrong shape.
2890d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    """
2891d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    if args is None or (nest.is_sequence(args) and not args):
2892d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      raise ValueError("`args` must be specified")
2893d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    if not nest.is_sequence(args):
2894d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      args = [args]
2895d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2896d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    # Calculate the total size of arguments on dimension 1.
2897d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    total_arg_size = 0
2898d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    shapes = [a.get_shape() for a in args]
2899d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    for shape in shapes:
2900d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if shape.ndims != 2:
2901d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        raise ValueError("linear is expecting 2D arguments: %s" % shapes)
2902d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if shape[1].value is None:
2903d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        raise ValueError("linear expects shape[1] to be provided for shape %s, "
2904d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie                         "but saw %s" % (shape, shape[1]))
2905d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      else:
2906d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        total_arg_size += shape[1].value
2907d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2908d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    dtype = [a.dtype for a in args][0]
2909d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2910d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    # Now the computation.
2911d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    scope = vs.get_variable_scope()
2912d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    with vs.variable_scope(scope) as outer_scope:
2913d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      weights = vs.get_variable(
2914d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          self._weights_variable_name, [total_arg_size, output_size],
2915d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          dtype=dtype,
2916d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          initializer=kernel_initializer)
2917d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if norm:
2918d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        wn = []
2919d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        st = 0
2920d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        with ops.control_dependencies(None):
2921d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          for i in range(len(args)):
2922d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie            en = st + shapes[i][1].value
2923ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            wn.append(
2924ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie                self._normalize(weights[st:en, :], name="norm_{}".format(i)))
2925d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie            st = en
2926d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2927d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          weights = array_ops.concat(wn, axis=0)
2928d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2929d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if len(args) == 1:
2930d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        res = math_ops.matmul(args[0], weights)
2931d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      else:
2932d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        res = math_ops.matmul(array_ops.concat(args, 1), weights)
2933d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if not bias:
2934d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        return res
2935d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2936d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      with vs.variable_scope(outer_scope) as inner_scope:
2937d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        inner_scope.set_partitioner(None)
2938d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        if bias_initializer is None:
2939d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
2940d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2941d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        biases = vs.get_variable(
2942d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie            self._bias_variable_name, [output_size],
2943d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie            dtype=dtype,
2944d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie            initializer=bias_initializer)
2945d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2946d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      return nn_ops.bias_add(res, biases)
2947d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2948d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie  def call(self, inputs, state):
2949d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    """Run one step of LSTM.
2950d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2951d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Args:
2952d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      inputs: input Tensor, 2D, batch x num_units.
2953d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      state: A tuple of state Tensors, both `2-D`, with column sizes
2954d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie       `c_state` and `m_state`.
2955d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2956d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Returns:
2957d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      A tuple containing:
2958d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2959d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
2960d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        LSTM after reading `inputs` when previous state was `state`.
2961d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        Here output_dim is:
2962d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie           num_proj if num_proj was set,
2963d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie           num_units otherwise.
2964d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      - Tensor(s) representing the new state of LSTM after reading `inputs` when
2965d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        the previous state was `state`.  Same type and shape(s) as `state`.
2966d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2967d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    Raises:
2968d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      ValueError: If input size cannot be inferred from inputs via
2969d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        static shape inference.
2970d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    """
2971d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    dtype = inputs.dtype
2972d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    num_units = self._num_units
2973d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    sigmoid = math_ops.sigmoid
2974d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    c, h = state
2975d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2976d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    input_size = inputs.get_shape().with_rank(2)[1]
2977d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    if input_size.value is None:
2978d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
2979d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2980d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie    with vs.variable_scope(self._scope, initializer=self._initializer):
2981d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2982ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie      concat = self._linear(
2983ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          [inputs, h], 4 * num_units, norm=self._norm, bias=True)
2984d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2985d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
2986d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
2987d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2988d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if self._use_peepholes:
2989d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
2990d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
2991d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
2992d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
2993ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        new_c = (
2994ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            c * sigmoid(f + self._forget_bias + w_f_diag * c) +
2995ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            sigmoid(i + w_i_diag * c) * self._activation(j))
2996d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      else:
2997ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie        new_c = (
2998ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            c * sigmoid(f + self._forget_bias) +
2999ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie            sigmoid(i) * self._activation(j))
3000d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
3001d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if self._cell_clip is not None:
3002d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        # pylint: disable=invalid-unary-operand-type
3003d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
3004d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        # pylint: enable=invalid-unary-operand-type
3005d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if self._use_peepholes:
3006d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
3007d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      else:
3008d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        new_h = sigmoid(o) * self._activation(new_c)
3009d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
3010d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      if self._num_proj is not None:
3011d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        with vs.variable_scope("projection"):
3012ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          new_h = self._linear(
3013ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie              new_h, self._num_proj, norm=self._norm, bias=False)
3014d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
3015d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie        if self._proj_clip is not None:
3016d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          # pylint: disable=invalid-unary-operand-type
3017ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie          new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
3018d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie                                         self._proj_clip)
3019d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie          # pylint: enable=invalid-unary-operand-type
3020d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie
3021d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
3022d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie      return new_h, new_state
3023