/external/tensorflow/tensorflow/core/kernels/ |
H A D | reverse_sequence_op.cc | 44 void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { argument 56 OP_REQUIRES(context, batch_dim != seq_dim, 57 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim)); 61 OP_REQUIRES(context, batch_dim < input.dims(), 62 errors::InvalidArgument("batch_dim must be < input.dims()", "( ", 63 batch_dim, " vs. ", input.dims(), ")")); 64 OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), 65 errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, 67 " vs. ", input.dim_size(batch_dim))); 78 void CheckErrorsGPU(OpKernelContext* context, int batch_dim, in argument 98 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument 104 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument [all...] |
H A D | reverse_sequence_op.h | 32 ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, argument 35 batch_dim_(batch_dim), 65 int32 batch_dim, int32 seq_dim, 68 generator::ReverseGenerator<T, Tlen, Dims> generator(input, batch_dim, 63 Compute( const Device& d, typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, int32 seq_dim, typename TTypes<Tlen>::ConstVec seq_lengths, typename TTypes<T, Dims>::Tensor output) argument
|
H A D | conv_grad_ops.cc | 116 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); local 117 dims->batch_size = input_shape.dim_size(batch_dim); 118 if (dims->batch_size != out_backprop_shape.dim_size(batch_dim)) { 122 "outbackprop batch: ", out_backprop_shape.dim_size(batch_dim), 123 " batch_dim: ", batch_dim);
|
H A D | scatter_nd_op.cc | 292 // Check whether updates.shape = indices.shape[:batch_dim] + 298 const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1; local 302 "Must have updates.shape = indices.shape[:batch_dim] + ", 307 ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim); 310 if (updates.dims() < batch_dim) return shape_err(); 311 if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) { 314 if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) { 317 for (int d = 0; d < batch_dim; ++d) { 320 for (int d = 0; d < updates.dims() - batch_dim; [all...] |
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
H A D | scatter_nd_op.cc | 29 // Check whether updates.shape = indices.shape[:batch_dim] + 41 const int64 batch_dim = indices_shape.dims() - 1; local 45 "Must have updates.shape = indices.shape[:batch_dim] + ", 50 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); 53 if (updates_shape.dims() < batch_dim) return shape_err(); 55 num_index_dims + (updates_shape.dims() - batch_dim)) { 59 batch_dim + buffer_shape.dims() - num_index_dims) { 62 for (int d = 0; d < batch_dim; ++d) { 67 for (int d = 0; d < updates_shape.dims() - batch_dim; [all...] |
H A D | extract_image_patches_op.cc | 53 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); variable 56 ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, 61 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 65 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 127 dims.set_input_batch_dimension(batch_dim); 128 dims.set_output_batch_dimension(batch_dim);
|
H A D | conv_ops.cc | 199 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable 202 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 211 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 264 dims.set_input_batch_dimension(batch_dim); 265 dims.set_output_batch_dimension(batch_dim); 350 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable 353 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, 362 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, 399 dnums.set_input_batch_dimension(batch_dim); 400 dnums.set_output_batch_dimension(batch_dim); [all...] |
/external/tensorflow/tensorflow/python/ops/ |
H A D | rnn.py | 424 batch_dim = 0 427 batch_dim = 1 429 def _reverse(input_, seq_lengths, seq_dim, batch_dim): 433 seq_dim=seq_dim, batch_dim=batch_dim) 440 seq_dim=time_dim, batch_dim=batch_dim) 449 seq_dim=time_dim, batch_dim=batch_dim)
|
H A D | data_flow_ops.py | 384 batch_dim = vals[0].get_shape().with_rank_at_least(1)[0] 386 batch_dim = batch_dim.merge_with( 489 batch_dim = tensor_shape.Dimension( 493 tensor_shape.TensorShape([batch_dim]).concatenate(shape)) 1089 batch_dim = None 1091 batch_dim = tensor_shape.Dimension( 1093 op.outputs[0].set_shape(tensor_shape.vector(batch_dim)) # indices 1094 op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys 1097 tensor_shape.TensorShape([batch_dim]) [all...] |
H A D | array_ops.py | 2641 batch_dim=None): 2645 "batch_dim", batch_dim) 2650 batch_dim=batch_axis, 2658 gen_array_ops.reverse_sequence.__doc__, "batch_dim", "batch_axis"),
|
/external/tensorflow/tensorflow/contrib/gan/python/features/python/ |
H A D | virtual_batchnorm_impl.py | 98 def _validate_call_input(tensor_list, batch_dim): 99 """Verifies that tensor shapes are compatible, except for `batch_dim`.""" 102 del shape[batch_dim] 212 # than in the `nn.batch_normalization` case, due to `batch_dim`.
|
/external/tensorflow/tensorflow/core/ops/ |
H A D | image_ops.cc | 28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where 30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, argument 58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim})); 532 DimensionHandle batch_dim; 534 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim)); 538 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
|
H A D | array_ops_test.cc | 1060 auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) { 1065 .Attr("batch_dim", batch_dim) 1076 // Validate seq_dim and batch_dim 1078 INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
|
H A D | array_ops.cc | 1243 .Attr("batch_dim: int = 0") 1253 int64 batch_dim; 1254 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim)); 1260 // Validate batch_dim and seq_dim against input. 1262 if (batch_dim >= input_rank) { 1264 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank); 1271 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim); 1275 // Replace batch_dim o [all...] |
/external/tensorflow/tensorflow/cc/gradients/ |
H A D | array_grad.cc | 228 int batch_dim; local 229 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim)); 234 ReverseSequence::BatchDim(batch_dim)));
|
/external/tensorflow/tensorflow/python/kernel_tests/ |
H A D | sparse_xent_op_test.py | 51 batch_dim = 0 53 batch_size = features.shape[batch_dim]
|
/external/tensorflow/tensorflow/core/framework/ |
H A D | common_shape_fns.cc | 316 DimensionHandle* batch_dim, 322 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); 340 Status ShapeFromDimensions(DimensionHandle batch_dim, argument 348 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; 315 DimensionsFromShape(ShapeHandle shape, TensorFormat format, DimensionHandle* batch_dim, gtl::MutableArraySlice<DimensionHandle> spatial_dims, DimensionHandle* filter_dim, InferenceContext* context) argument
|
/external/tensorflow/tensorflow/python/data/ops/ |
H A D | dataset_ops.py | 973 batch_dim = flat_tensors[0].get_shape()[0] 975 batch_dim.assert_is_compatible_with(t.get_shape()[0])
|