# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing RNN Cells."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math

from tensorflow.contrib.compiler import jit
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables  # pylint: disable=unused-import
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest


def _get_concat_variable(name, shape, dtype, num_shards, initializer=None):
  """Get a sharded variable concatenated into one tensor.

  Args:
    name: Base name of the variable; shard `i` is named `name + "_%d" % i`.
    shape: Sequence of ints, the full (unsharded) shape. Sharding happens
      along the first dimension.
    dtype: dtype passed to `vs.get_variable` for every shard.
    num_shards: int, number of shards to split the first dimension into.
    initializer: (optional) initializer passed to `vs.get_variable` for every
      shard. If None, `vs.get_variable` is called without an explicit
      initializer, exactly as before this argument existed.

  Returns:
    The single shard itself when there is exactly one shard; otherwise the
    shards concatenated along axis 0.

  Raises:
    ValueError: If `num_shards` is smaller than 1 or larger than `shape[0]`.
  """
  sharded_variable = _get_sharded_variable(
      name, shape, dtype, num_shards, initializer=initializer)
  if len(sharded_variable) == 1:
    return sharded_variable[0]

  concat_name = name + "/concat"
  concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0"
  # Reuse a concat tensor that was already registered under the same full name
  # instead of adding another concat op to the graph.
  for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES):
    if value.name == concat_full_name:
      return value

  concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name)
  ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable)
  return concat_variable


def _get_sharded_variable(name, shape, dtype, num_shards, initializer=None):
  """Get a list of sharded variables with the given dtype.

  The first dimension of `shape` is split as evenly as possible: every shard
  gets `shape[0] // num_shards` rows and the first `shape[0] % num_shards`
  shards get one extra row.

  Args:
    name: Base name of the variable; shard `i` is named `name + "_%d" % i`.
    shape: Sequence (list or tuple) of ints, the full (unsharded) shape.
    dtype: dtype passed to `vs.get_variable` for every shard.
    num_shards: int, number of shards to create.
    initializer: (optional) initializer passed to `vs.get_variable` for every
      shard. If None, no explicit initializer is requested.

  Returns:
    A list with `num_shards` variables.

  Raises:
    ValueError: If `num_shards` is smaller than 1 or larger than `shape[0]`.
  """
  if num_shards < 1:
    raise ValueError("num_shards must be at least 1: shape=%s, num_shards=%s" %
                     (shape, num_shards))
  if num_shards > shape[0]:
    raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape,
                                                                   num_shards))
  # Exact integer arithmetic; `/` is true division in this module.
  unit_shard_size, remaining_rows = divmod(shape[0], num_shards)
  # list() so that tuple shapes can be concatenated with the row count below.
  trailing_dims = list(shape[1:])

  shards = []
  for i in range(num_shards):
    current_size = unit_shard_size
    if i < remaining_rows:
      current_size += 1
    shards.append(
        vs.get_variable(
            name + "_%d" % i, [current_size] + trailing_dims,
            dtype=dtype,
            initializer=initializer))
  return shards


def _norm(g, b, inp, scope):
  """Apply layer normalization to `inp` inside variable scope `scope`.

  Args:
    g: Constant used to initialize the `gamma` variable.
    b: Constant used to initialize the `beta` variable.
    inp: Tensor to normalize; `gamma` and `beta` take the size of its last
      dimension.
    scope: Name of the variable scope holding `gamma` and `beta`.

  Returns:
    The result of `layers.layer_norm` applied to `inp`.
  """
  shape = inp.get_shape()[-1:]
  gamma_init = init_ops.constant_initializer(g)
  beta_init = init_ops.constant_initializer(b)
  with vs.variable_scope(scope):
    # Initialize beta and gamma for use by layer_norm.
    vs.get_variable("gamma", shape=shape, initializer=gamma_init)
    vs.get_variable("beta", shape=shape, initializer=beta_init)
  normalized = layers.layer_norm(inp, reuse=True, scope=scope)
  return normalized


class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
  """Long short-term memory unit (LSTM) recurrent network cell.

  The default non-peephole implementation is based on:

    http://www.bioinf.jku.at/publications/older/2604.pdf

  S. Hochreiter and J. Schmidhuber.
  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.

  The peephole implementation is based on:

    https://research.google.com/pubs/archive/43905.pdf

  Hasim Sak, Andrew Senior, and Francoise Beaufays.
  "Long short-term memory recurrent neural network architectures for
   large scale acoustic modeling." INTERSPEECH, 2014.

  The coupling of input and forget gate is based on:

    http://arxiv.org/pdf/1503.04069.pdf

  Greff et al. "LSTM: A Search Space Odyssey"

  The class uses optional peep-hole connections, and an optional projection
  layer.
  Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.

  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               initializer=None,
               num_proj=None,
               proj_clip=None,
               num_unit_shards=1,
               num_proj_shards=1,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=math_ops.tanh,
               reuse=None,
               layer_norm=False,
               norm_gain=1.0,
               norm_shift=0.0):
    """Initialize the parameters for an LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      use_peepholes: bool, set True to enable diagonal/peephole connections.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      proj_clip: (optional) A float value.  If `num_proj > 0` and `proj_clip` is
        provided, then the projected values are clipped elementwise to within
        `[-proj_clip, proj_clip]`.
      num_unit_shards: How to split the weight matrix.  If >1, the weight
        matrix is stored across num_unit_shards.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True (the default), accepted and returned states are
        2-tuples of the `c_state` and `m_state`.  If False, they are
        concatenated along the column axis; a warning is logged in that case
        because the concatenated form will soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._initializer = initializer
    self._num_proj = num_proj
    self._proj_clip = proj_clip
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse
    self._layer_norm = layer_norm
    self._norm_gain = norm_gain
    self._norm_shift = norm_shift

    if num_proj:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          rnn_cell_impl.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x input_size.  The second dimension must
        be statically known.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    sigmoid = math_ops.sigmoid

    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    # The user-supplied initializer is forwarded explicitly; it was previously
    # stored by __init__ but never applied to the weight matrix.
    concat_w = _get_concat_variable(
        "W", [input_size.value + num_proj, 3 * self._num_units], dtype,
        self._num_unit_shards, initializer=self._initializer)

    # The bias variable is always created, even when layer normalization means
    # it is not added below.
    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = math_ops.matmul(cell_inputs, concat_w)

    # If layer normalization is applied, do not add bias
    if not self._layer_norm:
      lstm_matrix = nn_ops.bias_add(lstm_matrix, b)

    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

    # Apply layer normalization
    if self._layer_norm:
      j = _norm(self._norm_gain, self._norm_shift, j, "transform")
      f = _norm(self._norm_gain, self._norm_shift, f, "forget")
      o = _norm(self._norm_gain, self._norm_shift, o, "output")

    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

    if self._use_peepholes:
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    # Coupled gates: the input gate is (1 - forget gate).
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

    # Apply layer normalization
    if self._layer_norm:
      c = _norm(self._norm_gain, self._norm_shift, c, "state")

    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj], dtype,
          self._num_proj_shards, initializer=self._initializer)

      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (
        rnn_cell_impl.LSTMStateTuple(c, m)
        if self._state_is_tuple else array_ops.concat([c, m], 1))
    return m, new_state


class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
  """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell.

  This implementation is based on:

    Tara N. Sainath and Bo Li
    "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures
    for LVCSR Tasks." submitted to INTERSPEECH, 2016.

  It uses peep-hole connections and optional cell clipping.
  """
