1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for functional style sequence-to-sequence models."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import math
23import random
24
25import numpy as np
26
27from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib
28from tensorflow.contrib.rnn.python.ops import core_rnn_cell
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import random_seed
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import clip_ops
35from tensorflow.python.ops import gradients_impl
36from tensorflow.python.ops import init_ops
37from tensorflow.python.ops import nn_impl
38from tensorflow.python.ops import rnn
39from tensorflow.python.ops import rnn_cell
40from tensorflow.python.ops import state_ops
41from tensorflow.python.ops import variable_scope
42from tensorflow.python.ops import variables
43from tensorflow.python.platform import test
44from tensorflow.python.training import adam
45
46
47class Seq2SeqTest(test.TestCase):
48
49  def testRNNDecoder(self):
50    with self.test_session() as sess:
51      with variable_scope.variable_scope(
52          "root", initializer=init_ops.constant_initializer(0.5)):
53        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
54        _, enc_state = rnn.static_rnn(
55            rnn_cell.GRUCell(2), inp, dtype=dtypes.float32)
56        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
57        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
58        dec, mem = seq2seq_lib.rnn_decoder(dec_inp, enc_state, cell)
59        sess.run([variables.global_variables_initializer()])
60        res = sess.run(dec)
61        self.assertEqual(3, len(res))
62        self.assertEqual((2, 4), res[0].shape)
63
64        res = sess.run([mem])
65        self.assertEqual((2, 2), res[0].shape)
66
67  def testBasicRNNSeq2Seq(self):
68    with self.test_session() as sess:
69      with variable_scope.variable_scope(
70          "root", initializer=init_ops.constant_initializer(0.5)):
71        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
72        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
73        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
74        dec, mem = seq2seq_lib.basic_rnn_seq2seq(inp, dec_inp, cell)
75        sess.run([variables.global_variables_initializer()])
76        res = sess.run(dec)
77        self.assertEqual(3, len(res))
78        self.assertEqual((2, 4), res[0].shape)
79
80        res = sess.run([mem])
81        self.assertEqual((2, 2), res[0].shape)
82
83  def testTiedRNNSeq2Seq(self):
84    with self.test_session() as sess:
85      with variable_scope.variable_scope(
86          "root", initializer=init_ops.constant_initializer(0.5)):
87        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
88        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
89        cell = core_rnn_cell.OutputProjectionWrapper(rnn_cell.GRUCell(2), 4)
90        dec, mem = seq2seq_lib.tied_rnn_seq2seq(inp, dec_inp, cell)
91        sess.run([variables.global_variables_initializer()])
92        res = sess.run(dec)
93        self.assertEqual(3, len(res))
94        self.assertEqual((2, 4), res[0].shape)
95
96        res = sess.run([mem])
97        self.assertEqual(1, len(res))
98        self.assertEqual((2, 2), res[0].shape)
99
100  def testEmbeddingRNNDecoder(self):
101    with self.test_session() as sess:
102      with variable_scope.variable_scope(
103          "root", initializer=init_ops.constant_initializer(0.5)):
104        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
105        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
106        cell = cell_fn()
107        _, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
108        dec_inp = [
109            constant_op.constant(
110                i, dtypes.int32, shape=[2]) for i in range(3)
111        ]
112        # Use a new cell instance since the attention decoder uses a
113        # different variable scope.
114        dec, mem = seq2seq_lib.embedding_rnn_decoder(
115            dec_inp, enc_state, cell_fn(), num_symbols=4, embedding_size=2)
116        sess.run([variables.global_variables_initializer()])
117        res = sess.run(dec)
118        self.assertEqual(3, len(res))
119        self.assertEqual((2, 2), res[0].shape)
120
121        res = sess.run([mem])
122        self.assertEqual(1, len(res))
123        self.assertEqual((2, 2), res[0].c.shape)
124        self.assertEqual((2, 2), res[0].h.shape)
125
126  def testEmbeddingRNNSeq2Seq(self):
127    with self.test_session() as sess:
128      with variable_scope.variable_scope(
129          "root", initializer=init_ops.constant_initializer(0.5)):
130        enc_inp = [
131            constant_op.constant(
132                1, dtypes.int32, shape=[2]) for i in range(2)
133        ]
134        dec_inp = [
135            constant_op.constant(
136                i, dtypes.int32, shape=[2]) for i in range(3)
137        ]
138        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
139        cell = cell_fn()
140        dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
141            enc_inp,
142            dec_inp,
143            cell,
144            num_encoder_symbols=2,
145            num_decoder_symbols=5,
146            embedding_size=2)
147        sess.run([variables.global_variables_initializer()])
148        res = sess.run(dec)
149        self.assertEqual(3, len(res))
150        self.assertEqual((2, 5), res[0].shape)
151
152        res = sess.run([mem])
153        self.assertEqual((2, 2), res[0].c.shape)
154        self.assertEqual((2, 2), res[0].h.shape)
155
156        # Test with state_is_tuple=False.
157        with variable_scope.variable_scope("no_tuple"):
158          cell_nt = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
159          dec, mem = seq2seq_lib.embedding_rnn_seq2seq(
160              enc_inp,
161              dec_inp,
162              cell_nt,
163              num_encoder_symbols=2,
164              num_decoder_symbols=5,
165              embedding_size=2)
166          sess.run([variables.global_variables_initializer()])
167          res = sess.run(dec)
168          self.assertEqual(3, len(res))
169          self.assertEqual((2, 5), res[0].shape)
170
171          res = sess.run([mem])
172          self.assertEqual((2, 4), res[0].shape)
173
174        # Test externally provided output projection.
175        w = variable_scope.get_variable("proj_w", [2, 5])
176        b = variable_scope.get_variable("proj_b", [5])
177        with variable_scope.variable_scope("proj_seq2seq"):
178          dec, _ = seq2seq_lib.embedding_rnn_seq2seq(
179              enc_inp,
180              dec_inp,
181              cell_fn(),
182              num_encoder_symbols=2,
183              num_decoder_symbols=5,
184              embedding_size=2,
185              output_projection=(w, b))
186        sess.run([variables.global_variables_initializer()])
187        res = sess.run(dec)
188        self.assertEqual(3, len(res))
189        self.assertEqual((2, 2), res[0].shape)
190
191        # Test that previous-feeding model ignores inputs after the first.
192        dec_inp2 = [
193            constant_op.constant(
194                0, dtypes.int32, shape=[2]) for _ in range(3)
195        ]
196        with variable_scope.variable_scope("other"):
197          d3, _ = seq2seq_lib.embedding_rnn_seq2seq(
198              enc_inp,
199              dec_inp2,
200              cell_fn(),
201              num_encoder_symbols=2,
202              num_decoder_symbols=5,
203              embedding_size=2,
204              feed_previous=constant_op.constant(True))
205        with variable_scope.variable_scope("other_2"):
206          d1, _ = seq2seq_lib.embedding_rnn_seq2seq(
207              enc_inp,
208              dec_inp,
209              cell_fn(),
210              num_encoder_symbols=2,
211              num_decoder_symbols=5,
212              embedding_size=2,
213              feed_previous=True)
214        with variable_scope.variable_scope("other_3"):
215          d2, _ = seq2seq_lib.embedding_rnn_seq2seq(
216              enc_inp,
217              dec_inp2,
218              cell_fn(),
219              num_encoder_symbols=2,
220              num_decoder_symbols=5,
221              embedding_size=2,
222              feed_previous=True)
223        sess.run([variables.global_variables_initializer()])
224        res1 = sess.run(d1)
225        res2 = sess.run(d2)
226        res3 = sess.run(d3)
227        self.assertAllClose(res1, res2)
228        self.assertAllClose(res1, res3)
229
230  def testEmbeddingTiedRNNSeq2Seq(self):
231    with self.test_session() as sess:
232      with variable_scope.variable_scope(
233          "root", initializer=init_ops.constant_initializer(0.5)):
234        enc_inp = [
235            constant_op.constant(
236                1, dtypes.int32, shape=[2]) for i in range(2)
237        ]
238        dec_inp = [
239            constant_op.constant(
240                i, dtypes.int32, shape=[2]) for i in range(3)
241        ]
242        cell = functools.partial(rnn_cell.BasicLSTMCell, 2, state_is_tuple=True)
243        dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
244            enc_inp, dec_inp, cell(), num_symbols=5, embedding_size=2)
245        sess.run([variables.global_variables_initializer()])
246        res = sess.run(dec)
247        self.assertEqual(3, len(res))
248        self.assertEqual((2, 5), res[0].shape)
249
250        res = sess.run([mem])
251        self.assertEqual((2, 2), res[0].c.shape)
252        self.assertEqual((2, 2), res[0].h.shape)
253
254        # Test when num_decoder_symbols is provided, the size of decoder output
255        # is num_decoder_symbols.
256        with variable_scope.variable_scope("decoder_symbols_seq2seq"):
257          dec, mem = seq2seq_lib.embedding_tied_rnn_seq2seq(
258              enc_inp,
259              dec_inp,
260              cell(),
261              num_symbols=5,
262              num_decoder_symbols=3,
263              embedding_size=2)
264        sess.run([variables.global_variables_initializer()])
265        res = sess.run(dec)
266        self.assertEqual(3, len(res))
267        self.assertEqual((2, 3), res[0].shape)
268
269        # Test externally provided output projection.
270        w = variable_scope.get_variable("proj_w", [2, 5])
271        b = variable_scope.get_variable("proj_b", [5])
272        with variable_scope.variable_scope("proj_seq2seq"):
273          dec, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
274              enc_inp,
275              dec_inp,
276              cell(),
277              num_symbols=5,
278              embedding_size=2,
279              output_projection=(w, b))
280        sess.run([variables.global_variables_initializer()])
281        res = sess.run(dec)
282        self.assertEqual(3, len(res))
283        self.assertEqual((2, 2), res[0].shape)
284
285        # Test that previous-feeding model ignores inputs after the first.
286        dec_inp2 = [constant_op.constant(0, dtypes.int32, shape=[2])] * 3
287        with variable_scope.variable_scope("other"):
288          d3, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
289              enc_inp,
290              dec_inp2,
291              cell(),
292              num_symbols=5,
293              embedding_size=2,
294              feed_previous=constant_op.constant(True))
295        with variable_scope.variable_scope("other_2"):
296          d1, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
297              enc_inp,
298              dec_inp,
299              cell(),
300              num_symbols=5,
301              embedding_size=2,
302              feed_previous=True)
303        with variable_scope.variable_scope("other_3"):
304          d2, _ = seq2seq_lib.embedding_tied_rnn_seq2seq(
305              enc_inp,
306              dec_inp2,
307              cell(),
308              num_symbols=5,
309              embedding_size=2,
310              feed_previous=True)
311        sess.run([variables.global_variables_initializer()])
312        res1 = sess.run(d1)
313        res2 = sess.run(d2)
314        res3 = sess.run(d3)
315        self.assertAllClose(res1, res2)
316        self.assertAllClose(res1, res3)
317
318  def testAttentionDecoder1(self):
319    with self.test_session() as sess:
320      with variable_scope.variable_scope(
321          "root", initializer=init_ops.constant_initializer(0.5)):
322        cell_fn = lambda: rnn_cell.GRUCell(2)
323        cell = cell_fn()
324        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
325        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
326        attn_states = array_ops.concat([
327            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
328        ], 1)
329        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
330
331        # Create a new cell instance for the decoder, since it uses a
332        # different variable scope
333        dec, mem = seq2seq_lib.attention_decoder(
334            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
335        sess.run([variables.global_variables_initializer()])
336        res = sess.run(dec)
337        self.assertEqual(3, len(res))
338        self.assertEqual((2, 4), res[0].shape)
339
340        res = sess.run([mem])
341        self.assertEqual((2, 2), res[0].shape)
342
343  def testAttentionDecoder2(self):
344    with self.test_session() as sess:
345      with variable_scope.variable_scope(
346          "root", initializer=init_ops.constant_initializer(0.5)):
347        cell_fn = lambda: rnn_cell.GRUCell(2)
348        cell = cell_fn()
349        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
350        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
351        attn_states = array_ops.concat([
352            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
353        ], 1)
354        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
355
356        # Use a new cell instance since the attention decoder uses a
357        # different variable scope.
358        dec, mem = seq2seq_lib.attention_decoder(
359            dec_inp, enc_state, attn_states, cell_fn(),
360            output_size=4, num_heads=2)
361        sess.run([variables.global_variables_initializer()])
362        res = sess.run(dec)
363        self.assertEqual(3, len(res))
364        self.assertEqual((2, 4), res[0].shape)
365
366        res = sess.run([mem])
367        self.assertEqual((2, 2), res[0].shape)
368
369  def testDynamicAttentionDecoder1(self):
370    with self.test_session() as sess:
371      with variable_scope.variable_scope(
372          "root", initializer=init_ops.constant_initializer(0.5)):
373        cell_fn = lambda: rnn_cell.GRUCell(2)
374        cell = cell_fn()
375        inp = constant_op.constant(0.5, shape=[2, 2, 2])
376        enc_outputs, enc_state = rnn.dynamic_rnn(
377            cell, inp, dtype=dtypes.float32)
378        attn_states = enc_outputs
379        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
380
381        # Use a new cell instance since the attention decoder uses a
382        # different variable scope.
383        dec, mem = seq2seq_lib.attention_decoder(
384            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
385        sess.run([variables.global_variables_initializer()])
386        res = sess.run(dec)
387        self.assertEqual(3, len(res))
388        self.assertEqual((2, 4), res[0].shape)
389
390        res = sess.run([mem])
391        self.assertEqual((2, 2), res[0].shape)
392
393  def testDynamicAttentionDecoder2(self):
394    with self.test_session() as sess:
395      with variable_scope.variable_scope(
396          "root", initializer=init_ops.constant_initializer(0.5)):
397        cell_fn = lambda: rnn_cell.GRUCell(2)
398        cell = cell_fn()
399        inp = constant_op.constant(0.5, shape=[2, 2, 2])
400        enc_outputs, enc_state = rnn.dynamic_rnn(
401            cell, inp, dtype=dtypes.float32)
402        attn_states = enc_outputs
403        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
404
405        # Use a new cell instance since the attention decoder uses a
406        # different variable scope.
407        dec, mem = seq2seq_lib.attention_decoder(
408            dec_inp, enc_state, attn_states, cell_fn(),
409            output_size=4, num_heads=2)
410        sess.run([variables.global_variables_initializer()])
411        res = sess.run(dec)
412        self.assertEqual(3, len(res))
413        self.assertEqual((2, 4), res[0].shape)
414
415        res = sess.run([mem])
416        self.assertEqual((2, 2), res[0].shape)
417
418  def testAttentionDecoderStateIsTuple(self):
419    with self.test_session() as sess:
420      with variable_scope.variable_scope(
421          "root", initializer=init_ops.constant_initializer(0.5)):
422        single_cell = lambda: rnn_cell.BasicLSTMCell(  # pylint: disable=g-long-lambda
423            2, state_is_tuple=True)
424        cell_fn = lambda: rnn_cell.MultiRNNCell(  # pylint: disable=g-long-lambda
425            cells=[single_cell() for _ in range(2)], state_is_tuple=True)
426        cell = cell_fn()
427        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
428        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
429        attn_states = array_ops.concat([
430            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
431        ], 1)
432        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
433
434        # Use a new cell instance since the attention decoder uses a
435        # different variable scope.
436        dec, mem = seq2seq_lib.attention_decoder(
437            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
438        sess.run([variables.global_variables_initializer()])
439        res = sess.run(dec)
440        self.assertEqual(3, len(res))
441        self.assertEqual((2, 4), res[0].shape)
442
443        res = sess.run([mem])
444        self.assertEqual(2, len(res[0]))
445        self.assertEqual((2, 2), res[0][0].c.shape)
446        self.assertEqual((2, 2), res[0][0].h.shape)
447        self.assertEqual((2, 2), res[0][1].c.shape)
448        self.assertEqual((2, 2), res[0][1].h.shape)
449
450  def testDynamicAttentionDecoderStateIsTuple(self):
451    with self.test_session() as sess:
452      with variable_scope.variable_scope(
453          "root", initializer=init_ops.constant_initializer(0.5)):
454        cell_fn = lambda: rnn_cell.MultiRNNCell(  # pylint: disable=g-long-lambda
455            cells=[rnn_cell.BasicLSTMCell(2) for _ in range(2)])
456        cell = cell_fn()
457        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
458        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
459        attn_states = array_ops.concat([
460            array_ops.reshape(e, [-1, 1, cell.output_size])
461            for e in enc_outputs
462        ], 1)
463        dec_inp = [constant_op.constant(0.4, shape=[2, 2])] * 3
464
465        # Use a new cell instance since the attention decoder uses a
466        # different variable scope.
467        dec, mem = seq2seq_lib.attention_decoder(
468            dec_inp, enc_state, attn_states, cell_fn(), output_size=4)
469        sess.run([variables.global_variables_initializer()])
470        res = sess.run(dec)
471        self.assertEqual(3, len(res))
472        self.assertEqual((2, 4), res[0].shape)
473
474        res = sess.run([mem])
475        self.assertEqual(2, len(res[0]))
476        self.assertEqual((2, 2), res[0][0].c.shape)
477        self.assertEqual((2, 2), res[0][0].h.shape)
478        self.assertEqual((2, 2), res[0][1].c.shape)
479        self.assertEqual((2, 2), res[0][1].h.shape)
480
481  def testEmbeddingAttentionDecoder(self):
482    with self.test_session() as sess:
483      with variable_scope.variable_scope(
484          "root", initializer=init_ops.constant_initializer(0.5)):
485        inp = [constant_op.constant(0.5, shape=[2, 2])] * 2
486        cell_fn = lambda: rnn_cell.GRUCell(2)
487        cell = cell_fn()
488        enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=dtypes.float32)
489        attn_states = array_ops.concat([
490            array_ops.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
491        ], 1)
492        dec_inp = [
493            constant_op.constant(
494                i, dtypes.int32, shape=[2]) for i in range(3)
495        ]
496
497        # Use a new cell instance since the attention decoder uses a
498        # different variable scope.
499        dec, mem = seq2seq_lib.embedding_attention_decoder(
500            dec_inp,
501            enc_state,
502            attn_states,
503            cell_fn(),
504            num_symbols=4,
505            embedding_size=2,
506            output_size=3)
507        sess.run([variables.global_variables_initializer()])
508        res = sess.run(dec)
509        self.assertEqual(3, len(res))
510        self.assertEqual((2, 3), res[0].shape)
511
512        res = sess.run([mem])
513        self.assertEqual((2, 2), res[0].shape)
514
515  def testEmbeddingAttentionSeq2Seq(self):
516    with self.test_session() as sess:
517      with variable_scope.variable_scope(
518          "root", initializer=init_ops.constant_initializer(0.5)):
519        enc_inp = [
520            constant_op.constant(
521                1, dtypes.int32, shape=[2]) for i in range(2)
522        ]
523        dec_inp = [
524            constant_op.constant(
525                i, dtypes.int32, shape=[2]) for i in range(3)
526        ]
527        cell_fn = lambda: rnn_cell.BasicLSTMCell(2)
528        cell = cell_fn()
529        dec, mem = seq2seq_lib.embedding_attention_seq2seq(
530            enc_inp,
531            dec_inp,
532            cell,
533            num_encoder_symbols=2,
534            num_decoder_symbols=5,
535            embedding_size=2)
536        sess.run([variables.global_variables_initializer()])
537        res = sess.run(dec)
538        self.assertEqual(3, len(res))
539        self.assertEqual((2, 5), res[0].shape)
540
541        res = sess.run([mem])
542        self.assertEqual((2, 2), res[0].c.shape)
543        self.assertEqual((2, 2), res[0].h.shape)
544
545        # Test with state_is_tuple=False.
546        with variable_scope.variable_scope("no_tuple"):
547          cell_fn = functools.partial(
548              rnn_cell.BasicLSTMCell, 2, state_is_tuple=False)
549          cell_nt = cell_fn()
550          dec, mem = seq2seq_lib.embedding_attention_seq2seq(
551              enc_inp,
552              dec_inp,
553              cell_nt,
554              num_encoder_symbols=2,
555              num_decoder_symbols=5,
556              embedding_size=2)
557          sess.run([variables.global_variables_initializer()])
558          res = sess.run(dec)
559          self.assertEqual(3, len(res))
560          self.assertEqual((2, 5), res[0].shape)
561
562          res = sess.run([mem])
563          self.assertEqual((2, 4), res[0].shape)
564
565        # Test externally provided output projection.
566        w = variable_scope.get_variable("proj_w", [2, 5])
567        b = variable_scope.get_variable("proj_b", [5])
568        with variable_scope.variable_scope("proj_seq2seq"):
569          dec, _ = seq2seq_lib.embedding_attention_seq2seq(
570              enc_inp,
571              dec_inp,
572              cell_fn(),
573              num_encoder_symbols=2,
574              num_decoder_symbols=5,
575              embedding_size=2,
576              output_projection=(w, b))
577        sess.run([variables.global_variables_initializer()])
578        res = sess.run(dec)
579        self.assertEqual(3, len(res))
580        self.assertEqual((2, 2), res[0].shape)
581
582        # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
583        # within a variable scope that already has a weights tensor.
584        #
585        # # Test that previous-feeding model ignores inputs after the first.
586        # dec_inp2 = [
587        #     constant_op.constant(
588        #         0, dtypes.int32, shape=[2]) for _ in range(3)
589        # ]
590        # with variable_scope.variable_scope("other"):
591        #   d3, _ = seq2seq_lib.embedding_attention_seq2seq(
592        #       enc_inp,
593        #       dec_inp2,
594        #       cell_fn(),
595        #       num_encoder_symbols=2,
596        #       num_decoder_symbols=5,
597        #       embedding_size=2,
598        #       feed_previous=constant_op.constant(True))
599        # sess.run([variables.global_variables_initializer()])
600        # variable_scope.get_variable_scope().reuse_variables()
601        # cell = cell_fn()
602        # d1, _ = seq2seq_lib.embedding_attention_seq2seq(
603        #     enc_inp,
604        #     dec_inp,
605        #     cell,
606        #     num_encoder_symbols=2,
607        #     num_decoder_symbols=5,
608        #     embedding_size=2,
609        #     feed_previous=True)
610        # d2, _ = seq2seq_lib.embedding_attention_seq2seq(
611        #     enc_inp,
612        #     dec_inp2,
613        #     cell,
614        #     num_encoder_symbols=2,
615        #     num_decoder_symbols=5,
616        #     embedding_size=2,
617        #     feed_previous=True)
618        # res1 = sess.run(d1)
619        # res2 = sess.run(d2)
620        # res3 = sess.run(d3)
621        # self.assertAllClose(res1, res2)
622        # self.assertAllClose(res1, res3)
623
624  def testOne2ManyRNNSeq2Seq(self):
625    with self.test_session() as sess:
626      with variable_scope.variable_scope(
627          "root", initializer=init_ops.constant_initializer(0.5)):
628        enc_inp = [
629            constant_op.constant(
630                1, dtypes.int32, shape=[2]) for i in range(2)
631        ]
632        dec_inp_dict = {}
633        dec_inp_dict["0"] = [
634            constant_op.constant(
635                i, dtypes.int32, shape=[2]) for i in range(3)
636        ]
637        dec_inp_dict["1"] = [
638            constant_op.constant(
639                i, dtypes.int32, shape=[2]) for i in range(4)
640        ]
641        dec_symbols_dict = {"0": 5, "1": 6}
642        def EncCellFn():
643          return rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
644        def DecCellsFn():
645          return dict((k, rnn_cell.BasicLSTMCell(2, state_is_tuple=True))
646                      for k in dec_symbols_dict)
647        outputs_dict, state_dict = (seq2seq_lib.one2many_rnn_seq2seq(
648            enc_inp, dec_inp_dict, EncCellFn(), DecCellsFn(),
649            2, dec_symbols_dict, embedding_size=2))
650
651        sess.run([variables.global_variables_initializer()])
652        res = sess.run(outputs_dict["0"])
653        self.assertEqual(3, len(res))
654        self.assertEqual((2, 5), res[0].shape)
655        res = sess.run(outputs_dict["1"])
656        self.assertEqual(4, len(res))
657        self.assertEqual((2, 6), res[0].shape)
658        res = sess.run([state_dict["0"]])
659        self.assertEqual((2, 2), res[0].c.shape)
660        self.assertEqual((2, 2), res[0].h.shape)
661        res = sess.run([state_dict["1"]])
662        self.assertEqual((2, 2), res[0].c.shape)
663        self.assertEqual((2, 2), res[0].h.shape)
664
665        # Test that previous-feeding model ignores inputs after the first, i.e.
666        # dec_inp_dict2 has different inputs from dec_inp_dict after the first
667        # time-step.
668        dec_inp_dict2 = {}
669        dec_inp_dict2["0"] = [
670            constant_op.constant(
671                0, dtypes.int32, shape=[2]) for _ in range(3)
672        ]
673        dec_inp_dict2["1"] = [
674            constant_op.constant(
675                0, dtypes.int32, shape=[2]) for _ in range(4)
676        ]
677        with variable_scope.variable_scope("other"):
678          outputs_dict3, _ = seq2seq_lib.one2many_rnn_seq2seq(
679              enc_inp,
680              dec_inp_dict2,
681              EncCellFn(),
682              DecCellsFn(),
683              2,
684              dec_symbols_dict,
685              embedding_size=2,
686              feed_previous=constant_op.constant(True))
687        with variable_scope.variable_scope("other_2"):
688          outputs_dict1, _ = seq2seq_lib.one2many_rnn_seq2seq(
689              enc_inp,
690              dec_inp_dict,
691              EncCellFn(),
692              DecCellsFn(),
693              2,
694              dec_symbols_dict,
695              embedding_size=2,
696              feed_previous=True)
697        with variable_scope.variable_scope("other_3"):
698          outputs_dict2, _ = seq2seq_lib.one2many_rnn_seq2seq(
699              enc_inp,
700              dec_inp_dict2,
701              EncCellFn(),
702              DecCellsFn(),
703              2,
704              dec_symbols_dict,
705              embedding_size=2,
706              feed_previous=True)
707        sess.run([variables.global_variables_initializer()])
708        res1 = sess.run(outputs_dict1["0"])
709        res2 = sess.run(outputs_dict2["0"])
710        res3 = sess.run(outputs_dict3["0"])
711        self.assertAllClose(res1, res2)
712        self.assertAllClose(res1, res3)
713
714  def testSequenceLoss(self):
715    with self.test_session() as sess:
716      logits = [constant_op.constant(i + 0.5, shape=[2, 5]) for i in range(3)]
717      targets = [
718          constant_op.constant(
719              i, dtypes.int32, shape=[2]) for i in range(3)
720      ]
721      weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)]
722
723      average_loss_per_example = seq2seq_lib.sequence_loss(
724          logits,
725          targets,
726          weights,
727          average_across_timesteps=True,
728          average_across_batch=True)
729      res = sess.run(average_loss_per_example)
730      self.assertAllClose(1.60944, res)
731
732      average_loss_per_sequence = seq2seq_lib.sequence_loss(
733          logits,
734          targets,
735          weights,
736          average_across_timesteps=False,
737          average_across_batch=True)
738      res = sess.run(average_loss_per_sequence)
739      self.assertAllClose(4.828314, res)
740
741      total_loss = seq2seq_lib.sequence_loss(
742          logits,
743          targets,
744          weights,
745          average_across_timesteps=False,
746          average_across_batch=False)
747      res = sess.run(total_loss)
748      self.assertAllClose(9.656628, res)
749
750  def testSequenceLossByExample(self):
751    with self.test_session() as sess:
752      output_classes = 5
753      logits = [
754          constant_op.constant(
755              i + 0.5, shape=[2, output_classes]) for i in range(3)
756      ]
757      targets = [
758          constant_op.constant(
759              i, dtypes.int32, shape=[2]) for i in range(3)
760      ]
761      weights = [constant_op.constant(1.0, shape=[2]) for i in range(3)]
762
763      average_loss_per_example = (seq2seq_lib.sequence_loss_by_example(
764          logits, targets, weights, average_across_timesteps=True))
765      res = sess.run(average_loss_per_example)
766      self.assertAllClose(np.asarray([1.609438, 1.609438]), res)
767
768      loss_per_sequence = seq2seq_lib.sequence_loss_by_example(
769          logits, targets, weights, average_across_timesteps=False)
770      res = sess.run(loss_per_sequence)
771      self.assertAllClose(np.asarray([4.828314, 4.828314]), res)
772
773  # TODO(ebrevdo, lukaszkaiser): Re-enable once RNNCells allow reuse
774  # within a variable scope that already has a weights tensor.
775  #
776  # def testModelWithBucketsScopeAndLoss(self):
777  #   """Test variable scope reuse is not reset after model_with_buckets."""
778  #   classes = 10
779  #   buckets = [(4, 4), (8, 8)]
780
781  #   with self.test_session():
782  #     # Here comes a sample Seq2Seq model using GRU cells.
783  #     def SampleGRUSeq2Seq(enc_inp, dec_inp, weights, per_example_loss):
784  #       """Example sequence-to-sequence model that uses GRU cells."""
785
786  #       def GRUSeq2Seq(enc_inp, dec_inp):
787  #         cell = rnn_cell.MultiRNNCell(
788  #             [rnn_cell.GRUCell(24) for _ in range(2)])
789  #         return seq2seq_lib.embedding_attention_seq2seq(
790  #             enc_inp,
791  #             dec_inp,
792  #             cell,
793  #             num_encoder_symbols=classes,
794  #             num_decoder_symbols=classes,
795  #             embedding_size=24)
796
797  #       targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]
798  #       return seq2seq_lib.model_with_buckets(
799  #           enc_inp,
800  #           dec_inp,
801  #           targets,
802  #           weights,
803  #           buckets,
804  #           GRUSeq2Seq,
805  #           per_example_loss=per_example_loss)
806
807  #     # Now we construct the copy model.
808  #     inp = [
809  #         array_ops.placeholder(
810  #             dtypes.int32, shape=[None]) for _ in range(8)
811  #     ]
812  #     out = [
813  #         array_ops.placeholder(
814  #             dtypes.int32, shape=[None]) for _ in range(8)
815  #     ]
816  #     weights = [
817  #         array_ops.ones_like(
818  #             inp[0], dtype=dtypes.float32) for _ in range(8)
819  #     ]
820  #     with variable_scope.variable_scope("root"):
821  #       _, losses1 = SampleGRUSeq2Seq(
822  #           inp, out, weights, per_example_loss=False)
823  #       # Now check that we did not accidentally set reuse.
824  #       self.assertEqual(False, variable_scope.get_variable_scope().reuse)
825  #     with variable_scope.variable_scope("new"):
  #       _, losses2 = SampleGRUSeq2Seq(
