1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for functional style sequence-to-sequence models.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import math 23import random 24 25import numpy as np 26 27from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib 28from tensorflow.contrib.rnn.python.ops import core_rnn_cell 29from tensorflow.python.framework import constant_op 30from tensorflow.python.framework import dtypes 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import random_seed 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import clip_ops 35from tensorflow.python.ops import gradients_impl 36from tensorflow.python.ops import init_ops 37from tensorflow.python.ops import nn_impl 38from tensorflow.python.ops import rnn 39from tensorflow.python.ops import rnn_cell 40from tensorflow.python.ops import state_ops 41from tensorflow.python.ops import variable_scope 42from tensorflow.python.ops import variables 43from tensorflow.python.platform import test 44from tensorflow.python.training import adam 45 46 47class Seq2SeqTest(test.TestCase): 48 49 def testRNNDecoder(self): 50 with self.test_session() as sess: 51 with variable_scope.variable_scope( 52 "root", initializer=init_ops.constant_initializer(0.5)): 53 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 54 _, enc_state = rnn.static_rnn( 55 rnn_cell.GRUCell(2), inp, dtype=dtypes.float32) 56 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 57 cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) 58 dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell) 59 sess.run([variables.global_variables_initializer()]) 60 res = sess.run(dec) 61 self.assertEqual(3, len(res)) 62 self.assertEqual((2, 4), res[0].shape) 63 64 res = sess.run([mem]) 65 self.assertEqual((2, 2), res[0].shape) 66 67 def testBasicRNNSeq2Seq(self): 68 with self.test_session() as sess: 69 with variable_scope.variable_scope( 70 "root", initializer=init_ops.constant_initializer(0.5)): 71 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 72 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 73 cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) 74 dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell) 75 sess.run([variables.global_variables_initializer()]) 76 res = sess.run(dec) 77 self.assertEqual(3, len(res)) 78 self.assertEqual((2, 4), res[0].shape) 79 80 res = sess.run([mem]) 81 self.assertEqual((2, 2), res[0].shape) 82 83 def testTiedRNNSeq2Seq(self): 84 with self.test_session() as sess: 85 with variable_scope.variable_scope( 86 "root", initializer=init_ops.constant_initializer(0.5)): 87 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 88 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 89 cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4) 90 dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell) 91 sess.run([variables.global_variables_initializer()]) 92 res = sess.run(dec) 93 self.assertEqual(3, len(res)) 94 self.assertEqual((2, 4), res[0].shape) 95 96 res = sess.run([mem]) 97 self.assertEqual(1, len(res)) 98 self.assertEqual((2, 2), res[0].shape) 99 100 def testEmbeddingRNNDecoder(self): 101 with self.test_session() as sess: 102 with variable_scope.variable_scope( 103 "root", initializer=init_ops.constant_initializer(0.5)): 104 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 105 cell_fn = lambda: rnn_cell.BasicLSTMCell(2) 106 cell = cell_fn() 107 _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 108 dec_inp = [ 109 constant_op.constant( 110 i, dtypes.int32, shape=[2]) for i in range(3) 111 ] 112 # Use a new cell instance since the attention decoder uses a 113 # different variable scope. 114 dec, mem = seq2seq_lib.embedding_rnn_decoder( 115 dec_inp, enc_state, cell_fn(), num_symbols=4, embedding_size=2) 116 sess.run([variables.global_variables_initializer()]) 117 res = sess.run(dec) 118 self.assertEqual(3, len(res)) 119 self.assertEqual((2, 2), res[0].shape) 120 121 res = sess.run([mem]) 122 self.assertEqual(1, len(res)) 123 self.assertEqual((2, 2), res[0].c.shape) 124 self.assertEqual((2, 2), res[0].h.shape) 125 126 def testEmbeddingRNNSeq2Seq(self): 127 with self.test_session() as sess: 128 with variable_scope.variable_scope( 129 "root", initializer=init_ops.constant_initializer(0.5)): 130 enc_inp = [ 131 constant_op.constant( 132 1, dtypes.int32, shape=[2]) for i in range(2) 133 ] 134 dec_inp = [ 135 constant_op.constant( 136 i, dtypes.int32, shape=[2]) for i in range(3) 137 ] 138 cell_fn = lambda: rnn_cell.BasicLSTMCell(2) 139 cell = cell_fn() 140 dec, mem = seq2seq_lib.embedding_rnn_seq2seq( 141 enc_inp, 142 dec_inp, 143 cell, 144 num_encoder_symbols=2, 145 num_decoder_symbols=5, 146 embedding_size=2) 147 sess.run([variables.global_variables_initializer()]) 148 res = sess.run(dec) 149 self.assertEqual(3, len(res)) 150 self.assertEqual((2, 5), res[0].shape) 151 152 res = sess.run([mem]) 153 self.assertEqual((2, 2), res[0].c.shape) 154 self.assertEqual((2, 2), res[0].h.shape) 155 156 # Test with state_is_tuple=False. 157 with variable_scope.variable_scope("no_tuple"): 158 cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) 159 dec, mem = seq2seq_lib.embedding_rnn_seq2seq( 160 enc_inp, 161 dec_inp, 162 cell_nt, 163 num_encoder_symbols=2, 164 num_decoder_symbols=5, 165 embedding_size=2) 166 sess.run([variables.global_variables_initializer()]) 167 res = sess.run(dec) 168 self.assertEqual(3, len(res)) 169 self.assertEqual((2, 5), res[0].shape) 170 171 res = sess.run([mem]) 172 self.assertEqual((2, 4), res[0].shape) 173 174 # Test externally provided output projection. 175 w = variable_scope.get_variable("proj_w", [2, 5]) 176 b = variable_scope.get_variable("proj_b", [5]) 177 with variable_scope.variable_scope("proj_seq2seq"): 178 dec, _ = seq2seq_lib.embedding_rnn_seq2seq( 179 enc_inp, 180 dec_inp, 181 cell_fn(), 182 num_encoder_symbols=2, 183 num_decoder_symbols=5, 184 embedding_size=2, 185 output_projection=(w, b)) 186 sess.run([variables.global_variables_initializer()]) 187 res = sess.run(dec) 188 self.assertEqual(3, len(res)) 189 self.assertEqual((2, 2), res[0].shape) 190 191 # Test that previous-feeding model ignores inputs after the first. 192 dec_inp2 = [ 193 constant_op.constant( 194 0, dtypes.int32, shape=[2]) for _ in range(3) 195 ] 196 with variable_scope.variable_scope("other"): 197 d3, _ = seq2seq_lib.embedding_rnn_seq2seq( 198 enc_inp, 199 dec_inp2, 200 cell_fn(), 201 num_encoder_symbols=2, 202 num_decoder_symbols=5, 203 embedding_size=2, 204 feed_previous=constant_op.constant(True)) 205 with variable_scope.variable_scope("other_2"): 206 d1, _ = seq2seq_lib.embedding_rnn_seq2seq( 207 enc_inp, 208 dec_inp, 209 cell_fn(), 210 num_encoder_symbols=2, 211 num_decoder_symbols=5, 212 embedding_size=2, 213 feed_previous=True) 214 with variable_scope.variable_scope("other_3"): 215 d2, _ = seq2seq_lib.embedding_rnn_seq2seq( 216 enc_inp, 217 dec_inp2, 218 cell_fn(), 219 num_encoder_symbols=2, 220 num_decoder_symbols=5, 221 embedding_size=2, 222 feed_previous=True) 223 sess.run([variables.global_variables_initializer()]) 224 res1 = sess.run(d1) 225 res2 = sess.run(d2) 226 res3 = sess.run(d3) 227 self.assertAllClose(res1, res2) 228 self.assertAllClose(res1, res3) 229 230 def testEmbeddingTiedRNNSeq2Seq(self): 231 with self.test_session() as sess: 232 with variable_scope.variable_scope( 233 "root", initializer=init_ops.constant_initializer(0.5)): 234 enc_inp = [ 235 constant_op.constant( 236 1, dtypes.int32, shape=[2]) for i in range(2) 237 ] 238 dec_inp = [ 239 constant_op.constant( 240 i, dtypes.int32, shape=[2]) for i in range(3) 241 ] 242 cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True) 243 dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq( 244 enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2) 245 sess.run([variables.global_variables_initializer()]) 246 res = sess.run(dec) 247 self.assertEqual(3, len(res)) 248 self.assertEqual((2, 5), res[0].shape) 249 250 res = sess.run([mem]) 251 self.assertEqual((2, 2), res[0].c.shape) 252 self.assertEqual((2, 2), res[0].h.shape) 253 254 # Test when num_decoder_symbols is provided, the size of decoder output 255 # is num_decoder_symbols. 256 with variable_scope.variable_scope("decoder_symbols_seq2seq"): 257 dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq( 258 enc_inp, 259 dec_inp, 260 cell(), 261 num_symbols=5, 262 num_decoder_symbols=3, 263 embedding_size=2) 264 sess.run([variables.global_variables_initializer()]) 265 res = sess.run(dec) 266 self.assertEqual(3, len(res)) 267 self.assertEqual((2, 3), res[0].shape) 268 269 # Test externally provided output projection. 270 w = variable_scope.get_variable("proj_w", [2, 5]) 271 b = variable_scope.get_variable("proj_b", [5]) 272 with variable_scope.variable_scope("proj_seq2seq"): 273 dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( 274 enc_inp, 275 dec_inp, 276 cell(), 277 num_symbols=5, 278 embedding_size=2, 279 output_projection=(w, b)) 280 sess.run([variables.global_variables_initializer()]) 281 res = sess.run(dec) 282 self.assertEqual(3, len(res)) 283 self.assertEqual((2, 2), res[0].shape) 284 285 # Test that previous-feeding model ignores inputs after the first. 286 dec_inp2 = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3 287 with variable_scope.variable_scope("other"): 288 d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( 289 enc_inp, 290 dec_inp2, 291 cell(), 292 num_symbols=5, 293 embedding_size=2, 294 feed_previous=constant_op.constant(True)) 295 with variable_scope.variable_scope("other_2"): 296 d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( 297 enc_inp, 298 dec_inp, 299 cell(), 300 num_symbols=5, 301 embedding_size=2, 302 feed_previous=True) 303 with variable_scope.variable_scope("other_3"): 304 d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq( 305 enc_inp, 306 dec_inp2, 307 cell(), 308 num_symbols=5, 309 embedding_size=2, 310 feed_previous=True) 311 sess.run([variables.global_variables_initializer()]) 312 res1 = sess.run(d1) 313 res2 = sess.run(d2) 314 res3 = sess.run(d3) 315 self.assertAllClose(res1, res2) 316 self.assertAllClose(res1, res3) 317 318 def testAttentionDecoder1(self): 319 with self.test_session() as sess: 320 with variable_scope.variable_scope( 321 "root", initializer=init_ops.constant_initializer(0.5)): 322 cell_fn = lambda: rnn_cell.GRUCell(2) 323 cell = cell_fn() 324 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 325 enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 326 attn_states = array_ops.concat([ 327 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs 328 ], 1) 329 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 330 331 # Create a new cell instance for the decoder, since it uses a 332 # different variable scope 333 dec, mem = seq2seq_lib.attention_decoder( 334 dec_inp, enc_state, attn_states, cell_fn(), output_size=4) 335 sess.run([variables.global_variables_initializer()]) 336 res = sess.run(dec) 337 self.assertEqual(3, len(res)) 338 self.assertEqual((2, 4), res[0].shape) 339 340 res = sess.run([mem]) 341 self.assertEqual((2, 2), res[0].shape) 342 343 def testAttentionDecoder2(self): 344 with self.test_session() as sess: 345 with variable_scope.variable_scope( 346 "root", initializer=init_ops.constant_initializer(0.5)): 347 cell_fn = lambda: rnn_cell.GRUCell(2) 348 cell = cell_fn() 349 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 350 enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 351 attn_states = array_ops.concat([ 352 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs 353 ], 1) 354 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 355 356 # Use a new cell instance since the attention decoder uses a 357 # different variable scope. 358 dec, mem = seq2seq_lib.attention_decoder( 359 dec_inp, enc_state, attn_states, cell_fn(), 360 output_size=4, num_heads=2) 361 sess.run([variables.global_variables_initializer()]) 362 res = sess.run(dec) 363 self.assertEqual(3, len(res)) 364 self.assertEqual((2, 4), res[0].shape) 365 366 res = sess.run([mem]) 367 self.assertEqual((2, 2), res[0].shape) 368 369 def testDynamicAttentionDecoder1(self): 370 with self.test_session() as sess: 371 with variable_scope.variable_scope( 372 "root", initializer=init_ops.constant_initializer(0.5)): 373 cell_fn = lambda: rnn_cell.GRUCell(2) 374 cell = cell_fn() 375 inp = constant_op.constant(0.5, shape=[2, 2, 2]) 376 enc_outputs, enc_state = rnn.dynamic_rnn( 377 cell, inp, dtype=dtypes.float32) 378 attn_states = enc_outputs 379 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 380 381 # Use a new cell instance since the attention decoder uses a 382 # different variable scope. 383 dec, mem = seq2seq_lib.attention_decoder( 384 dec_inp, enc_state, attn_states, cell_fn(), output_size=4) 385 sess.run([variables.global_variables_initializer()]) 386 res = sess.run(dec) 387 self.assertEqual(3, len(res)) 388 self.assertEqual((2, 4), res[0].shape) 389 390 res = sess.run([mem]) 391 self.assertEqual((2, 2), res[0].shape) 392 393 def testDynamicAttentionDecoder2(self): 394 with self.test_session() as sess: 395 with variable_scope.variable_scope( 396 "root", initializer=init_ops.constant_initializer(0.5)): 397 cell_fn = lambda: rnn_cell.GRUCell(2) 398 cell = cell_fn() 399 inp = constant_op.constant(0.5, shape=[2, 2, 2]) 400 enc_outputs, enc_state = rnn.dynamic_rnn( 401 cell, inp, dtype=dtypes.float32) 402 attn_states = enc_outputs 403 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 404 405 # Use a new cell instance since the attention decoder uses a 406 # different variable scope. 407 dec, mem = seq2seq_lib.attention_decoder( 408 dec_inp, enc_state, attn_states, cell_fn(), 409 output_size=4, num_heads=2) 410 sess.run([variables.global_variables_initializer()]) 411 res = sess.run(dec) 412 self.assertEqual(3, len(res)) 413 self.assertEqual((2, 4), res[0].shape) 414 415 res = sess.run([mem]) 416 self.assertEqual((2, 2), res[0].shape) 417 418 def testAttentionDecoderStateIsTuple(self): 419 with self.test_session() as sess: 420 with variable_scope.variable_scope( 421 "root", initializer=init_ops.constant_initializer(0.5)): 422 single_cell = lambda: rnn_cell.BasicLSTMCell( # pylint: disable=g-long-lambda 423 2, state_is_tuple=True) 424 cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda 425 cells=[single_cell() for _ in range(2)], state_is_tuple=True) 426 cell = cell_fn() 427 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 428 enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 429 attn_states = array_ops.concat([ 430 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs 431 ], 1) 432 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 433 434 # Use a new cell instance since the attention decoder uses a 435 # different variable scope. 436 dec, mem = seq2seq_lib.attention_decoder( 437 dec_inp, enc_state, attn_states, cell_fn(), output_size=4) 438 sess.run([variables.global_variables_initializer()]) 439 res = sess.run(dec) 440 self.assertEqual(3, len(res)) 441 self.assertEqual((2, 4), res[0].shape) 442 443 res = sess.run([mem]) 444 self.assertEqual(2, len(res[0])) 445 self.assertEqual((2, 2), res[0][0].c.shape) 446 self.assertEqual((2, 2), res[0][0].h.shape) 447 self.assertEqual((2, 2), res[0][1].c.shape) 448 self.assertEqual((2, 2), res[0][1].h.shape) 449 450 def testDynamicAttentionDecoderStateIsTuple(self): 451 with self.test_session() as sess: 452 with variable_scope.variable_scope( 453 "root", initializer=init_ops.constant_initializer(0.5)): 454 cell_fn = lambda: rnn_cell.MultiRNNCell( # pylint: disable=g-long-lambda 455 cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)]) 456 cell = cell_fn() 457 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 458 enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 459 attn_states = array_ops.concat([ 460 array_ops.reshape(e, [-1, 1, cell.output_size]) 461 for e in enc_outputs 462 ], 1) 463 dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3 464 465 # Use a new cell instance since the attention decoder uses a 466 # different variable scope. 467 dec, mem = seq2seq_lib.attention_decoder( 468 dec_inp, enc_state, attn_states, cell_fn(), output_size=4) 469 sess.run([variables.global_variables_initializer()]) 470 res = sess.run(dec) 471 self.assertEqual(3, len(res)) 472 self.assertEqual((2, 4), res[0].shape) 473 474 res = sess.run([mem]) 475 self.assertEqual(2, len(res[0])) 476 self.assertEqual((2, 2), res[0][0].c.shape) 477 self.assertEqual((2, 2), res[0][0].h.shape) 478 self.assertEqual((2, 2), res[0][1].c.shape) 479 self.assertEqual((2, 2), res[0][1].h.shape) 480 481 def testEmbeddingAttentionDecoder(self): 482 with self.test_session() as sess: 483 with variable_scope.variable_scope( 484 "root", initializer=init_ops.constant_initializer(0.5)): 485 inp = [constant_op.constant(0.5, shape=[2, 2])] * 2 486 cell_fn = lambda: rnn_cell.GRUCell(2) 487 cell = cell_fn() 488 enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32) 489 attn_states = array_ops.concat([ 490 array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs 491 ], 1) 492 dec_inp = [ 493 constant_op.constant( 494 i, dtypes.int32, shape=[2]) for i in range(3) 495 ] 496 497 # Use a new cell instance since the attention decoder uses a 498 # different variable scope. 499 dec, mem = seq2seq_lib.embedding_attention_decoder( 500 dec_inp, 501 enc_state, 502 attn_states, 503 cell_fn(), 504 num_symbols=4, 505 embedding_size=2, 506 output_size=3) 507 sess.run([variables.global_variables_initializer()]) 508 res = sess.run(dec) 509 self.assertEqual(3, len(res)) 510 self.assertEqual((2, 3), res[0].shape) 511 512 res = sess.run([mem]) 513 self.assertEqual((2, 2), res[0].shape) 514 515 def testEmbeddingAttentionSeq2Seq(self): 516 with self.test_session() as sess: 517 with variable_scope.variable_scope( 518 "root", initializer=init_ops.constant_initializer(0.5)): 519 enc_inp = [ 520 constant_op.constant( 521 1, dtypes.int32, shape=[2]) for i in range(2) 522 ] 523 dec_inp = [ 524 constant_op.constant( 525 i, dtypes.int32, shape=[2]) for i in range(3) 526 ] 527 cell_fn = lambda: rnn_cell.BasicLSTMCell(2) 528 cell = cell_fn() 529 dec, mem = seq2seq_lib.embedding_attention_seq2seq( 530 enc_inp, 531 dec_inp, 532 cell, 533 num_encoder_symbols=2, 534 num_decoder_symbols=5, 535 embedding_size=2) 536 sess.run([variables.global_variables_initializer()]) 537 res = sess.run(dec) 538 self.assertEqual(3, len(res)) 539 self.assertEqual((2, 5), res[0].shape) 540 541 res = sess.run([mem]) 542 self.assertEqual((2, 2), res[0].c.shape) 543 self.assertEqual((2, 2), res[0].h.shape) 544 545 # Test with state_is_tuple=False. 546 with variable_scope.variable_scope("no_tuple"): 547 cell_fn = functools.partial( 548 rnn_cell.BasicLSTMCell, 2, state_is_tuple=False) 549 cell_nt = cell_fn() 550 dec, mem = seq2seq_lib.embedding_attention_seq2seq( 551 enc_inp, 552 dec_inp, 553 cell_nt, 554 num_encoder_symbols=2, 555 num_decoder_symbols=5, 556 embedding_size=2) 557 sess.run([variables.global_variables_initializer()]) 558 res = sess.run(dec) 559 self.assertEqual(3, len(res)) 560 self.assertEqual((2, 5), res[0].shape) 561 562 res = sess.run([mem]) 563 self.assertEqual((2, 4), res[0].shape) 564 565 # Test externally provided output projection. 566 w = variable_scope.get_variable("proj_w", [2, 5]) 567 b = variable_scope.get_variable("proj_b", [5]) 568 with variable_scope.variable_scope("proj_seq2seq"): 569 dec, _ = seq2seq_lib.embedding_attention_seq2seq( 570 enc_inp, 571 dec_inp, 572 cell_fn(), 573 num_encoder_symbols=2, 574 num_decoder_symbols=5, 575 embedding_size=2, 576 output_projection=(w, b)) 577 sess.run([variables.global_variables_initializer()]) 578 res = sess.run(dec) 579 self.assertEqual(3, len(res)) 580 self.assertEqual((2, 2), res[0].shape) 581 582 # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse 583 # within a variable scope that already has a weights tensor. 584 # 585 # # Test that previous-feeding model ignores inputs after the first. 586 # dec_inp2 = [ 587 # constant_op.constant( 588 # 0, dtypes.int32, shape=[2]) for _ in range(3) 589 # ] 590 # with variable_scope.variable_scope("other"): 591 # d3, _ = seq2seq_lib.embedding_attention_seq2seq( 592 # enc_inp, 593 # dec_inp2, 594 # cell_fn(), 595 # num_encoder_symbols=2, 596 # num_decoder_symbols=5, 597 # embedding_size=2, 598 # feed_previous=constant_op.constant(True)) 599 # sess.run([variables.global_variables_initializer()]) 600 # variable_scope.get_variable_scope().reuse_variables() 601 # cell = cell_fn() 602 # d1, _ = seq2seq_lib.embedding_attention_seq2seq( 603 # enc_inp, 604 # dec_inp, 605 # cell, 606 # num_encoder_symbols=2, 607 # num_decoder_symbols=5, 608 # embedding_size=2, 609 # feed_previous=True) 610 # d2, _ = seq2seq_lib.embedding_attention_seq2seq( 611 # enc_inp, 612 # dec_inp2, 613 # cell, 614 # num_encoder_symbols=2, 615 # num_decoder_symbols=5, 616 # embedding_size=2, 617 # feed_previous=True) 618 # res1 = sess.run(d1) 619 # res2 = sess.run(d2) 620 # res3 = sess.run(d3) 621 # self.assertAllClose(res1, res2) 622 # self.assertAllClose(res1, res3) 623 624 def testOne2ManyRNNSeq2Seq(self): 625 with self.test_session() as sess: 626 with variable_scope.variable_scope( 627 "root", initializer=init_ops.constant_initializer(0.5)): 628 enc_inp = [ 629 constant_op.constant( 630 1, dtypes.int32, shape=[2]) for i in range(2) 631 ] 632 dec_inp_dict = {} 633 dec_inp_dict["0"] = [ 634 constant_op.constant( 635 i, dtypes.int32, shape=[2]) for i in range(3) 636 ] 637 dec_inp_dict["1"] = [ 638 constant_op.constant( 639 i, dtypes.int32, shape=[2]) for i in range(4) 640 ] 641 dec_symbols_dict = {"0": 5, "1": 6} 642 def EncCellFn(): 643 return rnn_cell.BasicLSTMCell(2, state_is_tuple=True) 644 def DecCellsFn(): 645 return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True)) 646 for k in dec_symbols_dict) 647 outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq( 648 enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(), 649 2, dec_symbols_dict, embedding_size=2)) 650 651 sess.run([variables.global_variables_initializer()]) 652 res = sess.run(outputs_dict["0"]) 653 self.assertEqual(3, len(res)) 654 self.assertEqual((2, 5), res[0].shape) 655 res = sess.run(outputs_dict["1"]) 656 self.assertEqual(4, len(res)) 657 self.assertEqual((2, 6), res[0].shape) 658 res = sess.run([state_dict["0"]]) 659 self.assertEqual((2, 2), res[0].c.shape) 660 self.assertEqual((2, 2), res[0].h.shape) 661 res = sess.run([state_dict["1"]]) 662 self.assertEqual((2, 2), res[0].c.shape) 663 self.assertEqual((2, 2), res[0].h.shape) 664 665 # Test that previous-feeding model ignores inputs after the first, i.e. 666 # dec_inp_dict2 has different inputs from dec_inp_dict after the first 667 # time-step. 668 dec_inp_dict2 = {} 669 dec_inp_dict2["0"] = [ 670 constant_op.constant( 671 0, dtypes.int32, shape=[2]) for _ in range(3) 672 ] 673 dec_inp_dict2["1"] = [ 674 constant_op.constant( 675 0, dtypes.int32, shape=[2]) for _ in range(4) 676 ] 677 with variable_scope.variable_scope("other"): 678 outputs_dict3, _ = seq2seq_lib.one2many_rnn_seq2seq( 679 enc_inp, 680 dec_inp_dict2, 681 EncCellFn(), 682 DecCellsFn(), 683 2, 684 dec_symbols_dict, 685 embedding_size=2, 686 feed_previous=constant_op.constant(True)) 687 with variable_scope.variable_scope("other_2"): 688 outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq( 689 enc_inp, 690 dec_inp_dict, 691 EncCellFn(), 692 DecCellsFn(), 693 2, 694 dec_symbols_dict, 695 embedding_size=2, 696 feed_previous=True) 697 with variable_scope.variable_scope("other_3"): 698 outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq( 699 enc_inp, 700 dec_inp_dict2, 701 EncCellFn(), 702 DecCellsFn(), 703 2, 704 dec_symbols_dict, 705 embedding_size=2, 706 feed_previous=True) 707 sess.run([variables.global_variables_initializer()]) 708 res1 = sess.run(outputs_dict1["0"]) 709 res2 = sess.run(outputs_dict2["0"]) 710 res3 = sess.run(outputs_dict3["0"]) 711 self.assertAllClose(res1, res2) 712 self.assertAllClose(res1, res3) 713 714 def testSequenceLoss(self): 715 with self.test_session() as sess: 716 logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)] 717 targets = [ 718 constant_op.constant( 719 i, dtypes.int32, shape=[2]) for i in range(3) 720 ] 721 weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)] 722 723 average_loss_per_example = seq2seq_lib.sequence_loss( 724 logits, 725 targets, 726 weights, 727 average_across_timesteps=True, 728 average_across_batch=True) 729 res = sess.run(average_loss_per_example) 730 self.assertAllClose(1.60944, res) 731 732 average_loss_per_sequence = seq2seq_lib.sequence_loss( 733 logits, 734 targets, 735 weights, 736 average_across_timesteps=False, 737 average_across_batch=True) 738 res = sess.run(average_loss_per_sequence) 739 self.assertAllClose(4.828314, res) 740 741 total_loss = seq2seq_lib.sequence_loss( 742 logits, 743 targets, 744 weights, 745 average_across_timesteps=False, 746 average_across_batch=False) 747 res = sess.run(total_loss) 748 self.assertAllClose(9.656628, res) 749 750 def testSequenceLossByExample(self): 751 with self.test_session() as sess: 752 output_classes = 5 753 logits = [ 754 constant_op.constant( 755 i + 0.5, shape=[2, output_classes]) for i in range(3) 756 ] 757 targets = [ 758 constant_op.constant( 759 i, dtypes.int32, shape=[2]) for i in range(3) 760 ] 761 weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)] 762 763 average_loss_per_example = (seq2seq_lib.sequence_loss_by_example( 764 logits, targets, weights, average_across_timesteps=True)) 765 res = sess.run(average_loss_per_example) 766 self.assertAllClose(np.asarray([1.609438, 1.609438]), res) 767 768 loss_per_sequence = seq2seq_lib.sequence_loss_by_example( 769 logits, targets, weights, average_across_timesteps=False) 770 res = sess.run(loss_per_sequence) 771 self.assertAllClose(np.asarray([4.828314, 4.828314]), res) 772 773 # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse 774 # within a variable scope that already has a weights tensor. 775 # 776 # def testModelWithBucketsScopeAndLoss(self): 777 # """Test variable scope reuse is not reset after model_with_buckets.""" 778 # classes = 10 779 # buckets = [(4, 4), (8, 8)] 780 781 # with self.test_session(): 782 # # Here comes a sample Seq2Seq model using GRU cells. 783 # def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss): 784 # """Example sequence-to-sequence model that uses GRU cells.""" 785 786 # def GRUSeq2Seq(enc_inp, dec_inp): 787 # cell = rnn_cell.MultiRNNCell( 788 # [rnn_cell.GRUCell(24) for _ in range(2)]) 789 # return seq2seq_lib.embedding_attention_seq2seq( 790 # enc_inp, 791 # dec_inp, 792 # cell, 793 # num_encoder_symbols=classes, 794 # num_decoder_symbols=classes, 795 # embedding_size=24) 796 797 # targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0] 798 # return seq2seq_lib.model_with_buckets( 799 # enc_inp, 800 # dec_inp, 801 # targets, 802 # weights, 803 # buckets, 804 # GRUSeq2Seq, 805 # per_example_loss=per_example_loss) 806 807 # # Now we construct the copy model. 808 # inp = [ 809 # array_ops.placeholder( 810 # dtypes.int32, shape=[None]) for _ in range(8) 811 # ] 812 # out = [ 813 # array_ops.placeholder( 814 # dtypes.int32, shape=[None]) for _ in range(8) 815 # ] 816 # weights = [ 817 # array_ops.ones_like( 818 # inp[0], dtype=dtypes.float32) for _ in range(8) 819 # ] 820 # with variable_scope.variable_scope("root"): 821 # _, losses1 = SampleGRUSeq2Seq( 822 # inp, out, weights, per_example_loss=False) 823 # # Now check that we did not accidentally set reuse. 824 # self.assertEqual(False, variable_scope.get_variable_scope().reuse) 825 # with variable_scope.variable_scope("new"): 826 # _, losses2 = SampleGRUSeq2Seq 827 # inp, out, weights, per_example_loss=True) 828 # # First loss is scalar, the second one is a 1-dimensional tensor. 829 # self.assertEqual([], losses1[0].get_shape().as_list()) 830 # self.assertEqual([None], losses2[0].get_shape().as_list()) 831 832 def testModelWithBuckets(self): 833 """Larger tests that does full sequence-to-sequence model training.""" 834 # We learn to copy 10 symbols in 2 buckets: length 4 and length 8. 835 classes = 10 836 buckets = [(4, 4), (8, 8)] 837 perplexities = [[], []] # Results for each bucket. 838 random_seed.set_random_seed(111) 839 random.seed(111) 840 np.random.seed(111) 841 842 with self.test_session() as sess: 843 # We use sampled softmax so we keep output projection separate. 844 w = variable_scope.get_variable("proj_w", [24, classes]) 845 w_t = array_ops.transpose(w) 846 b = variable_scope.get_variable("proj_b", [classes]) 847 848 # Here comes a sample Seq2Seq model using GRU cells. 849 def SampleGRUSeq2Seq(enc_inp, dec_inp, weights): 850 """Example sequence-to-sequence model that uses GRU cells.""" 851 852 def GRUSeq2Seq(enc_inp, dec_inp): 853 cell = rnn_cell.MultiRNNCell( 854 [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True) 855 return seq2seq_lib.embedding_attention_seq2seq( 856 enc_inp, 857 dec_inp, 858 cell, 859 num_encoder_symbols=classes, 860 num_decoder_symbols=classes, 861 embedding_size=24, 862 output_projection=(w, b)) 863 864 targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0] 865 866 def SampledLoss(labels, logits): 867 labels = array_ops.reshape(labels, [-1, 1]) 868 return nn_impl.sampled_softmax_loss( 869 weights=w_t, 870 biases=b, 871 labels=labels, 872 inputs=logits, 873 num_sampled=8, 874 num_classes=classes) 875 876 return seq2seq_lib.model_with_buckets( 877 enc_inp, 878 dec_inp, 879 targets, 880 weights, 881 buckets, 882 GRUSeq2Seq, 883 softmax_loss_function=SampledLoss) 884 885 # Now we construct the copy model. 886 batch_size = 8 887 inp = [ 888 array_ops.placeholder( 889 dtypes.int32, shape=[None]) for _ in range(8) 890 ] 891 out = [ 892 array_ops.placeholder( 893 dtypes.int32, shape=[None]) for _ in range(8) 894 ] 895 weights = [ 896 array_ops.ones_like( 897 inp[0], dtype=dtypes.float32) for _ in range(8) 898 ] 899 with variable_scope.variable_scope("root"): 900 _, losses = SampleGRUSeq2Seq(inp, out, weights) 901 updates = [] 902 params = variables.global_variables() 903 optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5) 904 for i in range(len(buckets)): 905 full_grads = gradients_impl.gradients(losses[i], params) 906 grads, _ = clip_ops.clip_by_global_norm(full_grads, 30.0) 907 update = optimizer.apply_gradients(zip(grads, params)) 908 updates.append(update) 909 sess.run([variables.global_variables_initializer()]) 910 steps = 6 911 for _ in range(steps): 912 bucket = random.choice(np.arange(len(buckets))) 913 length = buckets[bucket][0] 914 i = [ 915 np.array( 916 [np.random.randint(9) + 1 for _ in range(batch_size)], 917 dtype=np.int32) for _ in range(length) 918 ] 919 # 0 is our "GO" symbol here. 920 o = [np.array([0] * batch_size, dtype=np.int32)] + i 921 feed = {} 922 for i1, i2, o1, o2 in zip(inp[:length], i[:length], out[:length], 923 o[:length]): 924 feed[i1.name] = i2 925 feed[o1.name] = o2 926 if length < 8: # For the 4-bucket, we need the 5th as target. 927 feed[out[length].name] = o[length] 928 res = sess.run([updates[bucket], losses[bucket]], feed) 929 perplexities[bucket].append(math.exp(float(res[1]))) 930 for bucket in range(len(buckets)): 931 if len(perplexities[bucket]) > 1: # Assert that perplexity went down. 932 self.assertLess(perplexities[bucket][-1], # 20% margin of error. 933 1.2 * perplexities[bucket][0]) 934 935 def testModelWithBooleanFeedPrevious(self): 936 """Test the model behavior when feed_previous is True. 937 938 For example, the following two cases have the same effect: 939 - Train `embedding_rnn_seq2seq` with `feed_previous=True`, which contains 940 a `embedding_rnn_decoder` with `feed_previous=True` and 941 `update_embedding_for_previous=True`. The decoder is fed with "<Go>" 942 and outputs "A, B, C". 943 - Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder 944 is fed with "<Go>, A, B". 945 """ 946 num_encoder_symbols = 3 947 num_decoder_symbols = 5 948 batch_size = 2 949 num_enc_timesteps = 2 950 num_dec_timesteps = 3 951 952 def TestModel(seq2seq): 953 with self.test_session(graph=ops.Graph()) as sess: 954 random_seed.set_random_seed(111) 955 random.seed(111) 956 np.random.seed(111) 957 958 enc_inp = [ 959 constant_op.constant( 960 i + 1, dtypes.int32, shape=[batch_size]) 961 for i in range(num_enc_timesteps) 962 ] 963 dec_inp_fp_true = [ 964 constant_op.constant( 965 i, dtypes.int32, shape=[batch_size]) 966 for i in range(num_dec_timesteps) 967 ] 968 dec_inp_holder_fp_false = [ 969 array_ops.placeholder( 970 dtypes.int32, shape=[batch_size]) 971 for _ in range(num_dec_timesteps) 972 ] 973 targets = [ 974 constant_op.constant( 975 i + 1, dtypes.int32, shape=[batch_size]) 976 for i in range(num_dec_timesteps) 977 ] 978 weights = [ 979 constant_op.constant( 980 1.0, shape=[batch_size]) for i in range(num_dec_timesteps) 981 ] 982 983 def ForwardBackward(enc_inp, dec_inp, feed_previous): 984 scope_name = "fp_{}".format(feed_previous) 985 with variable_scope.variable_scope(scope_name): 986 dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous) 987 net_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, 988 scope_name) 989 optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5) 990 update_op = optimizer.minimize( 991 seq2seq_lib.sequence_loss(dec_op, targets, weights), 992 var_list=net_variables) 993 return dec_op, update_op, net_variables 994 995 dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward( 996 enc_inp, dec_inp_fp_true, feed_previous=True) 997 _, update_fp_false, variables_fp_false = ForwardBackward( 998 enc_inp, dec_inp_holder_fp_false, feed_previous=False) 999 1000 sess.run(variables.global_variables_initializer()) 1001 1002 # We only check consistencies between the variables existing in both 1003 # the models with True and False feed_previous. Variables created by 1004 # the loop_function in the model with True feed_previous are ignored. 1005 v_false_name_dict = { 1006 v.name.split("/", 1)[-1]: v 1007 for v in variables_fp_false 1008 } 1009 matched_variables = [(v, v_false_name_dict[v.name.split("/", 1)[-1]]) 1010 for v in variables_fp_true] 1011 for v_true, v_false in matched_variables: 1012 sess.run(state_ops.assign(v_false, v_true)) 1013 1014 # Take the symbols generated by the decoder with feed_previous=True as 1015 # the true input symbols for the decoder with feed_previous=False. 1016 dec_fp_true = sess.run(dec_op_fp_true) 1017 output_symbols_fp_true = np.argmax(dec_fp_true, axis=2) 1018 dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(), 1019 output_symbols_fp_true[:-1])) 1020 sess.run(update_fp_true) 1021 sess.run(update_fp_false, { 1022 holder: inp 1023 for holder, inp in zip(dec_inp_holder_fp_false, dec_inp_fp_false) 1024 }) 1025 1026 for v_true, v_false in matched_variables: 1027 self.assertAllClose(v_true.eval(), v_false.eval()) 1028 1029 def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous): 1030 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) 1031 return seq2seq_lib.embedding_rnn_seq2seq( 1032 enc_inp, 1033 dec_inp, 1034 cell, 1035 num_encoder_symbols, 1036 num_decoder_symbols, 1037 embedding_size=2, 1038 feed_previous=feed_previous) 1039 1040 def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous): 1041 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) 1042 return seq2seq_lib.embedding_rnn_seq2seq( 1043 enc_inp, 1044 dec_inp, 1045 cell, 1046 num_encoder_symbols, 1047 num_decoder_symbols, 1048 embedding_size=2, 1049 feed_previous=feed_previous) 1050 1051 def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous): 1052 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) 1053 return seq2seq_lib.embedding_tied_rnn_seq2seq( 1054 enc_inp, 1055 dec_inp, 1056 cell, 1057 num_decoder_symbols, 1058 embedding_size=2, 1059 feed_previous=feed_previous) 1060 1061 def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous): 1062 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) 1063 return seq2seq_lib.embedding_tied_rnn_seq2seq( 1064 enc_inp, 1065 dec_inp, 1066 cell, 1067 num_decoder_symbols, 1068 embedding_size=2, 1069 feed_previous=feed_previous) 1070 1071 def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous): 1072 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True) 1073 return seq2seq_lib.embedding_attention_seq2seq( 1074 enc_inp, 1075 dec_inp, 1076 cell, 1077 num_encoder_symbols, 1078 num_decoder_symbols, 1079 embedding_size=2, 1080 feed_previous=feed_previous) 1081 1082 def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous): 1083 cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False) 1084 return seq2seq_lib.embedding_attention_seq2seq( 1085 enc_inp, 1086 dec_inp, 1087 cell, 1088 num_encoder_symbols, 1089 num_decoder_symbols, 1090 embedding_size=2, 1091 feed_previous=feed_previous) 1092 1093 for model in (EmbeddingRNNSeq2SeqF, EmbeddingRNNSeq2SeqNoTupleF, 1094 EmbeddingTiedRNNSeq2Seq, EmbeddingTiedRNNSeq2SeqNoTuple, 1095 EmbeddingAttentionSeq2Seq, EmbeddingAttentionSeq2SeqNoTuple): 1096 TestModel(model) 1097 1098 1099if __name__ == "__main__": 1100 test.main() 1101