1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Module for constructing GridRNN cells""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from collections import namedtuple 22import functools 23 24from tensorflow.python.ops import array_ops 25from tensorflow.python.ops import math_ops 26from tensorflow.python.ops import nn 27from tensorflow.python.ops import variable_scope as vs 28 29from tensorflow.python.platform import tf_logging as logging 30from tensorflow.contrib import layers 31from tensorflow.contrib import rnn 32 33 34class GridRNNCell(rnn.RNNCell): 35 """Grid recurrent cell. 36 37 This implementation is based on: 38 39 http://arxiv.org/pdf/1507.01526v3.pdf 40 41 This is the generic implementation of GridRNN. Users can specify arbitrary 42 number of dimensions, 43 set some of them to be priority (section 3.2), non-recurrent (section 3.3) 44 and input/output dimensions (section 3.4). 45 Weight sharing can also be specified using the `tied` parameter. 46 Type of recurrent units can be specified via `cell_fn`. 47 """ 48 49 def __init__(self, 50 num_units, 51 num_dims=1, 52 input_dims=None, 53 output_dims=None, 54 priority_dims=None, 55 non_recurrent_dims=None, 56 tied=False, 57 cell_fn=None, 58 non_recurrent_fn=None, 59 state_is_tuple=True, 60 output_is_tuple=True): 61 """Initialize the parameters of a Grid RNN cell 62 63 Args: 64 num_units: int, The number of units in all dimensions of this GridRNN cell 65 num_dims: int, Number of dimensions of this grid. 66 input_dims: int or list, List of dimensions which will receive input data. 67 output_dims: int or list, List of dimensions from which the output will be 68 recorded. 69 priority_dims: int or list, List of dimensions to be considered as 70 priority dimensions. 71 If None, no dimension is prioritized. 72 non_recurrent_dims: int or list, List of dimensions that are not 73 recurrent. 74 The transfer function for non-recurrent dimensions is specified 75 via `non_recurrent_fn`, which is 76 default to be `tensorflow.nn.relu`. 77 tied: bool, Whether to share the weights among the dimensions of this 78 GridRNN cell. 79 If there are non-recurrent dimensions in the grid, weights are 80 shared between each group of recurrent and non-recurrent 81 dimensions. 82 cell_fn: function, a function which returns the recurrent cell object. 83 Has to be in the following signature: 84 ``` 85 def cell_func(num_units): 86 # ... 87 ``` 88 and returns an object of type `RNNCell`. If None, LSTMCell with 89 default parameters will be used. 90 Note that if you use a custom RNNCell (with `cell_fn`), it is your 91 responsibility to make sure the inner cell use `state_is_tuple=True`. 92 93 non_recurrent_fn: a tensorflow Op that will be the transfer function of 94 the non-recurrent dimensions 95 state_is_tuple: If True, accepted and returned states are tuples of the 96 states of the recurrent dimensions. If False, they are concatenated 97 along the column axis. The latter behavior will soon be deprecated. 98 99 Note that if you use a custom RNNCell (with `cell_fn`), it is your 100 responsibility to make sure the inner cell use `state_is_tuple=True`. 101 102 output_is_tuple: If True, the output is a tuple of the outputs of the 103 recurrent dimensions. If False, they are concatenated along the 104 column axis. The later behavior will soon be deprecated. 105 106 Raises: 107 TypeError: if cell_fn does not return an RNNCell instance. 108 """ 109 if not state_is_tuple: 110 logging.warning('%s: Using a concatenated state is slower and will ' 111 'soon be deprecated. Use state_is_tuple=True.', self) 112 if not output_is_tuple: 113 logging.warning('%s: Using a concatenated output is slower and will' 114 'soon be deprecated. Use output_is_tuple=True.', self) 115 116 if num_dims < 1: 117 raise ValueError('dims must be >= 1: {}'.format(num_dims)) 118 119 self._config = _parse_rnn_config(num_dims, input_dims, output_dims, 120 priority_dims, non_recurrent_dims, 121 non_recurrent_fn or nn.relu, tied, 122 num_units) 123 124 self._state_is_tuple = state_is_tuple 125 self._output_is_tuple = output_is_tuple 126 127 if cell_fn is None: 128 my_cell_fn = functools.partial( 129 rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple) 130 else: 131 my_cell_fn = lambda: cell_fn(num_units) 132 if tied: 133 self._cells = [my_cell_fn()] * num_dims 134 else: 135 self._cells = [my_cell_fn() for _ in range(num_dims)] 136 if not isinstance(self._cells[0], rnn.RNNCell): 137 raise TypeError('cell_fn must return an RNNCell instance, saw: %s' % 138 type(self._cells[0])) 139 140 if self._output_is_tuple: 141 self._output_size = tuple(self._cells[0].output_size 142 for _ in self._config.outputs) 143 else: 144 self._output_size = self._cells[0].output_size * len(self._config.outputs) 145 146 if self._state_is_tuple: 147 self._state_size = tuple(self._cells[0].state_size 148 for _ in self._config.recurrents) 149 else: 150 self._state_size = self._cell_state_size() * len(self._config.recurrents) 151 152 @property 153 def output_size(self): 154 return self._output_size 155 156 @property 157 def state_size(self): 158 return self._state_size 159 160 def __call__(self, inputs, state, scope=None): 161 """Run one step of GridRNN. 162 163 Args: 164 inputs: input Tensor, 2D, batch x input_size. Or None 165 state: state Tensor, 2D, batch x state_size. Note that state_size = 166 cell_state_size * recurrent_dims 167 scope: VariableScope for the created subgraph; defaults to "GridRNNCell". 168 169 Returns: 170 A tuple containing: 171 172 - A 2D, batch x output_size, Tensor representing the output of the cell 173 after reading "inputs" when previous state was "state". 174 - A 2D, batch x state_size, Tensor representing the new state of the cell 175 after reading "inputs" when previous state was "state". 176 """ 177 conf = self._config 178 dtype = inputs.dtype 179 180 c_prev, m_prev, cell_output_size = self._extract_states(state) 181 182 new_output = [None] * conf.num_dims 183 new_state = [None] * conf.num_dims 184 185 with vs.variable_scope(scope or type(self).__name__): # GridRNNCell 186 # project input, populate c_prev and m_prev 187 self._project_input(inputs, c_prev, m_prev, cell_output_size > 0) 188 189 # propagate along dimensions, first for non-priority dimensions 190 # then priority dimensions 191 _propagate(conf.non_priority, conf, self._cells, c_prev, m_prev, 192 new_output, new_state, True) 193 _propagate(conf.priority, conf, self._cells, 194 c_prev, m_prev, new_output, new_state, False) 195 196 # collect outputs and states 197 output_tensors = [new_output[i] for i in self._config.outputs] 198 if self._output_is_tuple: 199 output = tuple(output_tensors) 200 else: 201 if output_tensors: 202 output = array_ops.concat(output_tensors, 1) 203 else: 204 output = array_ops.zeros([0, 0], dtype) 205 206 if self._state_is_tuple: 207 states = tuple(new_state[i] for i in self._config.recurrents) 208 else: 209 # concat each state first, then flatten the whole thing 210 state_tensors = [ 211 x for i in self._config.recurrents for x in new_state[i] 212 ] 213 if state_tensors: 214 states = array_ops.concat(state_tensors, 1) 215 else: 216 states = array_ops.zeros([0, 0], dtype) 217 218 return output, states 219 220 def _extract_states(self, state): 221 """Extract the cell and previous output tensors from the given state. 222 223 Args: 224 state: The RNN state. 225 226 Returns: 227 Tuple of the cell value, previous output, and cell_output_size. 228 229 Raises: 230 ValueError: If len(self._config.recurrents) != len(state). 231 """ 232 conf = self._config 233 234 # c_prev is `m` (cell value), and 235 # m_prev is `h` (previous output) in the paper. 236 # Keeping c and m here for consistency with the codebase 237 c_prev = [None] * conf.num_dims 238 m_prev = [None] * conf.num_dims 239 240 # for LSTM : state = memory cell + output, hence cell_output_size > 0 241 # for GRU/RNN: state = output (whose size is equal to _num_units), 242 # hence cell_output_size = 0 243 total_cell_state_size = self._cell_state_size() 244 cell_output_size = total_cell_state_size - conf.num_units 245 246 if self._state_is_tuple: 247 if len(conf.recurrents) != len(state): 248 raise ValueError('Expected state as a tuple of {} ' 249 'element'.format(len(conf.recurrents))) 250 251 for recurrent_dim, recurrent_state in zip(conf.recurrents, state): 252 if cell_output_size > 0: 253 c_prev[recurrent_dim], m_prev[recurrent_dim] = recurrent_state 254 else: 255 m_prev[recurrent_dim] = recurrent_state 256 else: 257 for recurrent_dim, start_idx in zip(conf.recurrents, 258 range(0, self.state_size, 259 total_cell_state_size)): 260 if cell_output_size > 0: 261 c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], 262 [-1, conf.num_units]) 263 m_prev[recurrent_dim] = array_ops.slice( 264 state, [0, start_idx + conf.num_units], [-1, cell_output_size]) 265 else: 266 m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx], 267 [-1, conf.num_units]) 268 return c_prev, m_prev, cell_output_size 269 270 def _project_input(self, inputs, c_prev, m_prev, with_c): 271 """Fills in c_prev and m_prev with projected input, for input dimensions. 272 273 Args: 274 inputs: inputs tensor 275 c_prev: cell value 276 m_prev: previous output 277 with_c: boolean; whether to include project_c. 278 279 Raises: 280 ValueError: if len(self._config.input) != len(inputs) 281 """ 282 conf = self._config 283 284 if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and 285 conf.inputs): 286 if isinstance(inputs, tuple): 287 if len(conf.inputs) != len(inputs): 288 raise ValueError('Expect inputs as a tuple of {} ' 289 'tensors'.format(len(conf.inputs))) 290 input_splits = inputs 291 else: 292 input_splits = array_ops.split( 293 value=inputs, num_or_size_splits=len(conf.inputs), axis=1) 294 input_sz = input_splits[0].get_shape().with_rank(2)[1].value 295 296 for i, j in enumerate(conf.inputs): 297 input_project_m = vs.get_variable( 298 'project_m_{}'.format(j), [input_sz, conf.num_units], 299 dtype=inputs.dtype) 300 m_prev[j] = math_ops.matmul(input_splits[i], input_project_m) 301 302 if with_c: 303 input_project_c = vs.get_variable( 304 'project_c_{}'.format(j), [input_sz, conf.num_units], 305 dtype=inputs.dtype) 306 c_prev[j] = math_ops.matmul(input_splits[i], input_project_c) 307 308 def _cell_state_size(self): 309 """Total size of the state of the inner cell used in this grid. 310 311 Returns: 312 Total size of the state of the inner cell. 313 """ 314 state_sizes = self._cells[0].state_size 315 if isinstance(state_sizes, tuple): 316 return sum(state_sizes) 317 return state_sizes 318 319 320"""Specialized cells, for convenience 321""" 322 323 324class Grid1BasicRNNCell(GridRNNCell): 325 """1D BasicRNN cell""" 326 327 def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True): 328 super(Grid1BasicRNNCell, self).__init__( 329 num_units=num_units, 330 num_dims=1, 331 input_dims=0, 332 output_dims=0, 333 priority_dims=0, 334 tied=False, 335 cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), 336 state_is_tuple=state_is_tuple, 337 output_is_tuple=output_is_tuple) 338 339 340class Grid2BasicRNNCell(GridRNNCell): 341 """2D BasicRNN cell 342 343 This creates a 2D cell which receives input and gives output in the first 344 dimension. 345 346 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 347 specified. 348 """ 349 350 def __init__(self, 351 num_units, 352 tied=False, 353 non_recurrent_fn=None, 354 state_is_tuple=True, 355 output_is_tuple=True): 356 super(Grid2BasicRNNCell, self).__init__( 357 num_units=num_units, 358 num_dims=2, 359 input_dims=0, 360 output_dims=0, 361 priority_dims=0, 362 tied=tied, 363 non_recurrent_dims=None if non_recurrent_fn is None else 0, 364 cell_fn=lambda n: rnn.BasicRNNCell(num_units=n), 365 non_recurrent_fn=non_recurrent_fn, 366 state_is_tuple=state_is_tuple, 367 output_is_tuple=output_is_tuple) 368 369 370class Grid1BasicLSTMCell(GridRNNCell): 371 """1D BasicLSTM cell.""" 372 373 def __init__(self, 374 num_units, 375 forget_bias=1, 376 state_is_tuple=True, 377 output_is_tuple=True): 378 def cell_fn(n): 379 return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) 380 super(Grid1BasicLSTMCell, self).__init__( 381 num_units=num_units, 382 num_dims=1, 383 input_dims=0, 384 output_dims=0, 385 priority_dims=0, 386 tied=False, 387 cell_fn=cell_fn, 388 state_is_tuple=state_is_tuple, 389 output_is_tuple=output_is_tuple) 390 391 392class Grid2BasicLSTMCell(GridRNNCell): 393 """2D BasicLSTM cell. 394 395 This creates a 2D cell which receives input and gives output in the first 396 dimension. 397 398 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 399 specified. 400 """ 401 402 def __init__(self, 403 num_units, 404 tied=False, 405 non_recurrent_fn=None, 406 forget_bias=1, 407 state_is_tuple=True, 408 output_is_tuple=True): 409 def cell_fn(n): 410 return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias) 411 super(Grid2BasicLSTMCell, self).__init__( 412 num_units=num_units, 413 num_dims=2, 414 input_dims=0, 415 output_dims=0, 416 priority_dims=0, 417 tied=tied, 418 non_recurrent_dims=None if non_recurrent_fn is None else 0, 419 cell_fn=cell_fn, 420 non_recurrent_fn=non_recurrent_fn, 421 state_is_tuple=state_is_tuple, 422 output_is_tuple=output_is_tuple) 423 424 425class Grid1LSTMCell(GridRNNCell): 426 """1D LSTM cell. 427 428 This is different from Grid1BasicLSTMCell because it gives options to 429 specify the forget bias and enabling peepholes. 430 """ 431 432 def __init__(self, 433 num_units, 434 use_peepholes=False, 435 forget_bias=1.0, 436 state_is_tuple=True, 437 output_is_tuple=True): 438 439 def cell_fn(n): 440 return rnn.LSTMCell( 441 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 442 443 super(Grid1LSTMCell, self).__init__( 444 num_units=num_units, 445 num_dims=1, 446 input_dims=0, 447 output_dims=0, 448 priority_dims=0, 449 cell_fn=cell_fn, 450 state_is_tuple=state_is_tuple, 451 output_is_tuple=output_is_tuple) 452 453 454class Grid2LSTMCell(GridRNNCell): 455 """2D LSTM cell. 456 457 This creates a 2D cell which receives input and gives output in the first 458 dimension. 459 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 460 specified. 461 """ 462 463 def __init__(self, 464 num_units, 465 tied=False, 466 non_recurrent_fn=None, 467 use_peepholes=False, 468 forget_bias=1.0, 469 state_is_tuple=True, 470 output_is_tuple=True): 471 472 def cell_fn(n): 473 return rnn.LSTMCell( 474 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 475 476 super(Grid2LSTMCell, self).__init__( 477 num_units=num_units, 478 num_dims=2, 479 input_dims=0, 480 output_dims=0, 481 priority_dims=0, 482 tied=tied, 483 non_recurrent_dims=None if non_recurrent_fn is None else 0, 484 cell_fn=cell_fn, 485 non_recurrent_fn=non_recurrent_fn, 486 state_is_tuple=state_is_tuple, 487 output_is_tuple=output_is_tuple) 488 489 490class Grid3LSTMCell(GridRNNCell): 491 """3D BasicLSTM cell. 492 493 This creates a 2D cell which receives input and gives output in the first 494 dimension. 495 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 496 specified. 497 The second and third dimensions are LSTM. 498 """ 499 500 def __init__(self, 501 num_units, 502 tied=False, 503 non_recurrent_fn=None, 504 use_peepholes=False, 505 forget_bias=1.0, 506 state_is_tuple=True, 507 output_is_tuple=True): 508 509 def cell_fn(n): 510 return rnn.LSTMCell( 511 num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes) 512 513 super(Grid3LSTMCell, self).__init__( 514 num_units=num_units, 515 num_dims=3, 516 input_dims=0, 517 output_dims=0, 518 priority_dims=0, 519 tied=tied, 520 non_recurrent_dims=None if non_recurrent_fn is None else 0, 521 cell_fn=cell_fn, 522 non_recurrent_fn=non_recurrent_fn, 523 state_is_tuple=state_is_tuple, 524 output_is_tuple=output_is_tuple) 525 526 527class Grid2GRUCell(GridRNNCell): 528 """2D LSTM cell. 529 530 This creates a 2D cell which receives input and gives output in the first 531 dimension. 532 The first dimension can optionally be non-recurrent if `non_recurrent_fn` is 533 specified. 534 """ 535 536 def __init__(self, 537 num_units, 538 tied=False, 539 non_recurrent_fn=None, 540 state_is_tuple=True, 541 output_is_tuple=True): 542 super(Grid2GRUCell, self).__init__( 543 num_units=num_units, 544 num_dims=2, 545 input_dims=0, 546 output_dims=0, 547 priority_dims=0, 548 tied=tied, 549 non_recurrent_dims=None if non_recurrent_fn is None else 0, 550 cell_fn=lambda n: rnn.GRUCell(num_units=n), 551 non_recurrent_fn=non_recurrent_fn, 552 state_is_tuple=state_is_tuple, 553 output_is_tuple=output_is_tuple) 554 555 556# Helpers 557 558_GridRNNDimension = namedtuple('_GridRNNDimension', [ 559 'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn' 560]) 561 562_GridRNNConfig = namedtuple('_GridRNNConfig', 563 ['num_dims', 'dims', 'inputs', 'outputs', 564 'recurrents', 'priority', 'non_priority', 'tied', 565 'num_units']) 566 567 568def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims, 569 ls_non_recurrent_dims, non_recurrent_fn, tied, num_units): 570 def check_dim_list(ls): 571 if ls is None: 572 ls = [] 573 if not isinstance(ls, (list, tuple)): 574 ls = [ls] 575 ls = sorted(set(ls)) 576 if any(_ < 0 or _ >= num_dims for _ in ls): 577 raise ValueError('Invalid dims: {}. Must be in [0, {})'.format(ls, 578 num_dims)) 579 return ls 580 581 input_dims = check_dim_list(ls_input_dims) 582 output_dims = check_dim_list(ls_output_dims) 583 priority_dims = check_dim_list(ls_priority_dims) 584 non_recurrent_dims = check_dim_list(ls_non_recurrent_dims) 585 586 rnn_dims = [] 587 for i in range(num_dims): 588 rnn_dims.append( 589 _GridRNNDimension( 590 idx=i, 591 is_input=(i in input_dims), 592 is_output=(i in output_dims), 593 is_priority=(i in priority_dims), 594 non_recurrent_fn=non_recurrent_fn 595 if i in non_recurrent_dims else None)) 596 return _GridRNNConfig( 597 num_dims=num_dims, 598 dims=rnn_dims, 599 inputs=input_dims, 600 outputs=output_dims, 601 recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims], 602 priority=priority_dims, 603 non_priority=[x for x in range(num_dims) if x not in priority_dims], 604 tied=tied, 605 num_units=num_units) 606 607 608def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state, 609 first_call): 610 """Propagates through all the cells in dim_indices dimensions. 611 """ 612 if len(dim_indices) == 0: 613 return 614 615 # Because of the way RNNCells are implemented, we take the last dimension 616 # (H_{N-1}) out and feed it as the state of the RNN cell 617 # (in `last_dim_output`). 618 # The input of the cell (H_0 to H_{N-2}) are concatenated into `cell_inputs` 619 if conf.num_dims > 1: 620 ls_cell_inputs = [None] * (conf.num_dims - 1) 621 for d in conf.dims[:-1]: 622 if new_output[d.idx] is None: 623 ls_cell_inputs[d.idx] = m_prev[d.idx] 624 else: 625 ls_cell_inputs[d.idx] = new_output[d.idx] 626 cell_inputs = array_ops.concat(ls_cell_inputs, 1) 627 else: 628 cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0], 629 m_prev[0].dtype) 630 631 last_dim_output = (new_output[-1] 632 if new_output[-1] is not None else m_prev[-1]) 633 634 for i in dim_indices: 635 d = conf.dims[i] 636 if d.non_recurrent_fn: 637 if conf.num_dims > 1: 638 linear_args = array_ops.concat([cell_inputs, last_dim_output], 1) 639 else: 640 linear_args = last_dim_output 641 with vs.variable_scope('non_recurrent' if conf.tied else 642 'non_recurrent/cell_{}'.format(i)): 643 if conf.tied and not (first_call and i == dim_indices[0]): 644 vs.get_variable_scope().reuse_variables() 645 646 new_output[d.idx] = layers.fully_connected( 647 linear_args, 648 num_outputs=conf.num_units, 649 activation_fn=d.non_recurrent_fn, 650 weights_initializer=(vs.get_variable_scope().initializer or 651 layers.initializers.xavier_initializer), 652 weights_regularizer=vs.get_variable_scope().regularizer) 653 else: 654 if c_prev[i] is not None: 655 cell_state = (c_prev[i], last_dim_output) 656 else: 657 # for GRU/RNN, the state is just the previous output 658 cell_state = last_dim_output 659 660 with vs.variable_scope('recurrent' if conf.tied else 661 'recurrent/cell_{}'.format(i)): 662 if conf.tied and not (first_call and i == dim_indices[0]): 663 vs.get_variable_scope().reuse_variables() 664 cell = cells[i] 665 new_output[d.idx], new_state[d.idx] = cell(cell_inputs, cell_state) 666