rnn_cell.py revision e8482ab23bd0fce5c2941f6a190158bca2610a35
1122cdce33e3e0a01a7f82645617317530aa571fbA. Unique TensorFlower# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 3d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Licensed under the Apache License, Version 2.0 (the "License"); 4d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# you may not use this file except in compliance with the License. 5d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# You may obtain a copy of the License at 6d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 7d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# http://www.apache.org/licenses/LICENSE-2.0 8d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# 9d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# Unless required by applicable law or agreed to in writing, software 10d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# distributed under the License is distributed on an "AS IS" BASIS, 11d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# See the License for the specific language governing permissions and 13d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# limitations under the License. 14d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower# ============================================================================== 15d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 16d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower"""Module for constructing RNN Cells.""" 17d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import absolute_import 18d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom __future__ import division 19d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlowerfrom __future__ import print_function 20d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlowerimport collections 22d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerimport math 23d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 24bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.contrib.compiler import jit 2534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.contrib.layers.python.layers import layers 2637fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell 2737fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xiefrom tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl 281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerfrom tensorflow.python.framework import dtypes 29bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdofrom tensorflow.python.framework import op_def_registry 30d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.framework import ops 31d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import array_ops 32d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import clip_ops 3334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlowerfrom tensorflow.python.ops import init_ops 34d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import math_ops 35d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import nn_ops 36bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlowerfrom tensorflow.python.ops import random_ops 37d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerfrom tensorflow.python.ops import variable_scope as vs 385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlowerfrom tensorflow.python.platform import tf_logging as logging 394c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenechfrom tensorflow.python.util import nest 40d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 4137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie 4254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access 4354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo 4454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo 45d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerdef _get_concat_variable(name, shape, dtype, num_shards): 46d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Get a sharded variable concatenated into one tensor.""" 47d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) 48d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if len(sharded_variable) == 1: 49d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return sharded_variable[0] 50d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 51d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower concat_name = name + "/concat" 52d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" 53d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): 54d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if value.name == concat_full_name: 55d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return value 56d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 570e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) 58d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, 59d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower concat_variable) 60d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return concat_variable 61d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 62d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 63d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlowerdef _get_sharded_variable(name, shape, dtype, num_shards): 64d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Get a list of sharded variables with the given dtype.""" 65d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if num_shards > shape[0]: 66d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower raise ValueError("Too many shards: shape=%s, num_shards=%d" % 67d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower (shape, num_shards)) 68d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower unit_shard_size = int(math.floor(shape[0] / num_shards)) 69d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower remaining_rows = shape[0] - unit_shard_size * num_shards 70d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 71d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower shards = [] 72d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for i in range(num_shards): 73d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower current_size = unit_shard_size 74d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if i < remaining_rows: 75d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower current_size += 1 76d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower shards.append(vs.get_variable(name + "_%d" % i, [current_size] + shape[1:], 77d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype=dtype)) 78d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower return shards 79d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 80d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 8137fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell): 821032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Long short-term memory unit (LSTM) recurrent network cell. 831032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 841032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The default non-peephole implementation is based on: 851032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 861032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf 871032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 881032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower S. Hochreiter and J. Schmidhuber. 891032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 901032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 911032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The peephole implementation is based on: 921032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 931032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower https://research.google.com/pubs/archive/43905.pdf 941032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 951032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Hasim Sak, Andrew Senior, and Francoise Beaufays. 961032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "Long short-term memory recurrent neural network architectures for 971032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower large scale acoustic modeling." INTERSPEECH, 2014. 981032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 991032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower The coupling of input and forget gate is based on: 1001032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1011032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower http://arxiv.org/pdf/1503.04069.pdf 1021032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1031032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Greff et al. "LSTM: A Search Space Odyssey" 1041032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1051032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower The class uses optional peep-hole connections, and an optional projection 1061032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower layer. 1071032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """ 1081032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1091032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 1101032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower initializer=None, num_proj=None, proj_clip=None, 1111032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_unit_shards=1, num_proj_shards=1, 112c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower forget_bias=1.0, state_is_tuple=True, 11354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo activation=math_ops.tanh, reuse=None): 1141032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 1151032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1161032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Args: 1171032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 1181032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower use_peepholes: bool, set True to enable diagonal/peephole connections. 1191032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower initializer: (optional) The initializer to use for the weight and 1201032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower projection matrices. 1211032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj: (optional) int, The output dimensionality for the projection 1221032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower matrices. If None, no projection is performed. 1231032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 1241032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower provided, then the projected values are clipped elementwise to within 1251032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `[-proj_clip, proj_clip]`. 1261032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_unit_shards: How to split the weight matrix. If >1, the weight 1271032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower matrix is stored across num_unit_shards. 1281032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj_shards: How to split the projection matrix. If >1, the 1291032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower projection matrix is stored across num_proj_shards. 1301032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower forget_bias: Biases of the forget gate are initialized by default to 1 1311032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower in order to reduce the scale of forgetting at the beginning of 1321032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the training. 1331032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are 2-tuples of 1341032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the `c_state` and `m_state`. By default (False), they are concatenated 1351032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower along the column axis. 
This default behavior will soon be deprecated. 1361032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower activation: Activation function of the inner states. 13754d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 13854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 13954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 1401032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """ 141e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) 1421032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if not state_is_tuple: 1431032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower logging.warn( 1441032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower "%s: Using a concatenated state is slower and will soon be " 145d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 1461032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_units = num_units 1471032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._use_peepholes = use_peepholes 1481032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._initializer = initializer 1491032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_proj = num_proj 1501032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._proj_clip = proj_clip 1511032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_unit_shards = num_unit_shards 1521032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._num_proj_shards = num_proj_shards 1531032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._forget_bias = forget_bias 1541032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower self._state_is_tuple = state_is_tuple 1551032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._activation = activation 15654d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 1571032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1581032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if num_proj: 1591032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._state_size = ( 16037fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie core_rnn_cell.LSTMStateTuple(num_units, num_proj) 1611032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if state_is_tuple else num_units + num_proj) 1621032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._output_size = num_proj 1631032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower else: 1641032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._state_size = ( 16537fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie core_rnn_cell.LSTMStateTuple(num_units, num_units) 1661032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if state_is_tuple else 2 * num_units) 1671032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower self._output_size = num_units 1681032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1691032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower @property 1701032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def state_size(self): 1711032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return self._state_size 1721032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1731032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower @property 1741032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower def output_size(self): 1751032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return self._output_size 1761032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower 177e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 1781032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """Run one step of LSTM. 1791032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1801032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Args: 1811032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower inputs: input Tensor, 2D, batch x num_units. 1821032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower state: if `state_is_tuple` is False, this must be a state Tensor, 1831032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a 1841032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower tuple of state Tensors, both `2-D`, with column sizes `c_state` and 1851032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower `m_state`. 1861032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1871032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Returns: 1881032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower A tuple containing: 1891032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower - A `2-D, [batch x output_dim]`, Tensor representing the output of the 1901032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower LSTM after reading `inputs` when previous state was `state`. 1911032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Here output_dim is: 1921032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj if num_proj was set, 1931032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_units otherwise. 1941032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower - Tensor(s) representing the new state of LSTM after reading `inputs` when 1951032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower the previous state was `state`. Same type and shape(s) as `state`. 
1961032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 1971032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower Raises: 1981032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower ValueError: If input size cannot be inferred from inputs via 1991032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower static shape inference. 2001032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower """ 201e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 202e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo 2031032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower num_proj = self._num_units if self._num_proj is None else self._num_proj 2041032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2051032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if self._state_is_tuple: 2061032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower (c_prev, m_prev) = state 2071032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower else: 2081032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) 2091032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) 2101032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2111032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower dtype = inputs.dtype 2121032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower input_size = inputs.get_shape().with_rank(2)[1] 2131032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower if input_size.value is None: 2141032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 215e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat_w = _get_concat_variable( 216e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W", [input_size.value + num_proj, 3 * self._num_units], 217e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype, self._num_unit_shards) 2181032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 219e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo b = vs.get_variable( 220e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "B", 221e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo shape=[3 * self._num_units], 222e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=init_ops.zeros_initializer(), 223e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype=dtype) 2241032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 225e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # j = new_input, f = forget_gate, o = output_gate 226e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo cell_inputs = array_ops.concat([inputs, m_prev], 1) 227e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 228e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) 2291032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower 230e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # Diagonal connections 231e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 232e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_f_diag = vs.get_variable( 233e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_F_diag", shape=[self._num_units], dtype=dtype) 234e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_o_diag = vs.get_variable( 235e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_O_diag", shape=[self._num_units], dtype=dtype) 2361032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 237e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 238e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) 239e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 240e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo f_act = sigmoid(f + self._forget_bias) 241e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = (f_act * c_prev + (1 - f_act) * self._activation(j)) 2421032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 243e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 244e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o + w_o_diag * c) * self._activation(c) 245e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 246e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o) * self._activation(c) 2471032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 248e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._num_proj is not None: 249e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat_w_proj = _get_concat_variable( 250e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_P", [self._num_units, self._num_proj], 251e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype, self._num_proj_shards) 2521032ea21a8284b84fd9529e3823cac670860da12A. 
Unique TensorFlower 253e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = math_ops.matmul(m, concat_w_proj) 254e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._proj_clip is not None: 255e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: disable=invalid-unary-operand-type 256e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 257e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: enable=invalid-unary-operand-type 2581032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2590e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else 2600e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower array_ops.concat([c, m], 1)) 2611032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower return m, new_state 2621032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 2631032ea21a8284b84fd9529e3823cac670860da12A. Unique TensorFlower 26437fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass TimeFreqLSTMCell(core_rnn_cell.RNNCell): 265d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. 266d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 267d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower This implementation is based on: 268d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 269d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Tara N. Sainath and Bo Li 270d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 271d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for LVCSR Tasks." submitted to INTERSPEECH, 2016. 272d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 273d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower It uses peep-hole connections and optional cell clipping. 274d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 275d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 276d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 277d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip=None, initializer=None, 278d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 27954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo feature_size=None, frequency_skip=None, 28054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 281d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 282d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 283d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 284d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 285d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower use_peepholes: bool, set True to enable diagonal/peephole connections. 286d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip: (optional) A float value, if provided the cell state is clipped 287d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower by this value prior to the cell output activation. 288d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 289d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower projection matrices. 290d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards: int, How to split the weight matrix. If >1, the weight 291d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower matrix is stored across num_unit_shards. 292d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower forget_bias: float, Biases of the forget gate are initialized by default 293d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower to 1 in order to reduce the scale of forgetting at the beginning 294d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 295d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower feature_size: int, The size of the input feature the LSTM spans over. 296d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency_skip: int, The amount the LSTM filter is shifted by in 297d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency. 29854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 29954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 30054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 301d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 302e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(TimeFreqLSTMCell, self).__init__(_reuse=reuse) 303d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 304d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 305d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 306d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 307d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 308d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 309d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 310d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower self._frequency_skip = frequency_skip 311d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._state_size = 2 * num_units 312d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._output_size = num_units 31354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 314d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 315d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 316d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 317d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 318d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 319d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 320d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 321d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 322d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 323e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 324d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 325d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 326d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 327d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower inputs: input Tensor, 2D, batch x num_units. 328d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower state: state Tensor, 2D, batch x state_size. 329d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 330d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 331d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A tuple containing: 332d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the output of the LSTM 333d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower after reading "inputs" when previous state was "state". 334d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 335d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x state_size, Tensor representing the new state of LSTM 336d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 337d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 338d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 339d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 340d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 341e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 342e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 343d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 344d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 345d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype = inputs.dtype 346d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower actual_input_size = freq_inputs[0].get_shape().as_list()[1] 347d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 348e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat_w = _get_concat_variable( 349e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W", [actual_input_size + 2*self._num_units, 4 * self._num_units], 350e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype, self._num_unit_shards) 351d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 352e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo b = vs.get_variable( 353e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "B", 354e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo shape=[4 * self._num_units], 355e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=init_ops.zeros_initializer(), 356e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype=dtype) 357d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 358e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # Diagonal connections 359e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 360e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_f_diag = vs.get_variable( 361e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_F_diag", shape=[self._num_units], dtype=dtype) 362e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_i_diag = vs.get_variable( 363e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_I_diag", shape=[self._num_units], dtype=dtype) 364e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_o_diag = vs.get_variable( 365e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_O_diag", shape=[self._num_units], dtype=dtype) 366d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 367e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # initialize the first freq state to be zero 368e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), 369e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo self._num_units], dtype) 370e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for fq in range(len(freq_inputs)): 371e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c_prev = array_ops.slice(state, [0, 2*fq*self._num_units], 372e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._num_units]) 373e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units], 374e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._num_units]) 375e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # i = input_gate, j = new_input, f = forget_gate, o = output_gate 376e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 377e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 1) 378e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 379e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo i, j, f, o = array_ops.split( 380e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo value=lstm_matrix, num_or_size_splits=4, axis=1) 381e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 382e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 383e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 384e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo sigmoid(i + w_i_diag * c_prev) * tanh(j)) 385e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 386e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) 
387e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 388e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._cell_clip is not None: 389e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: disable=invalid-unary-operand-type 390e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 391e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: enable=invalid-unary-operand-type 392e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 393e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 394e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o + w_o_diag * c) * tanh(c) 395e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 396e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o) * tanh(c) 397e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_prev_freq = m 398e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if fq == 0: 399e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat([c, m], 1) 400e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = m 401e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 402e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat([state_out, c, m], 1) 403e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = array_ops.concat([m_out, m], 1) 404d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return m_out, state_out 405d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 406d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def _make_tf_features(self, input_feat): 407d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Make the frequency features. 408d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 409d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 410d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower input_feat: input Tensor, 2D, batch x num_units. 411d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 412d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 413d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A list of frequency features, with each element containing: 414d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the time-frequency feature 415d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for that frequency index. Here output_dim is feature_size. 416d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 417d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if input_size cannot be inferred from static shape inference. 418d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 419d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_size = input_feat.get_shape().with_rank(2)[-1].value 420d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if input_size is None: 421d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower raise ValueError("Cannot infer input_size from static shape inference.") 422d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_feats = int((input_size - self._feature_size) / ( 423d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip)) + 1 424d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = [] 425d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for f in range(num_feats): 426d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cur_input = array_ops.slice(input_feat, [0, f*self._frequency_skip], 427d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower [-1, self._feature_size]) 428d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs.append(cur_input) 429d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower return freq_inputs 430d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 431d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 43237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass GridLSTMCell(core_rnn_cell.RNNCell): 433d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Grid Long short-term memory unit (LSTM) recurrent network cell. 434d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 435d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The default is based on: 436d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Nal Kalchbrenner, Ivo Danihelka and Alex Graves 437d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Grid Long Short-Term Memory," Proc. ICLR 2016. 438d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower http://arxiv.org/abs/1507.01526 439d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 440d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower When peephole connections are used, the implementation is based on: 441d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Tara N. Sainath and Bo Li 442d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 443d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for LVCSR Tasks." submitted to INTERSPEECH, 2016. 444d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 445d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The code uses optional peephole connections, shared_weights and cell clipping. 446d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 447d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 448d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 449d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower share_time_frequency_weights=False, 450d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip=None, initializer=None, 451d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 4528fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower feature_size=None, frequency_skip=None, 4539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 4549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 4559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 4568fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower couple_input_forget_gates=False, 457c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower state_is_tuple=True, 45854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 459d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 460d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 461d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 462d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 4631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 4641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 4651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 4661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 4671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 4681855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower state is clipped by this value prior to the cell output activation. 469d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 4701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 4711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards: (optional) int, defualt 1, How to split the weight 4721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 4731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 4741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 475d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 4761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 4771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 4781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 4791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 4809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 4819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 4829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower and end_freqindex_list. 4839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 4849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower starting frequency index for each frequency block. 4859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 4869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 4871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 4881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 4891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 4908fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are 2-tuples of 4918fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower the `c_state` and `m_state`. By default (False), they are concatenated 4928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower along the column axis. This default behavior will soon be deprecated. 49354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 49454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 49554d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 4969e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower Raises: 4979e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower ValueError: if the num_frequency_blocks list is not specified 498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 499e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(GridLSTMCell, self).__init__(_reuse=reuse) 5008fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if not state_is_tuple: 5018fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower logging.warn("%s: Using a concatenated state is slower and will soon be " 5028fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 503d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 504d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 505d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._share_time_frequency_weights = share_time_frequency_weights 5068fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._couple_input_forget_gates = couple_input_forget_gates 5078fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_is_tuple = state_is_tuple 508d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 509d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 510d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 511d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 512d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 513d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip = frequency_skip 5149e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._start_freqindex_list = start_freqindex_list 5159e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._end_freqindex_list = end_freqindex_list 5169e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks = num_frequency_blocks 5179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks = 0 51854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 5199e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower if self._num_frequency_blocks is None: 5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError("Must specify num_frequency_blocks") 5219e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower 5229e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5239e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks += int(self._num_frequency_blocks[block_index]) 5248fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if state_is_tuple: 5258fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_names = "" 5269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 5289e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) 5299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 5308fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_tuple_type = collections.namedtuple( 5311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "GridLSTMStateTuple", state_names.strip(",")) 5328fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_size = self._state_tuple_type( 5339e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower *([num_units, num_units] * self._total_blocks)) 5348fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 5358fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_tuple_type = None 5369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._state_size = num_units * self._total_blocks * 2 5379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower self._output_size = num_units * self._total_blocks * 2 538d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 539d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 540d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 541d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 542d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 543d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 544d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 545d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 546d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 5478fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower @property 5488fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower def state_tuple_type(self): 5498fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower return self._state_tuple_type 5508fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower 551e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 552d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 553d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 554d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 5551855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs: input Tensor, 2D, [batch, feature_size]. 5561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the 5571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower flag self._state_is_tuple. 558d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 559d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 560d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower A tuple containing: 5611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 562d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 563d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 5641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, state_size], Tensor representing the new state of LSTM 565d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 566d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 567d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 568d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 569d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 570499166454f0c7dd6c724c364f6dc5d99357514b1A. Unique TensorFlower batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] 5711855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 572e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst = [] 573e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out_lst = [] 574e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for block in range(len(freq_inputs)): 575e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst_current, state_out_lst_current = self._compute( 576e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo freq_inputs[block], block, state, batch_size, 577e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_is_tuple=self._state_is_tuple) 578e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst.extend(m_out_lst_current) 579e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out_lst.extend(state_out_lst_current) 580e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._state_is_tuple: 581e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = self._state_tuple_type(*state_out_lst) 582e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 583e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat(state_out_lst, 1) 584e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = array_ops.concat(m_out_lst, 1) 5851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out, state_out 5861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 5879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower def _compute(self, freq_inputs, block, state, batch_size, 5889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_prefix="state", 5891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_is_tuple=True): 5901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Run the actual computation of one step LSTM. 5911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 5921855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower Args: 5931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower freq_inputs: list of Tensors, 2D, [batch, feature_size]. 5949e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block: int, current frequency block index to process. 5951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on 5961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the flag state_is_tuple. 5971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower batch_size: int32, batch size. 5981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_prefix: (optional) string, name prefix for states, defaults to 5991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "state". 6001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_is_tuple: boolean, indicates whether the state is a tuple or Tensor. 6011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Returns: 6031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower A tuple, containing: 6041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A list of [batch, output_dim] Tensors, representing the output of the 6051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower LSTM given the inputs and state. 6061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A list of [batch, state_size] Tensors, representing the LSTM state 6071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower values given the inputs and previous state. 6081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 609e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 610e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 6118fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower num_gates = 3 if self._couple_input_forget_gates else 4 6121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype = freq_inputs[0].dtype 613d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower actual_input_size = freq_inputs[0].get_shape().as_list()[1] 6141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_f = _get_concat_variable( 6169e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_f_%d" % block, [actual_input_size + 2 * self._num_units, 6179e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_gates * self._num_units], 6181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype, self._num_unit_shards) 6191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower b_f = vs.get_variable( 6204ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B_f_%d" % block, 6214ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[num_gates * self._num_units], 6224ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 6234ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 6241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._share_time_frequency_weights: 6251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_t = _get_concat_variable( 6269e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_t_%d" % block, [actual_input_size + 2 * self._num_units, 6279e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_gates * self._num_units], 628d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype, self._num_unit_shards) 6291855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower b_t = vs.get_variable( 6304ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist "B_t_%d" % block, 6314ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist shape=[num_gates * self._num_units], 6324ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist initializer=init_ops.zeros_initializer(), 6334ae96e5f3249190abbb4cc766ae04eede53f0199Olivia Nordquist dtype=dtype) 634d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 6351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Diagonal connections 6371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._couple_input_forget_gates: 6381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf = vs.get_variable( 6399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt = vs.get_variable( 6419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_freqt_%d"% block, shape=[self._num_units], dtype=dtype) 6421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqf = vs.get_variable( 6439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqt = vs.get_variable( 6459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 6461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqf = vs.get_variable( 6479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) 6481855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower w_o_diag_freqt = vs.get_variable( 6499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) 6501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if not self._share_time_frequency_weights: 6518fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if not self._couple_input_forget_gates: 6521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timef = vs.get_variable( 6539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timet = vs.get_variable( 6559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6561855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timef = vs.get_variable( 6579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timet = vs.get_variable( 6599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6601855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timef = vs.get_variable( 6619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) 6621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timet = vs.get_variable( 6639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) 6641855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6651855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower # initialize the first freq state to be zero 6661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 6671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) 6681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower for freq_index in range(len(freq_inputs)): 6691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if state_is_tuple: 6709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block) 6711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_time = getattr(state, name_prefix + "_c") 6721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_time = getattr(state, name_prefix + "_m") 6731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 6741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_time = array_ops.slice( 6751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state, [0, 2 * freq_index * self._num_units], 6761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower [-1, self._num_units]) 6771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_time = array_ops.slice( 6781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state, [0, (2 * freq_index + 1) * self._num_units], 6791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower [-1, self._num_units]) 6801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # i = input_gate, j = new_input, f = forget_gate, o = output_gate 6820e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower cell_inputs = array_ops.concat( 683d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. 
Unique TensorFlower [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1) 6841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 6851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM 6861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs, 6871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower concat_w_f), b_f) 6881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 689a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_freq, j_freq, o_freq = array_ops.split( 690a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 6911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq = None 6921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 693a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_freq, j_freq, f_freq, o_freq = array_ops.split( 694a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) 6951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM 6961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 6971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time = i_freq 6981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower j_time = j_freq 6991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time = f_freq 7001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower o_time = o_freq 7011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs, 7031855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower concat_w_t), b_t) 7048fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if self._couple_input_forget_gates: 705a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_time, j_time, o_time = array_ops.split( 706a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 7071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time = None 7088fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 709a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower i_time, j_time, f_time, o_time = array_ops.split( 710a46b6d211eac423c72d3a57a177daf2f64db8642A. Unique TensorFlower value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) 711d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 7121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM c_freq 7131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # input gate activations 7141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_freq_g = sigmoid(i_freq + 7161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqf * c_prev_freq + 7171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_freqt * c_prev_time) 7181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_freq_g = sigmoid(i_freq) 7201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # forget gate activations 7211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 7221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = 1.0 - i_freq_g 7231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 724d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower if self._use_peepholes: 7251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = sigmoid(f_freq + self._forget_bias + 7261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf * c_prev_freq + 7271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt * c_prev_time) 7281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_freq_g = sigmoid(f_freq + self._forget_bias) 7301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # cell state 7311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq) 7321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._cell_clip is not None: 7331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 7341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip, 7351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._cell_clip) 7361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 7371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM c_freq 7391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # input gate activations 7401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 7421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time + 7438fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower w_i_diag_freqf * c_prev_freq + 7448fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower w_i_diag_freqt * c_prev_time) 7458fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time + 7471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timef * c_prev_freq + 7481855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_i_diag_timet * c_prev_time) 7491855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7501855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower i_time_g = sigmoid(i_time) 7511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # forget gate activations 7521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._couple_input_forget_gates: 7531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = 1.0 - i_time_g 7541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 755d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._use_peepholes: 756d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if self._share_time_frequency_weights: 7571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias + 7581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqf * c_prev_freq + 7591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_freqt * c_prev_time) 760d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 7611855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias + 7621855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timef * c_prev_freq + 7631855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_f_diag_timet * c_prev_time) 7648fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7651855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower f_time_g = sigmoid(f_time + self._forget_bias) 7661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # cell state 7671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time) 7681855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._cell_clip is not None: 7691855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: disable=invalid-unary-operand-type 7701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_time = clip_ops.clip_by_value(c_time, -self._cell_clip, 7711855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._cell_clip) 7721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # pylint: enable=invalid-unary-operand-type 7731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # F-LSTM m_freq 7751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_freq = sigmoid(o_freq + 7771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqf * c_freq + 7781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqt * c_time) * tanh(c_freq) 7791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 7801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_freq = sigmoid(o_freq) * tanh(c_freq) 781d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 7821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # T-LSTM m_time 7831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._use_peepholes: 7841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._share_time_frequency_weights: 7851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time + 7868fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower w_o_diag_freqf * c_freq + 7871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_freqt * c_time) * tanh(c_time) 788d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower else: 7891855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time + 7901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timef * c_freq + 7911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower w_o_diag_timet * c_time) * tanh(c_time) 7928fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 7931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_time = sigmoid(o_time) * tanh(c_time) 7941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 7951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_prev_freq = m_freq 7961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower c_prev_freq = c_freq 7971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Concatenate the outputs for T-LSTM and F-LSTM for each shift 7981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if freq_index == 0: 7991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out_lst = [c_time, m_time] 8001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_out_lst = [m_time, m_freq] 8011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 8021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out_lst.extend([c_time, m_time]) 8031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower m_out_lst.extend([m_time, m_freq]) 804d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 8051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out_lst, state_out_lst 8061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 8071855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower def _make_tf_features(self, input_feat, slice_offset=0): 808d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Make the frequency features. 809d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 810d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 8111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat: input Tensor, 2D, [batch, num_units]. 8121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice_offset: (optional) Python int, default 0, the slicing offset is only 8131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower used for the backward processing in the BidirectionalGridLSTMCell. It 8141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower specifies a different starting point instead of always 0 to enable the 8151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forward and backward processing look at different frequency blocks. 816d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 817d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 818d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A list of frequency features, with each element containing: 8191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the time-frequency 8201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature for that frequency index. Here output_dim is feature_size. 821d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 822d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if input_size cannot be inferred from static shape inference. 823d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 824d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_size = input_feat.get_shape().with_rank(2)[-1].value 825d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower if input_size is None: 826d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower raise ValueError("Cannot infer input_size from static shape inference.") 8271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if slice_offset > 0: 8281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Padding to the end 8291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = array_ops.pad( 8301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat, array_ops.constant([0, 0, 0, slice_offset], shape=[2, 2], 8311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype=dtypes.int32), 8321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "CONSTANT") 8331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower elif slice_offset < 0: 8341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Padding to the front 8351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = array_ops.pad( 8361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower input_feat, array_ops.constant([0, 0, -slice_offset, 0], shape=[2, 2], 8371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower dtype=dtypes.int32), 8381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "CONSTANT") 8391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice_offset = 0 8401855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 8411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs = input_feat 842d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = [] 8439e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if not self._start_freqindex_list: 8449e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._num_frequency_blocks) != 1: 8459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower raise ValueError("Length of num_frequency_blocks" 8469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " is not 1, but instead is %d", 8479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._num_frequency_blocks)) 8489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_feats = int((input_size - self._feature_size) / ( 8499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip)) + 1 8509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if num_feats != self._num_frequency_blocks[0]: 8519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError( 8529e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "Invalid num_frequency_blocks, requires %d but gets %d, please" 8539e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " check the input size and filter config are correct." % ( 8549e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks[0], num_feats)) 8559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs = [] 8569e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for f in range(num_feats): 8579e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_input = array_ops.slice( 8589e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower inputs, [0, slice_offset + f * self._frequency_skip], 8599e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower [-1, self._feature_size]) 8609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs.append(cur_input) 8619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower freq_inputs.append(block_inputs) 8629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower else: 8639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._start_freqindex_list) != len(self._end_freqindex_list): 8649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower raise ValueError("Length of start and end freqindex_list" 8659e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " does not match %d %d", 8669e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._start_freqindex_list), 8679e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._end_freqindex_list)) 8689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if len(self._num_frequency_blocks) != len(self._start_freqindex_list): 8699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError("Length of num_frequency_blocks" 8709e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower " is not equal to start_freqindex_list %d %d", 8719e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._num_frequency_blocks), 8729e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower len(self._start_freqindex_list)) 8739e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for b in range(len(self._start_freqindex_list)): 8749e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_index = self._start_freqindex_list[b] 8759e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_index = self._end_freqindex_list[b] 8769e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_size = end_index - start_index 8779e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_feats = int((cur_size - self._feature_size) / ( 8789e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip)) + 1 8799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if block_feats != self._num_frequency_blocks[b]: 8809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError( 8819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower "Invalid num_frequency_blocks, requires %d but gets %d, please" 8829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower " check the input size and filter config are correct." % ( 8839e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks[b], block_feats)) 8849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs = [] 8859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for f in range(block_feats): 8869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cur_input = array_ops.slice( 8879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower inputs, [0, start_index + slice_offset + f * 8889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._frequency_skip], 8899e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower [-1, self._feature_size]) 8909e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_inputs.append(cur_input) 8919e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower freq_inputs.append(block_inputs) 892d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return freq_inputs 8935cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 8945cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 8951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlowerclass BidirectionalGridLSTMCell(GridLSTMCell): 8961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Bidirectional GridLstm cell. 8971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 8981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower The bidirection connection is only used in the frequency direction, which 8991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower hence doesn't affect the time direction's real-time processing that is 9001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower required for online recognition systems. 9011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower The current implementation uses different weights for the two directions. 
9021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 9031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9041855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower def __init__(self, num_units, use_peepholes=False, 9051855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights=False, 9061855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip=None, initializer=None, 9071855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards=1, forget_bias=1.0, 9081855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size=None, frequency_skip=None, 9099e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 9109e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 9119e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 9121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates=False, 91354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo backward_slice_offset=0, 91454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 9151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Initialize the parameters for an LSTM cell. 9161855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9171855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Args: 9181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_units: int, The number of units in the LSTM cell 9191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 9201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 9211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 9221855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 9231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 9241855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state is clipped by this value prior to the cell output activation. 9251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 9261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 9271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_unit_shards: (optional) int, defualt 1, How to split the weight 9281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 9291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 9301855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 9311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower of the training. 9321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 9331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 9341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 9351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 9369e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 9379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 9389e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower and end_freqindex_list. 9399e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 9409e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower starting frequency index for each frequency block. 9419e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 9429e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 9431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 9441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 9451855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 9461855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower backward_slice_offset: (optional) int32, default 0, the starting offset to 9471855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice the feature for backward processing. 94854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 94954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 95054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 9511855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 9521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower super(BidirectionalGridLSTMCell, self).__init__( 9531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_units, use_peepholes, share_time_frequency_weights, cell_clip, 9541855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower initializer, num_unit_shards, forget_bias, feature_size, frequency_skip, 9559e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks, start_freqindex_list, end_freqindex_list, 956a5493db6082a6fe29ed11e71f2aab5bfdf5f98c7A. Unique TensorFlower couple_input_forget_gates, True, reuse) 9571855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._backward_slice_offset = int(backward_slice_offset) 9581855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_names = "" 9591855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower for direction in ["fwd", "bwd"]: 9609e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 9619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 9629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index, 9639e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_index) 9649e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 9651855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._state_tuple_type = collections.namedtuple( 9661855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "BidirectionalGridLSTMStateTuple", state_names.strip(",")) 9671855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._state_size = self._state_tuple_type( 9689e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower *([num_units, num_units] * self._total_blocks * 2)) 9699e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._output_size = 2 * num_units * self._total_blocks * 2 9701855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower 971e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 9721855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Run one step of LSTM. 9731855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9741855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Args: 9751855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs: input Tensor, 2D, [batch, num_units]. 9761855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: tuple of Tensors, 2D, [batch, state_size]. 9771855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9781855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Returns: 9791855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower A tuple containing: 9801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 9811855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower after reading "inputs" when previous state was "state". 9821855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Here output_dim is num_units. 9831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, state_size], Tensor representing the new state of LSTM 9841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower after reading "inputs" when previous state was "state". 9851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Raises: 9861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 9871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower a different dimension. 9881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 989499166454f0c7dd6c724c364f6dc5d99357514b1A. Unique TensorFlower batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] 9901855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower fwd_inputs = self._make_tf_features(inputs) 9911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower if self._backward_slice_offset: 9921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset) 9931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower else: 9941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower bwd_inputs = fwd_inputs 9951855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Forward processing 997e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("fwd"): 998e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_m_out_lst = [] 999e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_state_out_lst = [] 1000e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for block in range(len(fwd_inputs)): 1001e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( 1002e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_inputs[block], block, state, batch_size, 1003e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_prefix="fwd_state", state_is_tuple=True) 1004e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_m_out_lst.extend(fwd_m_out_lst_current) 1005e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo fwd_state_out_lst.extend(fwd_state_out_lst_current) 1006e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # Backward processing 1007e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_m_out_lst = [] 1008e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_state_out_lst = [] 1009e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("bwd"): 1010e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for block in range(len(bwd_inputs)): 1011e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # Reverse the blocks 
1012e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_inputs_reverse = bwd_inputs[block][::-1] 1013e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( 1014e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_inputs_reverse, block, state, batch_size, 1015e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_prefix="bwd_state", state_is_tuple=True) 1016e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_m_out_lst.extend(bwd_m_out_lst_current) 1017e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo bwd_state_out_lst.extend(bwd_state_out_lst_current) 10181855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) 10191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower # Outputs are always concated as it is never used separately. 10200e226af7eed5e2764aa8acb825af4cd3e06d2452A. Unique TensorFlower m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1) 10211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out, state_out 10221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 10231855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 10245cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: disable=protected-access 102537fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie_linear = core_rnn_cell_impl._linear 10265cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower# pylint: enable=protected-access 10275cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10285cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 102937fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass AttentionCellWrapper(core_rnn_cell.RNNCell): 10305cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """Basic attention cell wrapper. 10315cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower 1032ec3f4d62979ef1e70e8e12e2568b13dad45fd39eA. Unique TensorFlower Implementation based on https://arxiv.org/abs/1409.0473. 10335cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """ 10345cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None, 1036c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower input_size=None, state_is_tuple=True, reuse=None): 10375cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """Create a cell with attention. 10385cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10395cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower Args: 10405cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower cell: an RNNCell, an attention is added to it. 10415cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_length: integer, the size of an attention window. 10425cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_size: integer, the size of an attention vector. Equal to 10435cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower cell.output_size by default. 10445cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_vec_size: integer, the number of convolutional features calculated 10455cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower on attention state and a size of the hidden layer built from 10465cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower base cell state. Equal attn_size to by default. 10475cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower input_size: integer, the size of a hidden linear layer, 10485cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower built from inputs and attention. Derived from the input tensor 10495cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower by default. 
10505cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are n-tuples, where 10515cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower `n = len(cells)`. By default (False), the states are all 10525cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower concatenated along the column axis. 105354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 105454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 105554d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 10565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower Raises: 10585cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower TypeError: if cell is not an RNNCell. 10595cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower ValueError: if cell returns a state tuple but the flag 10605cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower `state_is_tuple` is `False` or if attn_length is zero or less. 10615cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower """ 1062e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(AttentionCellWrapper, self).__init__(_reuse=reuse) 106337fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie if not isinstance(cell, core_rnn_cell.RNNCell): 10645cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise TypeError("The parameter cell is not RNNCell.") 10654c7fde3025c70bfd19291511f0360eaab48f8c0dAdria Puigdomenech if nest.is_sequence(cell.state_size) and not state_is_tuple: 10665cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise ValueError("Cell returns tuple of states, but the flag " 10675cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower "state_is_tuple is not set. 
State size is: %s" 10685cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower % str(cell.state_size)) 10695cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_length <= 0: 10705cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower raise ValueError("attn_length should be greater than zero, got %s" 10715cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower % str(attn_length)) 10725cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if not state_is_tuple: 10735cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower logging.warn( 10745cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower "%s: Using a concatenated state is slower and will soon be " 1075d4eb834824d79c6a64a3c4a1c4a88b434b73e63eA. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 10765cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_size is None: 10775cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_size = cell.output_size 10785cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if attn_vec_size is None: 10795cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower attn_vec_size = attn_size 10805cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._state_is_tuple = state_is_tuple 10815cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._cell = cell 10825cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_vec_size = attn_vec_size 10835cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._input_size = input_size 10845cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_size = attn_size 10855cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_length = attn_length 108654d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 10875cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10885cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower @property 10895cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def state_size(self): 10905cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower size = (self._cell.state_size, self._attn_size, 10915cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower self._attn_size * self._attn_length) 10925cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower if self._state_is_tuple: 10935cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return size 10945cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower else: 10955cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return sum(list(size)) 10965cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 10975cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower @property 10985cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower def output_size(self): 10995cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return self._attn_size 11005cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 1101e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 11025cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower """Long short-term memory cell with attention (LSTMA).""" 1103e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._state_is_tuple: 1104e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state, attns, attn_states = state 1105e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 1106e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo states = state 1107e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) 1108e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo attns = array_ops.slice( 1109e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo states, [0, self._cell.state_size], [-1, self._attn_size]) 1110e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo attn_states = array_ops.slice( 1111e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo states, [0, self._cell.state_size + self._attn_size], 1112e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._attn_size * self._attn_length]) 1113e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo attn_states = array_ops.reshape(attn_states, 1114e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._attn_length, self._attn_size]) 1115e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo input_size = self._input_size 1116e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if input_size is None: 1117e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo input_size = inputs.get_shape().as_list()[1] 1118e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo inputs = _linear([inputs, attns], input_size, True) 1119e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo lstm_output, new_state = self._cell(inputs, state) 1120e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._state_is_tuple: 1121e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state_cat = array_ops.concat(nest.flatten(new_state), 1) 1122e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 
1123e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state_cat = new_state 1124e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_attns, new_attn_states = self._attention(new_state_cat, attn_states) 1125e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("attn_output_projection"): 1126e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo output = _linear([lstm_output, new_attns], self._attn_size, True) 1127e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_attn_states = array_ops.concat( 1128e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [new_attn_states, array_ops.expand_dims(output, 1)], 1) 1129e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_attn_states = array_ops.reshape( 1130e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_attn_states, [-1, self._attn_length * self._attn_size]) 1131e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state = (new_state, new_attns, new_attn_states) 1132e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if not self._state_is_tuple: 1133e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state = array_ops.concat(list(new_state), 1) 1134e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo return output, new_state 11355cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower 11365cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower def _attention(self, query, attn_states): 1137e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo conv2d = nn_ops.conv2d 1138e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo reduce_sum = math_ops.reduce_sum 1139e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo softmax = nn_ops.softmax 1140e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 1141e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo 114292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope("attention"): 114392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo k = vs.get_variable( 114492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo "attn_w", [1, 1, self._attn_size, self._attn_vec_size]) 114592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo v = vs.get_variable("attn_v", [self._attn_vec_size]) 11465cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower hidden = array_ops.reshape(attn_states, 11475cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower [-1, self._attn_length, 1, self._attn_size]) 11485cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") 11495cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower y = _linear(query, self._attn_vec_size, True) 11505cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) 11515cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) 11525cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower a = softmax(s) 11535cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower d = reduce_sum( 11545cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2]) 11555cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. 
Unique TensorFlower new_attns = array_ops.reshape(d, [-1, self._attn_size]) 11565cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1]) 11575cb6030958dc6fb7c4a7a1bbcbf06223c5da7390A. Unique TensorFlower return new_attns, new_attn_states 115834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 115934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116037fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xieclass LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): 116134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """LSTM unit with layer normalization and recurrent dropout. 116234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower This class adds layer normalization and recurrent dropout to a 116434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower basic LSTM unit. Layer normalization implementation is based on: 116534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1607.06450. 116734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 116834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Layer Normalization" 116934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 117034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 117134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower and is applied before the internal nonlinearities. 117234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Recurrent dropout is base on: 117334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 117434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1603.05118 117534aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower 117634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Recurrent Dropout without Memory Loss" 117734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. 117834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 117934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 118034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def __init__(self, num_units, forget_bias=1.0, 118134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower input_size=None, activation=math_ops.tanh, 118234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower layer_norm=True, norm_gain=1.0, norm_shift=0.0, 118354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo dropout_keep_prob=1.0, dropout_prob_seed=None, 118454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 118534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """Initializes the basic LSTM cell. 118634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 118734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Args: 118834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower num_units: int, The number of units in the LSTM cell. 118934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower forget_bias: float, The bias added to forget gates (see above). 119034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower input_size: Deprecated and unused. 119134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower activation: Activation function of the inner states. 119234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower layer_norm: If `True`, layer normalization will be applied. 119334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower norm_gain: float, The layer normalization gain initial value. If 119434aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 119534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower norm_shift: float, The layer normalization shift initial value. If 119634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 119734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_keep_prob: unit Tensor or float between 0 and 1 representing the 119834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower recurrent dropout probability value. If float and 1.0, no dropout will 119934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower be applied. 120034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_prob_seed: (optional) integer, the randomness seed. 120154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 120254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 120354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 120434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 1205e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse) 120634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 120734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if input_size is not None: 120834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower logging.warn("%s: The input_size parameter is deprecated.", self) 120934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 121034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._num_units = num_units 121134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._activation = activation 121234aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower self._forget_bias = forget_bias 121334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._keep_prob = dropout_keep_prob 121434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._seed = dropout_prob_seed 121534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._layer_norm = layer_norm 121634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._g = norm_gain 121734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._b = norm_shift 121854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 121934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 122034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 122134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def state_size(self): 122237fbebdd6c3c8f274896cc36e6feb5b7e2097a59Jianwei Xie return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units) 122334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 122434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 122534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def output_size(self): 122634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower return self._num_units 122734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 122834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def _norm(self, inp, scope): 122992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo shape = inp.get_shape()[-1:] 123092da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo gamma_init = init_ops.constant_initializer(self._g) 123192da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo beta_init = init_ops.constant_initializer(self._b) 123292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope): 123392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo # Initialize beta and gamma for use by layer_norm. 
123492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo vs.get_variable("gamma", shape=shape, initializer=gamma_init) 123592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo vs.get_variable("beta", shape=shape, initializer=beta_init) 123692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo normalized = layers.layer_norm(inp, reuse=True, scope=scope) 123792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return normalized 123892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo 123992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo def _linear(self, args): 124034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower out_size = 4 * self._num_units 124134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower proj_size = args.get_shape()[-1] 124292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo weights = vs.get_variable("weights", [proj_size, out_size]) 124392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = math_ops.matmul(args, weights) 124492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo if not self._layer_norm: 124592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo bias = vs.get_variable("biases", [out_size]) 124692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = nn_ops.bias_add(out, bias) 124792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return out 124834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1249e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 125034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """LSTM cell with layer normalization and recurrent dropout.""" 1251e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c, h = state 1252e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo args = array_ops.concat([inputs, h], 1) 1253e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat = self._linear(args) 125434aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower 1255e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 1256e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._layer_norm: 1257e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo i = self._norm(i, "input") 1258e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo j = self._norm(j, "transform") 1259e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo f = self._norm(f, "forget") 1260e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo o = self._norm(o, "output") 126134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1262e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo g = self._activation(j) 1263e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: 1264e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) 126534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1266e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_c = (c * math_ops.sigmoid(f + self._forget_bias) 1267e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo + math_ops.sigmoid(i) * g) 1268e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._layer_norm: 1269e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_c = self._norm(new_c, "state") 1270e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_h = self._activation(new_c) * math_ops.sigmoid(o) 127134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1272e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) 1273e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo return new_h, new_state 1274bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1275bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 12761e982bb8330c0c5a571559e8653c5e8b948a621eA. 
Unique TensorFlowerclass NASCell(core_rnn_cell.RNNCell): 12771e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower """Neural Architecture Search (NAS) recurrent network cell. 12781e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12791e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower This implements the recurrent cell from the paper: 12801e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12811e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower https://arxiv.org/abs/1611.01578 12821e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12831e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower Barret Zoph and Quoc V. Le. 12841e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017. 12851e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12861e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower The class uses an optional projection layer. 12871e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower """ 12881e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12891e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower def __init__(self, num_units, num_proj=None, 129054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo use_biases=False, reuse=None): 12911e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower """Initialize the parameters for a NAS cell. 12921e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 12931e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower Args: 12941e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower num_units: int, The number of units in the NAS cell 12951e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower num_proj: (optional) int, The output dimensionality for the projection 12961e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower matrices. If None, no projection is performed. 
12971e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower use_biases: (optional) bool, If True then use biases within the cell. This 12981e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower is False by default. 129954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 130054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 130154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 13021e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower """ 1303e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(NASCell, self).__init__(_reuse=reuse) 13041e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._num_units = num_units 13051e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._num_proj = num_proj 13061e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._use_biases = use_biases 130754d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 13081e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 13091e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower if num_proj is not None: 13101e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_proj) 13111e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._output_size = num_proj 13121e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower else: 13131e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._state_size = core_rnn_cell.LSTMStateTuple(num_units, num_units) 13141e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower self._output_size = num_units 13151e982bb8330c0c5a571559e8653c5e8b948a621eA. Unique TensorFlower 13161e982bb8330c0c5a571559e8653c5e8b948a621eA. 
# NASCell (continued): size properties and the single-step computation.
  @property
  def state_size(self):
    """`LSTMStateTuple` of (c width, m width), as chosen in `__init__`."""
    return self._state_size

  @property
  def output_size(self):
    """Width of the emitted output: num_proj if set, else num_units."""
    return self._output_size

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x input_size. The input width is read
        from the static shape and only sets the row count of the "weights"
        variable, so it does not have to equal num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`. `c_state` (the cell) is num_units wide;
        `m_state` (the hidden/output) is num_proj wide, or num_units wide
        when num_proj is None.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading `inputs`
        when the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid
    tanh = math_ops.tanh
    relu = nn_ops.relu

    # Width of m_prev: the projected size when a projection is configured.
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    (c_prev, m_prev) = state

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # Variables for the NAS cell. concat_w_m holds the eight matrices applied
    # to the previous hidden state m_prev, and concat_w_inputs holds the eight
    # applied to the inputs. In both, the eight matrices sit side by side
    # along the column axis.
    concat_w_m = vs.get_variable(
        "recurrent_weights", [num_proj, 8 * self._num_units],
        dtype)
    concat_w_inputs = vs.get_variable(
        "weights", [input_size.value, 8 * self._num_units],
        dtype)

    m_matrix = math_ops.matmul(m_prev, concat_w_m)
    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)

    if self._use_biases:
      # A single bias is added to the recurrent projection only; the input
      # projection carries none.
      b = vs.get_variable(
          "bias",
          shape=[8 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
      m_matrix = nn_ops.bias_add(m_matrix, b)

    # The NAS cell branches into 8 different splits for both the hidden state
    # and the input; split i of each side feeds first-layer node i.
    m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                      value=m_matrix)
    inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                           value=inputs_matrix)

    # First layer
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    # NOTE(review): node 3 multiplies the input and recurrent terms, while the
    # other seven nodes add them. Presumably this is the combination found by
    # the architecture search rather than a typo. Confirm against the NAS
    # paper before changing it, because changing it changes the model.
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer: pairwise combinations of the first-layer nodes.
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the cell: the previous cell state c_prev enters only here.
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer
    l3_0_pre = l2_0 * l2_1
    new_c = l3_0_pre  # create new cell (no further activation is applied)
    l3_0 = l3_0_pre
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(l3_0 * l3_1)

    # Projection layer if specified: a plain matmul (no bias, no activation)
    # from num_units down to num_proj. Only new_m is projected; new_c keeps
    # num_units columns.
    if self._num_proj is not None:
      concat_w_proj = vs.get_variable(
          "projection_weights", [self._num_units, self._num_proj],
          dtype)
      new_m = math_ops.matmul(new_m, concat_w_proj)

    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
    return new_m, new_state
class UGRNNCell(core_rnn_cell.RNNCell):
  """Update Gate RNN (UGRNN) cell.

  Sits between a vanilla RNN and an LSTM/GRU. A single gate decides, per
  unit, how much of the previous state to keep and how much of a freshly
  computed candidate to take in. This is the gating idea of the feed-forward
  Highway Network, applied through time.

  Implements the recurrent cell described in:

    https://arxiv.org/abs/1611.09913

    Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
    "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self, num_units, initializer=None, forget_bias=1.0,
               activation=math_ops.tanh, reuse=None):
    """Create a UGRNN cell.

    Args:
      num_units: int, number of units; this is both the state and output width.
      initializer: (optional) initializer used for the weight matrices this
        cell creates.
      forget_bias: (optional) float, default 1.0. Added to the gate
        pre-activation; a positive value biases the gate towards keeping the
        previous state at the start of training.
      activation: (optional) activation applied to the candidate state.
        Defaults to `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._initializer = initializer
    self._activation = activation
    self._forget_bias = forget_bias

  @property
  def state_size(self):
    """Width of the recurrent state (num_units)."""
    return self._num_units

  @property
  def output_size(self):
    """Width of the output (num_units); the output is the state itself."""
    return self._num_units

  def call(self, inputs, state):
    """Advance the UGRNN by one time step.

    Args:
      inputs: 2-D input Tensor, batch x input size.
      state: 2-D state Tensor, batch x num_units.

    Returns:
      A pair `(output, new_state)`. Both entries are the same batch x num_units
      Tensor: the cell state after consuming `inputs`.

    Raises:
      ValueError: If `inputs` is not rank 2 or its last dimension is not
        known from static shape inference.
    """
    in_width = inputs.get_shape().with_rank(2)[1]
    if in_width.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    outer_scope = vs.get_variable_scope()
    with vs.variable_scope(outer_scope, initializer=self._initializer):
      # One affine map over [inputs, state] yields the gate and the candidate
      # pre-activations, stacked along axis 1.
      joint = array_ops.concat([inputs, state], 1)
      stacked = _linear(joint, 2 * self._num_units, True)
      gate_logits, cand_logits = array_ops.split(
          value=stacked, num_or_size_splits=2, axis=1)

      candidate = self._activation(cand_logits)
      keep = math_ops.sigmoid(gate_logits + self._forget_bias)
      # Convex mix of the old state and the candidate, weighted by the gate.
      new_state = keep * state + (1.0 - keep) * candidate

    return new_state, new_state


class IntersectionRNNCell(core_rnn_cell.RNNCell):
  """Intersection RNN (+RNN) cell.

  Uses one coupled gate through time and another coupled gate through depth.
  The aim is to improve information flow through stacked RNNs. Because of the
  depth gating, the output passed to the next layer (y) must have the same
  width as this layer's input (input size == output size).

  So the first layer of a stacked +RNN projects its inputs to N (num_units)
  dimensions. Pass `num_in_proj = N` when building the first layer, and leave
  the default for every later layer.

  Implements the recurrent cell described in:

    https://arxiv.org/abs/1611.09913

    Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
    "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The cell targets deeply stacked RNNs and may not perform best at depth 1.
  """

  def __init__(self, num_units, num_in_proj=None,
               initializer=None, forget_bias=1.0,
               y_activation=nn_ops.relu, reuse=None):
    """Create a +RNN cell.

    Args:
      num_units: int, number of units; this is both the state and output width.
      num_in_proj: (optional) int, input dimensionality of the RNN.
        - First layer of a +RNN: set it to `num_units`.
        - Any other layer: leave it as `None` (the default). The width of
          `inputs` must then equal `num_units`, otherwise a ValueError is
          raised.
      initializer: (optional) initializer used for the weight matrices this
        cell creates.
      forget_bias: (optional) float, default 1.0. Added to the pre-activation
        of both gates (time and depth); a positive value biases them towards
        passing their carried signal through at the start of training.
      y_activation: (optional) activation applied to the signal passed
        through depth. Defaults to `tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._num_input_proj = num_in_proj
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._y_activation = y_activation

  @property
  def state_size(self):
    """Width of the recurrent state (num_units)."""
    return self._num_units

  @property
  def output_size(self):
    """Width of the depth output y (num_units)."""
    return self._num_units

  def call(self, inputs, state):
    """Advance the +RNN by one time step.

    Args:
      inputs: 2-D input Tensor, batch x input size.
      state: 2-D state Tensor, batch x num_units.

    Returns:
      A pair `(new_y, new_state)`, both batch x num_units:
      - new_y: the output sent up through depth after consuming `inputs`.
      - new_state: the state carried forward through time.

    Raises:
      ValueError: If `inputs` is not rank 2 or its last dimension is not
        known from static shape inference.
      ValueError: If the input width differs from num_units and no input
        projection was requested at construction time.
    """
    in_width = inputs.get_shape().with_rank(2)[1]
    if in_width.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    units = self._num_units
    outer_scope = vs.get_variable_scope()
    with vs.variable_scope(outer_scope, initializer=self._initializer):
      # Only a cell built with num_in_proj (the first layer of a deep +RNN)
      # may map inputs of width I onto width N.
      if in_width.value != units:
        if not self._num_input_proj:
          raise ValueError("Must have input size == output size for "
                           "Intersection RNN. To fix, num_in_proj should "
                           "be set to num_units at cell init.")
        with vs.variable_scope("in_projection"):
          inputs = _linear(inputs, units, True)

      joint = array_ops.concat([inputs, state], 1)
      stacked = _linear(joint, 4 * units, True)

      # Four equal-width column blocks:
      # [time gate | h candidate | depth gate | y candidate].
      time_gate_logits = stacked[:, :units]
      h_logits = stacked[:, units:2 * units]
      depth_gate_logits = stacked[:, 2 * units:3 * units]
      y_logits = stacked[:, 3 * units:4 * units]

      h = math_ops.tanh(h_logits)
      y = self._y_activation(y_logits)
      time_gate = math_ops.sigmoid(time_gate_logits + self._forget_bias)
      depth_gate = math_ops.sigmoid(depth_gate_logits + self._forget_bias)

      # Carried through time: mix of the old state and the h candidate.
      new_state = time_gate * state + (1.0 - time_gate) * h
      # Carried through depth: mix of this layer's input and the y candidate.
      new_y = depth_gate * inputs + (1.0 - depth_gate) * y

    return new_y, new_state
Unique TensorFlower 1631bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo_REGISTERED_OPS = None 1632bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1633bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1634bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdoclass CompiledWrapper(core_rnn_cell.RNNCell): 1635bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """Wraps step execution in an XLA JIT scope.""" 1636bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1637bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def __init__(self, cell, compile_stateful=False): 1638bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """Create CompiledWrapper cell. 1639bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1640bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo Args: 1641bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo cell: Instance of `RNNCell`. 1642bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo compile_stateful: Whether to compile stateful ops like initializers 1643bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo and random number generators (default: False). 
1644bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo """ 1645bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo self._cell = cell 1646bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo self._compile_stateful = compile_stateful 1647bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1648bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo @property 1649bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def state_size(self): 1650bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return self._cell.state_size 1651bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1652bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo @property 1653bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def output_size(self): 1654bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return self._cell.output_size 1655bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 165603abac7f23e6e6864949b959435282729192692eEugene Brevdo def zero_state(self, batch_size, dtype): 165703abac7f23e6e6864949b959435282729192692eEugene Brevdo with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): 165803abac7f23e6e6864949b959435282729192692eEugene Brevdo return self._cell.zero_state(batch_size, dtype) 165903abac7f23e6e6864949b959435282729192692eEugene Brevdo 1660bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def __call__(self, inputs, state, scope=None): 1661bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo if self._compile_stateful: 1662bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo compile_ops = True 1663bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo else: 1664bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo def compile_ops(node_def): 1665bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo global _REGISTERED_OPS 1666bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo if _REGISTERED_OPS is None: 1667bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo _REGISTERED_OPS = op_def_registry.get_registered_ops() 
1668bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo return not _REGISTERED_OPS[node_def.op].is_stateful 1669bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo 1670bab22b9f25741e172bb70ff1f82dc803ced0f579Eugene Brevdo with jit.experimental_jit_scope(compile_ops=compile_ops): 1671e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo return self._cell(inputs, state, scope) 1672bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1673bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1674bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlowerdef _random_exp_initializer(minval, 1675bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower maxval, 1676bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower seed=None, 1677bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower dtype=dtypes.float32): 1678bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """Returns an exponential distribution initializer. 1679bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1680bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower Args: 1681bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower minval: float or a scalar float Tensor. With value > 0. Lower bound of the 1682bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower range of random values to generate. 1683bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower maxval: float or a scalar float Tensor. With value > minval. Upper bound of 1684bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower the range of random values to generate. 1685bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower seed: An integer. Used to create random seeds. 1686bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower dtype: The data type. 1687bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1688bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower Returns: 1689bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower An initializer that generates tensors with an exponential distribution. 1690bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """ 1691bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1692bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower def _initializer(shape, dtype=dtype, partition_info=None): 1693bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower del partition_info # Unused. 1694bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower return math_ops.exp( 1695bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower random_ops.random_uniform( 1696bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower shape, 1697bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower math_ops.log(minval), 1698bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower math_ops.log(maxval), 1699bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower dtype, 1700bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower seed=seed)) 1701bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1702bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower return _initializer 1703bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1704bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1705bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlowerclass PhasedLSTMCell(core_rnn_cell.RNNCell): 1706bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """Phased LSTM recurrent network cell. 1707bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1708bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower https://arxiv.org/pdf/1610.09513v1.pdf 1709bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """ 1710bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1711bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower def __init__(self, 1712bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower num_units, 1713bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower use_peepholes=False, 1714bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower leak=0.001, 17150a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower ratio_on=0.1, 17160a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower trainable_ratio_on=True, 1717bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower period_init_min=1.0, 1718bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower period_init_max=1000.0, 1719bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower reuse=None): 1720bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """Initialize the Phased LSTM cell. 1721bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1722bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower Args: 1723bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower num_units: int, The number of units in the Phased LSTM cell. 1724bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower use_peepholes: bool, set True to enable peephole connections. 1725bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower leak: float or scalar float Tensor with value in [0, 1]. Leak applied 1726bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower during training. 17270a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the 1728bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower period during which the gates are open. 17290a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower trainable_ratio_on: bool, weather ratio_on is trainable. 1730bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower period_init_min: float or scalar float Tensor. With value > 0. 1731bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower Minimum value of the initalized period. 1732bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower The period values are initialized by drawing from the distribution: 1733bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower e^U(log(period_init_min), log(period_init_max)) 1734bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower Where U(.,.) is the uniform distribution. 1735bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower period_init_max: float or scalar float Tensor. 1736bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower With value > period_init_min. Maximum value of the initalized period. 1737bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower reuse: (optional) Python boolean describing whether to reuse variables 1738bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower in an existing scope. If not `True`, and the existing scope already has 1739bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower the given variables, an error is raised. 1740bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """ 1741e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(PhasedLSTMCell, self).__init__(_reuse=reuse) 1742bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower self._num_units = num_units 1743bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower self._use_peepholes = use_peepholes 1744bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower self._leak = leak 17450a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower self._ratio_on = ratio_on 17460a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower self._trainable_ratio_on = trainable_ratio_on 1747bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower self._period_init_min = period_init_min 1748bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower self._period_init_max = period_init_max 1749bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower self._reuse = reuse 1750bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1751bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower @property 1752bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower def state_size(self): 1753bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units) 1754bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1755bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower @property 1756bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower def output_size(self): 1757bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower return self._num_units 1758bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1759bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower def _mod(self, x, y): 1760bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """Modulo function that propagates x gradients.""" 1761bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x 1762bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 17630a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower def _get_cycle_ratio(self, time, phase, period): 17640a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower """Compute the cycle ratio in the dtype of the time.""" 17650a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower phase_casted = math_ops.cast(phase, dtype=time.dtype) 17660a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower period_casted = math_ops.cast(period, dtype=time.dtype) 17670a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower shifted_time = time - phase_casted 17680a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower cycle_ratio = self._mod(shifted_time, period_casted) / period_casted 17690a5652254eee640c1f400fc76dcae394bd9206a0A. 
Unique TensorFlower return math_ops.cast(cycle_ratio, dtype=dtypes.float32) 17700a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower 1771e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 1772bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower """Phased LSTM Cell. 1773bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1774bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower Args: 17750a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower inputs: A tuple of 2 Tensor. 17760a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower The first Tensor has shape [batch, 1], and type float32 or float64. 17770a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower It stores the time. 17780a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower The second Tensor has shape [batch, features_size], and type float32. 17790a5652254eee640c1f400fc76dcae394bd9206a0A. Unique TensorFlower It stores the features. 1780bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower state: core_rnn_cell.LSTMStateTuple, state from previous timestep. 1781bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1782bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower Returns: 1783bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower A tuple containing: 1784bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower - A Tensor of float32, and shape [batch_size, num_units], representing the 1785bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower output of the cell. 1786bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape 1787bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower [batch_size, num_units], representing the new state and the output. 1788bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower """ 1789e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo (c_prev, h_prev) = state 1790e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo (time, x) = inputs 1791bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1792e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo in_mask_gates = [x, h_prev] 1793e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 1794e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo in_mask_gates.append(c_prev) 1795bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1796e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("mask_gates"): 1797e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo mask_gates = math_ops.sigmoid( 1798e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo _linear(in_mask_gates, 2 * self._num_units, True)) 1799e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [input_gate, forget_gate] = array_ops.split( 1800e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo axis=1, num_or_size_splits=2, value=mask_gates) 1801bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1802e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("new_input"): 1803e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_input = math_ops.tanh( 1804e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo _linear([x, h_prev], self._num_units, True)) 1805bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1806e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_c = (c_prev * forget_gate + input_gate * new_input) 1807bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1808e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo in_out_gate = [x, h_prev] 1809e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 1810e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo in_out_gate.append(new_c) 1811bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower 1812e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo with vs.variable_scope("output_gate"): 1813e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo output_gate = math_ops.sigmoid( 1814e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo _linear(in_out_gate, self._num_units, True)) 1815bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1816e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_h = math_ops.tanh(new_c) * output_gate 1817bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1818e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo period = vs.get_variable( 1819e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "period", [self._num_units], 1820e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=_random_exp_initializer( 1821e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo self._period_init_min, self._period_init_max)) 1822e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo phase = vs.get_variable( 1823e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "phase", [self._num_units], 1824e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=init_ops.random_uniform_initializer( 1825e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 0., period.initial_value)) 1826e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo ratio_on = vs.get_variable( 1827e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "ratio_on", [self._num_units], 1828e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=init_ops.constant_initializer(self._ratio_on), 1829e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo trainable=self._trainable_ratio_on) 1830bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1831e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo cycle_ratio = self._get_cycle_ratio(time, phase, period) 1832bd755618e43931b53618a5dd0b77e155cd6151eeA. 
Unique TensorFlower 1833e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo k_up = 2 * cycle_ratio / ratio_on 1834e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo k_down = 2 - k_up 1835e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo k_closed = self._leak * cycle_ratio 1836bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1837e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) 1838e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) 1839bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1840e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_c = k * new_c + (1 - k) * c_prev 1841e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_h = k * new_h + (1 - k) * h_prev 1842bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1843e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h) 1844bd755618e43931b53618a5dd0b77e155cd6151eeA. Unique TensorFlower 1845e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo return new_h, new_state 1846