827  #           inp, out, weights, per_example_loss=True)
828  #       # First loss is scalar, the second one is a 1-dimensional tensor.
829  #       self.assertEqual([], losses1[0].get_shape().as_list())
830  #       self.assertEqual([None], losses2[0].get_shape().as_list())
831
  def testModelWithBuckets(self):
    """Larger test that does full sequence-to-sequence model training.

    Builds a bucketed GRU attention seq2seq model trained with a sampled
    softmax loss, runs a few Adam steps on randomly generated sequences, and
    checks, for every bucket that was sampled more than once, that the last
    recorded perplexity is below 1.2x the first one.
    """
    # We learn to copy 10 symbols in 2 buckets: length 4 and length 8.
    classes = 10
    buckets = [(4, 4), (8, 8)]
    perplexities = [[], []]  # Results for each bucket.
    # Seed the TensorFlow, Python and NumPy generators used below.
    # NOTE(review): the training trajectory presumably depends on the order in
    # which ops are created and random numbers are drawn -- confirm before
    # reordering any statement in this test.
    random_seed.set_random_seed(111)
    random.seed(111)
    np.random.seed(111)

    with self.test_session() as sess:
      # We use sampled softmax so we keep output projection separate.
      # `w` is passed to the model as the output projection; its transpose
      # `w_t` is what the sampled softmax loss receives as `weights`.
      w = variable_scope.get_variable("proj_w", [24, classes])
      w_t = array_ops.transpose(w)
      b = variable_scope.get_variable("proj_b", [classes])

      # Here comes a sample Seq2Seq model using GRU cells.
      def SampleGRUSeq2Seq(enc_inp, dec_inp, weights):
        """Example sequence-to-sequence model that uses GRU cells."""

        def GRUSeq2Seq(enc_inp, dec_inp):
          # Two stacked 24-unit GRU cells; the embedding size matches.
          cell = rnn_cell.MultiRNNCell(
              [rnn_cell.GRUCell(24) for _ in range(2)], state_is_tuple=True)
          return seq2seq_lib.embedding_attention_seq2seq(
              enc_inp,
              dec_inp,
              cell,
              num_encoder_symbols=classes,
              num_decoder_symbols=classes,
              embedding_size=24,
              output_projection=(w, b))

        # Targets are the decoder inputs shifted left by one step; the final
        # position is filled with the plain integer 0.
        targets = [dec_inp[i + 1] for i in range(len(dec_inp) - 1)] + [0]

        def SampledLoss(labels, logits):
          # Reshape the labels into a single column before handing them to
          # the sampled softmax loss.
          labels = array_ops.reshape(labels, [-1, 1])
          return nn_impl.sampled_softmax_loss(
              weights=w_t,
              biases=b,
              labels=labels,
              inputs=logits,
              num_sampled=8,
              num_classes=classes)

        return seq2seq_lib.model_with_buckets(
            enc_inp,
            dec_inp,
            targets,
            weights,
            buckets,
            GRUSeq2Seq,
            softmax_loss_function=SampledLoss)

      # Now we construct the copy model.
      batch_size = 8
      # Eight placeholders per side: one per step of the longest bucket.
      inp = [
          array_ops.placeholder(
              dtypes.int32, shape=[None]) for _ in range(8)
      ]
      out = [
          array_ops.placeholder(
              dtypes.int32, shape=[None]) for _ in range(8)
      ]
      # All-ones float weights, shaped like the first input placeholder.
      weights = [
          array_ops.ones_like(
              inp[0], dtype=dtypes.float32) for _ in range(8)
      ]
      with variable_scope.variable_scope("root"):
        _, losses = SampleGRUSeq2Seq(inp, out, weights)
        updates = []
        params = variables.global_variables()
        optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5)
        # One update op per bucket: gradients of that bucket's loss, clipped
        # to a global norm of 30.0, applied through the shared optimizer.
        for i in range(len(buckets)):
          full_grads = gradients_impl.gradients(losses[i], params)
          grads, _ = clip_ops.clip_by_global_norm(full_grads, 30.0)
          update = optimizer.apply_gradients(zip(grads, params))
          updates.append(update)
        sess.run([variables.global_variables_initializer()])
      steps = 6
      for _ in range(steps):
        # Pick a bucket at random and generate `length` steps of random
        # symbols in [1, 9] for a batch of `batch_size` sequences.
        bucket = random.choice(np.arange(len(buckets)))
        length = buckets[bucket][0]
        i = [
            np.array(
                [np.random.randint(9) + 1 for _ in range(batch_size)],
                dtype=np.int32) for _ in range(length)
        ]
        # 0 is our "GO" symbol here.
        # `o` is the input sequence prefixed with GO, so it has length + 1
        # entries.
        o = [np.array([0] * batch_size, dtype=np.int32)] + i
        feed = {}
        for i1, i2, o1, o2 in zip(inp[:length], i[:length], out[:length],
                                  o[:length]):
          feed[i1.name] = i2
          feed[o1.name] = o2
        if length < 8:  # For the 4-bucket, we need the 5th as target.
          feed[out[length].name] = o[length]
        # Run one training step for the chosen bucket and record exp(loss).
        res = sess.run([updates[bucket], losses[bucket]], feed)
        perplexities[bucket].append(math.exp(float(res[1])))
      # Buckets sampled at most once have nothing to compare and are skipped.
      for bucket in range(len(buckets)):
        if len(perplexities[bucket]) > 1:  # Assert that perplexity went down.
          self.assertLess(perplexities[bucket][-1],  # 20% margin of error.
                          1.2 * perplexities[bucket][0])
