data_flow_grad.py revision f41959ccb2d9d4c722fe8fc3351401d53bcf4900
"""Gradients for operators defined in data_flow_ops.py."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import types
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops


@ops.RegisterGradient("DynamicStitch")
def _DynamicStitchGrads(op, grad):
  """Gradients for DynamicStitch.

  The first half of `op.inputs` is treated as the index tensors and the
  second half as the corresponding value tensors.

  Args:
    op: The DynamicStitch operation being differentiated.
    grad: Gradient with respect to the op's output; either a dense tensor or
      an `ops.IndexedSlices`.

  Returns:
    A list with one entry per input of `op`: `None` for every index input,
    followed by, for every value input, the rows of `grad` gathered at the
    matching index tensor.
  """
  # Floor division keeps this an int; true division would produce a float on
  # Python 3 and break both the list repetition and range() below.
  num_values = len(op.inputs) // 2

  # The index inputs receive no gradient.
  indices_grad = [None] * num_values

  def AsInt32(x):
    # The cast decision is made from the dtype of the first input only.
    return (x if op.inputs[0].dtype == types.int32 else
            math_ops.cast(x, types.int32))

  inputs = [AsInt32(op.inputs[i]) for i in range(num_values)]

  if isinstance(grad, ops.IndexedSlices):
    # Collapse the sparse gradient into a dense tensor with as many rows as
    # the op's output so that it can be indexed with gather below.
    output_shape = array_ops.shape(op.outputs[0])
    output_rows = output_shape[0]
    grad = math_ops.unsorted_segment_sum(grad.values, grad.indices, output_rows)

  values_grad = [array_ops.gather(grad, inp) for inp in inputs]
  return indices_grad + values_grad


# Queue operations are registered as having no gradient.
ops.NoGradient("Queue")
ops.NoGradient("QueueEnqueue")
ops.NoGradient("QueueEnqueueMany")
ops.NoGradient("QueueDequeue")
ops.NoGradient("QueueDequeueMany")
ops.NoGradient("QueueClose")
ops.NoGradient("QueueSize")