# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Module for constructing GridRNN cells"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
import functools

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs

from tensorflow.python.platform import tf_logging as logging
from tensorflow.contrib import layers
from tensorflow.contrib import rnn


class GridRNNCell(rnn.RNNCell):
  """Grid recurrent cell.

  This implementation is based on:

    http://arxiv.org/pdf/1507.01526v3.pdf

  This is the generic implementation of GridRNN. Users can specify an
  arbitrary number of dimensions, mark some of them as priority (section 3.2)
  or non-recurrent (section 3.3), and select the input/output dimensions
  (section 3.4). Weight sharing can be specified via the `tied` parameter,
  and the type of recurrent unit via `cell_fn`.
  """

  def __init__(self,
               num_units,
               num_dims=1,
               input_dims=None,
               output_dims=None,
               priority_dims=None,
               non_recurrent_dims=None,
               tied=False,
               cell_fn=None,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    """Initialize the parameters of a Grid RNN cell.

    Args:
      num_units: int, The number of units in all dimensions of this GridRNN
        cell.
      num_dims: int, Number of dimensions of this grid.
      input_dims: int or list, List of dimensions which will receive input
        data.
      output_dims: int or list, List of dimensions from which the output will
        be recorded.
      priority_dims: int or list, List of dimensions to be considered as
        priority dimensions. If None, no dimension is prioritized.
      non_recurrent_dims: int or list, List of dimensions that are not
        recurrent. The transfer function for non-recurrent dimensions is
        specified via `non_recurrent_fn`, which defaults to
        `tensorflow.nn.relu`.
      tied: bool, Whether to share the weights among the dimensions of this
        GridRNN cell. If there are non-recurrent dimensions in the grid,
        weights are shared between each group of recurrent and non-recurrent
        dimensions.
      cell_fn: function, a function which returns the recurrent cell object.
        It must have the signature:
        ```
        def cell_func(num_units):
          # ...
        ```
        and return an object of type `RNNCell`. If None, an LSTMCell with
        default parameters will be used.
        Note that if you use a custom RNNCell (with `cell_fn`), it is your
        responsibility to make sure the inner cell uses `state_is_tuple=True`.
      non_recurrent_fn: a tensorflow Op that will be the transfer function of
        the non-recurrent dimensions.
      state_is_tuple: If True, accepted and returned states are tuples of the
        states of the recurrent dimensions. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
        Note that if you use a custom RNNCell (with `cell_fn`), it is your
        responsibility to make sure the inner cell uses `state_is_tuple=True`.
      output_is_tuple: If True, the output is a tuple of the outputs of the
        recurrent dimensions. If False, they are concatenated along the
        column axis. The latter behavior will soon be deprecated.

    Raises:
      TypeError: if cell_fn does not return an RNNCell instance.
    """
    if not state_is_tuple:
      logging.warning('%s: Using a concatenated state is slower and will '
                      'soon be deprecated.  Use state_is_tuple=True.', self)
    if not output_is_tuple:
      logging.warning('%s: Using a concatenated output is slower and will '
                      'soon be deprecated.  Use output_is_tuple=True.', self)

    if num_dims < 1:
      raise ValueError('dims must be >= 1: {}'.format(num_dims))

    self._config = _parse_rnn_config(num_dims, input_dims, output_dims,
                                     priority_dims, non_recurrent_dims,
                                     non_recurrent_fn or nn.relu, tied,
                                     num_units)

    self._state_is_tuple = state_is_tuple
    self._output_is_tuple = output_is_tuple

    if cell_fn is None:
      my_cell_fn = functools.partial(
          rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
    else:
      my_cell_fn = lambda: cell_fn(num_units)
    if tied:
      self._cells = [my_cell_fn()] * num_dims
    else:
      self._cells = [my_cell_fn() for _ in range(num_dims)]
    if not isinstance(self._cells[0], rnn.RNNCell):
      raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
                      type(self._cells[0]))

    if self._output_is_tuple:
      self._output_size = tuple(self._cells[0].output_size
                                for _ in self._config.outputs)
    else:
      self._output_size = self._cells[0].output_size * len(self._config.outputs)

    if self._state_is_tuple:
      self._state_size = tuple(self._cells[0].state_size
                               for _ in self._config.recurrents)
    else:
      self._state_size = self._cell_state_size() * len(self._config.recurrents)
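
  # Sketch of a custom `cell_fn` (hedged; `gru_fn` and `grid` are
  # illustrative names, while `rnn.GRUCell` and the constructor arguments
  # are from this module):
  #
  #   def gru_fn(num_units):
  #     return rnn.GRUCell(num_units=num_units)
  #
  #   grid = GridRNNCell(num_units=16, num_dims=2, input_dims=0,
  #                      output_dims=0, cell_fn=gru_fn)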

  @property
  def output_size(self):
    return self._output_size

  @property
  def state_size(self):
    return self._state_size

  def __call__(self, inputs, state, scope=None):
    """Run one step of GridRNN.

    Args:
      inputs: input Tensor, 2D, batch x input_size, or None.
      state: state Tensor, 2D, batch x state_size. Note that state_size =
        cell_state_size * recurrent_dims.
      scope: VariableScope for the created subgraph; defaults to "GridRNNCell".

    Returns:
      A tuple containing:

      - A 2D, batch x output_size, Tensor representing the output of the cell
        after reading "inputs" when previous state was "state".
      - A 2D, batch x state_size, Tensor representing the new state of the cell
        after reading "inputs" when previous state was "state".
    """
    conf = self._config
    dtype = inputs.dtype

    c_prev, m_prev, cell_output_size = self._extract_states(state)

    new_output = [None] * conf.num_dims
    new_state = [None] * conf.num_dims

    with vs.variable_scope(scope or type(self).__name__):  # GridRNNCell
      # project input, populate c_prev and m_prev
      self._project_input(inputs, c_prev, m_prev, cell_output_size > 0)

      # propagate along dimensions, first for non-priority dimensions,
      # then priority dimensions
      _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev,
                 new_output, new_state, True)
      _propagate(conf.priority, conf, self._cells,
                 c_prev, m_prev, new_output, new_state, False)

      # collect outputs and states
      output_tensors = [new_output[i] for i in self._config.outputs]
      if self._output_is_tuple:
        output = tuple(output_tensors)
      else:
        if output_tensors:
          output = array_ops.concat(output_tensors, 1)
        else:
          output = array_ops.zeros([0, 0], dtype)

      if self._state_is_tuple:
        states = tuple(new_state[i] for i in self._config.recurrents)
      else:
        # concat each state first, then flatten the whole thing
        state_tensors = [
            x for i in self._config.recurrents for x in new_state[i]
        ]
        if state_tensors:
          states = array_ops.concat(state_tensors, 1)
        else:
          states = array_ops.zeros([0, 0], dtype)

    return output, states
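
  # One-step usage sketch (hedged; shapes are illustrative): with
  # state_is_tuple=True, `state` is a tuple with one entry per recurrent
  # dimension (an LSTMStateTuple for LSTM inner cells), and a single step
  # is just:
  #
  #   output, new_state = cell(inputs, state)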

  def _extract_states(self, state):
    """Extract the cell and previous output tensors from the given state.

    Args:
      state: The RNN state.

    Returns:
      Tuple of the cell value, previous output, and cell_output_size.

    Raises:
      ValueError: If len(self._config.recurrents) != len(state).
    """
    conf = self._config

    # c_prev is `m` (cell value), and
    # m_prev is `h` (previous output) in the paper.
    # Keeping c and m here for consistency with the codebase
    c_prev = [None] * conf.num_dims
    m_prev = [None] * conf.num_dims

    # for LSTM   : state = memory cell + output, hence cell_output_size > 0
    # for GRU/RNN: state = output (whose size is equal to _num_units),
    #              hence cell_output_size = 0
    total_cell_state_size = self._cell_state_size()
    cell_output_size = total_cell_state_size - conf.num_units

    if self._state_is_tuple:
      if len(conf.recurrents) != len(state):
        raise ValueError('Expected state as a tuple of {} '
                         'elements'.format(len(conf.recurrents)))

      for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
        if cell_output_size > 0:
          c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state
        else:
          m_prev[recurrent_dim] = recurrent_state
    else:
      for recurrent_dim, start_idx in zip(conf.recurrents,
                                          range(0, self.state_size,
                                                total_cell_state_size)):
        if cell_output_size > 0:
          c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
          m_prev[recurrent_dim] = array_ops.slice(
              state, [0, start_idx + conf.num_units], [-1, cell_output_size])
        else:
          m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
    return c_prev, m_prev, cell_output_size
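
  # Worked example (illustrative): with an inner LSTMCell of num_units=8,
  # _cell_state_size() is 16 (memory cell plus output), so cell_output_size
  # is 16 - 8 = 8 and both c_prev and m_prev are filled per recurrent
  # dimension. With an inner GRUCell, _cell_state_size() is 8,
  # cell_output_size is 0, and only m_prev is filled.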

  def _project_input(self, inputs, c_prev, m_prev, with_c):
    """Fills in c_prev and m_prev with projected input, for input dimensions.

    Args:
      inputs: inputs tensor
      c_prev: cell value
      m_prev: previous output
      with_c: boolean; whether to include project_c.

    Raises:
      ValueError: if len(self._config.inputs) != len(inputs)
    """
    conf = self._config

    if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
        conf.inputs):
      if isinstance(inputs, tuple):
        if len(conf.inputs) != len(inputs):
          raise ValueError('Expected inputs as a tuple of {} '
                           'tensors'.format(len(conf.inputs)))
        input_splits = inputs
      else:
        input_splits = array_ops.split(
            value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
      input_sz = input_splits[0].get_shape().with_rank(2)[1].value

      for i, j in enumerate(conf.inputs):
        input_project_m = vs.get_variable(
            'project_m_{}'.format(j), [input_sz, conf.num_units],
            dtype=inputs.dtype)
        m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)

        if with_c:
          input_project_c = vs.get_variable(
              'project_c_{}'.format(j), [input_sz, conf.num_units],
              dtype=inputs.dtype)
          c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)

  def _cell_state_size(self):
    """Total size of the state of the inner cell used in this grid.

    Returns:
      Total size of the state of the inner cell.
    """
    state_sizes = self._cells[0].state_size
    if isinstance(state_sizes, tuple):
      return sum(state_sizes)
    return state_sizes
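
  # Example (illustrative): rnn.LSTMCell(8, state_is_tuple=True).state_size
  # is LSTMStateTuple(8, 8), a tuple, so this returns 16; for rnn.GRUCell(8),
  # state_size is the int 8 and is returned unchanged.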


"""Specialized cells, for convenience."""


class Grid1BasicRNNCell(GridRNNCell):
  """1D BasicRNN cell."""

  def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
    super(Grid1BasicRNNCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=False,
        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2BasicRNNCell(GridRNNCell):
  """2D BasicRNN cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    super(Grid2BasicRNNCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid1BasicLSTMCell(GridRNNCell):
  """1D BasicLSTM cell."""

  def __init__(self,
               num_units,
               forget_bias=1,
               state_is_tuple=True,
               output_is_tuple=True):
    def cell_fn(n):
      return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
    super(Grid1BasicLSTMCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=False,
        cell_fn=cell_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2BasicLSTMCell(GridRNNCell):
  """2D BasicLSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               forget_bias=1,
               state_is_tuple=True,
               output_is_tuple=True):
    def cell_fn(n):
      return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
    super(Grid2BasicLSTMCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=cell_fn,
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid1LSTMCell(GridRNNCell):
  """1D LSTM cell.

  This is different from Grid1BasicLSTMCell because it allows specifying the
  forget bias and enabling peepholes.
  """

  def __init__(self,
               num_units,
               use_peepholes=False,
               forget_bias=1.0,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.LSTMCell(
          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)

    super(Grid1LSTMCell, self).__init__(
        num_units=num_units,
        num_dims=1,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        cell_fn=cell_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2LSTMCell(GridRNNCell):
  """2D LSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.

  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               use_peepholes=False,
               forget_bias=1.0,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.LSTMCell(
          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)

    super(Grid2LSTMCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=cell_fn,
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)
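
# Usage sketch (hedged; `x` and the sizes are illustrative, while
# tf.nn.dynamic_rnn and Grid2LSTMCell are real APIs):
#
#   import tensorflow as tf
#   cell = Grid2LSTMCell(num_units=256)
#   # x: a [batch, time, features] float32 tensor
#   outputs, final_state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)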


class Grid3LSTMCell(GridRNNCell):
  """3D LSTM cell.

  This creates a 3D cell which receives input and gives output in the first
  dimension.
  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  The second and third dimensions are LSTM.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               use_peepholes=False,
               forget_bias=1.0,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.LSTMCell(
          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)

    super(Grid3LSTMCell, self).__init__(
        num_units=num_units,
        num_dims=3,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=cell_fn,
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2GRUCell(GridRNNCell):
  """2D GRU cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.
  The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    super(Grid2GRUCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=lambda n: rnn.GRUCell(num_units=n),
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


# Helpers

_GridRNNDimension = namedtuple('_GridRNNDimension', [
    'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
])

_GridRNNConfig = namedtuple('_GridRNNConfig',
                            ['num_dims', 'dims', 'inputs', 'outputs',
                             'recurrents', 'priority', 'non_priority', 'tied',
                             'num_units'])


def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
                      ls_non_recurrent_dims, non_recurrent_fn, tied, num_units):
  def check_dim_list(ls):
    if ls is None:
      ls = []
    if not isinstance(ls, (list, tuple)):
      ls = [ls]
    ls = sorted(set(ls))
    if any(d < 0 or d >= num_dims for d in ls):
      raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(
          ls, num_dims))
    return ls

  input_dims = check_dim_list(ls_input_dims)
  output_dims = check_dim_list(ls_output_dims)
  priority_dims = check_dim_list(ls_priority_dims)
  non_recurrent_dims = check_dim_list(ls_non_recurrent_dims)

  rnn_dims = []
  for i in range(num_dims):
    rnn_dims.append(
        _GridRNNDimension(
            idx=i,
            is_input=(i in input_dims),
            is_output=(i in output_dims),
            is_priority=(i in priority_dims),
            non_recurrent_fn=non_recurrent_fn
            if i in non_recurrent_dims else None))
  return _GridRNNConfig(
      num_dims=num_dims,
      dims=rnn_dims,
      inputs=input_dims,
      outputs=output_dims,
      recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
      priority=priority_dims,
      non_priority=[x for x in range(num_dims) if x not in priority_dims],
      tied=tied,
      num_units=num_units)
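
# Example (illustrative): _parse_rnn_config(2, 0, 0, 0, None, nn.relu, False,
# 8) yields inputs=[0], outputs=[0], priority=[0], non_priority=[1],
# recurrents=[0, 1], and one _GridRNNDimension per dimension.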


def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
               first_call):
  """Propagates through all the cells in dim_indices dimensions."""
  if not dim_indices:
    return

  # Because of the way RNNCells are implemented, we take the last dimension
  # (H_{N-1}) out and feed it as the state of the RNN cell
  # (in `last_dim_output`).
  # The inputs of the cell (H_0 to H_{N-2}) are concatenated into
  # `cell_inputs`.
  if conf.num_dims > 1:
    ls_cell_inputs = [None] * (conf.num_dims - 1)
    for d in conf.dims[:-1]:
      if new_output[d.idx] is None:
        ls_cell_inputs[d.idx] = m_prev[d.idx]
      else:
        ls_cell_inputs[d.idx] = new_output[d.idx]
    cell_inputs = array_ops.concat(ls_cell_inputs, 1)
  else:
    cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
                                  m_prev[0].dtype)

  last_dim_output = (new_output[-1]
                     if new_output[-1] is not None else m_prev[-1])

  for i in dim_indices:
    d = conf.dims[i]
    if d.non_recurrent_fn:
      if conf.num_dims > 1:
        linear_args = array_ops.concat([cell_inputs, last_dim_output], 1)
      else:
        linear_args = last_dim_output
      with vs.variable_scope('non_recurrent' if conf.tied else
                             'non_recurrent/cell_{}'.format(i)):
        if conf.tied and not (first_call and i == dim_indices[0]):
          vs.get_variable_scope().reuse_variables()

        new_output[d.idx] = layers.fully_connected(
            linear_args,
            num_outputs=conf.num_units,
            activation_fn=d.non_recurrent_fn,
            weights_initializer=(vs.get_variable_scope().initializer or
                                 layers.initializers.xavier_initializer),
            weights_regularizer=vs.get_variable_scope().regularizer)
    else:
      if c_prev[i] is not None:
        cell_state = (c_prev[i], last_dim_output)
      else:
        # for GRU/RNN, the state is just the previous output
        cell_state = last_dim_output

      with vs.variable_scope('recurrent' if conf.tied else
                             'recurrent/cell_{}'.format(i)):
        if conf.tied and not (first_call and i == dim_indices[0]):
          vs.get_variable_scope().reuse_variables()
        cell = cells[i]
        new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state)
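
# Shape walkthrough for _propagate (illustrative): with num_dims=2 and
# num_units=8, `cell_inputs` is the dimension-0 hidden state, [batch, 8],
# and `last_dim_output` is the dimension-1 hidden state; each dimension in
# dim_indices then runs its cell (or non-recurrent transfer function) on
# cell_inputs, with last_dim_output as the state, producing a new
# [batch, 8] output for that dimension.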