934
935  def testModelWithBooleanFeedPrevious(self):
936    """Test the model behavior when feed_previous is True.
937
938    For example, the following two cases have the same effect:
939      - Train `embedding_rnn_seq2seq` with `feed_previous=True`, which contains
940        a `embedding_rnn_decoder` with `feed_previous=True` and
941        `update_embedding_for_previous=True`. The decoder is fed with "<Go>"
942        and outputs "A, B, C".
943      - Train `embedding_rnn_seq2seq` with `feed_previous=False`. The decoder
944        is fed with "<Go>, A, B".
945    """
946    num_encoder_symbols = 3
947    num_decoder_symbols = 5
948    batch_size = 2
949    num_enc_timesteps = 2
950    num_dec_timesteps = 3
951
952    def TestModel(seq2seq):
953      with self.test_session(graph=ops.Graph()) as sess:
954        random_seed.set_random_seed(111)
955        random.seed(111)
956        np.random.seed(111)
957
958        enc_inp = [
959            constant_op.constant(
960                i + 1, dtypes.int32, shape=[batch_size])
961            for i in range(num_enc_timesteps)
962        ]
963        dec_inp_fp_true = [
964            constant_op.constant(
965                i, dtypes.int32, shape=[batch_size])
966            for i in range(num_dec_timesteps)
967        ]
968        dec_inp_holder_fp_false = [
969            array_ops.placeholder(
970                dtypes.int32, shape=[batch_size])
971            for _ in range(num_dec_timesteps)
972        ]
973        targets = [
974            constant_op.constant(
975                i + 1, dtypes.int32, shape=[batch_size])
976            for i in range(num_dec_timesteps)
977        ]
978        weights = [
979            constant_op.constant(
980                1.0, shape=[batch_size]) for i in range(num_dec_timesteps)
981        ]
982
983        def ForwardBackward(enc_inp, dec_inp, feed_previous):
984          scope_name = "fp_{}".format(feed_previous)
985          with variable_scope.variable_scope(scope_name):
986            dec_op, _ = seq2seq(enc_inp, dec_inp, feed_previous=feed_previous)
987            net_variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
988                                               scope_name)
989          optimizer = adam.AdamOptimizer(0.03, epsilon=1e-5)
990          update_op = optimizer.minimize(
991              seq2seq_lib.sequence_loss(dec_op, targets, weights),
992              var_list=net_variables)
993          return dec_op, update_op, net_variables
994
995        dec_op_fp_true, update_fp_true, variables_fp_true = ForwardBackward(
996            enc_inp, dec_inp_fp_true, feed_previous=True)
997        _, update_fp_false, variables_fp_false = ForwardBackward(
998            enc_inp, dec_inp_holder_fp_false, feed_previous=False)
999
1000        sess.run(variables.global_variables_initializer())
1001
1002        # We only check consistencies between the variables existing in both
1003        # the models with True and False feed_previous. Variables created by
1004        # the loop_function in the model with True feed_previous are ignored.
1005        v_false_name_dict = {
1006            v.name.split("/", 1)[-1]: v
1007            for v in variables_fp_false
1008        }
1009        matched_variables = [(v, v_false_name_dict[v.name.split("/", 1)[-1]])
1010                             for v in variables_fp_true]
1011        for v_true, v_false in matched_variables:
1012          sess.run(state_ops.assign(v_false, v_true))
1013
1014        # Take the symbols generated by the decoder with feed_previous=True as
1015        # the true input symbols for the decoder with feed_previous=False.
1016        dec_fp_true = sess.run(dec_op_fp_true)
1017        output_symbols_fp_true = np.argmax(dec_fp_true, axis=2)
1018        dec_inp_fp_false = np.vstack((dec_inp_fp_true[0].eval(),
1019                                      output_symbols_fp_true[:-1]))
1020        sess.run(update_fp_true)
1021        sess.run(update_fp_false, {
1022            holder: inp
1023            for holder, inp in zip(dec_inp_holder_fp_false, dec_inp_fp_false)
1024        })
1025
1026        for v_true, v_false in matched_variables:
1027          self.assertAllClose(v_true.eval(), v_false.eval())
1028
1029    def EmbeddingRNNSeq2SeqF(enc_inp, dec_inp, feed_previous):
1030      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
1031      return seq2seq_lib.embedding_rnn_seq2seq(
1032          enc_inp,
1033          dec_inp,
1034          cell,
1035          num_encoder_symbols,
1036          num_decoder_symbols,
1037          embedding_size=2,
1038          feed_previous=feed_previous)
1039
1040    def EmbeddingRNNSeq2SeqNoTupleF(enc_inp, dec_inp, feed_previous):
1041      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
1042      return seq2seq_lib.embedding_rnn_seq2seq(
1043          enc_inp,
1044          dec_inp,
1045          cell,
1046          num_encoder_symbols,
1047          num_decoder_symbols,
1048          embedding_size=2,
1049          feed_previous=feed_previous)
1050
1051    def EmbeddingTiedRNNSeq2Seq(enc_inp, dec_inp, feed_previous):
1052      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
1053      return seq2seq_lib.embedding_tied_rnn_seq2seq(
1054          enc_inp,
1055          dec_inp,
1056          cell,
1057          num_decoder_symbols,
1058          embedding_size=2,
1059          feed_previous=feed_previous)
1060
1061    def EmbeddingTiedRNNSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
1062      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
1063      return seq2seq_lib.embedding_tied_rnn_seq2seq(
1064          enc_inp,
1065          dec_inp,
1066          cell,
1067          num_decoder_symbols,
1068          embedding_size=2,
1069          feed_previous=feed_previous)
1070
1071    def EmbeddingAttentionSeq2Seq(enc_inp, dec_inp, feed_previous):
1072      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=True)
1073      return seq2seq_lib.embedding_attention_seq2seq(
1074          enc_inp,
1075          dec_inp,
1076          cell,
1077          num_encoder_symbols,
1078          num_decoder_symbols,
1079          embedding_size=2,
1080          feed_previous=feed_previous)
1081
1082    def EmbeddingAttentionSeq2SeqNoTuple(enc_inp, dec_inp, feed_previous):
1083      cell = rnn_cell.BasicLSTMCell(2, state_is_tuple=False)
1084      return seq2seq_lib.embedding_attention_seq2seq(
1085          enc_inp,
1086          dec_inp,
1087          cell,
1088          num_encoder_symbols,
1089          num_decoder_symbols,
1090          embedding_size=2,
1091          feed_previous=feed_previous)
1092
1093    for model in (EmbeddingRNNSeq2SeqF, EmbeddingRNNSeq2SeqNoTupleF,
1094                  EmbeddingTiedRNNSeq2Seq, EmbeddingTiedRNNSeq2SeqNoTuple,
1095                  EmbeddingAttentionSeq2Seq, EmbeddingAttentionSeq2SeqNoTuple):
1096      TestModel(model)
1097
1098
# Invoke the TensorFlow test runner when this file is executed directly.
if __name__ == "__main__":
  test.main()
1101