rnn_cell.py revision bab22b9f25741e172bb70ff1f82dc803ced0f579
1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 3d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 4d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# you may not use this file except in compliance with the License. 5d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# You may obtain a copy of the License at 6d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 7d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 8d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 9d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 10d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 11d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# See the License for the specific language governing permissions and 13d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# limitations under the License. 14d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# ============================================================================== 15d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 16d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower"""Module for constructing RNN Cells.""" 17d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import absolute_import 18d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import division 19d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlowerfrom __future__ import print_function 20d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlowerimport collections 22d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerimport math 23d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 24bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.contrib.compiler import jit 2534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.contrib.layers.python.layers import layers 2637fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell 2737fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl 281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerfrom tensorflow.python.framework import dtypes 29bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.python.framework import op_def_registry 30d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.framework import ops 31d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops 32d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import clip_ops 3334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.python.ops import init_ops 34d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops 35d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import nn_ops 36d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import variable_scope as vs 375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _get_concat_variable(name, shape, dtype, num_shards):
  """Get a sharded variable concatenated into one tensor.

  The variable is created (or reused) as `num_shards` pieces split along
  dimension 0 by `_get_sharded_variable`. With a single shard that variable is
  returned directly. Otherwise the shards are concatenated along axis 0, and
  the resulting tensor is recorded in the
  `GraphKeys.CONCATENATED_VARIABLES` collection so that later calls looking up
  the same name return the cached tensor instead of building a new concat op.

  Args:
    name: Base name of the variable; shards are named `name_0`, `name_1`, ...
    shape: Full shape of the concatenated variable (list or tuple of ints).
    dtype: Data type of the variable.
    num_shards: Number of shards dimension 0 is split into.

  Returns:
    The single shard variable, or a tensor holding the concatenation of all
    shards along axis 0.

  Raises:
    ValueError: Propagated from `_get_sharded_variable` when `num_shards` is
      out of range.
  """
  sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  # NOTE(review): the cache key is derived from the *variable* scope name;
  # presumably this matches the concat op's name only when the current name
  # scope mirrors the variable scope -- confirm for unusual scope setups.
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES,
                        concat_variable)
  return concat_variable


def _get_sharded_variable(name, shape, dtype, num_shards):
  """Get a list of sharded variables with the given dtype.

  Dimension 0 of `shape` is split into `num_shards` contiguous pieces whose
  sizes differ by at most one: the first `shape[0] % num_shards` shards get
  one extra row. Shard `i` is created (or reused) via `vs.get_variable` under
  the name `name + "_%d" % i`.

  Args:
    name: Base name for the shard variables.
    shape: Full shape (list or tuple of ints); only dimension 0 is sharded.
    dtype: Data type of the shards.
    num_shards: Number of shards; must satisfy `1 <= num_shards <= shape[0]`.

  Returns:
    A list of `num_shards` variables.

  Raises:
    ValueError: If `num_shards` is less than 1 or greater than `shape[0]`.
  """
  # Without this check, 0 raises ZeroDivisionError and negative values
  # silently produce an empty shard list.
  if num_shards < 1:
    raise ValueError("num_shards must be at least 1: shape=%s, num_shards=%d" %
                     (shape, num_shards))
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                     (shape, num_shards))
  # Exact integer arithmetic; avoids float rounding for very large dimensions.
  unit_shard_size, remaining_rows = divmod(shape[0], num_shards)
  # list() so that tuple shapes also work with the list concatenation below.
  trailing_dims = list(shape[1:])

  shards = []
  for i in range(num_shards):
    current_size = unit_shard_size + (1 if i < remaining_rows else 0)
    shards.append(vs.get_variable(name + "_%d" % i,
                                  [current_size] + trailing_dims,
                                  dtype=dtype))
  return shards