Unique TensorFlower 330ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie def __init__(self, 331ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_units, 332ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie use_peepholes=False, 333ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie cell_clip=None, 334ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie initializer=None, 335ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_unit_shards=1, 336ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie forget_bias=1.0, 337ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie feature_size=None, 338ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie frequency_skip=1, 33954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 340d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 341d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 342d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 343d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 344d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower use_peepholes: bool, set True to enable diagonal/peephole connections. 345d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower cell_clip: (optional) A float value, if provided the cell state is clipped 346d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower by this value prior to the cell output activation. 347d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 348d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower projection matrices. 349d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_unit_shards: int, How to split the weight matrix. If >1, the weight 350d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower matrix is stored across num_unit_shards. 
351d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower forget_bias: float, Biases of the forget gate are initialized by default 352d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower to 1 in order to reduce the scale of forgetting at the beginning 353d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 354d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower feature_size: int, The size of the input feature the LSTM spans over. 355d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency_skip: int, The amount the LSTM filter is shifted by in 356d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower frequency. 35754d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 35854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 35954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 360d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 361e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(TimeFreqLSTMCell, self).__init__(_reuse=reuse) 362d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 363d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 364d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 365d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 366d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 367d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 368d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 369d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower self._frequency_skip = frequency_skip 370d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._state_size = 2 * num_units 371d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._output_size = num_units 37254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 373d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 374d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 375d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 376d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 377d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 378d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 379d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 380d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 381d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 382e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 383d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 384d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 385d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 386d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower inputs: input Tensor, 2D, batch x num_units. 387d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower state: state Tensor, 2D, batch x state_size. 388d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 389d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 390d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A tuple containing: 391d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the output of the LSTM 392d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower after reading "inputs" when previous state was "state". 393d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 394d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x state_size, Tensor representing the new state of LSTM 395d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 396d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 397d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 398d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 399d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 400e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo sigmoid = math_ops.sigmoid 401e5a1c6a933eeae54ca69bc9eadf54c51f1614519Eugene Brevdo tanh = math_ops.tanh 402d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 403d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 404d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower dtype = inputs.dtype 405d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower actual_input_size = freq_inputs[0].get_shape().as_list()[1] 406d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 407e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat_w = _get_concat_variable( 408ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units], 409e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype, self._num_unit_shards) 410d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 411e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo b = vs.get_variable( 412e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "B", 413e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo shape=[4 * self._num_units], 414e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo initializer=init_ops.zeros_initializer(), 415e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo dtype=dtype) 416d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 417e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # Diagonal connections 418e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 419e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_f_diag = vs.get_variable( 420e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_F_diag", shape=[self._num_units], dtype=dtype) 421e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_i_diag = vs.get_variable( 422e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_I_diag", shape=[self._num_units], dtype=dtype) 423e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo w_o_diag = vs.get_variable( 424e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo "W_O_diag", shape=[self._num_units], dtype=dtype) 425d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 426e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # initialize the first freq state to be zero 427f70db141135dd463c294e76284cab7331c55933eRasmus Munk Larsen m_prev_freq = array_ops.zeros( 428f70db141135dd463c294e76284cab7331c55933eRasmus Munk Larsen [inputs.shape[0].value or inputs.get_shape()[0], self._num_units], 429f70db141135dd463c294e76284cab7331c55933eRasmus Munk Larsen dtype) 430e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for fq in range(len(freq_inputs)): 431ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], 432e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._num_units]) 433ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units], 434e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo [-1, self._num_units]) 435e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # i = input_gate, j = new_input, f = forget_gate, o = output_gate 436ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1) 437e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) 438e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo i, j, f, o = array_ops.split( 439e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo value=lstm_matrix, num_or_size_splits=4, axis=1) 440e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 441e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 442ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie c = ( 443ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 444ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie sigmoid(i + w_i_diag * c_prev) * tanh(j)) 445e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 
446e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) 447e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 448e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._cell_clip is not None: 449e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: disable=invalid-unary-operand-type 450e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 451e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo # pylint: enable=invalid-unary-operand-type 452e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo 453e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._use_peepholes: 454e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o + w_o_diag * c) * tanh(c) 455e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 456e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m = sigmoid(o) * tanh(c) 457e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_prev_freq = m 458e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if fq == 0: 459e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat([c, m], 1) 460e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = m 461e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 462e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat([state_out, c, m], 1) 463e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = array_ops.concat([m_out, m], 1) 464d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return m_out, state_out 465d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 466d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def _make_tf_features(self, input_feat): 467d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Make the frequency features. 468d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 469d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 470d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_feat: input Tensor, 2D, batch x num_units. 471d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 472d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 473d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A list of frequency features, with each element containing: 474d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower - A 2D, batch x output_dim, Tensor representing the time-frequency feature 475d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for that frequency index. Here output_dim is feature_size. 476d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 477d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if input_size cannot be inferred from static shape inference. 478d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 479d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower input_size = input_feat.get_shape().with_rank(2)[-1].value 480d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower if input_size is None: 481d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower raise ValueError("Cannot infer input_size from static shape inference.") 482ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_feats = int( 483ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie (input_size - self._feature_size) / (self._frequency_skip)) + 1 484d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs = [] 485d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for f in range(num_feats): 486ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip], 487d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower [-1, self._feature_size]) 488d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower freq_inputs.append(cur_input) 489d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return freq_inputs 490d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 491d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 492827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdoclass GridLSTMCell(rnn_cell_impl.RNNCell): 493d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Grid Long short-term memory unit (LSTM) recurrent network cell. 494d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 495d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The default is based on: 496d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Nal Kalchbrenner, Ivo Danihelka and Alex Graves 497d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Grid Long Short-Term Memory," Proc. ICLR 2016. 498d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower http://arxiv.org/abs/1507.01526 499d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 500d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower When peephole connections are used, the implementation is based on: 501d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Tara N. Sainath and Bo Li 502d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures 503d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower for LVCSR Tasks." submitted to INTERSPEECH, 2016. 504d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 505d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower The code uses optional peephole connections, shared_weights and cell clipping. 506d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 507d1a31025aaaecbc9137998b587b90221c5a36cb3A. 
Unique TensorFlower 508ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie def __init__(self, 509ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_units, 510ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie use_peepholes=False, 511d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower share_time_frequency_weights=False, 512ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie cell_clip=None, 513ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie initializer=None, 514ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_unit_shards=1, 515ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie forget_bias=1.0, 516ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie feature_size=None, 517ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie frequency_skip=None, 5189e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 5199e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 5209e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 5218fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower couple_input_forget_gates=False, 522c3d99052ec49bb219f5e29567846d3af391d7b28A. Unique TensorFlower state_is_tuple=True, 52354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 524d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Initialize the parameters for an LSTM cell. 525d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 526d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 527d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower num_units: int, The number of units in the LSTM cell 5281855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 5291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 5301855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 5311855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 5321855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 5331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state is clipped by this value prior to the cell output activation. 534d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 5351855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 5361b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu num_unit_shards: (optional) int, default 1, How to split the weight 5371855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 5381855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 5391855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 540d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower of the training. 5411855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 5421855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 5431855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 5441855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 5459e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 5469e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 5479e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower and end_freqindex_list. 5489e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 5499e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower starting frequency index for each frequency block. 5509e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 5519e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 5521855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 5531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 5541855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 5558fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_is_tuple: If True, accepted and returned states are 2-tuples of 5568fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower the `c_state` and `m_state`. By default (False), they are concatenated 5578fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower along the column axis. This default behavior will soon be deprecated. 55854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 55954d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 56054d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 5619e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower Raises: 5629e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower ValueError: if the num_frequency_blocks list is not specified 563d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 564e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(GridLSTMCell, self).__init__(_reuse=reuse) 5658fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if not state_is_tuple: 5668fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower logging.warn("%s: Using a concatenated state is slower and will soon be " 5678fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower "deprecated. Use state_is_tuple=True.", self) 568d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_units = num_units 569d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._use_peepholes = use_peepholes 570d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._share_time_frequency_weights = share_time_frequency_weights 5718fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._couple_input_forget_gates = couple_input_forget_gates 5728fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_is_tuple = state_is_tuple 573d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._cell_clip = cell_clip 574d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._initializer = initializer 575d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._num_unit_shards = num_unit_shards 576d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._forget_bias = forget_bias 577d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._feature_size = feature_size 578d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower self._frequency_skip = frequency_skip 5799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._start_freqindex_list = start_freqindex_list 5809e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower self._end_freqindex_list = end_freqindex_list 5819e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._num_frequency_blocks = num_frequency_blocks 5829e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks = 0 58354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 5849e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower if self._num_frequency_blocks is None: 5859e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower raise ValueError("Must specify num_frequency_blocks") 5869e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower 5879e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5889e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._total_blocks += int(self._num_frequency_blocks[block_index]) 5898fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower if state_is_tuple: 5908fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower state_names = "" 5919e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 5929e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 5939e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) 5949e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 595ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple", 596ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie state_names.strip(",")) 597ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie self._state_size = self._state_tuple_type(*( 598ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [num_units, num_units] * self._total_blocks)) 5998fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower else: 6008fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower self._state_tuple_type = None 6019e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._state_size = num_units * self._total_blocks * 2 6029e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._output_size = num_units * self._total_blocks * 2 603d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 604d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 605d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def output_size(self): 606d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._output_size 607d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 608d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower @property 609d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower def state_size(self): 610d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower return self._state_size 611d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 6128fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower @property 6138fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower def state_tuple_type(self): 6148fdea19f4151970915b69ba729cbcf725ea05860A. Unique TensorFlower return self._state_tuple_type 6158fdea19f4151970915b69ba729cbcf725ea05860A. 
Unique TensorFlower 616e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 617d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """Run one step of LSTM. 618d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 619d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Args: 6201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower inputs: input Tensor, 2D, [batch, feature_size]. 6211855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the 6221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower flag self._state_is_tuple. 623d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower 624d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Returns: 625d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower A tuple containing: 6261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, output_dim], Tensor representing the output of the LSTM 627d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 628d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Here output_dim is num_units. 6291855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower - A 2D, [batch, state_size], Tensor representing the new state of LSTM 630d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower after reading "inputs" when previous state was "state". 631d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower Raises: 632d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower ValueError: if an input_size was specified and the provided inputs have 633d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower a different dimension. 634d1a31025aaaecbc9137998b587b90221c5a36cb3A. Unique TensorFlower """ 635499166454f0c7dd6c724c364f6dc5d99357514b1A. 
Unique TensorFlower batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] 6361855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower freq_inputs = self._make_tf_features(inputs) 637e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst = [] 638e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out_lst = [] 639e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo for block in range(len(freq_inputs)): 640e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst_current, state_out_lst_current = self._compute( 641ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie freq_inputs[block], 642ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie block, 643ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie state, 644ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie batch_size, 645e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_is_tuple=self._state_is_tuple) 646e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out_lst.extend(m_out_lst_current) 647e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out_lst.extend(state_out_lst_current) 648e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._state_is_tuple: 649e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = self._state_tuple_type(*state_out_lst) 650e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo else: 651e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo state_out = array_ops.concat(state_out_lst, 1) 652e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo m_out = array_ops.concat(m_out_lst, 1) 6531855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower return m_out, state_out 6541855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
def _compute(self,
             freq_inputs,
             block,
             state,
             batch_size,
             state_prefix="state",
             state_is_tuple=True):
  """Run the actual computation of one step LSTM.

  For every frequency step in `freq_inputs` a frequency LSTM (F-LSTM) and a
  time LSTM (T-LSTM) are evaluated on the same concatenated input
  `[freq_input, m_prev_time, m_prev_freq]`. The F-LSTM state is carried along
  the frequency axis inside this call (starting from zeros); the T-LSTM state
  is read from `state`.

  Args:
    freq_inputs: list of Tensors, 2D, [batch, feature_size].
    block: int, current frequency block index to process.
    state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on
      the flag state_is_tuple.
    batch_size: int32, batch size.
    state_prefix: (optional) string, name prefix for states, defaults to
      "state".
    state_is_tuple: boolean, indicates whether the state is a tuple or Tensor.

  Returns:
    A tuple, containing:
    - A list of [batch, output_dim] Tensors, representing the output of the
      LSTM given the inputs and state. For each frequency step the list
      holds `m_time` followed by `m_freq`.
    - A list of [batch, state_size] Tensors, representing the LSTM state
      values given the inputs and previous state. For each frequency step
      the list holds `c_time` followed by `m_time`.
  """
  sigmoid = math_ops.sigmoid
  tanh = math_ops.tanh
  num_gates = 3 if self._couple_input_forget_gates else 4
  dtype = freq_inputs[0].dtype
  actual_input_size = freq_inputs[0].get_shape().as_list()[1]

  def _peephole_weight(name):
    # One diagonal (peephole) weight vector, suffixed with the block index.
    return vs.get_variable(
        "%s_%d" % (name, block), shape=[self._num_units], dtype=dtype)

  concat_w_f = _get_concat_variable(
      "W_f_%d" % block,
      [actual_input_size + 2 * self._num_units, num_gates * self._num_units],
      dtype, self._num_unit_shards)
  b_f = vs.get_variable(
      "B_f_%d" % block,
      shape=[num_gates * self._num_units],
      initializer=init_ops.zeros_initializer(),
      dtype=dtype)
  if not self._share_time_frequency_weights:
    concat_w_t = _get_concat_variable("W_t_%d" % block, [
        actual_input_size + 2 * self._num_units, num_gates * self._num_units
    ], dtype, self._num_unit_shards)
    b_t = vs.get_variable(
        "B_t_%d" % block,
        shape=[num_gates * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

  if self._use_peepholes:
    # Diagonal connections. Forget-gate peepholes only exist when the forget
    # gate is not derived from the input gate.
    w_f_diag_freqf = w_f_diag_freqt = None
    if not self._couple_input_forget_gates:
      w_f_diag_freqf = _peephole_weight("W_F_diag_freqf")
      w_f_diag_freqt = _peephole_weight("W_F_diag_freqt")
    w_i_diag_freqf = _peephole_weight("W_I_diag_freqf")
    w_i_diag_freqt = _peephole_weight("W_I_diag_freqt")
    w_o_diag_freqf = _peephole_weight("W_O_diag_freqf")
    w_o_diag_freqt = _peephole_weight("W_O_diag_freqt")
    if self._share_time_frequency_weights:
      # The T-LSTM reuses the F-LSTM peephole weights.
      w_f_diag_timef, w_f_diag_timet = w_f_diag_freqf, w_f_diag_freqt
      w_i_diag_timef, w_i_diag_timet = w_i_diag_freqf, w_i_diag_freqt
      w_o_diag_timef, w_o_diag_timet = w_o_diag_freqf, w_o_diag_freqt
    else:
      w_f_diag_timef = w_f_diag_timet = None
      if not self._couple_input_forget_gates:
        w_f_diag_timef = _peephole_weight("W_F_diag_timef")
        w_f_diag_timet = _peephole_weight("W_F_diag_timet")
      w_i_diag_timef = _peephole_weight("W_I_diag_timef")
      w_i_diag_timet = _peephole_weight("W_I_diag_timet")
      w_o_diag_timef = _peephole_weight("W_O_diag_timef")
      w_o_diag_timet = _peephole_weight("W_O_diag_timet")

  # initialize the first freq state to be zero
  m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
  c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype)
  state_out_lst = []
  m_out_lst = []
  for freq_index, freq_input in enumerate(freq_inputs):
    if state_is_tuple:
      name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block)
      c_prev_time = getattr(state, name_prefix + "_c")
      m_prev_time = getattr(state, name_prefix + "_m")
    else:
      c_prev_time = array_ops.slice(
          state, [0, 2 * freq_index * self._num_units], [-1, self._num_units])
      m_prev_time = array_ops.slice(
          state, [0, (2 * freq_index + 1) * self._num_units],
          [-1, self._num_units])

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([freq_input, m_prev_time, m_prev_freq], 1)

    # F-LSTM
    lstm_matrix_freq = nn_ops.bias_add(
        math_ops.matmul(cell_inputs, concat_w_f), b_f)
    if self._couple_input_forget_gates:
      i_freq, j_freq, o_freq = array_ops.split(
          value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
      f_freq = None
    else:
      i_freq, j_freq, f_freq, o_freq = array_ops.split(
          value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
    # T-LSTM
    if self._share_time_frequency_weights:
      i_time, j_time, f_time, o_time = i_freq, j_freq, f_freq, o_freq
    else:
      lstm_matrix_time = nn_ops.bias_add(
          math_ops.matmul(cell_inputs, concat_w_t), b_t)
      if self._couple_input_forget_gates:
        i_time, j_time, o_time = array_ops.split(
            value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
        f_time = None
      else:
        i_time, j_time, f_time, o_time = array_ops.split(
            value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)

    # F-LSTM c_freq
    # input gate activations
    if self._use_peepholes:
      i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq +
                         w_i_diag_freqt * c_prev_time)
    else:
      i_freq_g = sigmoid(i_freq)
    # forget gate activations
    if self._couple_input_forget_gates:
      f_freq_g = 1.0 - i_freq_g
    elif self._use_peepholes:
      f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf *
                         c_prev_freq + w_f_diag_freqt * c_prev_time)
    else:
      f_freq_g = sigmoid(f_freq + self._forget_bias)
    # cell state
    c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq)
    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip,
                                      self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    # T-LSTM c_time
    # input gate activations
    if self._use_peepholes:
      i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq +
                         w_i_diag_timet * c_prev_time)
    else:
      i_time_g = sigmoid(i_time)
    # forget gate activations
    if self._couple_input_forget_gates:
      f_time_g = 1.0 - i_time_g
    elif self._use_peepholes:
      f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef *
                         c_prev_freq + w_f_diag_timet * c_prev_time)
    else:
      f_time_g = sigmoid(f_time + self._forget_bias)
    # cell state
    c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time)
    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      c_time = clip_ops.clip_by_value(c_time, -self._cell_clip,
                                      self._cell_clip)
      # pylint: enable=invalid-unary-operand-type

    # F-LSTM m_freq
    if self._use_peepholes:
      m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq +
                       w_o_diag_freqt * c_time) * tanh(c_freq)
    else:
      m_freq = sigmoid(o_freq) * tanh(c_freq)

    # T-LSTM m_time
    if self._use_peepholes:
      m_time = sigmoid(o_time + w_o_diag_timef * c_freq +
                       w_o_diag_timet * c_time) * tanh(c_time)
    else:
      m_time = sigmoid(o_time) * tanh(c_time)

    # The F-LSTM state feeds the next frequency step of this same call.
    m_prev_freq = m_freq
    c_prev_freq = c_freq
    # Concatenate the outputs for T-LSTM and F-LSTM for each shift
    state_out_lst.extend([c_time, m_time])
    m_out_lst.extend([m_time, m_freq])

  return m_out_lst, state_out_lst

