data_flow_grad.py revision f41959ccb2d9d4c722fe8fc3351401d53bcf4900
1"""Gradients for operators defined in data_flow_ops.py."""
2
3from tensorflow.python.framework import ops
4from tensorflow.python.framework import types
5from tensorflow.python.ops import array_ops
6from tensorflow.python.ops import constant_op
7from tensorflow.python.ops import data_flow_ops
8from tensorflow.python.ops import gen_data_flow_ops
9from tensorflow.python.ops import math_ops
10
11
@ops.RegisterGradient("DynamicStitch")
def _DynamicStitchGrads(op, grad):
  """Gradients for DynamicStitch.

  The first half of `op.inputs` are index tensors: they are used below as
  `gather` indices and receive a `None` gradient. The second half are the
  data tensors. Each data input's gradient is the rows of `grad` selected by
  the index tensor paired with it.

  Args:
    op: The DynamicStitch operation whose gradient is being computed.
    grad: Gradient with respect to the op's output (`op.outputs[0]`); either
      a dense tensor or an `ops.IndexedSlices`.

  Returns:
    A list aligned with `op.inputs`: `None` for each index input, followed by
    the gathered gradient for each data input.
  """
  # Floor division keeps this an int. True division yields a float on
  # Python 3, which breaks both the list repetition and range() below.
  num_values = len(op.inputs) // 2
  indices_grad = [None] * num_values

  def AsInt32(x):
    # Cast index tensors to int32; a no-op when the indices already are int32.
    return (x if op.inputs[0].dtype == types.int32 else
            math_ops.cast(x, types.int32))
  inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]
  if isinstance(grad, ops.IndexedSlices):
    # Turn the sparse gradient into a dense tensor with one row per output
    # row, summing any repeated indices, so it can be gathered below.
    output_shape = array_ops.shape(op.outputs[0])
    output_rows = output_shape[0]
    grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)
  values_grad = [array_ops.gather(grad, inp) for inp in inputs]
  return indices_grad + values_grad
29
30
# Register each queue op type below with ops.NoGradient, in this order.
for _queue_op_type in ("Queue",
                       "QueueEnqueue",
                       "QueueEnqueueMany",
                       "QueueDequeue",
                       "QueueDequeueMany",
                       "QueueClose",
                       "QueueSize"):
  ops.NoGradient(_queue_op_type)
# Drop the loop variable so it does not linger in the module namespace.
del _queue_op_type
38