# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A powerful dynamic attention wrapper object."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import functools
import math

import numpy as np

from tensorflow.contrib.framework.python.framework import tensor_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


__all__ = [
    "AttentionMechanism",
    "AttentionWrapper",
    "AttentionWrapperState",
    "LuongAttention",
    "BahdanauAttention",
    "hardmax",
    "safe_cumprod",
    "monotonic_attention",
    "BahdanauMonotonicAttention",
    "LuongMonotonicAttention",
]


_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access


class AttentionMechanism(object):

  @property
  def alignments_size(self):
    raise NotImplementedError

  @property
  def state_size(self):
    raise NotImplementedError


def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
  """Convert to tensor and possibly mask `memory`.

  Args:
    memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
    memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
    check_inner_dims_defined: Python boolean. If `True`, the `memory`
      argument's shape is checked to ensure all but the two outermost
      dimensions are fully defined.

  Returns:
    A (possibly masked), checked, new `memory`.

  Raises:
    ValueError: If `check_inner_dims_defined` is `True` and not
      `memory.shape[2:].is_fully_defined()`.
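
  Example (a minimal sketch; `encoder_outputs` and `source_lengths` are
  hypothetical caller-provided tensors, not defined in this module):

  ```python
  # Rows of the memory past each entry of `source_lengths` are zeroed out
  # in the returned (masked) memory.
  masked_memory = _prepare_memory(
      encoder_outputs, source_lengths, check_inner_dims_defined=True)
  ```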
  """
  memory = nest.map_structure(
      lambda m: ops.convert_to_tensor(m, name="memory"), memory)
  if memory_sequence_length is not None:
    memory_sequence_length = ops.convert_to_tensor(
        memory_sequence_length, name="memory_sequence_length")
  if check_inner_dims_defined:
    def _check_dims(m):
      if not m.get_shape()[2:].is_fully_defined():
        raise ValueError("Expected memory %s to have fully defined inner dims, "
                         "but saw shape: %s" % (m.name, m.get_shape()))
    nest.map_structure(_check_dims, memory)
  if memory_sequence_length is None:
    seq_len_mask = None
  else:
    seq_len_mask = array_ops.sequence_mask(
        memory_sequence_length,
        maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
        dtype=nest.flatten(memory)[0].dtype)
    seq_len_batch_size = (
        memory_sequence_length.shape[0].value
        or array_ops.shape(memory_sequence_length)[0])
  def _maybe_mask(m, seq_len_mask):
    rank = m.get_shape().ndims
    rank = rank if rank is not None else array_ops.rank(m)
    extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
    m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
    if memory_sequence_length is not None:
      message = ("memory_sequence_length and memory tensor batch sizes do not "
                 "match.")
      with ops.control_dependencies([
          check_ops.assert_equal(
              seq_len_batch_size, m_batch_size, message=message)]):
        seq_len_mask = array_ops.reshape(
            seq_len_mask,
            array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
        return m * seq_len_mask
    else:
      return m
  return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)


def _maybe_mask_score(score, memory_sequence_length, score_mask_value):
  if memory_sequence_length is None:
    return score
  message = ("All values in memory_sequence_length must be greater than zero.")
  with ops.control_dependencies(
      [check_ops.assert_positive(memory_sequence_length, message=message)]):
    score_mask = array_ops.sequence_mask(
        memory_sequence_length, maxlen=array_ops.shape(score)[1])
    score_mask_values = score_mask_value * array_ops.ones_like(score)
    return array_ops.where(score_mask, score, score_mask_values)


class _BaseAttentionMechanism(AttentionMechanism):
  """A base AttentionMechanism class providing common functionality.

  Common functionality includes:
    1. Storing the query and memory layers.
    2. Preprocessing and storing the memory.
  """

  def __init__(self,
               query_layer,
               memory,
               probability_fn,
               memory_sequence_length=None,
               memory_layer=None,
               check_inner_dims_defined=True,
               score_mask_value=None,
               name=None):
    """Construct base AttentionMechanism class.

    Args:
      query_layer: Callable. Instance of `tf.layers.Layer`. The layer's depth
        must match the depth of `memory_layer`. If `query_layer` is not
        provided, the shape of `query` must match that of `memory_layer`.
      memory: The memory to query; usually the output of an RNN encoder. This
        tensor should be shaped `[batch_size, max_time, ...]`.
      probability_fn: A `callable`. Converts the score and previous alignments
        to probabilities. Its signature should be:
        `probabilities = probability_fn(score, state)`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      memory_layer: Instance of `tf.layers.Layer` (may be None). The layer's
        depth must match the depth of `query_layer`.
        If `memory_layer` is not provided, the shape of `memory` must match
        that of `query_layer`.
      check_inner_dims_defined: Python boolean. If `True`, the `memory`
        argument's shape is checked to ensure all but the two outermost
        dimensions are fully defined.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`. The default is -inf. Only used if
        `memory_sequence_length` is not None.
      name: Name to use when creating ops.
    """
    if (query_layer is not None
        and not isinstance(query_layer, layers_base.Layer)):
      raise TypeError(
          "query_layer is not a Layer: %s" % type(query_layer).__name__)
    if (memory_layer is not None
        and not isinstance(memory_layer, layers_base.Layer)):
      raise TypeError(
          "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
    self._query_layer = query_layer
    self._memory_layer = memory_layer
    self.dtype = memory_layer.dtype
    if not callable(probability_fn):
      raise TypeError("probability_fn must be callable, saw type: %s" %
                      type(probability_fn).__name__)
    if score_mask_value is None:
      score_mask_value = dtypes.as_dtype(
          self._memory_layer.dtype).as_numpy_dtype(-np.inf)
    self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
        probability_fn(
            _maybe_mask_score(score, memory_sequence_length, score_mask_value),
            prev))
    with ops.name_scope(
        name, "BaseAttentionMechanismInit", nest.flatten(memory)):
      self._values = _prepare_memory(
          memory, memory_sequence_length,
          check_inner_dims_defined=check_inner_dims_defined)
      self._keys = (
          self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
          else self._values)
      self._batch_size = (
          self._keys.shape[0].value or array_ops.shape(self._keys)[0])
      self._alignments_size = (self._keys.shape[1].value or
                               array_ops.shape(self._keys)[1])

  @property
  def memory_layer(self):
    return self._memory_layer

  @property
  def query_layer(self):
    return self._query_layer

  @property
  def values(self):
    return self._values

  @property
  def keys(self):
    return self._keys

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def alignments_size(self):
    return self._alignments_size

  @property
  def state_size(self):
    return self._alignments_size

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return a tensor of all zeros.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return _zero_state_tensors(max_time, batch_size, dtype)

  def initial_state(self, batch_size, dtype):
    """Creates the initial state values for the `AttentionWrapper` class.

    This is important for AttentionMechanisms that use the previous alignment
    to calculate the alignment at the next time step (e.g. monotonic attention).

    The default behavior is to return the same output as initial_alignments.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A structure of all-zero tensors with shapes as described by `state_size`.
    """
    return self.initial_alignments(batch_size, dtype)


def _luong_score(query, keys, scale):
  """Implements Luong-style (multiplicative) scoring function.

  This attention has two forms. The first is standard Luong attention,
  as described in:

  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
  "Effective Approaches to Attention-based Neural Machine Translation."
  EMNLP 2015. https://arxiv.org/abs/1508.04025

  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, call this function with `scale=True`.

  Args:
    query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    scale: Whether to apply a scale to the score function.

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.

  Raises:
    ValueError: If `key` and `query` depths do not match.
  """
  depth = query.get_shape()[-1]
  key_units = keys.get_shape()[-1]
  if depth != key_units:
    raise ValueError(
        "Incompatible or unknown inner dimensions between query and keys. "
        "Query (%s) has units: %s. Keys (%s) have units: %s. "
        "Perhaps you need to set num_units to the keys' dimension (%s)?"
        % (query, depth, keys, key_units, key_units))
  dtype = query.dtype

  # Reshape from [batch_size, depth] to [batch_size, 1, depth]
  # for matmul.
  query = array_ops.expand_dims(query, 1)

  # Inner product along the query units dimension.
  # matmul shapes: query is [batch_size, 1, depth] and
  #                keys is [batch_size, max_time, depth].
  # the inner product is asked to **transpose keys' inner shape** to get a
  # batched matmul on:
  #   [batch_size, 1, depth] . [batch_size, depth, max_time]
  # resulting in an output shape of:
  #   [batch_size, 1, max_time].
  # we then squeeze out the center singleton dimension.
  score = math_ops.matmul(query, keys, transpose_b=True)
  score = array_ops.squeeze(score, [1])

  if scale:
    # Scalar used in weight scaling
    g = variable_scope.get_variable(
        "attention_g", dtype=dtype, initializer=1.)
    score = g * score
  return score


class LuongAttention(_BaseAttentionMechanism):
  """Implements Luong-style (multiplicative) attention scoring.

  This attention has two forms. The first is standard Luong attention,
  as described in:

  Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
  "Effective Approaches to Attention-based Neural Machine Translation."
  EMNLP 2015. https://arxiv.org/abs/1508.04025

  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, construct the object with parameter
  `scale=True`.
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               scale=False,
               probability_fn=None,
               score_mask_value=None,
               dtype=None,
               name="LuongAttention"):
    """Construct the LuongAttention mechanism.

    Args:
      num_units: The depth of the attention mechanism.
      memory: The memory to query; usually the output of an RNN encoder. This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length: (optional) Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      scale: Python boolean. Whether to scale the energy term.
      probability_fn: (optional) A `callable`. Converts the score to
        probabilities. The default is @{tf.nn.softmax}. Other options include
        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
        Its signature should be: `probabilities = probability_fn(score)`.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`. The default is -inf. Only used if
        `memory_sequence_length` is not None.
      dtype: The data type for the memory layer of the attention mechanism.
      name: Name to use when creating ops.
    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
    if probability_fn is None:
      probability_fn = nn_ops.softmax
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    super(LuongAttention, self).__init__(
        query_layer=None,
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._scale = scale
    self._name = name

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    """
    with variable_scope.variable_scope(None, "luong_attention", [query]):
      score = _luong_score(query, self._keys, self._scale)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


def _bahdanau_score(processed_query, keys, normalize):
  """Implements Bahdanau-style (additive) scoring function.

  This attention has two forms. The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form. This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, set `normalize=True`.

  Args:
    processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    normalize: Whether to normalize the score function.

  Returns:
    A `[batch_size, max_time]` tensor of unnormalized score values.
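
  In the unnormalized form, the returned score is, in effect,

    score[i, j] = sum_k v[k] * tanh(keys[i, j, k] + processed_query[i, k])

  for each batch entry `i` and memory time step `j`; the normalized form
  replaces `v` with `g * v / ||v||` and adds a learned bias inside the `tanh`.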
  """
  dtype = processed_query.dtype
  # Get the number of hidden units from the trailing dimension of keys
  num_units = keys.shape[2].value or array_ops.shape(keys)[2]
  # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
  processed_query = array_ops.expand_dims(processed_query, 1)
  v = variable_scope.get_variable(
      "attention_v", [num_units], dtype=dtype)
  if normalize:
    # Scalar used in weight normalization
    g = variable_scope.get_variable(
        "attention_g", dtype=dtype,
        initializer=math.sqrt((1. / num_units)))
    # Bias added prior to the nonlinearity
    b = variable_scope.get_variable(
        "attention_b", [num_units], dtype=dtype,
        initializer=init_ops.zeros_initializer())
    # normed_v = g * v / ||v||
    normed_v = g * v * math_ops.rsqrt(
        math_ops.reduce_sum(math_ops.square(v)))
    return math_ops.reduce_sum(
        normed_v * math_ops.tanh(keys + processed_query + b), [2])
  else:
    return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])


class BahdanauAttention(_BaseAttentionMechanism):
  """Implements Bahdanau-style (additive) attention.

  This attention has two forms. The first is Bahdanau attention,
  as described in:

  Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form. This form is inspired by the
  weight normalization article:

  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, construct the object with parameter
  `normalize=True`.
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               probability_fn=None,
               score_mask_value=None,
               dtype=None,
               name="BahdanauAttention"):
    """Construct the Attention mechanism.

    Args:
      num_units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder. This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean. Whether to normalize the energy term.
      probability_fn: (optional) A `callable`. Converts the score to
        probabilities. The default is @{tf.nn.softmax}. Other options include
        @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
        Its signature should be: `probabilities = probability_fn(score)`.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`. The default is -inf. Only used if
        `memory_sequence_length` is not None.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
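
    Example (a minimal sketch; `encoder_outputs`, `source_lengths`, and
    `decoder_cell` are assumed to be defined by the caller, not by this
    module):

    ```python
    attention_mechanism = BahdanauAttention(
        num_units=128,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths)
    attention_cell = AttentionWrapper(
        decoder_cell, attention_mechanism, attention_layer_size=128)
    ```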
    """
    if probability_fn is None:
      probability_fn = nn_ops.softmax
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = lambda score, _: probability_fn(score)
    super(BahdanauAttention, self).__init__(
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False, dtype=dtype),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._name = name

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    """
    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
      processed_query = self.query_layer(query) if self.query_layer else query
      score = _bahdanau_score(processed_query, self._keys, self._normalize)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


def safe_cumprod(x, *args, **kwargs):
  """Computes cumprod of x in logspace using cumsum to avoid underflow.

  The cumprod function and its gradient can result in numerical instabilities
  when its argument has very small and/or zero values. As long as the argument
  is all positive, we can instead compute the cumulative product as
  exp(cumsum(log(x))). This function can be called identically to tf.cumprod.

  Args:
    x: Tensor to take the cumulative product of.
    *args: Passed on to cumsum; these are identical to those in cumprod.
    **kwargs: Passed on to cumsum; these are identical to those in cumprod.
  Returns:
    Cumulative product of x.
  """
  with ops.name_scope(None, "SafeCumprod", [x]):
    x = ops.convert_to_tensor(x, name="x")
    tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
    return math_ops.exp(math_ops.cumsum(
        math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs))


def monotonic_attention(p_choose_i, previous_attention, mode):
  """Compute monotonic attention distribution from choosing probabilities.

  Monotonic attention implies that the input sequence is processed in an
  explicitly left-to-right manner when generating the output sequence. In
  addition, once an input sequence element is attended to at a given output
  timestep, elements occurring before it cannot be attended to at subsequent
  output timesteps. This function generates attention distributions according
  to these assumptions. For more information, see "Online and Linear-Time
  Attention by Enforcing Monotonic Alignments".

  Args:
    p_choose_i: Probability of choosing input sequence/memory element i. Should
      be of shape (batch_size, input_sequence_length), and should all be in the
      range [0, 1].
    previous_attention: The attention distribution from the previous output
      timestep. Should be of shape (batch_size, input_sequence_length). For
      the first output timestep, previous_attention[n] should be [1, 0, 0, ...,
      0] for all n in [0, ... batch_size - 1].
    mode: How to compute the attention distribution. Must be one of
      'recursive', 'parallel', or 'hard'.
        * 'recursive' uses tf.scan to recursively compute the distribution.
          This is slowest but is exact, general, and does not suffer from
          numerical instabilities.
        * 'parallel' uses parallelized cumulative-sum and cumulative-product
          operations to compute a closed-form solution to the recurrence
          relation defining the attention distribution. This makes it more
          efficient than 'recursive', but it requires numerical checks which
          make the distribution non-exact. This can be a problem in particular
          when input_sequence_length is long and/or p_choose_i has entries very
          close to 0 or 1.
        * 'hard' requires that the probabilities in p_choose_i are all either 0
          or 1, and subsequently uses a more efficient and exact solution.

  Returns:
    A tensor of shape (batch_size, input_sequence_length) representing the
    attention distributions for each sequence in the batch.

  Raises:
    ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
  """
  # Force things to be tensors
  p_choose_i = ops.convert_to_tensor(p_choose_i, name="p_choose_i")
  previous_attention = ops.convert_to_tensor(
      previous_attention, name="previous_attention")
  if mode == "recursive":
    # Use .shape[0].value when it's not None, or fall back on symbolic shape
    batch_size = p_choose_i.shape[0].value or array_ops.shape(p_choose_i)[0]
    # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
    shifted_1mp_choose_i = array_ops.concat(
        [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
    # Compute attention distribution recursively as
    # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
    # attention[i] = p_choose_i[i]*q[i]
    attention = p_choose_i*array_ops.transpose(functional_ops.scan(
        # Need to use reshape to remind TF of the shape between loop iterations
        lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)),
        # Loop variables yz[0] and yz[1]
        [array_ops.transpose(shifted_1mp_choose_i),
         array_ops.transpose(previous_attention)],
        # Initial value of x is just zeros
        array_ops.zeros((batch_size,))))
  elif mode == "parallel":
    # safe_cumprod computes cumprod in logspace with numeric checks
    cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
    # Compute recurrence relation solution
    attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum(
        previous_attention /
        # Clip cumprod_1mp to avoid divide-by-zero
        clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1)
  elif mode == "hard":
    # Remove any probabilities before the index chosen last time step
    p_choose_i *= math_ops.cumsum(previous_attention, axis=1)
    # Now, use exclusive cumprod to remove probabilities after the first
    # chosen index, like so:
    # p_choose_i                                = [0, 0, 0, 1, 1, 0, 1, 1]
    # cumprod(1 - p_choose_i, exclusive=True)   = [1, 1, 1, 1, 0, 0, 0, 0]
    # Product of above:                           [0, 0, 0, 1, 0, 0, 0, 0]
    attention = p_choose_i*math_ops.cumprod(
        1 - p_choose_i, axis=1, exclusive=True)
  else:
    raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
  return attention


def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
                              seed=None):
  """Attention probability function for monotonic attention.

  Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
  the model to make discrete attention decisions, passes them through a sigmoid
  to obtain "choosing" probabilities, and then calls monotonic_attention to
  obtain the attention distribution. For more information, see

  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  ICML 2017. https://arxiv.org/abs/1704.00784

  Args:
    score: Unnormalized attention scores, shape `[batch_size, alignments_size]`
    previous_alignments: Previous attention distribution, shape
      `[batch_size, alignments_size]`
    sigmoid_noise: Standard deviation of pre-sigmoid noise. Setting this larger
      than 0 will encourage the model to produce large attention scores,
      effectively making the choosing probabilities discrete and the resulting
      attention distribution one-hot. It should be set to 0 at test-time, and
      when hard attention is not desired.
    mode: How to compute the attention distribution. Must be one of
      'recursive', 'parallel', or 'hard'. See the docstring for
      `tf.contrib.seq2seq.monotonic_attention` for more information.
    seed: (optional) Random seed for pre-sigmoid noise.

  Returns:
    A `[batch_size, alignments_size]`-shape tensor corresponding to the
    resulting attention distribution.
  """
  # Optionally add pre-sigmoid noise to the scores
  if sigmoid_noise > 0:
    noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype,
                                     seed=seed)
    score += sigmoid_noise*noise
  # Compute "choosing" probabilities from the attention scores
  if mode == "hard":
    # When mode is hard, use a hard sigmoid
    p_choose_i = math_ops.cast(score > 0, score.dtype)
  else:
    p_choose_i = math_ops.sigmoid(score)
  # Convert from choosing probabilities to attention distribution
  return monotonic_attention(p_choose_i, previous_alignments, mode)


class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism):
  """Base attention mechanism for monotonic attention.

  Simply overrides the initial_alignments function to provide a dirac
  distribution, which is needed in order for the monotonic attention
  distributions to have the correct behavior.
  """

  def initial_alignments(self, batch_size, dtype):
    """Creates the initial alignment values for the monotonic attentions.

    Initializes to dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
    for all entries in the batch.

    Args:
      batch_size: `int32` scalar, the batch_size.
      dtype: The `dtype`.

    Returns:
      A `dtype` tensor shaped `[batch_size, alignments_size]`
      (`alignments_size` is the values' `max_time`).
    """
    max_time = self._alignments_size
    return array_ops.one_hot(
        array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
        dtype=dtype)


class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
  """Monotonic attention mechanism with Bahdanau-style energy function.

  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps. It
  achieves this by using the _monotonic_probability_fn instead of softmax to
  construct its attention distributions. Since the attention scores are passed
  through a sigmoid, a learnable scalar bias parameter is applied after the
  score function and before the sigmoid. Otherwise, it is equivalent to
  BahdanauAttention. This approach is proposed in

  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  ICML 2017. https://arxiv.org/abs/1704.00784
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               score_mask_value=None,
               sigmoid_noise=0.,
               sigmoid_noise_seed=None,
               score_bias_init=0.,
               mode="parallel",
               dtype=None,
               name="BahdanauMonotonicAttention"):
    """Construct the Attention mechanism.

    Args:
      num_units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder. This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean. Whether to normalize the energy term.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`. The default is -inf. Only used if
        `memory_sequence_length` is not None.
      sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring
        for `_monotonic_probability_fn` for more information.
      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
      score_bias_init: Initial value for score bias scalar. It's recommended to
        initialize this to a negative value when the length of the memory is
        large.
      mode: How to compute the attention distribution. Must be one of
        'recursive', 'parallel', or 'hard'. See the docstring for
        `tf.contrib.seq2seq.monotonic_attention` for more information.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
    """
    # Set up the monotonic probability fn with supplied parameters
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = functools.partial(
        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
        seed=sigmoid_noise_seed)
    super(BahdanauMonotonicAttention, self).__init__(
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False, dtype=dtype),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._name = name
    self._score_bias_init = score_bias_init

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
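
    Example (a minimal sketch; `attention_mechanism` is an assumed instance of
    this class, `decoder_output` an assumed query tensor, and
    `previous_alignments` the alignments returned by the previous call, or
    `initial_alignments(...)` at the first step):

    ```python
    alignments, next_state = attention_mechanism(
        decoder_output, state=previous_alignments)
    ```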
    """
    with variable_scope.variable_scope(
        None, "bahdanau_monotonic_attention", [query]):
      processed_query = self.query_layer(query) if self.query_layer else query
      score = _bahdanau_score(processed_query, self._keys, self._normalize)
      score_bias = variable_scope.get_variable(
          "attention_score_bias", dtype=processed_query.dtype,
          initializer=self._score_bias_init)
      score += score_bias
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
  """Monotonic attention mechanism with Luong-style energy function.

  This type of attention enforces a monotonic constraint on the attention
  distributions; that is, once the model attends to a given point in the memory
  it can't attend to any prior points at subsequent output timesteps. It
  achieves this by using the _monotonic_probability_fn instead of softmax to
  construct its attention distributions. Otherwise, it is equivalent to
  LuongAttention. This approach is proposed in

  Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  ICML 2017. https://arxiv.org/abs/1704.00784
  """

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               scale=False,
               score_mask_value=None,
               sigmoid_noise=0.,
               sigmoid_noise_seed=None,
               score_bias_init=0.,
               mode="parallel",
               dtype=None,
               name="LuongMonotonicAttention"):
    """Construct the Attention mechanism.

    Args:
      num_units: The depth of the query mechanism.
      memory: The memory to query; usually the output of an RNN encoder. This
        tensor should be shaped `[batch_size, max_time, ...]`.
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      scale: Python boolean. Whether to scale the energy term.
      score_mask_value: (optional) The mask value for score before passing into
        `probability_fn`. The default is -inf. Only used if
        `memory_sequence_length` is not None.
      sigmoid_noise: Standard deviation of pre-sigmoid noise. See the docstring
        for `_monotonic_probability_fn` for more information.
      sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
      score_bias_init: Initial value for score bias scalar. It's recommended to
        initialize this to a negative value when the length of the memory is
        large.
      mode: How to compute the attention distribution. Must be one of
        'recursive', 'parallel', or 'hard'. See the docstring for
        `tf.contrib.seq2seq.monotonic_attention` for more information.
      dtype: The data type for the query and memory layers of the attention
        mechanism.
      name: Name to use when creating ops.
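
    Example (a minimal sketch; `encoder_outputs`, `source_lengths`, and
    `decoder_cell` are assumed to be defined by the caller, not by this
    module):

    ```python
    attention_mechanism = LuongMonotonicAttention(
        num_units=128,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths,
        sigmoid_noise=1.0,  # Encourage discrete attention during training.
        mode="parallel")
    attention_cell = AttentionWrapper(
        decoder_cell, attention_mechanism, attention_layer_size=128)
    ```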
    """
    # Set up the monotonic probability fn with supplied parameters
    if dtype is None:
      dtype = dtypes.float32
    wrapped_probability_fn = functools.partial(
        _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
        seed=sigmoid_noise_seed)
    super(LuongMonotonicAttention, self).__init__(
        query_layer=None,
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False, dtype=dtype),
        memory=memory,
        probability_fn=wrapped_probability_fn,
        memory_sequence_length=memory_sequence_length,
        score_mask_value=score_mask_value,
        name=name)
    self._num_units = num_units
    self._scale = scale
    self._score_bias_init = score_bias_init
    self._name = name

  def __call__(self, query, state):
    """Score the query based on the keys and values.

    Args:
      query: Tensor of dtype matching `self.values` and shape
        `[batch_size, query_depth]`.
      state: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]`
        (`alignments_size` is memory's `max_time`).

    Returns:
      alignments: Tensor of dtype matching `self.values` and shape
        `[batch_size, alignments_size]` (`alignments_size` is memory's
        `max_time`).
    """
    with variable_scope.variable_scope(None, "luong_monotonic_attention",
                                       [query]):
      score = _luong_score(query, self._keys, self._scale)
      score_bias = variable_scope.get_variable(
          "attention_score_bias", dtype=query.dtype,
          initializer=self._score_bias_init)
      score += score_bias
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


class AttentionWrapperState(
    collections.namedtuple("AttentionWrapperState",
                           ("cell_state", "attention", "time", "alignments",
                            "alignment_history", "attention_state"))):
  """`namedtuple` storing the state of an `AttentionWrapper`.

  Contains:

    - `cell_state`: The state of the wrapped `RNNCell` at the previous time
      step.
    - `attention`: The attention emitted at the previous time step.
    - `time`: int32 scalar containing the current time step.
    - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
      emitted at the previous time step for each attention mechanism.
    - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
      containing alignment matrices from all time steps for each attention
      mechanism. Call `stack()` on each to convert to a `Tensor`.
    - `attention_state`: A single or tuple of nested objects
      containing attention mechanism state for each attention mechanism.
      The objects may contain Tensors or TensorArrays.
  """

  def clone(self, **kwargs):
    """Clone this object, overriding components provided by kwargs.

    The new state fields' shape must match original state fields' shape. This
    will be validated, and original fields' shape will be propagated to new
    fields.

    Example:

    ```python
    initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
    initial_state = initial_state.clone(cell_state=encoder_state)
    ```

    Args:
      **kwargs: Any properties of the state object to replace in the returned
        `AttentionWrapperState`.

    Returns:
      A new `AttentionWrapperState` whose properties are the same as
      this one, except any overridden properties as provided in `kwargs`.
    """
    def with_same_shape(old, new):
      """Check and set new tensor's shape."""
      if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
        return tensor_util.with_same_shape(old, new)
      return new

    return nest.map_structure(
        with_same_shape,
        self,
        super(AttentionWrapperState, self)._replace(**kwargs))


def hardmax(logits, name=None):
  """Returns batched one-hot vectors.

  The depth index containing the `1` is that of the maximum logit value.

  Args:
    logits: A batch tensor of logit values.
    name: Name to use when creating ops.
  Returns:
    A batched one-hot tensor.
  """
  with ops.name_scope(name, "Hardmax", [logits]):
    logits = ops.convert_to_tensor(logits, name="logits")
    if logits.get_shape()[-1].value is not None:
      depth = logits.get_shape()[-1].value
    else:
      depth = array_ops.shape(logits)[-1]
    return array_ops.one_hot(
        math_ops.argmax(logits, -1), depth, dtype=logits.dtype)


def _compute_attention(attention_mechanism, cell_output, attention_state,
                       attention_layer):
  """Computes the attention and alignments for a given attention_mechanism."""
  alignments, next_attention_state = attention_mechanism(
      cell_output, state=attention_state)

  # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
  expanded_alignments = array_ops.expand_dims(alignments, 1)
  # Context is the inner product of alignments and values along the
  # memory time dimension.
  # alignments shape is
  #   [batch_size, 1, memory_time]
  # attention_mechanism.values shape is
  #   [batch_size, memory_time, memory_size]
  # the batched matmul is over memory_time, so the output shape is
  #   [batch_size, 1, memory_size].
  # we then squeeze out the singleton dim.
  context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
  context = array_ops.squeeze(context, [1])

  if attention_layer is not None:
    attention = attention_layer(array_ops.concat([cell_output, context], 1))
  else:
    attention = context

  return attention, alignments, next_attention_state


class AttentionWrapper(rnn_cell_impl.RNNCell):
  """Wraps another `RNNCell` with attention.
  """

  def __init__(self,
               cell,
               attention_mechanism,
               attention_layer_size=None,
               alignment_history=False,
               cell_input_fn=None,
               output_attention=True,
               initial_cell_state=None,
               name=None):
    """Construct the `AttentionWrapper`.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
    - The `batch_size` argument passed to the `zero_state` method of this
      wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `zero_state` above contains a
      `cell_state` value containing properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```

    Args:
      cell: An instance of `RNNCell`.
      attention_mechanism: A list of `AttentionMechanism` instances or a single
        instance.
      attention_layer_size: A list of Python integers or a single Python
        integer, the depth of the attention (output) layer(s). If None
        (default), use the context as attention at each time step. Otherwise,
        feed the context and cell output into the attention layer to generate
        attention at each time step. If attention_mechanism is a list,
        attention_layer_size must be a list of the same length.
      alignment_history: Python boolean, whether to store alignment history
        from all time steps in the final output state (currently stored as a
        time-major `TensorArray` on which you must call `stack()`).
      cell_input_fn: (optional) A `callable`. The default is:
        `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
      output_attention: Python bool. If `True` (default), the output at each
        time step is the attention value. This is the behavior of Luong-style
        attention mechanisms. If `False`, the output at each time step is
        the output of `cell`. This is the behavior of Bahdanau-style
        attention mechanisms. In both cases, the `attention` tensor is
        propagated to the next time step via the state and is used there.
        This flag only controls whether the attention mechanism is propagated
        up to the next cell in an RNN stack or to the top RNN output.
      initial_cell_state: The initial state value to use for the cell when
        the user calls `zero_state()`. Note that if this value is provided
        now, and the user uses a `batch_size` argument of `zero_state` which
        does not match the batch size of `initial_cell_state`, proper
        behavior is not guaranteed.
      name: Name to use when creating ops.

    Raises:
      TypeError: `attention_layer_size` is not None and (`attention_mechanism`
        is a list but `attention_layer_size` is not; or vice versa).
      ValueError: if `attention_layer_size` is not None, `attention_mechanism`
        is a list, and its length does not match that of `attention_layer_size`.
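
    A basic example without beam search (a minimal sketch; `encoder_outputs`,
    `source_lengths`, `num_units`, and `batch_size` are assumed to be defined
    by the caller):

    ```
    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
        num_units=num_units,
        memory=encoder_outputs,
        memory_sequence_length=source_lengths)
    attention_cell = AttentionWrapper(
        tf.nn.rnn_cell.BasicLSTMCell(num_units),
        attention_mechanism,
        attention_layer_size=num_units,
        alignment_history=True)
    decoder_initial_state = attention_cell.zero_state(batch_size, tf.float32)
    ```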
    """
    super(AttentionWrapper, self).__init__(name=name)
    if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
      raise TypeError(
          "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
    if isinstance(attention_mechanism, (list, tuple)):
      self._is_multi = True
      attention_mechanisms = attention_mechanism
      for attention_mechanism in attention_mechanisms:
        if not isinstance(attention_mechanism, AttentionMechanism):
          raise TypeError(
              "attention_mechanism must contain only instances of "
              "AttentionMechanism, saw type: %s"
              % type(attention_mechanism).__name__)
    else:
      self._is_multi = False
      if not isinstance(attention_mechanism, AttentionMechanism):
        raise TypeError(
            "attention_mechanism must be an AttentionMechanism or list of "
            "multiple AttentionMechanism instances, saw type: %s"
            % type(attention_mechanism).__name__)
      attention_mechanisms = (attention_mechanism,)

    if cell_input_fn is None:
      cell_input_fn = (
          lambda inputs, attention: array_ops.concat([inputs, attention], -1))
    else:
      if not callable(cell_input_fn):
        raise TypeError(
            "cell_input_fn must be callable, saw type: %s"
            % type(cell_input_fn).__name__)

    if attention_layer_size is not None:
      attention_layer_sizes = tuple(
          attention_layer_size
          if isinstance(attention_layer_size, (list, tuple))
          else (attention_layer_size,))
      if len(attention_layer_sizes) != len(attention_mechanisms):
        raise ValueError(
            "If provided, attention_layer_size must contain exactly one "
            "integer per attention_mechanism, saw: %d vs %d"
            % (len(attention_layer_sizes), len(attention_mechanisms)))
      self._attention_layers = tuple(
          layers_core.Dense(
              attention_layer_size,
              name="attention_layer",
              use_bias=False,
              dtype=attention_mechanisms[i].dtype)
          for i, attention_layer_size in enumerate(attention_layer_sizes))
      self._attention_layer_size = sum(attention_layer_sizes)
    else:
      self._attention_layers = None
      self._attention_layer_size = sum(
          attention_mechanism.values.get_shape()[-1].value
          for attention_mechanism in attention_mechanisms)

    self._cell = cell
    self._attention_mechanisms = attention_mechanisms
    self._cell_input_fn = cell_input_fn
    self._output_attention = output_attention
    self._alignment_history = alignment_history
    with ops.name_scope(name, "AttentionWrapperInit"):
      if initial_cell_state is None:
        self._initial_cell_state = None
      else:
        final_state_tensor = nest.flatten(initial_cell_state)[-1]
        state_batch_size = (
            final_state_tensor.shape[0].value
            or array_ops.shape(final_state_tensor)[0])
        error_message = (
            "When constructing AttentionWrapper %s: " % self._base_name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and initial_cell_state. Are you using "
            "the BeamSearchDecoder? You may need to tile your initial state "
            "via the tf.contrib.seq2seq.tile_batch function with argument "
            "multiplier=beam_width.")
        with ops.control_dependencies(
            self._batch_size_checks(state_batch_size, error_message)):
          self._initial_cell_state = nest.map_structure(
              lambda s: array_ops.identity(s, name="check_initial_cell_state"),
              initial_cell_state)

  def _batch_size_checks(self, batch_size, error_message):
    return [check_ops.assert_equal(batch_size,
                                   attention_mechanism.batch_size,
                                   message=error_message)
            for attention_mechanism in self._attention_mechanisms]

  def _item_or_tuple(self, seq):
    """Returns `seq` as tuple or the singular element.

    Which is returned is determined by how the AttentionMechanism(s) were passed
    to the constructor.

    Args:
      seq: A non-empty sequence of items or generator.

    Returns:
      Either the values in the sequence as a tuple if AttentionMechanism(s)
      were passed to the constructor as a sequence or the singular element.
    """
    t = tuple(seq)
    if self._is_multi:
      return t
    else:
      return t[0]

  @property
  def output_size(self):
    if self._output_attention:
      return self._attention_layer_size
    else:
      return self._cell.output_size

  @property
  def state_size(self):
    """The `state_size` property of `AttentionWrapper`.

    Returns:
      An `AttentionWrapperState` tuple containing shapes used by this object.
    """
    return AttentionWrapperState(
        cell_state=self._cell.state_size,
        time=tensor_shape.TensorShape([]),
        attention=self._attention_layer_size,
        alignments=self._item_or_tuple(
            a.alignments_size for a in self._attention_mechanisms),
        attention_state=self._item_or_tuple(
            a.state_size for a in self._attention_mechanisms),
        alignment_history=self._item_or_tuple(
            () for _ in self._attention_mechanisms))  # sometimes a TensorArray

  def zero_state(self, batch_size, dtype):
    """Return an initial (zero) state tuple for this `AttentionWrapper`.

    **NOTE** Please see the initializer documentation for details of how
    to call `zero_state` if using an `AttentionWrapper` with a
    `BeamSearchDecoder`.

    Args:
      batch_size: `0D` integer tensor: the batch size.
      dtype: The internal state data type.

    Returns:
      An `AttentionWrapperState` tuple containing zeroed out tensors and,
      possibly, empty `TensorArray` objects.

    Raises:
      ValueError: (or, possibly at runtime, InvalidArgument), if
        `batch_size` does not match the output size of the encoder passed
        to the wrapper object at initialization time.
    """
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      if self._initial_cell_state is not None:
        cell_state = self._initial_cell_state
      else:
        cell_state = self._cell.zero_state(batch_size, dtype)
      error_message = (
          "When calling zero_state of AttentionWrapper %s: " % self._base_name +
          "Non-matching batch sizes between the memory "
          "(encoder output) and the requested batch size. Are you using "
          "the BeamSearchDecoder? If so, make sure your encoder output has "
          "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
          "the batch_size= argument passed to zero_state is "
          "batch_size * beam_width.")
      with ops.control_dependencies(
          self._batch_size_checks(batch_size, error_message)):
        cell_state = nest.map_structure(
            lambda s: array_ops.identity(s, name="checked_cell_state"),
            cell_state)
      return AttentionWrapperState(
          cell_state=cell_state,
          time=array_ops.zeros([], dtype=dtypes.int32),
          attention=_zero_state_tensors(self._attention_layer_size, batch_size,
                                        dtype),
          alignments=self._item_or_tuple(
              attention_mechanism.initial_alignments(batch_size, dtype)
              for attention_mechanism in self._attention_mechanisms),
          attention_state=self._item_or_tuple(
              attention_mechanism.initial_state(batch_size, dtype)
              for attention_mechanism in self._attention_mechanisms),
          alignment_history=self._item_or_tuple(
              tensor_array_ops.TensorArray(dtype=dtype, size=0,
                                           dynamic_size=True)
              if self._alignment_history else ()
              for _ in self._attention_mechanisms))

  def call(self, inputs, state):
    """Perform a step of attention-wrapped RNN.

    - Step 1: Mix the `inputs` and previous step's `attention` output via
      `cell_input_fn`.
    - Step 2: Call the wrapped `cell` with this input and its previous state.
    - Step 3: Score the cell's output with `attention_mechanism`.
    - Step 4: Calculate the alignments by passing the score through the
      `normalizer`.
    - Step 5: Calculate the context vector as the inner product between the
      alignments and the attention_mechanism's values (memory).
    - Step 6: Calculate the attention output by concatenating the cell output
      and context through the attention layer (a linear layer with
      `attention_layer_size` outputs).

    Args:
      inputs: (Possibly nested tuple of) Tensor, the input at this time step.
      state: An instance of `AttentionWrapperState` containing
        tensors from the previous time step.

    Returns:
      A tuple `(attention_or_cell_output, next_state)`, where:

      - `attention_or_cell_output` depending on `output_attention`.
      - `next_state` is an instance of `AttentionWrapperState`
        containing the state calculated at this time step.

    Raises:
      TypeError: If `state` is not an instance of `AttentionWrapperState`.
    """
    if not isinstance(state, AttentionWrapperState):
      raise TypeError("Expected state to be instance of AttentionWrapperState. "
                      "Received type %s instead." % type(state))

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    cell_batch_size = (
        cell_output.shape[0].value or array_ops.shape(cell_output)[0])
    error_message = (
        "When applying AttentionWrapper %s: " % self.name +
        "Non-matching batch sizes between the memory "
        "(encoder output) and the query (decoder output). Are you using "
        "the BeamSearchDecoder? You may need to tile your memory input via "
        "the tf.contrib.seq2seq.tile_batch function with argument "
        "multiplier=beam_width.")
    with ops.control_dependencies(
        self._batch_size_checks(cell_batch_size, error_message)):
      cell_output = array_ops.identity(
          cell_output, name="checked_cell_output")

    if self._is_multi:
      previous_attention_state = state.attention_state
      previous_alignment_history = state.alignment_history
    else:
      previous_attention_state = [state.attention_state]
      previous_alignment_history = [state.alignment_history]

    all_alignments = []
    all_attentions = []
    all_attention_states = []
    maybe_all_histories = []
    for i, attention_mechanism in enumerate(self._attention_mechanisms):
      attention, alignments, next_attention_state = _compute_attention(
          attention_mechanism, cell_output, previous_attention_state[i],
          self._attention_layers[i] if self._attention_layers else None)
      alignment_history = previous_alignment_history[i].write(
          state.time, alignments) if self._alignment_history else ()

      all_attention_states.append(next_attention_state)
      all_alignments.append(alignments)
      all_attentions.append(attention)
      maybe_all_histories.append(alignment_history)

    attention = array_ops.concat(all_attentions, 1)
    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_state=self._item_or_tuple(all_attention_states),
        alignments=self._item_or_tuple(all_alignments),
        alignment_history=self._item_or_tuple(maybe_all_histories))

    if self._output_attention:
      return attention, next_state
    else:
      return cell_output, next_state