def _make_tf_features(self, input_feat, slice_offset=0):
  """Make the frequency features.

  Args:
    input_feat: input Tensor, 2D, [batch, num_units].
    slice_offset: (optional) Python int, default 0, the slicing offset is only
      used for the backward processing in the BidirectionalGridLSTMCell. It
      specifies a different starting point instead of always 0 to enable the
      forward and backward processing look at different frequency blocks.

  Returns:
    A list with one entry per frequency block. Each entry is a list of
    frequency features, with each element containing:
    - A 2D, [batch, output_dim], Tensor representing the time-frequency
      feature for that frequency index. Here output_dim is feature_size.
  Raises:
    ValueError: if input_size cannot be inferred from static shape inference,
      if the lengths of num_frequency_blocks, start_freqindex_list and
      end_freqindex_list are inconsistent, or if the number of features
      computed for a block differs from the configured
      num_frequency_blocks entry.
  """
  input_size = input_feat.get_shape().with_rank(2)[-1].value
  if input_size is None:
    raise ValueError("Cannot infer input_size from static shape inference.")
  if slice_offset > 0:
    # Padding to the end
    inputs = array_ops.pad(input_feat,
                           array_ops.constant(
                               [0, 0, 0, slice_offset],
                               shape=[2, 2],
                               dtype=dtypes.int32), "CONSTANT")
  elif slice_offset < 0:
    # Padding to the front
    inputs = array_ops.pad(input_feat,
                           array_ops.constant(
                               [0, 0, -slice_offset, 0],
                               shape=[2, 2],
                               dtype=dtypes.int32), "CONSTANT")
    # The front padding already shifts the features, so slice from 0.
    slice_offset = 0
  else:
    inputs = input_feat
  freq_inputs = []
  if not self._start_freqindex_list:
    if len(self._num_frequency_blocks) != 1:
      raise ValueError("Length of num_frequency_blocks"
                       " is not 1, but instead is %d" %
                       len(self._num_frequency_blocks))
    num_feats = int(
        (input_size - self._feature_size) / (self._frequency_skip)) + 1
    if num_feats != self._num_frequency_blocks[0]:
      raise ValueError(
          "Invalid num_frequency_blocks, requires %d but gets %d, please"
          " check the input size and filter config are correct." %
          (self._num_frequency_blocks[0], num_feats))
    block_inputs = [
        array_ops.slice(inputs, [0, slice_offset + f * self._frequency_skip],
                        [-1, self._feature_size]) for f in range(num_feats)
    ]
    freq_inputs.append(block_inputs)
  else:
    if len(self._start_freqindex_list) != len(self._end_freqindex_list):
      raise ValueError("Length of start and end freqindex_list"
                       " does not match %d %d" %
                       (len(self._start_freqindex_list),
                        len(self._end_freqindex_list)))
    if len(self._num_frequency_blocks) != len(self._start_freqindex_list):
      raise ValueError("Length of num_frequency_blocks"
                       " is not equal to start_freqindex_list %d %d" %
                       (len(self._num_frequency_blocks),
                        len(self._start_freqindex_list)))
    for b, (start_index, end_index) in enumerate(
        zip(self._start_freqindex_list, self._end_freqindex_list)):
      cur_size = end_index - start_index
      block_feats = int(
          (cur_size - self._feature_size) / (self._frequency_skip)) + 1
      if block_feats != self._num_frequency_blocks[b]:
        raise ValueError(
            "Invalid num_frequency_blocks, requires %d but gets %d, please"
            " check the input size and filter config are correct." %
            (self._num_frequency_blocks[b], block_feats))
      block_inputs = [
          array_ops.slice(
              inputs,
              [0, start_index + slice_offset + f * self._frequency_skip],
              [-1, self._feature_size]) for f in range(block_feats)
      ]
      freq_inputs.append(block_inputs)
  return freq_inputs


class BidirectionalGridLSTMCell(GridLSTMCell):
  """Bidirectional GridLstm cell.

  The bidirectional connection is only used in the frequency direction, which
  hence doesn't affect the time direction's real-time processing that is
  required for online recognition systems.
  The current implementation uses different weights for the two directions.
  """
Unique TensorFlower 967ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie def __init__(self, 968ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_units, 969ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie use_peepholes=False, 9701855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower share_time_frequency_weights=False, 971ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie cell_clip=None, 972ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie initializer=None, 973ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_unit_shards=1, 974ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie forget_bias=1.0, 975ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie feature_size=None, 976ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie frequency_skip=None, 9779e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks=None, 9789e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list=None, 9799e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list=None, 9801855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates=False, 98154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo backward_slice_offset=0, 98254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 9831855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """Initialize the parameters for an LSTM cell. 9841855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower 9851855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower Args: 9861855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower num_units: int, The number of units in the LSTM cell 9871855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower use_peepholes: (optional) bool, default False. Set True to enable 9881855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower diagonal/peephole connections. 9891855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower share_time_frequency_weights: (optional) bool, default False. Set True to 9901855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower enable shared cell weights between time and frequency LSTMs. 9911855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower cell_clip: (optional) A float value, default None, if provided the cell 9921855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state is clipped by this value prior to the cell output activation. 9931855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer: (optional) The initializer to use for the weight and 9941855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower projection matrices, default None. 9951b5235fd897f7ea5cffc715300f67b4dc852fa27Jonathan Hseu num_unit_shards: (optional) int, default 1, How to split the weight 9961855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower matrix. If > 1,the weight matrix is stored across num_unit_shards. 9971855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget_bias: (optional) float, default 1.0, The initial bias of the 9981855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower forget gates, used to reduce the scale of forgetting at the beginning 9991855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower of the training. 10001855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower feature_size: (optional) int, default None, The size of the input feature 10011855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower the LSTM spans over. 10021855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower frequency_skip: (optional) int, default None, The amount the LSTM filter 10031855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower is shifted by in frequency. 10049e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks: [required] A list of frequency blocks needed to 10059e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. 
Unique TensorFlower cover the whole input feature splitting defined by start_freqindex_list 10069e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower and end_freqindex_list. 10079e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower start_freqindex_list: [optional], list of ints, default None, The 10089e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower starting frequency index for each frequency block. 10099e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower end_freqindex_list: [optional], list of ints, default None. The ending 10109e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower frequency index for each frequency block. 10111855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple_input_forget_gates: (optional) bool, default False, Whether to 10121855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce 10131855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower model parameters and computation cost. 10141855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower backward_slice_offset: (optional) int32, default 0, the starting offset to 10151855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower slice the feature for backward processing. 101654d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 101754d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 101854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 10191855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower """ 10201855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower super(BidirectionalGridLSTMCell, self).__init__( 10211855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
Unique TensorFlower num_units, use_peepholes, share_time_frequency_weights, cell_clip, 10221855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower initializer, num_unit_shards, forget_bias, feature_size, frequency_skip, 10239e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower num_frequency_blocks, start_freqindex_list, end_freqindex_list, 1024a5493db6082a6fe29ed11e71f2aab5bfdf5f98c7A. Unique TensorFlower couple_input_forget_gates, True, reuse) 10251855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._backward_slice_offset = int(backward_slice_offset) 10261855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower state_names = "" 10271855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower for direction in ["fwd", "bwd"]: 10289e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for block_index in range(len(self._num_frequency_blocks)): 10299e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower for freq_index in range(self._num_frequency_blocks[block_index]): 10309e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index, 10319e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower block_index) 10329e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) 10331855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower self._state_tuple_type = collections.namedtuple( 10341855c80558fecebc1e8b90efbe2cfe8573bff38aA. Unique TensorFlower "BidirectionalGridLSTMStateTuple", state_names.strip(",")) 1035ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie self._state_size = self._state_tuple_type(*( 1036ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie [num_units, num_units] * self._total_blocks * 2)) 10379e601230a91f9da95faa99f4bceb7b7e97fa8eb0A. Unique TensorFlower self._output_size = 2 * num_units * self._total_blocks * 2 10381855c80558fecebc1e8b90efbe2cfe8573bff38aA. 
def call(self, inputs, state):
  """Run one step of the bidirectional grid LSTM.

  `inputs` is turned into a list of feature blocks by
  `self._make_tf_features`. Every block is processed twice by
  `self._compute`: once as-is under the "fwd" variable scope and once with
  its feature sequence reversed under the "bwd" variable scope. When a
  non-zero backward slice offset was configured, the backward pass uses a
  second set of blocks built with that offset; otherwise it reuses the
  forward blocks.

  Args:
    inputs: input Tensor; its first dimension is taken as the batch size.
    state: previous state. It is forwarded unchanged to every
      `self._compute` call, which is invoked with `state_is_tuple=True`.

  Returns:
    A tuple `(m_out, state_out)` where:
    - `m_out` is the concatenation along axis 1 of every forward output
      followed by every backward output.
    - `state_out` is a `self._state_tuple_type` built from the forward
      states followed by the backward states.
  """
  # Use the static batch dimension when it is known, else the dynamic one.
  batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]

  forward_blocks = self._make_tf_features(inputs)
  backward_blocks = (
      self._make_tf_features(inputs, self._backward_slice_offset)
      if self._backward_slice_offset else forward_blocks)

  # Forward direction: blocks are consumed in their natural order.
  with vs.variable_scope("fwd"):
    forward_outputs = []
    forward_states = []
    for block_index, block_features in enumerate(forward_blocks):
      block_outputs, block_states = self._compute(
          block_features,
          block_index,
          state,
          batch_size,
          state_prefix="fwd_state",
          state_is_tuple=True)
      forward_outputs.extend(block_outputs)
      forward_states.extend(block_states)

  # Backward direction: same blocks, each one walked in reverse.
  backward_outputs = []
  backward_states = []
  with vs.variable_scope("bwd"):
    for block_index, block_features in enumerate(backward_blocks):
      block_outputs, block_states = self._compute(
          block_features[::-1],
          block_index,
          state,
          batch_size,
          state_prefix="bwd_state",
          state_is_tuple=True)
      backward_outputs.extend(block_outputs)
      backward_states.extend(block_states)

  state_out = self._state_tuple_type(*(forward_states + backward_states))
  # The outputs are always returned as one concatenated tensor.
  m_out = array_ops.concat(forward_outputs + backward_outputs, 1)
  return m_out, state_out
# pylint: disable=protected-access
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name

# pylint: enable=protected-access