Unique TensorFlower 7737fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): 781032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Long short-term memory unit (LSTM) recurrent network cell. 791032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 801032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The default non-peephole implementation is based on: 811032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 821032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf 831032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 841032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower S. Hochreiter and J. Schmidhuber. 851032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 861032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 871032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The peephole implementation is based on: 881032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 891032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower https://research.google.com/pubs/archive/43905.pdf 901032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 911032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Hasim Sak, Andrew Senior, and Francoise Beaufays. 921032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "Long short-term memory recurrent neural network architectures for 931032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower large scale acoustic modeling." INTERSPEECH, 2014. 941032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 951032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The coupling of input and forget gate is based on: 961032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 971032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower http://arxiv.org/pdf/1503.04069.pdf 981032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 991032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Greff et al. "LSTM: A Search Space Odyssey" 1001032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1011032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The class uses optional peep-hole connections, and an optional projection 1021032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower layer. 1031032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """ 1041032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1051032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 1061032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower initializer=None, num_proj=None, proj_clip=None, 1071032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_unit_shards=1, num_proj_shards=1, 1081032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower forget_bias=1.0, state_is_tuple=False, 109e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo activation=math_ops.tanh): 1101032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 1111032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1121032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Args: 1131032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 1141032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower use_peepholes: bool, set True to enable diagonal/peephole connections. 1151032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 1161032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower projection matrices. 1171032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower num_proj: (optional) int, The output dimensionality for the projection 1181032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower matrices. If None, no projection is performed. 1191032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 1201032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower provided, then the projected values are clipped elementwise to within 1211032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `[-proj_clip, proj_clip]`. 1221032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_unit_shards: How to split the weight matrix. If >1, the weight 1231032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower matrix is stored across num_unit_shards. 1241032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj_shards: How to split the projection matrix. If >1, the 1251032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower projection matrix is stored across num_proj_shards. 1261032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower forget_bias: Biases of the forget gate are initialized by default to 1 1271032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower in order to reduce the scale of forgetting at the beginning of 1281032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the training. 1291032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are 2-tuples of 1301032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the `c_state` and `m_state`. By default (False), they are concatenated 1311032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower along the column axis. This default behavior will soon be deprecated. 1321032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower activation: Activation function of the inner states. 1331032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower """ 1341032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if not state_is_tuple: 1351032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower logging.warn( 1361032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "%s: Using a concatenated state is slower and will soon be " 137d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 1381032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_units = num_units 1391032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._use_peepholes = use_peepholes 1401032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._initializer = initializer 1411032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_proj = num_proj 1421032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._proj_clip = proj_clip 1431032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_unit_shards = num_unit_shards 1441032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_proj_shards = num_proj_shards 1451032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._forget_bias = forget_bias 1461032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._state_is_tuple = state_is_tuple 1471032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._activation = activation 1481032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1491032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if num_proj: 1501032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._state_size = ( 15137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie core_rnn_cell.LSTMStateTuple(num_units, num_proj) 1521032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if state_is_tuple else num_units + num_proj) 1531032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower self._output_size = num_proj 1541032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower else: 1551032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._state_size = ( 15637fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie core_rnn_cell.LSTMStateTuple(num_units, num_units) 1571032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if state_is_tuple else 2 * num_units) 1581032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._output_size = num_units 1591032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1601032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower @property 1611032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def state_size(self): 1621032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return self._state_size 1631032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1641032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower @property 1651032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def output_size(self): 1661032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return self._output_size 1671032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1681032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def __call__(self, inputs, state, scope=None): 1691032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Run one step of LSTM. 1701032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1711032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Args: 1721032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower inputs: input Tensor, 2D, batch x num_units. 1731032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower state: if `state_is_tuple` is False, this must be a state Tensor, 1741032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a 1751032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower tuple of state Tensors, both `2-D`, with column sizes `c_state` and 1761032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `m_state`. 1771032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower scope: VariableScope for the created subgraph; defaults to "LSTMCell". 1781032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1791032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Returns: 1801032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower A tuple containing: 1811032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower - A `2-D, [batch x output_dim]`, Tensor representing the output of the 1821032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower LSTM after reading `inputs` when previous state was `state`. 1831032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Here output_dim is: 1841032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj if num_proj was set, 1851032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_units otherwise. 1861032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower - Tensor(s) representing the new state of LSTM after reading `inputs` when 1871032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the previous state was `state`. Same type and shape(s) as `state`. 1881032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1891032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Raises: 1901032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower ValueError: If input size cannot be inferred from inputs via 1911032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower static shape inference. 1921032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """ 193e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 194e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo 1951032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower num_proj = self._num_units if self._num_proj is None else self._num_proj 1961032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1971032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._state_is_tuple: 1981032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower (c_prev, m_prev) = state 1991032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower else: 2001032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 2011032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 2021032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2031032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower dtype = inputs.dtype 2041032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower input_size = inputs.get_shape().with_rank(2)[1] 2051032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if input_size.value is None: 2061032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 20792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "coupled_input_forget_gate_lstm_cell", 20892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo initializer=self._initializer): 2091032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower concat_w = _get_concat_variable( 2101032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "W", [input_size.value + num_proj, 3 * self._num_units], 2111032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower dtype, self._num_unit_shards) 2121032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2131032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower b = vs.get_variable( 2144ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B", 2154ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[3 * self._num_units], 2164ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 2174ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 2181032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2191032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower # j = new_input, f = forget_gate, o = output_gate 2200e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower cell_inputs = array_ops.concat([inputs, m_prev], 1) 2211032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 222a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) 2231032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2241032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower # Diagonal connections 2251032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._use_peepholes: 2261032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower w_f_diag = vs.get_variable( 2271032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "W_F_diag", shape=[self._num_units], dtype=dtype) 2281032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower w_o_diag = vs.get_variable( 2291032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "W_O_diag", shape=[self._num_units], dtype=dtype) 2301032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2311032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._use_peepholes: 2321032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) 2331032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower else: 2341032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower f_act = sigmoid(f + self._forget_bias) 2351032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower c = (f_act * c_prev + (1 - f_act) * self._activation(j)) 2361032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2371032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._use_peepholes: 2381032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m = sigmoid(o + w_o_diag * c) * self._activation(c) 2391032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower else: 2401032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m = sigmoid(o) * self._activation(c) 2411032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2421032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._num_proj is not None: 2431032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower concat_w_proj = _get_concat_variable( 2441032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "W_P", [self._num_units, self._num_proj], 2451032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower dtype, self._num_proj_shards) 2461032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2471032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m = math_ops.matmul(m, concat_w_proj) 2481032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._proj_clip is not None: 2491032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 2501032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 2511032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 2521032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2530e226af7eed5e2764aa8acb825af4cd3e06d2452A. 
Unique TensorFlower new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else 2540e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower array_ops.concat([c, m], 1)) 2551032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return m, new_state 2561032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2571032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 25837fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass TimeFreqLSTMCell(core_rnn_cell.RNNCell): 259d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. 260d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 261d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower This implementation is based on: 262d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 263d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Tara N. Sainath and Bo Li 264d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 265d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for LVCSR Tasks." submitted to INTERSPEECH, 2016. 266d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 267d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower It uses peep-hole connections and optional cell clipping. 268d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 269d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 270d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 271d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip=None, initializer=None, 272d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 273d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower feature_size=None, frequency_skip=None): 274d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 275d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 276d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 277d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 278d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower use_peepholes: bool, set True to enable diagonal/peephole connections. 279d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip: (optional) A float value, if provided the cell state is clipped 280d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower by this value prior to the cell output activation. 281d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 282d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower projection matrices. 283d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards: int, How to split the weight matrix. If >1, the weight 284d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower matrix is stored across num_unit_shards. 285d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower forget_bias: float, Biases of the forget gate are initialized by default 286d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower to 1 in order to reduce the scale of forgetting at the beginning 287d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 288d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower feature_size: int, The size of the input feature the LSTM spans over. 289d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency_skip: int, The amount the LSTM filter is shifted by in 290d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency. 
291d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 292d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 293d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 294d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 295d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 296d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 297d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 298d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 299d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip = frequency_skip 300d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._state_size = 2 * num_units 301d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._output_size = num_units 302d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 303d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 304d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 305d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 306d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 307d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 308d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 309d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 310d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 311d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __call__(self, inputs, state, scope=None): 312d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 
313d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 314d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 315d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower inputs: input Tensor, 2D, batch x num_units. 316d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower state: state Tensor, 2D, batch x state_size. 317d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower scope: VariableScope for the created subgraph; defaults to 318d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "TimeFreqLSTMCell". 319d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 320d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 321d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A tuple containing: 322d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the output of the LSTM 323d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 324d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 325d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x state_size, Tensor representing the new state of LSTM 326d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 327d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 328d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 329d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 330d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 331e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 332e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 333d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 334d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 335d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype = inputs.dtype 336d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower actual_input_size = freq_inputs[0].get_shape().as_list()[1] 33792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "time_freq_lstm_cell", 338d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer=self._initializer): # "TimeFreqLSTMCell" 339d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower concat_w = _get_concat_variable( 340d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], 341d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype, self._num_unit_shards) 342d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower b = vs.get_variable( 3434ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B", 3444ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[4 * self._num_units], 3454ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 3464ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 347d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 348763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower # Diagonal connections 349763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower if self._use_peepholes: 350763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower w_f_diag = vs.get_variable( 351763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower "W_F_diag", shape=[self._num_units], dtype=dtype) 352763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower w_i_diag = vs.get_variable( 353763aca87cdf3084f9efa71f4439127b969035367A. 
Unique TensorFlower "W_I_diag", shape=[self._num_units], dtype=dtype) 354763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower w_o_diag = vs.get_variable( 355763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower "W_O_diag", shape=[self._num_units], dtype=dtype) 356763aca87cdf3084f9efa71f4439127b969035367A. Unique TensorFlower 357d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower # initialize the first freq state to be zero 358d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), 359d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units], dtype) 360d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for fq in range(len(freq_inputs)): 361d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower c_prev = array_ops.slice(state, [0, 2*fq*self._num_units], 362d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower [-1, self._num_units]) 363d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units], 364d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower [-1, self._num_units]) 365d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower # i = input_gate, j = new_input, f = forget_gate, o = output_gate 3660e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 3670e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower 1) 368d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 369a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i, j, f, o = array_ops.split( 370a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix, num_or_size_splits=4, axis=1) 371d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 372d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._use_peepholes: 373d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 374d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower sigmoid(i + w_i_diag * c_prev) * tanh(j)) 375d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 376d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) 377d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 378d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._cell_clip is not None: 379d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 380d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 381d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 382d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 383d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._use_peepholes: 384d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m = sigmoid(o + w_o_diag * c) * tanh(c) 385d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 386d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m = sigmoid(o) * tanh(c) 387d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m_prev_freq = m 388d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if fq == 0: 3890e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower state_out = array_ops.concat([c, m], 1) 390d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower m_out = m 391d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 3920e226af7eed5e2764aa8acb825af4cd3e06d2452A. 
Unique TensorFlower state_out = array_ops.concat([state_out, c, m], 1) 3930e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower m_out = array_ops.concat([m_out, m], 1) 394d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return m_out, state_out 395d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 396d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def _make_tf_features(self, input_feat): 397d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Make the frequency features. 398d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 399d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 400d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_feat: input Tensor, 2D, batch x num_units. 401d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 402d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 403d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A list of frequency features, with each element containing: 404d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the time-frequency feature 405d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for that frequency index. Here output_dim is feature_size. 406d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 407d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if input_size cannot be inferred from static shape inference. 408d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 409d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_size = input_feat.get_shape().with_rank(2)[-1].value 410d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if input_size is None: 411d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower raise ValueError("Cannot infer input_size from static shape inference.") 412d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_feats = int((input_size - self._feature_size) / ( 413d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip)) + 1 414d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = [] 415d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for f in range(num_feats): 416d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip], 417d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower [-1, self._feature_size]) 418d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs.append(cur_input) 419d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return freq_inputs 420d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 421d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 42237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass GridLSTMCell(core_rnn_cell.RNNCell): 423d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Grid Long short-term memory unit (LSTM) recurrent network cell. 424d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 425d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The default is based on: 426d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Nal Kalchbrenner, Ivo Danihelka and Alex Graves 427d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Grid Long Short-Term Memory," Proc. ICLR 2016. 428d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower http://arxiv.org/abs/1507.01526 429d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 430d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower When peephole connections are used, the implementation is based on: 431d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower Tara N. Sainath and Bo Li 432d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 433d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for LVCSR Tasks." submitted to INTERSPEECH, 2016. 434d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 435d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The code uses optional peephole connections, shared_weights and cell clipping. 436d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 437d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 438d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 439d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower share_time_frequency_weights=False, 440d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip=None, initializer=None, 441d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 4428fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower feature_size=None, frequency_skip=None, 4439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 4449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 4459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 4468fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower couple_input_forget_gates=False, 4478fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_is_tuple=False): 448d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 449d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 450d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 451d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower num_units: int, The number of units in the LSTM cell 4521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 4531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 4541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 4551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 4561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 4571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state is clipped by this value prior to the cell output activation. 458d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 4591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 4601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards: (optional) int, defualt 1, How to split the weight 4611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 4621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 4631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 464d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 4651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 4661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 4671855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 4681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 4699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 4709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 4719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower and end_freqindex_list. 4729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 4739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower starting frequency index for each frequency block. 4749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 4759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 4761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 4771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 4781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 4798fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are 2-tuples of 4808fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower the `c_state` and `m_state`. By default (False), they are concatenated 4818fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower along the column axis. This default behavior will soon be deprecated. 4829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower Raises: 4839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower ValueError: if the num_frequency_blocks list is not specified 484d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 4858fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if not state_is_tuple: 4868fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower logging.warn("%s: Using a concatenated state is slower and will soon be " 4878fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 488d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 489d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 490d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._share_time_frequency_weights = share_time_frequency_weights 4918fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._couple_input_forget_gates = couple_input_forget_gates 4928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_is_tuple = state_is_tuple 493d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 494d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 495d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 496d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 497d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip = frequency_skip 4999e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._start_freqindex_list = start_freqindex_list 5009e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._end_freqindex_list = end_freqindex_list 5019e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower self._num_frequency_blocks = num_frequency_blocks 5029e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks = 0 5039e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if self._num_frequency_blocks is None: 5049e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError("Must specify num_frequency_blocks") 5059e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower 5069e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5079e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks += int(self._num_frequency_blocks[block_index]) 5088fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if state_is_tuple: 5098fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_names = "" 5109e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5119e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 5129e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) 5139e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 5148fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_tuple_type = collections.namedtuple( 5151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "GridLSTMStateTuple", state_names.strip(",")) 5168fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_size = self._state_tuple_type( 5179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower *([num_units, num_units] * self._total_blocks)) 5188fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 5198fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower self._state_tuple_type = None 5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._state_size = num_units * self._total_blocks * 2 5219e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._output_size = num_units * self._total_blocks * 2 522d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 523d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 524d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 525d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 526d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 527d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 528d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 529d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 530d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 5318fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower @property 5328fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower def state_tuple_type(self): 5338fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower return self._state_tuple_type 5348fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower 535d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __call__(self, inputs, state, scope=None): 536d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 537d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 538d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 5391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs: input Tensor, 2D, [batch, feature_size]. 5401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the 5411855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower flag self._state_is_tuple. 5421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower scope: (optional) VariableScope for the created subgraph; if None, it 5431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower defaults to "GridLSTMCell". 544d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 545d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 546d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A tuple containing: 5471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 548d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 549d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 5501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, state_size], Tensor representing the new state of LSTM 551d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 552d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 553d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 554d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 555d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 5561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower batch_size = int(inputs.get_shape()[0]) 5571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 55892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "grid_lstm_cell", 5591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer=self._initializer): # "GridLSTMCell" 5609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower m_out_lst = [] 5619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_out_lst = [] 5629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block in range(len(freq_inputs)): 5639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower m_out_lst_current, state_out_lst_current = self._compute( 5649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower freq_inputs[block], block, state, batch_size, 5659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_is_tuple=self._state_is_tuple) 5669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower m_out_lst.extend(m_out_lst_current) 5679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_out_lst.extend(state_out_lst_current) 5681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._state_is_tuple: 5691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out = self._state_tuple_type(*state_out_lst) 5701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 5710e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower state_out = array_ops.concat(state_out_lst, 1) 5720e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower m_out = array_ops.concat(m_out_lst, 1) 5731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out, state_out 5741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 5759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower def _compute(self, freq_inputs, block, state, batch_size, 5769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_prefix="state", 5771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_is_tuple=True): 5781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Run the actual computation of one step LSTM. 5791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 5801855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower Args: 5811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower freq_inputs: list of Tensors, 2D, [batch, feature_size]. 5829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block: int, current frequency block index to process. 5831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on 5841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the flag state_is_tuple. 5851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower batch_size: int32, batch size. 5861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_prefix: (optional) string, name prefix for states, defaults to 5871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "state". 5881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_is_tuple: boolean, indicates whether the state is a tuple or Tensor. 5891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 5901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Returns: 5911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower A tuple, containing: 5921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A list of [batch, output_dim] Tensors, representing the output of the 5931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower LSTM given the inputs and state. 5941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A list of [batch, state_size] Tensors, representing the LSTM state 5951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower values given the inputs and previous state. 5961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 597e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 598e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 5998fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower num_gates = 3 if self._couple_input_forget_gates else 4 6001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype = freq_inputs[0].dtype 601d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower actual_input_size = freq_inputs[0].get_shape().as_list()[1] 6021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_f = _get_concat_variable( 6049e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_f_%d" % block, [actual_input_size + 2 * self._num_units, 6059e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_gates * self._num_units], 6061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype, self._num_unit_shards) 6071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower b_f = vs.get_variable( 6084ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B_f_%d" % block, 6094ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[num_gates * self._num_units], 6104ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 6114ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 6121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._share_time_frequency_weights: 6131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_t = _get_concat_variable( 6149e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_t_%d" % block, [actual_input_size + 2 * self._num_units, 6159e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_gates * self._num_units], 616d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype, self._num_unit_shards) 6171855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower b_t = vs.get_variable( 6184ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B_t_%d" % block, 6194ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[num_gates * self._num_units], 6204ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 6214ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 622d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 6231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 6241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Diagonal connections 6251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._couple_input_forget_gates: 6261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf = vs.get_variable( 6279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt = vs.get_variable( 6299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype) 6301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqf = vs.get_variable( 6319e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqt = vs.get_variable( 6339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 6341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqf = vs.get_variable( 6359e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower w_o_diag_freqt = vs.get_variable( 6379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 6381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._share_time_frequency_weights: 6398fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if not self._couple_input_forget_gates: 6401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timef = vs.get_variable( 6419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timet = vs.get_variable( 6439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timef = vs.get_variable( 6459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timet = vs.get_variable( 6479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6481855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timef = vs.get_variable( 6499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timet = vs.get_variable( 6519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6531855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower # initialize the first freq state to be zero 6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 6551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 6561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower for freq_index in range(len(freq_inputs)): 6571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if state_is_tuple: 6589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block) 6591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_time = getattr(state, name_prefix + "_c") 6601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_time = getattr(state, name_prefix + "_m") 6611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 6621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_time = array_ops.slice( 6631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state, [0, 2 * freq_index * self._num_units], 6641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower [-1, self._num_units]) 6651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_time = array_ops.slice( 6661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state, [0, (2 * freq_index + 1) * self._num_units], 6671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower [-1, self._num_units]) 6681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # i = input_gate, j = new_input, f = forget_gate, o = output_gate 6700e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower cell_inputs = array_ops.concat( 671d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. 
Unique TensorFlower [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1) 6721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM 6741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs, 6751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_f), b_f) 6761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 677a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_freq, j_freq, o_freq = array_ops.split( 678a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 6791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq = None 6801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 681a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_freq, j_freq, f_freq, o_freq = array_ops.split( 682a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 6831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM 6841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 6851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time = i_freq 6861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower j_time = j_freq 6871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time = f_freq 6881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower o_time = o_freq 6891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 6901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs, 6911855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower concat_w_t), b_t) 6928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if self._couple_input_forget_gates: 693a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_time, j_time, o_time = array_ops.split( 694a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 6951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time = None 6968fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 697a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_time, j_time, f_time, o_time = array_ops.split( 698a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 699d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 7001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM c_freq 7011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # input gate activations 7021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_freq_g = sigmoid(i_freq + 7041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqf * c_prev_freq + 7051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqt * c_prev_time) 7061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_freq_g = sigmoid(i_freq) 7081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # forget gate activations 7091855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 7101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = 1.0 - i_freq_g 7111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 712d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower if self._use_peepholes: 7131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = sigmoid(f_freq + self._forget_bias + 7141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf * c_prev_freq + 7151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt * c_prev_time) 7161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = sigmoid(f_freq + self._forget_bias) 7181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # cell state 7191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq) 7201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._cell_clip is not None: 7211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 7221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip, 7231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._cell_clip) 7241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 7251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM c_freq 7271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # input gate activations 7281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 7301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time + 7318fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower w_i_diag_freqf * c_prev_freq + 7328fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower w_i_diag_freqt * c_prev_time) 7338fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time + 7351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timef * c_prev_freq + 7361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timet * c_prev_time) 7371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time) 7391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # forget gate activations 7401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 7411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = 1.0 - i_time_g 7421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 743d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._use_peepholes: 744d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._share_time_frequency_weights: 7451855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias + 7461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf * c_prev_freq + 7471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt * c_prev_time) 748d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 7491855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias + 7501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timef * c_prev_freq + 7511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timet * c_prev_time) 7528fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7531855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias) 7541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # cell state 7551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time) 7561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._cell_clip is not None: 7571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 7581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_time = clip_ops.clip_by_value(c_time, -self._cell_clip, 7591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._cell_clip) 7601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 7611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM m_freq 7631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_freq = sigmoid(o_freq + 7651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqf * c_freq + 7661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqt * c_time) * tanh(c_freq) 7671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_freq = sigmoid(o_freq) * tanh(c_freq) 769d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 7701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM m_time 7711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 7731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time + 7748fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower w_o_diag_freqf * c_freq + 7751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqt * c_time) * tanh(c_time) 776d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 7771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time + 7781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timef * c_freq + 7791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timet * c_time) * tanh(c_time) 7808fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time) * tanh(c_time) 7821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_freq = m_freq 7841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_freq = c_freq 7851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Concatenate the outputs for T-LSTM and F-LSTM for each shift 7861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if freq_index == 0: 7871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out_lst = [c_time, m_time] 7881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_out_lst = [m_time, m_freq] 7891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out_lst.extend([c_time, m_time]) 7911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_out_lst.extend([m_time, m_freq]) 792d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 7931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out_lst, state_out_lst 7941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7951855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower def _make_tf_features(self, input_feat, slice_offset=0): 796d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Make the frequency features. 797d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 798d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 7991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat: input Tensor, 2D, [batch, num_units]. 8001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice_offset: (optional) Python int, default 0, the slicing offset is only 8011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower used for the backward processing in the BidirectionalGridLSTMCell. It 8021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower specifies a different starting point instead of always 0 to enable the 8031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forward and backward processing look at different frequency blocks. 804d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 805d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 806d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A list of frequency features, with each element containing: 8071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the time-frequency 8081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature for that frequency index. Here output_dim is feature_size. 809d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 810d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if input_size cannot be inferred from static shape inference. 811d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 812d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_size = input_feat.get_shape().with_rank(2)[-1].value 813d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower if input_size is None: 814d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower raise ValueError("Cannot infer input_size from static shape inference.") 8151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if slice_offset > 0: 8161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Padding to the end 8171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = array_ops.pad( 8181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2], 8191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype=dtypes.int32), 8201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "CONSTANT") 8211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower elif slice_offset < 0: 8221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Padding to the front 8231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = array_ops.pad( 8241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2], 8251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype=dtypes.int32), 8261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "CONSTANT") 8271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice_offset = 0 8281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 8291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = input_feat 830d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = [] 8319e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if not self._start_freqindex_list: 8329e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._num_frequency_blocks) != 1: 8339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower raise ValueError("Length of num_frequency_blocks" 8349e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " is not 1, but instead is %d", 8359e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._num_frequency_blocks)) 8369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_feats = int((input_size - self._feature_size) / ( 8379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip)) + 1 8389e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if num_feats != self._num_frequency_blocks[0]: 8399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError( 8409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "Invalid num_frequency_blocks, requires %d but gets %d, please" 8419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " check the input size and filter config are correct." % ( 8429e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks[0], num_feats)) 8439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs = [] 8449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for f in range(num_feats): 8459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_input = array_ops.slice( 8469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower inputs, [0, slice_offset + f * self._frequency_skip], 8479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower [-1, self._feature_size]) 8489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs.append(cur_input) 8499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower freq_inputs.append(block_inputs) 8509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower else: 8519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._start_freqindex_list) != len(self._end_freqindex_list): 8529e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower raise ValueError("Length of start and end freqindex_list" 8539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " does not match %d %d", 8549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._start_freqindex_list), 8559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._end_freqindex_list)) 8569e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._num_frequency_blocks) != len(self._start_freqindex_list): 8579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError("Length of num_frequency_blocks" 8589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " is not equal to start_freqindex_list %d %d", 8599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._num_frequency_blocks), 8609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._start_freqindex_list)) 8619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for b in range(len(self._start_freqindex_list)): 8629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_index = self._start_freqindex_list[b] 8639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_index = self._end_freqindex_list[b] 8649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_size = end_index - start_index 8659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_feats = int((cur_size - self._feature_size) / ( 8669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip)) + 1 8679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if block_feats != self._num_frequency_blocks[b]: 8689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError( 8699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "Invalid num_frequency_blocks, requires %d but gets %d, please" 8709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower " check the input size and filter config are correct." % ( 8719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks[b], block_feats)) 8729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs = [] 8739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for f in range(block_feats): 8749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_input = array_ops.slice( 8759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower inputs, [0, start_index + slice_offset + f * 8769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip], 8779e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower [-1, self._feature_size]) 8789e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs.append(cur_input) 8799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower freq_inputs.append(block_inputs) 880d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return freq_inputs 8815cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 8825cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 8831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerclass BidirectionalGridLSTMCell(GridLSTMCell): 8841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Bidirectional GridLstm cell. 8851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 8861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower The bidirection connection is only used in the frequency direction, which 8871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower hence doesn't affect the time direction's real-time processing that is 8881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower required for online recognition systems. 8891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower The current implementation uses different weights for the two directions. 
8901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 8911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 8921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 8931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights=False, 8941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip=None, initializer=None, 8951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 8961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size=None, frequency_skip=None, 8979e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 8989e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 8999e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 9001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates=False, 9011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower backward_slice_offset=0): 9021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Initialize the parameters for an LSTM cell. 9031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Args: 9051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_units: int, The number of units in the LSTM cell 9061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 9071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 9081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 9091855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 
9101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 9111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state is clipped by this value prior to the cell output activation. 9121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 9131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 9141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards: (optional) int, defualt 1, How to split the weight 9151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 9161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 9171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 9181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower of the training. 9191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 9201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 9211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 9221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 9239e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 9249e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 9259e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower and end_freqindex_list. 9269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 9279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower starting frequency index for each frequency block. 9289e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 9299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 9301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 9311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 9321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 9331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower backward_slice_offset: (optional) int32, default 0, the starting offset to 9341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice the feature for backward processing. 9351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 9361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower super(BidirectionalGridLSTMCell, self).__init__( 9371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_units, use_peepholes, share_time_frequency_weights, cell_clip, 9381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer, num_unit_shards, forget_bias, feature_size, frequency_skip, 9399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks, start_freqindex_list, end_freqindex_list, 9409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower couple_input_forget_gates=False, 9411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_is_tuple=True) 9421855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower self._backward_slice_offset = int(backward_slice_offset) 9431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_names = "" 9441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower for direction in ["fwd", "bwd"]: 9459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 9469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 9479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index, 9489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_index) 9499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 9501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._state_tuple_type = collections.namedtuple( 9511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "BidirectionalGridLSTMStateTuple", state_names.strip(",")) 9521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._state_size = self._state_tuple_type( 9539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower *([num_units, num_units] * self._total_blocks * 2)) 9549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._output_size = 2 * num_units * self._total_blocks * 2 9551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower def __call__(self, inputs, state, scope=None): 9571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Run one step of LSTM. 9581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Args: 9601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs: input Tensor, 2D, [batch, num_units]. 
9611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: tuple of Tensors, 2D, [batch, state_size]. 9621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower scope: (optional) VariableScope for the created subgraph; if None, it 9631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower defaults to "BidirectionalGridLSTMCell". 9641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Returns: 9661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower A tuple containing: 9671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 9681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower after reading "inputs" when previous state was "state". 9691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Here output_dim is num_units. 9701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, state_size], Tensor representing the new state of LSTM 9711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower after reading "inputs" when previous state was "state". 9721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Raises: 9731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 9741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower a different dimension. 9751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 9761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower batch_size = int(inputs.get_shape()[0]) 9771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower fwd_inputs = self._make_tf_features(inputs) 9781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._backward_slice_offset: 9791855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset) 9801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 9811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower bwd_inputs = fwd_inputs 9821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Forward processing 98492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "bidirectional_grid_lstm_cell", 9851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer=self._initializer): 98692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope("fwd"): 98792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_m_out_lst = [] 98892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_state_out_lst = [] 98992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo for block in range(len(fwd_inputs)): 99092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( 99192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_inputs[block], block, state, batch_size, 99292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo state_prefix="fwd_state", state_is_tuple=True) 99392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_m_out_lst.extend(fwd_m_out_lst_current) 99492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo fwd_state_out_lst.extend(fwd_state_out_lst_current) 99592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo # Backward processing 99692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_m_out_lst = [] 99792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_state_out_lst = [] 99892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope("bwd"): 99992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo for block in range(len(bwd_inputs)): 100092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo # Reverse the blocks 
100192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_inputs_reverse = bwd_inputs[block][::-1] 100292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( 100392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_inputs_reverse, block, state, batch_size, 100492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo state_prefix="bwd_state", state_is_tuple=True) 100592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_m_out_lst.extend(bwd_m_out_lst_current) 100692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bwd_state_out_lst.extend(bwd_state_out_lst_current) 10071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) 10081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Outputs are always concated as it is never used separately. 10090e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1) 10101855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out, state_out 10111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 10121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 10135cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: disable=protected-access 101437fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie_linear = core_rnn_cell_impl._linear 10155cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: enable=protected-access 10165cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10175cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 101837fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass AttentionCellWrapper(core_rnn_cell.RNNCell): 10195cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """Basic attention cell wrapper. 10205cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower 1021ec3f4d62979ef1e70e8e12e2568b13dad45fd39eA. Unique TensorFlower Implementation based on https://arxiv.org/abs/1409.0473. 10225cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """ 10235cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10245cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None, 10255cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower input_size=None, state_is_tuple=False): 10265cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """Create a cell with attention. 10275cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10285cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower Args: 10295cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower cell: an RNNCell, an attention is added to it. 10305cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_length: integer, the size of an attention window. 10315cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_size: integer, the size of an attention vector. Equal to 10325cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower cell.output_size by default. 10335cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_vec_size: integer, the number of convolutional features calculated 10345cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower on attention state and a size of the hidden layer built from 10355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower base cell state. Equal attn_size to by default. 10365cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower input_size: integer, the size of a hidden linear layer, 10375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower built from inputs and attention. Derived from the input tensor 10385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower by default. 10395cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower state_is_tuple: If True, accepted and returned states are n-tuples, where 10405cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower `n = len(cells)`. By default (False), the states are all 10415cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower concatenated along the column axis. 10425cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10435cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower Raises: 10445cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower TypeError: if cell is not an RNNCell. 10455cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower ValueError: if cell returns a state tuple but the flag 10465cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower `state_is_tuple` is `False` or if attn_length is zero or less. 10475cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """ 104837fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie if not isinstance(cell, core_rnn_cell.RNNCell): 10495cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise TypeError("The parameter cell is not RNNCell.") 10504c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenech if nest.is_sequence(cell.state_size) and not state_is_tuple: 10515cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise ValueError("Cell returns tuple of states, but the flag " 10525cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower "state_is_tuple is not set. State size is: %s" 10535cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower % str(cell.state_size)) 10545cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_length <= 0: 10555cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise ValueError("attn_length should be greater than zero, got %s" 10565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower % str(attn_length)) 10575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower if not state_is_tuple: 10585cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower logging.warn( 10595cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower "%s: Using a concatenated state is slower and will soon be " 1060d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 10615cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_size is None: 10625cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_size = cell.output_size 10635cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_vec_size is None: 10645cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_vec_size = attn_size 10655cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._state_is_tuple = state_is_tuple 10665cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._cell = cell 10675cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_vec_size = attn_vec_size 10685cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._input_size = input_size 10695cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_size = attn_size 10705cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_length = attn_length 10715cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10725cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower @property 10735cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def state_size(self): 10745cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower size = (self._cell.state_size, self._attn_size, 10755cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_size * self._attn_length) 10765cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if self._state_is_tuple: 10775cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return size 10785cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower else: 10795cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return sum(list(size)) 10805cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10815cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower @property 10825cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def output_size(self): 10835cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return self._attn_size 10845cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10855cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def __call__(self, inputs, state, scope=None): 10865cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """Long short-term memory cell with attention (LSTMA).""" 108792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "attention_cell_wrapper"): 10885cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if self._state_is_tuple: 10895cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower state, attns, attn_states = state 10905cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower else: 10915cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower states = state 10925cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) 10935cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attns = array_ops.slice( 10945cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower states, [0, self._cell.state_size], [-1, self._attn_size]) 10955cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_states = array_ops.slice( 10965cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower states, [0, self._cell.state_size + self._attn_size], 10975cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower [-1, self._attn_size * self._attn_length]) 10985cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower attn_states = array_ops.reshape(attn_states, 10995cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower [-1, self._attn_length, self._attn_size]) 11005cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower input_size = self._input_size 11015cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if input_size is None: 11025cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower input_size = inputs.get_shape().as_list()[1] 11035cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower inputs = _linear([inputs, attns], input_size, True) 11045cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower lstm_output, new_state = self._cell(inputs, state) 11055cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if self._state_is_tuple: 11060e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower new_state_cat = array_ops.concat(nest.flatten(new_state), 1) 11075cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower else: 11085cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_state_cat = new_state 11095cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attns, new_attn_states = self._attention(new_state_cat, attn_states) 111092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope("attn_output_projection"): 11115cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower output = _linear([lstm_output, new_attns], self._attn_size, True) 11120e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower new_attn_states = array_ops.concat( 1113d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower [new_attn_states, array_ops.expand_dims(output, 1)], 1) 11145cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attn_states = array_ops.reshape( 11155cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attn_states, [-1, self._attn_length * self._attn_size]) 11165cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower new_state = (new_state, new_attns, new_attn_states) 11175cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if not self._state_is_tuple: 11180e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower new_state = array_ops.concat(list(new_state), 1) 11195cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return output, new_state 11205cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 11215cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def _attention(self, query, attn_states): 1122e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo conv2d = nn_ops.conv2d 1123e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo reduce_sum = math_ops.reduce_sum 1124e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo softmax = nn_ops.softmax 1125e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 1126e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo 112792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope("attention"): 112892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo k = vs.get_variable( 112992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo "attn_w", [1, 1, self._attn_size, self._attn_vec_size]) 113092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo v = vs.get_variable("attn_v", [self._attn_vec_size]) 11315cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower hidden = array_ops.reshape(attn_states, 11325cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower [-1, self._attn_length, 1, self._attn_size]) 11335cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") 11345cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower y = _linear(query, self._attn_vec_size, True) 11355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) 11365cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) 11375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower a = softmax(s) 11385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower d = reduce_sum( 11395cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2]) 11405cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attns = array_ops.reshape(d, [-1, self._attn_size]) 11415cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1]) 11425cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return new_attns, new_attn_states 114334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 114434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 114537fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): 114634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """LSTM unit with layer normalization and recurrent dropout. 114734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 114834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower This class adds layer normalization and recurrent dropout to a 114934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower basic LSTM unit. Layer normalization implementation is based on: 115034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 115134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1607.06450. 115234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 115334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Layer Normalization" 115434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 115534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 115634aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower and is applied before the internal nonlinearities. 115734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Recurrent dropout is base on: 115834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 115934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1603.05118 116034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Recurrent Dropout without Memory Loss" 116234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. 116334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 116434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def __init__(self, num_units, forget_bias=1.0, 116634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower input_size=None, activation=math_ops.tanh, 116734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower layer_norm=True, norm_gain=1.0, norm_shift=0.0, 116834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_keep_prob=1.0, dropout_prob_seed=None): 116934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """Initializes the basic LSTM cell. 117034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 117134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Args: 117234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower num_units: int, The number of units in the LSTM cell. 117334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower forget_bias: float, The bias added to forget gates (see above). 117434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower input_size: Deprecated and unused. 117534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower activation: Activation function of the inner states. 117634aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower layer_norm: If `True`, layer normalization will be applied. 117734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower norm_gain: float, The layer normalization gain initial value. If 117834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 117934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower norm_shift: float, The layer normalization shift initial value. If 118034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 118134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_keep_prob: unit Tensor or float between 0 and 1 representing the 118234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower recurrent dropout probability value. If float and 1.0, no dropout will 118334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower be applied. 118434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_prob_seed: (optional) integer, the randomness seed. 118534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 118634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 118734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if input_size is not None: 118834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower logging.warn("%s: The input_size parameter is deprecated.", self) 118934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 119034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._num_units = num_units 119134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._activation = activation 119234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._forget_bias = forget_bias 119334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._keep_prob = dropout_keep_prob 119434aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower self._seed = dropout_prob_seed 119534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._layer_norm = layer_norm 119634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._g = norm_gain 119734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._b = norm_shift 119834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 119934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 120034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def state_size(self): 120137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units) 120234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 120334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 120434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def output_size(self): 120534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower return self._num_units 120634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 120734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def _norm(self, inp, scope): 120892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo shape = inp.get_shape()[-1:] 120992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo gamma_init = init_ops.constant_initializer(self._g) 121092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo beta_init = init_ops.constant_initializer(self._b) 121192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope): 121292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo # Initialize beta and gamma for use by layer_norm. 
121392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo vs.get_variable("gamma", shape=shape, initializer=gamma_init) 121492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo vs.get_variable("beta", shape=shape, initializer=beta_init) 121592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo normalized = layers.layer_norm(inp, reuse=True, scope=scope) 121692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return normalized 121792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo 121892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo def _linear(self, args): 121934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower out_size = 4 * self._num_units 122034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower proj_size = args.get_shape()[-1] 122192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo weights = vs.get_variable("weights", [proj_size, out_size]) 122292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = math_ops.matmul(args, weights) 122392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo if not self._layer_norm: 122492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bias = vs.get_variable("biases", [out_size]) 122592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = nn_ops.bias_add(out, bias) 122692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return out 122734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 122834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def __call__(self, inputs, state, scope=None): 122934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """LSTM cell with layer normalization and recurrent dropout.""" 123034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 123192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope or "layer_norm_basic_lstm_cell"): 123234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower c, h = state 12330e226af7eed5e2764aa8acb825af4cd3e06d2452A. 
Unique TensorFlower args = array_ops.concat([inputs, h], 1) 123434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower concat = self._linear(args) 123534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1236a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 123734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if self._layer_norm: 123834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower i = self._norm(i, "input") 123934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower j = self._norm(j, "transform") 124034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower f = self._norm(f, "forget") 124134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower o = self._norm(o, "output") 124234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 124334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower g = self._activation(j) 124434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: 124534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) 124634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 124734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower new_c = (c * math_ops.sigmoid(f + self._forget_bias) 124834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower + math_ops.sigmoid(i) * g) 124934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if self._layer_norm: 125034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower new_c = self._norm(new_c, "state") 125134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower new_h = self._activation(new_c) * math_ops.sigmoid(o) 125234aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower 125337fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) 125434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower return new_h, new_state 1255bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1256bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1257bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo_REGISTERED_OPS = None 1258bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1259bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1260bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdoclass CompiledWrapper(core_rnn_cell.RNNCell): 1261bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """Wraps step execution in an XLA JIT scope.""" 1262bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1263bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def __init__(self, cell, compile_stateful=False): 1264bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """Create CompiledWrapper cell. 1265bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1266bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo Args: 1267bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo cell: Instance of `RNNCell`. 1268bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo compile_stateful: Whether to compile stateful ops like initializers 1269bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo and random number generators (default: False). 
1270bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """ 1271bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo self._cell = cell 1272bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo self._compile_stateful = compile_stateful 1273bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1274bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo @property 1275bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def state_size(self): 1276bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return self._cell.state_size 1277bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1278bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo @property 1279bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def output_size(self): 1280bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return self._cell.output_size 1281bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1282bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def __call__(self, inputs, state, scope=None): 1283bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo if self._compile_stateful: 1284bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo compile_ops = True 1285bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo else: 1286bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def compile_ops(node_def): 1287bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo global _REGISTERED_OPS 1288bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo if _REGISTERED_OPS is None: 1289bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo _REGISTERED_OPS = op_def_registry.get_registered_ops() 1290bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return not _REGISTERED_OPS[node_def.op].is_stateful 1291bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1292bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo with jit.experimental_jit_scope(compile_ops=compile_ops): 1293bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return self._cell(inputs, state, scope=scope) 1294