class AttentionCellWrapper(rnn_cell_impl.RNNCell):
  """Wraps an RNNCell and adds a basic attention mechanism to it.

  Implementation based on https://arxiv.org/abs/1409.0473.
  """

  def __init__(self,
               cell,
               attn_length,
               attn_size=None,
               attn_vec_size=None,
               input_size=None,
               state_is_tuple=True,
               reuse=None):
    """Create a cell with attention.

    Args:
      cell: an RNNCell, an attention is added to it.
      attn_length: integer, the size of an attention window. Must be greater
        than zero.
      attn_size: integer, the size of an attention vector. Equal to
        `cell.output_size` by default.
      attn_vec_size: integer, the number of convolutional features calculated
        on attention state and the size of the hidden layer built from the
        base cell state. Equal to `attn_size` by default.
      input_size: integer, the size of a hidden linear layer built from
        inputs and attention. Derived from the input tensor by default.
      state_is_tuple: If True (the default), accepted and returned states are
        3-tuples `(cell_state, attns, attn_states)`. If False, the three
        parts are concatenated along the column axis; a warning is logged in
        that case.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
      ValueError: if cell returns a state tuple but the flag
        `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)

    # Validate the arguments before storing anything on the instance.
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError("The parameter cell is not RNNCell.")
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
      raise ValueError(
          "Cell returns tuple of states, but the flag "
          "state_is_tuple is not set. State size is: %s" %
          str(cell.state_size))
    if attn_length <= 0:
      raise ValueError(
          "attn_length should be greater than zero, got %s" %
          str(attn_length))
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)

    # Fill in the sizes that default to values derived from other arguments.
    resolved_attn_size = cell.output_size if attn_size is None else attn_size
    resolved_vec_size = (
        resolved_attn_size if attn_vec_size is None else attn_vec_size)

    self._state_is_tuple = state_is_tuple
    self._cell = cell
    self._attn_vec_size = resolved_vec_size
    self._input_size = input_size
    self._attn_size = resolved_attn_size
    self._attn_length = attn_length
    self._reuse = reuse
    # The three linear layers are created on first use in call/_attention.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    """Sizes of (cell state, attention vector, flattened attention window).

    Returned as a tuple when `state_is_tuple` is set, otherwise as their sum.
    """
    parts = (self._cell.state_size, self._attn_size,
             self._attn_size * self._attn_length)
    if not self._state_is_tuple:
      return sum(parts)
    return parts

  @property
  def output_size(self):
    """Size of the cell output, which equals the attention size."""
    return self._attn_size

  def call(self, inputs, state):
    """Run one step of the wrapped cell with attention.

    Args:
      inputs: input Tensor for this step.
      state: either the 3-tuple `(cell_state, attns, attn_states)` or, when
        `state_is_tuple` is False, those three parts concatenated along
        axis 1.

    Returns:
      A pair `(output, new_state)`; `new_state` has the same layout (tuple or
      concatenated) as the incoming `state`.
    """
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      # Carve the three parts out of the single concatenated state tensor.
      packed = state
      cell_width = self._cell.state_size
      state = array_ops.slice(packed, [0, 0], [-1, cell_width])
      attns = array_ops.slice(packed, [0, cell_width], [-1, self._attn_size])
      attn_states = array_ops.slice(
          packed, [0, cell_width + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(
        attn_states, [-1, self._attn_length, self._attn_size])

    if self._input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    else:
      input_size = self._input_size

    # Mix the previous attention vector into the inputs of the wrapped cell.
    if self._linear1 is None:
      self._linear1 = _Linear([inputs, attns], input_size, True)
    inputs = self._linear1([inputs, attns])
    cell_output, new_state = self._cell(inputs, state)

    if self._state_is_tuple:
      attention_query = array_ops.concat(nest.flatten(new_state), 1)
    else:
      attention_query = new_state
    new_attns, new_attn_states = self._attention(attention_query, attn_states)

    with vs.variable_scope("attn_output_projection"):
      if self._linear2 is None:
        self._linear2 = _Linear(
            [cell_output, new_attns], self._attn_size, True)
      output = self._linear2([cell_output, new_attns])

    # Append the new output as the most recent slot of the attention window.
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])

    new_state = (new_state, new_attns, new_attn_states)
    if self._state_is_tuple:
      return output, new_state
    return output, array_ops.concat(list(new_state), 1)

  def _attention(self, query, attn_states):
    """Compute the attention read-out and drop the oldest window slot.

    Args:
      query: Tensor projected to `attn_vec_size` and used to score the window.
      attn_states: Tensor reshaped here to
        `[-1, attn_length, 1, attn_size]`; `call` passes it as
        `[-1, attn_length, attn_size]`.

    Returns:
      A pair `(new_attns, new_attn_states)`: the softmax-weighted sum of the
      window reshaped to `[-1, attn_size]`, and `attn_states` with its first
      slot along axis 1 sliced off.
    """
    with vs.variable_scope("attention"):
      attn_kernel = vs.get_variable(
          "attn_w", [1, 1, self._attn_size, self._attn_vec_size])
      attn_v = vs.get_variable("attn_v", [self._attn_vec_size])
      hidden = array_ops.reshape(
          attn_states, [-1, self._attn_length, 1, self._attn_size])
      # A 1x1 convolution projects every window slot to attn_vec_size.
      hidden_features = nn_ops.conv2d(
          hidden, attn_kernel, [1, 1, 1, 1], "SAME")
      if self._linear3 is None:
        self._linear3 = _Linear(query, self._attn_vec_size, True)
      projected_query = self._linear3(query)
      projected_query = array_ops.reshape(
          projected_query, [-1, 1, 1, self._attn_vec_size])
      scores = math_ops.reduce_sum(
          attn_v * math_ops.tanh(hidden_features + projected_query), [2, 3])
      weights = nn_ops.softmax(scores)
      weighted = math_ops.reduce_sum(
          array_ops.reshape(weights, [-1, self._attn_length, 1, 1]) * hidden,
          [1, 2])
      new_attns = array_ops.reshape(weighted, [-1, self._attn_size])
      new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1])
      return new_attns, new_attn_states
class HighwayWrapper(rnn_cell_impl.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self,
               cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean. When true the transform gate is
        computed as `1 - carry`; otherwise it gets its own weights and bias.
      carry_bias_init: float, initial value of the carry gate bias. The
        separate transform gate bias, when used, starts at the negated value.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    """State size of the wrapped cell."""
    return self._cell.state_size

  @property
  def output_size(self):
    """Output size of the wrapped cell."""
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    """Return the wrapped cell's zero state, built under a named scope."""
    scope_name = type(self).__name__ + "ZeroState"
    with ops.name_scope(scope_name, values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    """Gate one (input, output) pair: `inp * carry + out * transform`.

    Args:
      inp: rank-2 cell input Tensor; its second dimension sizes the gates.
      out: cell output Tensor paired with `inp`.

    Returns:
      The gated combination of `inp` and `out`.
    """
    width = inp.get_shape().with_rank(2)[1].value
    carry_weight = vs.get_variable("carry_w", [width, width])
    carry_bias = vs.get_variable(
        "carry_b", [width],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))

    if not self._couple_carry_transform_gates:
      # Independent transform gate with its own parameters.
      transform_weight = vs.get_variable("transform_w", [width, width])
      transform_bias = vs.get_variable(
          "transform_b", [width],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    else:
      # Coupled gates: whatever is not carried is transformed.
      transform = 1 - carry
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and combine its inputs and outputs with highway gates.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of gated cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    def check_compatible_shapes(cell_input, cell_output):
      # Each input must be shape-compatible with its paired output.
      cell_input.get_shape().assert_is_compatible_with(
          cell_output.get_shape())

    nest.map_structure(check_compatible_shapes, inputs, outputs)
    gated_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (gated_outputs, new_state)
Unique TensorFlower 133734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower This class adds layer normalization and recurrent dropout to a 133834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower basic LSTM unit. Layer normalization implementation is based on: 133934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 134034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1607.06450. 134134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 134234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Layer Normalization" 134334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 134434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 134534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower and is applied before the internal nonlinearities. 134634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Recurrent dropout is base on: 134734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 134834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower https://arxiv.org/abs/1603.05118 134934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 135034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower "Recurrent Dropout without Memory Loss" 135134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth. 135234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 135334aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower 1354ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie def __init__(self, 1355ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie num_units, 1356ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie forget_bias=1.0, 1357ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie input_size=None, 1358ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie activation=math_ops.tanh, 1359ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie layer_norm=True, 1360ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie norm_gain=1.0, 1361ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie norm_shift=0.0, 1362ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie dropout_keep_prob=1.0, 1363ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie dropout_prob_seed=None, 136454d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse=None): 136534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """Initializes the basic LSTM cell. 136634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 136734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower Args: 136834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower num_units: int, The number of units in the LSTM cell. 136934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower forget_bias: float, The bias added to forget gates (see above). 137034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower input_size: Deprecated and unused. 137134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower activation: Activation function of the inner states. 137234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower layer_norm: If `True`, layer normalization will be applied. 137334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower norm_gain: float, The layer normalization gain initial value. If 137434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 137534aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower norm_shift: float, The layer normalization shift initial value. If 137634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower `layer_norm` has been set to `False`, this argument will be ignored. 137734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_keep_prob: unit Tensor or float between 0 and 1 representing the 137834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower recurrent dropout probability value. If float and 1.0, no dropout will 137934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower be applied. 138034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower dropout_prob_seed: (optional) integer, the randomness seed. 138154d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo reuse: (optional) Python boolean describing whether to reuse variables 138254d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo in an existing scope. If not `True`, and the existing scope already has 138354d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo the given variables, an error is raised. 138434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower """ 1385e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse) 138634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 138734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower if input_size is not None: 138834aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower logging.warn("%s: The input_size parameter is deprecated.", self) 138934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 139034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._num_units = num_units 139134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._activation = activation 139234aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._forget_bias = forget_bias 139334aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower self._keep_prob = dropout_keep_prob 139434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._seed = dropout_prob_seed 139534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower self._layer_norm = layer_norm 1396b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._norm_gain = norm_gain 1397b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._norm_shift = norm_shift 139854d50ffec8df4f748694632dbe5ebde9971e2c9eEugene Brevdo self._reuse = reuse 139934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 140034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 140134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def state_size(self): 1402827d2e4b9180db67853f60c125e548d83986b96cEugene Brevdo return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) 140334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 140434aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower @property 140534aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower def output_size(self): 140634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower return self._num_units 140734aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1408b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def _norm(self, inp, scope, dtype=dtypes.float32): 140992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo shape = inp.get_shape()[-1:] 1410b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng gamma_init = init_ops.constant_initializer(self._norm_gain) 1411b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng beta_init = init_ops.constant_initializer(self._norm_shift) 141292da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo with vs.variable_scope(scope): 141392da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo # Initialize beta and gamma for use by layer_norm. 
1414b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) 1415b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) 141692da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo normalized = layers.layer_norm(inp, reuse=True, scope=scope) 141792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return normalized 141892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo 141992da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo def _linear(self, args): 142034aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower out_size = 4 * self._num_units 142134aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower proj_size = args.get_shape()[-1] 1422b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng dtype = args.dtype 1423b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) 142492da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = math_ops.matmul(args, weights) 142592da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo if not self._layer_norm: 1426b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias = vs.get_variable("bias", [out_size], dtype=dtype) 142792da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo out = nn_ops.bias_add(out, bias) 142892da8abfd35b93488ed7a55308b8f589ee23b622Eugene Brevdo return out 142934aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1430e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo def call(self, inputs, state): 143134aede1549663343fca0083ca05193e57e70411dA. 
Unique TensorFlower """LSTM cell with layer normalization and recurrent dropout.""" 1432e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo c, h = state 1433e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo args = array_ops.concat([inputs, h], 1) 1434e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo concat = self._linear(args) 1435b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng dtype = args.dtype 143634aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1437e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) 1438e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if self._layer_norm: 1439b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng i = self._norm(i, "input", dtype=dtype) 1440b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng j = self._norm(j, "transform", dtype=dtype) 1441b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng f = self._norm(f, "forget", dtype=dtype) 1442b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng o = self._norm(o, "output", dtype=dtype) 144334aede1549663343fca0083ca05193e57e70411dA. Unique TensorFlower 1444e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo g = self._activation(j) 1445e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: 1446e8482ab23bd0fce5c2941f6a190158bca2610a35Eugene Brevdo g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) 144734aede1549663343fca0083ca05193e57e70411dA. 
class NASCell(rnn_cell_impl.RNNCell):
  """Neural Architecture Search (NAS) recurrent network cell.

  Implementation of the recurrent cell described in:

    https://arxiv.org/abs/1611.01578

  Barret Zoph and Quoc V. Le.
  "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017.

  An optional projection layer can be applied to the cell output.
  """

  def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None):
    """Initialize the parameters for a NAS cell.

    Args:
      num_units: int, The number of units in the NAS cell
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      use_biases: (optional) bool, If True then use biases within the cell.
        This is False by default.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(NASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
    self._reuse = reuse

    # Without a projection the hidden state keeps the cell's own width.
    hidden_dim = num_units if num_proj is None else num_proj
    self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, hidden_dim)
    self._output_size = hidden_dim

  @property
  def state_size(self):
    """`LSTMStateTuple` with the cell-state and hidden-state sizes."""
    return self._state_size

  @property
  def output_size(self):
    """`num_proj` when a projection is configured, `num_units` otherwise."""
    return self._output_size

  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        NAS Cell after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of NAS Cell after reading
        `inputs` when the previous state was `state`.  Same type and shape(s)
        as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    if self._num_proj is None:
      recurrent_dim = self._num_units
    else:
      recurrent_dim = self._num_proj

    c_prev, m_prev = state

    dtype = inputs.dtype
    input_dim = inputs.get_shape().with_rank(2)[1]
    if input_dim.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # One wide kernel per source (hidden state / inputs); each is later cut
    # into eight equally sized branches.
    branch_cols = 8 * self._num_units
    recurrent_kernel = vs.get_variable(
        "recurrent_kernel", [recurrent_dim, branch_cols], dtype)
    input_kernel = vs.get_variable(
        "kernel", [input_dim.value, branch_cols], dtype)

    hidden_proj = math_ops.matmul(m_prev, recurrent_kernel)
    input_proj = math_ops.matmul(inputs, input_kernel)

    if self._use_biases:
      # The bias is added to the hidden-state projection only.
      bias = vs.get_variable(
          "bias",
          shape=[branch_cols],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
      hidden_proj = nn_ops.bias_add(hidden_proj, bias)

    hid = array_ops.split(axis=1, num_or_size_splits=8, value=hidden_proj)
    inp = array_ops.split(axis=1, num_or_size_splits=8, value=input_proj)

    # First layer: combine matching input/hidden branches.  Branch 3 is the
    # only one that multiplies instead of adding.
    first_0 = math_ops.sigmoid(inp[0] + hid[0])
    first_1 = nn_ops.relu(inp[1] + hid[1])
    first_2 = math_ops.sigmoid(inp[2] + hid[2])
    first_3 = nn_ops.relu(inp[3] * hid[3])
    first_4 = math_ops.tanh(inp[4] + hid[4])
    first_5 = math_ops.sigmoid(inp[5] + hid[5])
    first_6 = math_ops.tanh(inp[6] + hid[6])
    first_7 = math_ops.sigmoid(inp[7] + hid[7])

    # Second layer: merge neighbouring pairs of first-layer branches.
    second_0 = math_ops.tanh(first_0 * first_1)
    second_1 = math_ops.tanh(first_2 + first_3)
    second_2 = math_ops.tanh(first_4 * first_5)
    second_3 = math_ops.sigmoid(first_6 + first_7)

    # The previous cell state enters through the first second-layer branch.
    second_0 = math_ops.tanh(second_0 + c_prev)

    # Third layer: the product of the first two branches is the new cell
    # state and also feeds the final output.
    new_c = second_0 * second_1
    third_1 = math_ops.tanh(second_2 + second_3)

    # Final layer.
    new_m = math_ops.tanh(new_c * third_1)

    # Optional output projection.
    if self._num_proj is not None:
      proj_kernel = vs.get_variable(
          "projection_weights", [self._num_units, self._num_proj], dtype)
      new_m = math_ops.matmul(new_m, proj_kernel)

    return new_m, rnn_cell_impl.LSTMStateTuple(new_c, new_m)
class UGRNNCell(rnn_cell_impl.RNNCell):
  """Update Gate Recurrent Neural Network (UGRNN) cell.

  A middle ground between an LSTM/GRU and a vanilla RNN: a single gate
  decides whether each unit keeps integrating its previous state or takes
  the freshly computed value.  This is the recurrent idea of the
  feedforward Highway Network.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
  """

  def __init__(self,
               num_units,
               initializer=None,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters for an UGRNN cell.

    Args:
      num_units: int, The number of units in the UGRNN cell
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, added to the gate
        pre-activation before the sigmoid, used to reduce the scale of
        forgetting at the beginning of the training.
      activation: (optional) Activation function of the inner states.
        Default is `tf.tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._activation = activation
    self._forget_bias = forget_bias
    self._initializer = initializer
    self._reuse = reuse
    # The linear layer is built on the first call and cached afterwards.
    self._linear = None

  @property
  def state_size(self):
    """Size of the recurrent state, equal to `num_units`."""
    return self._num_units

  @property
  def output_size(self):
    """Size of the cell output, equal to `num_units`."""
    return self._num_units

  def call(self, inputs, state):
    """Run one step of UGRNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_output: batch x num units, Tensor representing the output of the
        UGRNN after reading `inputs` when previous state was `state`.
        Identical to `new_state`.
      new_state: batch x num units, Tensor representing the state of the UGRNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    # Only used to validate that the input width is statically known.
    input_dim = inputs.get_shape().with_rank(2)[1]
    if input_dim.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    current_scope = vs.get_variable_scope()
    with vs.variable_scope(current_scope, initializer=self._initializer):
      joined = array_ops.concat([inputs, state], 1)
      if self._linear is None:
        self._linear = _Linear(joined, 2 * self._num_units, True)
      projected = self._linear(joined)

      # First half drives the gate, second half the candidate state.
      gate_pre, candidate_pre = array_ops.split(
          axis=1, num_or_size_splits=2, value=projected)

      candidate = self._activation(candidate_pre)
      gate = math_ops.sigmoid(gate_pre + self._forget_bias)
      new_state = gate * state + (1.0 - gate) * candidate

    # Output and state are the same tensor for this cell.
    return new_state, new_state
class IntersectionRNNCell(rnn_cell_impl.RNNCell):
  """Intersection Recurrent Neural Network (+RNN) cell.

  The cell pairs a coupled recurrent gate (information flowing through time)
  with a coupled depth gate (information flowing through the layer stack),
  and is meant to improve information flow in stacked RNNs. Because the
  depth gate mixes the cell input with the depth output `y`, both must have
  the same width (input size == output size). The first layer of a stack
  therefore has to project its inputs to N (`num_units`) dimensions: pass
  `num_in_proj = N` for that layer and keep the defaults for the layers
  above it.

  This implements the recurrent cell from the paper:

    https://arxiv.org/abs/1611.09913

  Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
  "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.

  The Intersection RNN is built for use in deeply stacked
  RNNs so it may not achieve best performance with depth 1.
  """

  def __init__(self,
               num_units,
               num_in_proj=None,
               initializer=None,
               forget_bias=1.0,
               y_activation=nn_ops.relu,
               reuse=None):
    """Initialize the parameters for an +RNN cell.

    Args:
      num_units: int, The number of units in the +RNN cell.
      num_in_proj: (optional) int, The input dimensionality for the RNN.
        Set this to `num_units` when creating the first layer of an +RNN and
        leave it as `None` (default) otherwise. Only its truthiness is
        consulted: when it is set and the input width differs from
        `num_units`, the inputs are projected to `num_units`. When it is not
        set, the input width must already equal `num_units`, otherwise a
        ValueError is raised from `call`.
      initializer: (optional) The initializer to use for the weight matrices.
      forget_bias: (optional) float, default 1.0, The initial bias of the
        forget gates, used to reduce the scale of forgetting at the beginning
        of the training.
      y_activation: (optional) Activation function of the states passed
        through depth. Default is `tf.nn.relu`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._num_input_proj = num_in_proj
    self._initializer = initializer
    self._forget_bias = forget_bias
    self._y_activation = y_activation
    # The linear maps are built on the first `call` and reused afterwards.
    self._linear1 = None
    self._linear2 = None

  @property
  def state_size(self):
    """Width of the recurrent state, equal to `num_units`."""
    return self._num_units

  @property
  def output_size(self):
    """Width of the depth output, equal to `num_units`."""
    return self._num_units

  def call(self, inputs, state):
    """Run one step of the Intersection RNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
        after reading `inputs` when previous state was `state`.
      new_state: batch x num units, Tensor representing the state of the +RNN
        after reading `inputs` when previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from `inputs` via
        static shape inference.
      ValueError: If input size != output size (these must be equal when
        using the Intersection RNN).
    """
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    units = self._num_units
    with vs.variable_scope(
        vs.get_variable_scope(), initializer=self._initializer):
      # Read-in projection: used by the first layer of a deep +RNN to map
      # inputs of width I onto width N.
      if input_size.value != units:
        if not self._num_input_proj:
          raise ValueError("Must have input size == output size for "
                           "Intersection RNN. To fix, num_in_proj should "
                           "be set to num_units at cell init.")
        with vs.variable_scope("in_projection"):
          if self._linear1 is None:
            self._linear1 = _Linear(inputs, units, True)
          inputs = self._linear1(inputs)

      stacked = array_ops.concat([inputs, state], 1)
      if self._linear2 is None:
        self._linear2 = _Linear(stacked, 4 * units, True)
      pre_acts = self._linear2(stacked)

      # Column layout of `pre_acts`, each slice `units` wide:
      #   [recurrent gate | recurrent candidate | depth gate | depth candidate]
      gate_h_pre = pre_acts[:, :units]
      cand_h_pre = pre_acts[:, units:2 * units]
      gate_y_pre = pre_acts[:, 2 * units:3 * units]
      cand_y_pre = pre_acts[:, 3 * units:4 * units]

      cand_h = math_ops.tanh(cand_h_pre)
      cand_y = self._y_activation(cand_y_pre)
      gate_h = math_ops.sigmoid(gate_h_pre + self._forget_bias)
      gate_y = math_ops.sigmoid(gate_y_pre + self._forget_bias)

      # Each gate interpolates between the carried value and its candidate.
      new_state = gate_h * state + (1.0 - gate_h) * cand_h  # through time
      new_y = gate_y * inputs + (1.0 - gate_y) * cand_y  # through depth

    return new_y, new_state
# Lazily populated cache of the op registry, shared by all CompiledWrapper
# instances; filled on the first stateless-op check.
_REGISTERED_OPS = None


class CompiledWrapper(rnn_cell_impl.RNNCell):
  """Wraps step execution in an XLA JIT scope."""

  def __init__(self, cell, compile_stateful=False):
    """Create CompiledWrapper cell.

    Args:
      cell: Instance of `RNNCell`.
      compile_stateful: Whether to compile stateful ops like initializers
        and random number generators (default: False).
    """
    self._cell = cell
    self._compile_stateful = compile_stateful

  @property
  def state_size(self):
    """State size of the wrapped cell."""
    return self._cell.state_size

  @property
  def output_size(self):
    """Output size of the wrapped cell."""
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    """Return the wrapped cell's zero state, built under a named scope."""
    scope_name = type(self).__name__ + "ZeroState"
    with ops.name_scope(scope_name, values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
    """Run the wrapped cell inside `jit.experimental_jit_scope`.

    When `compile_stateful` is set, `compile_ops=True` is passed to the JIT
    scope. Otherwise a predicate is passed that approves only ops whose
    registered definition is not stateful.

    Args:
      inputs: Inputs forwarded to the wrapped cell.
      state: State forwarded to the wrapped cell.
      scope: Optional scope forwarded to the wrapped cell.

    Returns:
      Whatever the wrapped cell returns for `(inputs, state)`.
    """

    def _is_stateless(node_def):
      global _REGISTERED_OPS
      if _REGISTERED_OPS is None:
        _REGISTERED_OPS = op_def_registry.get_registered_ops()
      return not _REGISTERED_OPS[node_def.op].is_stateful

    compile_ops = True if self._compile_stateful else _is_stateless

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32):
  """Returns an initializer whose values are uniform in log space.

  Each generated value is `exp(u)`, where `u` is drawn uniformly between
  `log(minval)` and `log(maxval)`.

  Args:
    minval: float or a scalar float Tensor. With value > 0. Lower bound of the
        range of random values to generate.
    maxval: float or a scalar float Tensor. With value > minval. Upper bound of
        the range of random values to generate.
    seed: An integer. Used to create random seeds.
    dtype: The data type.

  Returns:
    An initializer callable taking `(shape, dtype, partition_info)`.
  """

  def _initializer(shape, dtype=dtype, partition_info=None):
    del partition_info  # Unused.
    log_low = math_ops.log(minval)
    log_high = math_ops.log(maxval)
    log_samples = random_ops.random_uniform(
        shape, minval=log_low, maxval=log_high, dtype=dtype, seed=seed)
    return math_ops.exp(log_samples)

  return _initializer
class PhasedLSTMCell(rnn_cell_impl.RNNCell):
  """Phased LSTM recurrent network cell.

  https://arxiv.org/pdf/1610.09513v1.pdf
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               leak=0.001,
               ratio_on=0.1,
               trainable_ratio_on=True,
               period_init_min=1.0,
               period_init_max=1000.0,
               reuse=None):
    """Initialize the Phased LSTM cell.

    Args:
      num_units: int, The number of units in the Phased LSTM cell.
      use_peepholes: bool, set True to enable peephole connections.
      leak: float or scalar float Tensor with value in [0, 1]. Leak applied
          during training.
      ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
          period during which the gates are open.
      trainable_ratio_on: bool, whether ratio_on is trainable.
      period_init_min: float or scalar float Tensor. With value > 0.
          Minimum value of the initialized period.
          The period values are initialized by drawing from the distribution:
          e^U(log(period_init_min), log(period_init_max))
          Where U(.,.) is the uniform distribution.
      period_init_max: float or scalar float Tensor.
          With value > period_init_min. Maximum value of the initialized
          period.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
    self._reuse = reuse
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
    self._ratio_on = ratio_on
    self._trainable_ratio_on = trainable_ratio_on
    self._period_init_min = period_init_min
    self._period_init_max = period_init_max
    # The linear maps are built on the first `call` and reused afterwards.
    self._linear1 = None
    self._linear2 = None
    self._linear3 = None

  @property
  def state_size(self):
    """An `LSTMStateTuple` with both entries equal to `num_units`."""
    return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units)

  @property
  def output_size(self):
    """Width of the cell output, equal to `num_units`."""
    return self._num_units

  def _mod(self, x, y):
    """Modulo function that propagates x gradients."""
    # The stop_gradient term equals mod(x, y) - x in the forward pass but
    # contributes no gradient, so the result has the gradient of `x` alone.
    offset = array_ops.stop_gradient(math_ops.mod(x, y) - x)
    return offset + x

  def _get_cycle_ratio(self, time, phase, period):
    """Compute the cycle ratio in the dtype of the time.

    `phase` and `period` are cast to the dtype of `time`, the ratio
    `mod(time - phase, period) / period` is evaluated in that dtype, and the
    result is cast to float32.
    """
    time_dtype = time.dtype
    phase_t = math_ops.cast(phase, dtype=time_dtype)
    period_t = math_ops.cast(period, dtype=time_dtype)
    elapsed = time - phase_t
    ratio = self._mod(elapsed, period_t) / period_t
    return math_ops.cast(ratio, dtype=dtypes.float32)

  def call(self, inputs, state):
    """Phased LSTM Cell.

    Args:
      inputs: A tuple of 2 Tensor.
         The first Tensor has shape [batch, 1], and type float32 or float64.
         It stores the time.
         The second Tensor has shape [batch, features_size], and type float32.
         It stores the features.
      state: rnn_cell_impl.LSTMStateTuple, state from previous timestep.

    Returns:
      A tuple containing:
      - A Tensor of float32, and shape [batch_size, num_units], representing
        the output of the cell.
      - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    (c_prev, h_prev) = state
    (time, x) = inputs
    units = self._num_units

    # With peepholes enabled the gates additionally see the cell state.
    mask_inputs = [x, h_prev] + ([c_prev] if self._use_peepholes else [])

    with vs.variable_scope("mask_gates"):
      if self._linear1 is None:
        self._linear1 = _Linear(mask_inputs, 2 * units, True)

      mask_gates = math_ops.sigmoid(self._linear1(mask_inputs))
      input_gate, forget_gate = array_ops.split(
          value=mask_gates, num_or_size_splits=2, axis=1)

    with vs.variable_scope("new_input"):
      candidate_inputs = [x, h_prev]
      if self._linear2 is None:
        self._linear2 = _Linear(candidate_inputs, units, True)
      new_input = math_ops.tanh(self._linear2(candidate_inputs))

    new_c = (c_prev * forget_gate + input_gate * new_input)

    # The output gate peeks at the freshly computed cell state.
    out_inputs = [x, h_prev] + ([new_c] if self._use_peepholes else [])

    with vs.variable_scope("output_gate"):
      if self._linear3 is None:
        self._linear3 = _Linear(out_inputs, units, True)
      output_gate = math_ops.sigmoid(self._linear3(out_inputs))

    new_h = math_ops.tanh(new_c) * output_gate

    period = vs.get_variable(
        "period", [units],
        initializer=_random_exp_initializer(self._period_init_min,
                                            self._period_init_max))
    phase = vs.get_variable(
        "phase", [units],
        initializer=init_ops.random_uniform_initializer(0.,
                                                        period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

    cycle_ratio = self._get_cycle_ratio(time, phase, period)

    # Openness of the time gate, chosen by position within the cycle:
    #   cycle_ratio < ratio_on / 2  -> rising ramp
    #   cycle_ratio < ratio_on      -> falling ramp
    #   otherwise                   -> leak
    rising = 2 * cycle_ratio / ratio_on
    falling = 2 - rising
    leaking = self._leak * cycle_ratio

    openness = array_ops.where(cycle_ratio < ratio_on, falling, leaking)
    openness = array_ops.where(cycle_ratio < 0.5 * ratio_on, rising, openness)

    # Blend the proposed update with the previous values by the openness.
    new_c = openness * new_c + (1 - openness) * c_prev
    new_h = openness * new_h + (1 - openness) * h_prev

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)

    return new_h, new_state
Unique TensorFlower name="conv_lstm_cell"): 206128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower """Construct ConvLSTMCell. 206228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower Args: 206328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower conv_ndims: Convolution dimensionality (1, 2 or 3). 206428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower input_shape: Shape of the input as int tuple, excluding the batch size. 206528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower output_channels: int, number of output channels of the conv LSTM. 206628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). 206728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower use_bias: Use bias in convolutions. 206828ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower skip_connection: If set to `True`, concatenate the input to the 206928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower output of the conv LSTM. Default: `False`. 207028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower forget_bias: Forget bias. 207128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower name: Name of the module. 207228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower Raises: 207328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower ValueError: If `skip_connection` is `True` and stride is different from 1 207428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower or if `input_shape` is incompatible with `conv_ndims`. 207528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower """ 207628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower super(ConvLSTMCell, self).__init__(name=name) 207728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 2078ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie if conv_ndims != len(input_shape) - 1: 207928ce1d163eeffe618a6972c5245be0e660d94e85A. 
Unique TensorFlower raise ValueError("Invalid input_shape {} for conv_ndims={}.".format( 208028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower input_shape, conv_ndims)) 208128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 208228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._conv_ndims = conv_ndims 208328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._input_shape = input_shape 208428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._output_channels = output_channels 208528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._kernel_shape = kernel_shape 208628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._use_bias = use_bias 208728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._forget_bias = forget_bias 208828ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._skip_connection = skip_connection 208928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 209028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._total_output_channels = output_channels 209128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower if self._skip_connection: 209228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower self._total_output_channels += self._input_shape[-1] 209328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 2094b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng state_size = tensor_shape.TensorShape( 2095b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._input_shape[:-1] + [self._output_channels]) 209628ce1d163eeffe618a6972c5245be0e660d94e85A. 
Unique TensorFlower self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) 2097ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie self._output_size = tensor_shape.TensorShape( 2098ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie self._input_shape[:-1] + [self._total_output_channels]) 209928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 210028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower @property 210128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower def output_size(self): 210228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower return self._output_size 210328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 210428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower @property 210528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower def state_size(self): 210628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower return self._state_size 210728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 210828ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower def call(self, inputs, state, scope=None): 210928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower cell, hidden = state 2110ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie new_hidden = _conv([inputs, hidden], self._kernel_shape, 2111ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie 4 * self._output_channels, self._use_bias) 2112ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie gates = array_ops.split( 2113ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1) 211428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 211528ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower input_gate, new_input, forget_gate, output_gate = gates 211628ce1d163eeffe618a6972c5245be0e660d94e85A. 
Unique TensorFlower new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell 211728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input) 211828ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate) 211928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 212028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower if self._skip_connection: 212128ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower output = array_ops.concat([output, inputs], axis=-1) 212228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output) 212328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower return output, new_state 212428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 2125ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie 212628ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlowerclass Conv1DLSTMCell(ConvLSTMCell): 212728ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower """1D Convolutional LSTM recurrent network cell. 212828ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower 212928ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower https://arxiv.org/pdf/1506.04214v1.pdf 213028ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower """ 2131ad07a86d75ab06bbcfd6f8f6a24debd9036a52d0Jianwei Xie 213228ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower def __init__(self, name="conv_1d_lstm_cell", **kwargs): 213328ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" 213428ce1d163eeffe618a6972c5245be0e660d94e85A. Unique TensorFlower super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs) 213528ce1d163eeffe618a6972c5245be0e660d94e85A. 
class Conv2DLSTMCell(ConvLSTMCell):
  """2D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_2d_lstm_cell", **kwargs):
    """Construct Conv2DLSTM. See `ConvLSTMCell` for more details."""
    # Forward `name`; otherwise the base-class default name is used instead.
    super(Conv2DLSTMCell, self).__init__(conv_ndims=2, name=name, **kwargs)


class Conv3DLSTMCell(ConvLSTMCell):
  """3D Convolutional LSTM recurrent network cell.

  https://arxiv.org/pdf/1506.04214v1.pdf
  """

  def __init__(self, name="conv_3d_lstm_cell", **kwargs):
    """Construct Conv3DLSTM. See `ConvLSTMCell` for more details."""
    # Forward `name`; otherwise the base-class default name is used instead.
    super(Conv3DLSTMCell, self).__init__(conv_ndims=3, name=name, **kwargs)


def _conv(args, filter_size, num_features, bias, bias_start=0.0):
  """Convolve the channel-wise concatenation of `args` with a new kernel.

  Args:
    args: a Tensor or a list of Tensors, each 3D, 4D or 5D and all of the
      same rank. They are concatenated along the last axis before the
      convolution.
    filter_size: int list or tuple with the kernel's spatial dimensions.
    num_features: int, number of features.
    bias: whether to add a bias term to the convolution result.
    bias_start: starting value to initialize the bias; 0 by default.

  Returns:
    A 3D, 4D, or 5D Tensor with shape [batch ... num_features]

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
  # The docstring allows a bare Tensor; normalise it to a one-element list.
  if not isinstance(args, (list, tuple)):
    args = [args]
  if not args:
    raise ValueError("`args` must be specified")

  # Calculate the total size of arguments on the last (depth) dimension.
  total_arg_size_depth = 0
  shapes = [a.get_shape().as_list() for a in args]
  shape_length = len(shapes[0])
  for shape in shapes:
    if len(shape) not in [3, 4, 5]:
      raise ValueError("Conv Linear expects 3D, 4D "
                       "or 5D arguments: %s" % str(shapes))
    if len(shape) != shape_length:
      raise ValueError("Conv Linear expects all args "
                       "to be of same Dimension: %s" % str(shapes))
    if shape[-1] is None:
      # The kernel's input depth must be known statically.
      raise ValueError("Conv Linear expects the last dimension of all args "
                       "to be defined: %s" % str(shapes))
    total_arg_size_depth += shape[-1]
  dtype = args[0].dtype

  # Determine the correct conv operation from the (already validated) rank.
  if shape_length == 3:
    conv_op = nn_ops.conv1d
    strides = 1
  elif shape_length == 4:
    conv_op = nn_ops.conv2d
    strides = shape_length * [1]
  else:  # shape_length == 5
    conv_op = nn_ops.conv3d
    strides = shape_length * [1]

  # Now the computation. `list(...)` lets `filter_size` be a tuple.
  kernel = vs.get_variable(
      "kernel",
      list(filter_size) + [total_arg_size_depth, num_features],
      dtype=dtype)
  if len(args) == 1:
    conv_input = args[0]
  else:
    conv_input = array_ops.concat(axis=shape_length - 1, values=args)
  res = conv_op(conv_input, kernel, strides, padding="SAME")
  if not bias:
    return res
  bias_term = vs.get_variable(
      "biases", [num_features],
      dtype=dtype,
      initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
  return res + bias_term


class GLSTMCell(rnn_cell_impl.RNNCell):
  """Group LSTM cell (G-LSTM).

  The implementation is based on:

    https://arxiv.org/abs/1703.10722

  O. Kuchaiev and B. Ginsburg
  "Factorization Tricks for LSTM Networks", ICLR 2017 workshop.
  """

  def __init__(self,
               num_units,
               initializer=None,
               num_proj=None,
               number_of_groups=1,
               forget_bias=1.0,
               activation=math_ops.tanh,
               reuse=None):
    """Initialize the parameters of G-LSTM cell.

    Args:
      num_units: int, The number of units in the G-LSTM cell
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None (or 0), no projection is performed.
      number_of_groups: (optional) int, number of groups to use.
        If `number_of_groups` is 1, then it should be equivalent to LSTM cell
      forget_bias: Biases of the forget gate are initialized by default to 1
        in order to reduce the scale of forgetting at the beginning of
        the training.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.

    Raises:
      ValueError: If `num_units` or `num_proj` is not divisible by
        `number_of_groups`.
    """
    super(GLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._num_proj = num_proj
    self._forget_bias = forget_bias
    self._activation = activation
    self._number_of_groups = number_of_groups

    if self._num_units % self._number_of_groups != 0:
      raise ValueError("num_units must be divisible by number_of_groups")
    # Divisibility was just checked, so floor division is exact.
    units_per_group = self._num_units // self._number_of_groups

    # `_group_shape` is [columns sliced per group, units produced per group].
    if self._num_proj:
      if self._num_proj % self._number_of_groups != 0:
        raise ValueError("num_proj must be divisible by number_of_groups")
      self._group_shape = [
          self._num_proj // self._number_of_groups, units_per_group
      ]
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
      self._output_size = num_proj
    else:
      self._group_shape = [units_per_group, units_per_group]
      self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units)
      self._output_size = num_units

    # `_Linear` layers are built lazily on the first `call` and then reused.
    self._linear1 = [None] * number_of_groups
    self._linear2 = None

  @property
  def state_size(self):
    """`LSTMStateTuple` holding the sizes of the `c` and `m` states."""
    return self._state_size

  @property
  def output_size(self):
    """`num_proj` when a projection is configured, otherwise `num_units`."""
    return self._output_size

  def _get_input_for_group(self, inputs, group_id, group_size):
    """Slices inputs into groups to prepare for processing by cell's groups

    Reads `self._batch_size`, which `call` sets before using this helper.

    Args:
      inputs: cell input or its previous state,
              a Tensor, 2D, [batch x num_units]
      group_id: group id, a Scalar, for which to prepare input
      group_size: size of the group

    Returns:
      subset of inputs corresponding to group "group_id",
      a Tensor, 2D, [batch x num_units/number_of_groups]
    """
    return array_ops.slice(
        input_=inputs,
        begin=[0, group_id * group_size],
        size=[self._batch_size, group_size],
        name=("GLSTM_group%d_input_generation" % group_id))

  def call(self, inputs, state):
    """Run one step of G-LSTM.

    Args:
      inputs: input Tensor, 2D, [batch x num_units].
        NOTE(review): `inputs` and the previous `m` state are both sliced
        with `self._group_shape[0]` columns per group, so `inputs` appears to
        need the same width as the `m` state -- confirm against callers.
      state: this must be a tuple of state Tensors, both `2-D`,
        with column sizes `c_state` and `m_state`.

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        G-LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - LSTMStateTuple representing the new state of G-LSTM cell
        after reading `inputs` when the previous state was `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference (presumably raised by `_Linear`; nothing in
        this method raises it directly).
    """
    (c_prev, m_prev) = state

    # Prefer the static batch size; fall back to the dynamic one.
    self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    dtype = inputs.dtype
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer):
      i_parts = []
      j_parts = []
      f_parts = []
      o_parts = []

      slice_width = self._group_shape[0]
      for group_id in range(self._number_of_groups):
        with vs.variable_scope("group%d" % group_id):
          # Each group only sees its own slice of the input and of `m_prev`.
          x_g_id = array_ops.concat(
              [
                  self._get_input_for_group(inputs, group_id, slice_width),
                  self._get_input_for_group(m_prev, group_id, slice_width)
              ],
              axis=1)
          linear = self._linear1[group_id]
          if linear is None:
            linear = _Linear(x_g_id, 4 * self._group_shape[1], False)
            self._linear1[group_id] = linear
          group_gates = linear(x_g_id)
          i_k, j_k, f_k, o_k = array_ops.split(group_gates, 4, 1)

        i_parts.append(i_k)
        j_parts.append(j_k)
        f_parts.append(f_k)
        o_parts.append(o_k)

      # One zero-initialised bias per gate, shared across all groups.
      bi, bj, bf, bo = [
          vs.get_variable(
              name="bias_" + gate,
              shape=[self._num_units],
              dtype=dtype,
              initializer=init_ops.constant_initializer(0.0, dtype=dtype))
          for gate in ("i", "j", "f", "o")
      ]

      i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi)
      j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj)
      f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf)
      o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo)

      # NOTE(review): the candidate `j` always goes through tanh, while the
      # configured `activation` is only applied to `c` -- confirm intended.
      c = (
          math_ops.sigmoid(f + self._forget_bias) * c_prev +
          math_ops.sigmoid(i) * math_ops.tanh(j))
      m = math_ops.sigmoid(o) * self._activation(c)

      # Same truthiness test as `__init__`, so the projection and the
      # declared state/output sizes always agree.
      if self._num_proj:
        with vs.variable_scope("projection"):
          if self._linear2 is None:
            self._linear2 = _Linear(m, self._num_proj, False)
          m = self._linear2(m)

    new_state = rnn_cell_impl.LSTMStateTuple(c, m)
    return m, new_state
2433b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2434b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng The class uses optional peep-hole connections, optional cell clipping, and 2435b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng an optional projection layer. 2436b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2437b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Layer normalization implementation is based on: 2438b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2439b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng https://arxiv.org/abs/1607.06450. 2440b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2441b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "Layer Normalization" 2442b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton 2443b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2444b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng and is applied before the internal nonlinearities. 2445b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2446b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """ 2447b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2448b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def __init__(self, 2449b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_units, 2450b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng use_peepholes=False, 2451b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng cell_clip=None, 2452b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng initializer=None, 2453b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_proj=None, 2454b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng proj_clip=None, 2455b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng forget_bias=1.0, 2456b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng activation=None, 2457b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng layer_norm=False, 2458b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng norm_gain=1.0, 2459b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng norm_shift=0.0, 
2460b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng reuse=None): 2461b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """Initialize the parameters for an LSTM cell. 2462b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2463b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Args: 2464b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_units: int, The number of units in the LSTM cell 2465b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng use_peepholes: bool, set True to enable diagonal/peephole connections. 2466b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng cell_clip: (optional) A float value, if provided the cell state is clipped 2467b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng by this value prior to the cell output activation. 2468b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng initializer: (optional) The initializer to use for the weight and 2469b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng projection matrices. 2470b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_proj: (optional) int, The output dimensionality for the projection 2471b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng matrices. If None, no projection is performed. 2472b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is 2473b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng provided, then the projected values are clipped elementwise to within 2474b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng `[-proj_clip, proj_clip]`. 2475b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng forget_bias: Biases of the forget gate are initialized by default to 1 2476b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng in order to reduce the scale of forgetting at the beginning of 2477b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng the training. Must set it manually to `0.0` when restoring from 2478b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng CudnnLSTM trained checkpoints. 
2479b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng activation: Activation function of the inner states. Default: `tanh`. 2480b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng layer_norm: If `True`, layer normalization will be applied. 2481b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng norm_gain: float, The layer normalization gain initial value. If 2482b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng `layer_norm` has been set to `False`, this argument will be ignored. 2483b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng norm_shift: float, The layer normalization shift initial value. If 2484b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng `layer_norm` has been set to `False`, this argument will be ignored. 2485b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng reuse: (optional) Python boolean describing whether to reuse variables 2486b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng in an existing scope. If not `True`, and the existing scope already has 2487b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng the given variables, an error is raised. 2488b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2489b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng When restoring from CudnnLSTM-trained checkpoints, must use 2490b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng CudnnCompatibleLSTMCell instead. 
2491b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """ 2492b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng super(LayerNormLSTMCell, self).__init__(_reuse=reuse) 2493b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2494b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._num_units = num_units 2495b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._use_peepholes = use_peepholes 2496b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._cell_clip = cell_clip 2497b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._initializer = initializer 2498b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._num_proj = num_proj 2499b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._proj_clip = proj_clip 2500b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._forget_bias = forget_bias 2501b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._activation = activation or math_ops.tanh 2502b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._layer_norm = layer_norm 2503b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._norm_gain = norm_gain 2504b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._norm_shift = norm_shift 2505b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2506b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if num_proj: 2507b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) 2508b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._output_size = num_proj 2509b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng else: 2510b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) 2511b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng self._output_size = num_units 2512b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2513b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng @property 2514b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def state_size(self): 
2515b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng return self._state_size 2516b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2517b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng @property 2518b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def output_size(self): 2519b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng return self._output_size 2520b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2521b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def _linear(self, 2522b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng args, 2523b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng output_size, 2524b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias, 2525b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias_initializer=None, 2526b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng kernel_initializer=None, 2527b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng layer_norm=False): 2528b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. 2529b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2530b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Args: 2531b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng args: a 2D Tensor or a list of 2D, batch x n, Tensors. 2532b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng output_size: int, second dimension of W[i]. 2533b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias: boolean, whether to add a bias term or not. 2534b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias_initializer: starting value to initialize the bias 2535b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng (default is all zeros). 2536b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng kernel_initializer: starting value to initialize the weight. 2537b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng layer_norm: boolean, whether to apply layer normalization. 
2538b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2539b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2540b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Returns: 2541b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng A 2D Tensor with shape [batch x output_size] taking value 2542b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. 2543b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2544b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Raises: 2545b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng ValueError: if some of the arguments has unspecified or wrong shape. 2546b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """ 2547b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if args is None or (nest.is_sequence(args) and not args): 2548b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng raise ValueError("`args` must be specified") 2549b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if not nest.is_sequence(args): 2550b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng args = [args] 2551b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2552b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # Calculate the total size of arguments on dimension 1. 
2553b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng total_arg_size = 0 2554b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng shapes = [a.get_shape() for a in args] 2555b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng for shape in shapes: 2556b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if shape.ndims != 2: 2557b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng raise ValueError("linear is expecting 2D arguments: %s" % shapes) 2558b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if shape[1].value is None: 2559b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng raise ValueError("linear expects shape[1] to be provided for shape %s, " 2560b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "but saw %s" % (shape, shape[1])) 2561b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng else: 2562b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng total_arg_size += shape[1].value 2563b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2564b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng dtype = [a.dtype for a in args][0] 2565b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2566b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # Now the computation. 
2567b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng scope = vs.get_variable_scope() 2568b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng with vs.variable_scope(scope) as outer_scope: 2569b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng weights = vs.get_variable( 2570b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "kernel", [total_arg_size, output_size], 2571b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng dtype=dtype, 2572b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng initializer=kernel_initializer) 2573b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if len(args) == 1: 2574b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng res = math_ops.matmul(args[0], weights) 2575b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng else: 2576b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng res = math_ops.matmul(array_ops.concat(args, 1), weights) 2577b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if not bias: 2578b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng return res 2579b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng with vs.variable_scope(outer_scope) as inner_scope: 2580b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng inner_scope.set_partitioner(None) 2581b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if bias_initializer is None: 2582b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 2583b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng biases = vs.get_variable( 2584b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "bias", [output_size], dtype=dtype, initializer=bias_initializer) 2585b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2586b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if not layer_norm: 2587b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng res = nn_ops.bias_add(res, biases) 2588b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2589b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng return res 2590b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 
2591b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng def call(self, inputs, state): 2592b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """Run one step of LSTM. 2593b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2594b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Args: 2595b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng inputs: input Tensor, 2D, batch x num_units. 2596b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng state: this must be a tuple of state Tensors, 2597b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng both `2-D`, with column sizes `c_state` and 2598b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng `m_state`. 2599b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2600b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Returns: 2601b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng A tuple containing: 2602b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2603b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng - A `2-D, [batch x output_dim]`, Tensor representing the output of the 2604b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng LSTM after reading `inputs` when previous state was `state`. 2605b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Here output_dim is: 2606b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_proj if num_proj was set, 2607b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng num_units otherwise. 2608b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng - Tensor(s) representing the new state of LSTM after reading `inputs` when 2609b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng the previous state was `state`. Same type and shape(s) as `state`. 2610b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2611b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng Raises: 2612b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng ValueError: If input size cannot be inferred from inputs via 2613b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng static shape inference. 
2614b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng """ 2615b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sigmoid = math_ops.sigmoid 2616b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2617b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng (c_prev, m_prev) = state 2618b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2619b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng dtype = inputs.dtype 2620b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng input_size = inputs.get_shape().with_rank(2)[1] 2621b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if input_size.value is None: 2622b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng raise ValueError("Could not infer input size from inputs.get_shape()[-1]") 2623b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng scope = vs.get_variable_scope() 2624b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: 2625b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2626b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # i = input_gate, j = new_input, f = forget_gate, o = output_gate 2627b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng lstm_matrix = self._linear( 2628b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng [inputs, m_prev], 2629b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 4 * self._num_units, 2630b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias=True, 2631b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng bias_initializer=None, 2632b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng layer_norm=self._layer_norm) 2633b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng i, j, f, o = array_ops.split( 2634b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng value=lstm_matrix, num_or_size_splits=4, axis=1) 2635b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2636b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._layer_norm: 2637b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng i = _norm(self._norm_gain, self._norm_shift, 
i, "input") 2638b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng j = _norm(self._norm_gain, self._norm_shift, j, "transform") 2639b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng f = _norm(self._norm_gain, self._norm_shift, f, "forget") 2640b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng o = _norm(self._norm_gain, self._norm_shift, o, "output") 2641b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2642b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # Diagonal connections 2643b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._use_peepholes: 2644b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng with vs.variable_scope(unit_scope): 2645b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng w_f_diag = vs.get_variable( 2646b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "w_f_diag", shape=[self._num_units], dtype=dtype) 2647b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng w_i_diag = vs.get_variable( 2648b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "w_i_diag", shape=[self._num_units], dtype=dtype) 2649b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng w_o_diag = vs.get_variable( 2650b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng "w_o_diag", shape=[self._num_units], dtype=dtype) 2651b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2652b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._use_peepholes: 2653b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng c = ( 2654b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + 2655b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sigmoid(i + w_i_diag * c_prev) * self._activation(j)) 2656b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng else: 2657b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng c = ( 2658b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sigmoid(f + self._forget_bias) * c_prev + 2659b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng sigmoid(i) * self._activation(j)) 2660b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei 
Feng 2661b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._layer_norm: 2662b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng c = _norm(self._norm_gain, self._norm_shift, c, "state") 2663b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2664b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._cell_clip is not None: 2665b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # pylint: disable=invalid-unary-operand-type 2666b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) 2667b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # pylint: enable=invalid-unary-operand-type 2668b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._use_peepholes: 2669b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng m = sigmoid(o + w_o_diag * c) * self._activation(c) 2670b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng else: 2671b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng m = sigmoid(o) * self._activation(c) 2672b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2673b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._num_proj is not None: 2674b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng with vs.variable_scope("projection"): 2675b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng m = self._linear(m, self._num_proj, bias=False) 2676b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2677b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng if self._proj_clip is not None: 2678b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # pylint: disable=invalid-unary-operand-type 2679b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) 2680b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng # pylint: enable=invalid-unary-operand-type 2681b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng 2682b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei Feng new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) 2683b1d8c59e9b014b527fb2fbef9ce9afc14dbc4938Yifei 
class SRUCell(rnn_cell_impl.LayerRNNCell):
  """Simple Recurrent Unit (SRU) cell.

  Implementation based on
  Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755).

  This variation of RNN cell is characterized by the simplified data
  dependence between hidden states of two consecutive time steps.
  Traditionally, hidden states from a cell at time step t-1 need to be
  multiplied with a matrix W_hh before being fed into the ensuing cell at
  time step t. This flavor of RNN replaces the matrix multiplication between
  h_{t-1} and W_hh with a pointwise multiplication, resulting in a
  performance gain.

  Args:
    num_units: int, The number of units in the SRU cell.
    activation: Nonlinearity to use.  Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: (optional) String, the name of the layer. Layers with the same name
      will share weights, but to avoid mistakes we require reuse=True in such
      cases.
  """

  def __init__(self, num_units, activation=None, reuse=None, name=None):
    super(SRUCell, self).__init__(_reuse=reuse, name=name)
    self._num_units = num_units
    self._activation = activation or math_ops.tanh

    # Only 2-dimensional inputs are accepted.
    self.input_spec = base_layer.InputSpec(ndim=2)

  @property
  def state_size(self):
    """Width of the recurrent state, equal to `num_units`."""
    return self._num_units

  @property
  def output_size(self):
    """Width of the cell output, equal to `num_units`."""
    return self._num_units

  def build(self, inputs_shape):
    """Create the kernel and bias variables for the given input shape.

    Args:
      inputs_shape: shape of the inputs; its second dimension must be
        statically known.

    Raises:
      ValueError: if the second dimension of `inputs_shape` is unknown.
    """
    input_depth = inputs_shape[1].value
    if input_depth is None:
      raise ValueError(
          "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape)

    # A single kernel yields the four `num_units`-wide column groups that
    # `call` splits apart.
    self._kernel = self.add_variable(
        rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
        shape=[input_depth, 4 * self._num_units])

    # Only the two gate groups receive a bias, hence 2 * num_units.
    self._bias = self.add_variable(
        rnn_cell_impl._BIAS_VARIABLE_NAME,
        shape=[2 * self._num_units],
        initializer=init_ops.constant_initializer(0.0, dtype=self.dtype))

    self._built = True

  def call(self, inputs, state):
    """Simple recurrent unit (SRU) with num_units cells.

    Args:
      inputs: 2D input Tensor that is multiplied by the kernel.
      state: previous cell state `c`.

    Returns:
      A pair `(h, c)`, where `c = f * state + (1 - f) * x_bar` and
      `h = r * activation(c) + (1 - r) * x_tx`; `f` and `r` are the sigmoid
      gates and `x_bar`, `x_tx` the ungated column groups of the projected
      input.
    """
    projected = math_ops.matmul(inputs, self._kernel)
    x_bar, f_pre, r_pre, x_tx = array_ops.split(
        value=projected, num_or_size_splits=4, axis=1)

    # Bias and squash both gates in one pass, then separate them again.
    gate_logits = nn_ops.bias_add(
        array_ops.concat([f_pre, r_pre], 1), self._bias)
    gates = math_ops.sigmoid(gate_logits)
    f, r = array_ops.split(value=gates, num_or_size_splits=2, axis=1)

    retained = f * state
    refreshed = (1.0 - f) * x_bar
    c = retained + refreshed

    gated = r * self._activation(c)
    bypass = (1.0 - r) * x_tx
    h = gated + bypass

    return h, c
2770d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie Weight Normalization: A Simple Reparameterization to Accelerate 2771d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie Training of Deep Neural Networks 2772d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 2773d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie The default LSTM implementation based on: 2774d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie http://www.bioinf.jku.at/publications/older/2604.pdf 2775d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie S. Hochreiter and J. Schmidhuber. 2776d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. 2777d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 2778d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie The class uses optional peephole connections, optional cell clipping 2779d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie and an optional projection layer. 2780d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie 2781d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie The optional peephole implementation is based on: 2782d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie https://research.google.com/pubs/archive/43905.pdf 2783d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie Hasim Sak, Andrew Senior, and Francoise Beaufays. 2784d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie "Long short-term memory recurrent neural network architectures for 2785d9f93c42a50b1f1401d9c186eac0ae8dc9093c3bJianwei Xie large scale acoustic modeling." INTERSPEECH, 2014. 
def __init__(self,
             num_units,
             norm=True,
             use_peepholes=False,
             cell_clip=None,
             initializer=None,
             num_proj=None,
             proj_clip=None,
             forget_bias=1,
             activation=None,
             reuse=None):
  """Store the hyper-parameters of a weight-normalized LSTM cell.

  Args:
    num_units: int, number of units in the LSTM cell.
    norm: If `True`, the weight matrices are weight-normalized. If `False`,
      the result is identical to that obtained from
      `rnn_cell_impl.LSTMCell`.
    use_peepholes: bool, set `True` to enable diagonal/peephole connections.
    cell_clip: (optional) A float value. If provided, the cell state is
      clipped by this value prior to the cell output activation.
    initializer: (optional) The initializer to use for the weight matrices.
    num_proj: (optional) int, the output dimensionality of the projection
      matrices. If None, no projection is performed.
    proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip`
      is provided, the projected values are clipped elementwise to within
      `[-proj_clip, proj_clip]`.
    forget_bias: Value added to the forget gate pre-activation. Defaults to
      1 in order to reduce the scale of forgetting at the beginning of the
      training.
    activation: Activation function of the inner states. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already
      has the given variables, an error is raised.
  """
  super(WeightNormLSTMCell, self).__init__(_reuse=reuse)

  # Scope and variable names used when the cell's variables are created.
  self._scope = "wn_lstm_cell"
  self._weights_variable_name = "kernel"
  self._bias_variable_name = "bias"

  # Cell hyper-parameters, stored as given.
  self._num_units = num_units
  self._norm = norm
  self._use_peepholes = use_peepholes
  self._initializer = initializer
  self._forget_bias = forget_bias
  self._activation = activation or math_ops.tanh

  # Optional clipping / projection settings.
  self._cell_clip = cell_clip
  self._num_proj = num_proj
  self._proj_clip = proj_clip

  # The hidden state (and hence the output) is `num_proj` wide when a
  # projection is requested, and `num_units` wide otherwise. A falsy
  # `num_proj` (None or 0) counts as "no projection" here.
  hidden_size = num_proj if num_proj else num_units
  self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, hidden_size)
  self._output_size = hidden_size

@property
def state_size(self):
  """The `LSTMStateTuple` of state sizes computed in `__init__`."""
  return self._state_size

@property
def output_size(self):
  """The output width computed in `__init__`."""
  return self._output_size

def _normalize(self, weight, name):
  """Apply weight normalization.

  `weight` is L2-normalized along dimension 0 and then multiplied by a
  variable of shape `[number of columns]` that is fetched (or created) under
  `name` in the current variable scope.

  Args:
    weight: a 2D tensor with known number of columns.
    name: string, variable name for the normalizer.
  Returns:
    A tensor with the same shape as `weight`.
  """

  output_size = weight.get_shape().as_list()[1]
  g = vs.get_variable(name, [output_size], dtype=weight.dtype)
  return nn_impl.l2_normalize(weight, dim=0) * g

def _linear(self,
            args,
            output_size,
            norm,
            bias,
            bias_initializer=None,
            kernel_initializer=None):
  """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.

  A single kernel of shape `[sum of argument widths, output_size]` is
  created; W[i] is the slice of its rows that lines up with `args[i]`.

  Args:
    args: a 2D Tensor or a list of 2D, batch x n, Tensors.
    output_size: int, second dimension of W[i].
    norm: boolean, whether to weight-normalize the kernel. When true, the
      row slice belonging to each argument is normalized separately with
      `_normalize`, using variables named `norm_0`, `norm_1`, ...
    bias: boolean, whether to add a bias term or not.
    bias_initializer: starting value to initialize the bias
      (default is all zeros).
    kernel_initializer: starting value to initialize the weight.

  Returns:
    A 2D Tensor with shape [batch x output_size] equal to
    sum_i(args[i] * W[i]), where W[i]s are newly created matrices.

  Raises:
    ValueError: if some of the arguments has unspecified or wrong shape.
  """
  if args is None or (nest.is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not nest.is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 1.
  total_arg_size = 0
  shapes = [a.get_shape() for a in args]
  for shape in shapes:
    if shape.ndims != 2:
      raise ValueError("linear is expecting 2D arguments: %s" % shapes)
    if shape[1].value is None:
      raise ValueError("linear expects shape[1] to be provided for shape %s, "
                       "but saw %s" % (shape, shape[1]))
    else:
      total_arg_size += shape[1].value

  # The kernel and bias take the dtype of the first argument.
  dtype = args[0].dtype

  # Now the computation.
  scope = vs.get_variable_scope()
  with vs.variable_scope(scope) as outer_scope:
    weights = vs.get_variable(
        self._weights_variable_name, [total_arg_size, output_size],
        dtype=dtype,
        initializer=kernel_initializer)
    if norm:
      wn = []
      st = 0
      # Build the normalized kernel with all pending control dependencies
      # cleared.
      with ops.control_dependencies(None):
        # Normalize the rows belonging to each argument independently.
        for i, shape in enumerate(shapes):
          en = st + shape[1].value
          wn.append(
              self._normalize(weights[st:en, :], name="norm_{}".format(i)))
          st = en

        weights = array_ops.concat(wn, axis=0)

    if len(args) == 1:
      res = math_ops.matmul(args[0], weights)
    else:
      res = math_ops.matmul(array_ops.concat(args, 1), weights)
    if not bias:
      return res

    with vs.variable_scope(outer_scope) as inner_scope:
      # The bias variable is created without a partitioner.
      inner_scope.set_partitioner(None)
      if bias_initializer is None:
        bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)

      biases = vs.get_variable(
          self._bias_variable_name, [output_size],
          dtype=dtype,
          initializer=bias_initializer)

    return nn_ops.bias_add(res, biases)

def call(self, inputs, state):
  """Run one step of LSTM.

  Args:
    inputs: input Tensor, 2D, batch x input_size. The input size must be
      statically known.
    state: A tuple of state Tensors, both `2-D`, with column sizes
     `c_state` and `m_state`.

  Returns:
    A tuple containing:

    - A `2-D, [batch x output_dim]`, Tensor representing the output of the
      LSTM after reading `inputs` when previous state was `state`.
      Here output_dim is:
         num_proj if num_proj was set,
         num_units otherwise.
    - Tensor(s) representing the new state of LSTM after reading `inputs` when
      the previous state was `state`.  Same type and shape(s) as `state`.

  Raises:
    ValueError: If input size cannot be inferred from inputs via
      static shape inference.
  """
  dtype = inputs.dtype
  num_units = self._num_units
  sigmoid = math_ops.sigmoid
  c, h = state

  input_size = inputs.get_shape().with_rank(2)[1]
  if input_size.value is None:
    raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

  with vs.variable_scope(self._scope, initializer=self._initializer):

    # One fused linear map produces the pre-activations of all four gates.
    concat = self._linear(
        [inputs, h], 4 * num_units, norm=self._norm, bias=True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

    if self._use_peepholes:
      w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
      w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
      w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)

      new_c = (
          c * sigmoid(f + self._forget_bias + w_f_diag * c) +
          sigmoid(i + w_i_diag * c) * self._activation(j))
    else:
      new_c = (
          c * sigmoid(f + self._forget_bias) +
          sigmoid(i) * self._activation(j))

    if self._cell_clip is not None:
      # pylint: disable=invalid-unary-operand-type
      new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
      # pylint: enable=invalid-unary-operand-type
    if self._use_peepholes:
      new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
    else:
      new_h = sigmoid(o) * self._activation(new_c)

    # Use the same truthiness test as `__init__`, which sizes `state_size`
    # and `output_size` for a projection only when `num_proj` is truthy.
    # Testing `is not None` here would project to width 0 for `num_proj=0`
    # and contradict the advertised sizes.
    if self._num_proj:
      with vs.variable_scope("projection"):
        new_h = self._linear(
            new_h, self._num_proj, norm=self._norm, bias=False)

      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        new_h = clip_ops.clip_by_value(new_h, -self._proj_clip,
                                       self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
    return new_h, new_state