Searched defs:batch_dim (Results 1 - 10 of 10) sorted by relevance

/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
H A Dscatter_nd_op.cc29 // Check whether updates.shape = indices.shape[:batch_dim] +
41 const int64 batch_dim = indices_shape.dims() - 1; local
45 "Must have updates.shape = indices.shape[:batch_dim] + ",
50 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
53 if (updates_shape.dims() < batch_dim) return shape_err();
55 num_index_dims + (updates_shape.dims() - batch_dim)) {
59 batch_dim + buffer_shape.dims() - num_index_dims) {
62 for (int d = 0; d < batch_dim; ++d) {
67 for (int d = 0; d < updates_shape.dims() - batch_dim;
[all...]
H A Dextract_image_patches_op.cc53 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); variable
56 ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
61 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
65 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
127 dims.set_input_batch_dimension(batch_dim);
128 dims.set_output_batch_dimension(batch_dim);
H A Dconv_ops.cc199 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable
202 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
211 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
264 dims.set_input_batch_dimension(batch_dim);
265 dims.set_output_batch_dimension(batch_dim);
350 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable
353 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
362 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
399 dnums.set_input_batch_dimension(batch_dim);
400 dnums.set_output_batch_dimension(batch_dim);
[all...]
/external/tensorflow/tensorflow/core/kernels/
H A Dreverse_sequence_op.cc44 void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { argument
56 OP_REQUIRES(context, batch_dim != seq_dim,
57 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
61 OP_REQUIRES(context, batch_dim < input.dims(),
62 errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
63 batch_dim, " vs. ", input.dims(), ")"));
64 OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
65 errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
67 " vs. ", input.dim_size(batch_dim)));
78 void CheckErrorsGPU(OpKernelContext* context, int batch_dim, in argument
98 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument
104 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument
[all...]
H A Dreverse_sequence_op.h32 ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, argument
35 batch_dim_(batch_dim),
65 int32 batch_dim, int32 seq_dim,
68 generator::ReverseGenerator<T, Tlen, Dims> generator(input, batch_dim,
63 Compute( const Device& d, typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, int32 seq_dim, typename TTypes<Tlen>::ConstVec seq_lengths, typename TTypes<T, Dims>::Tensor output) argument
H A Dconv_grad_ops.cc116 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); local
117 dims->batch_size = input_shape.dim_size(batch_dim);
118 if (dims->batch_size != out_backprop_shape.dim_size(batch_dim)) {
122 "outbackprop batch: ", out_backprop_shape.dim_size(batch_dim),
123 " batch_dim: ", batch_dim);
H A Dscatter_nd_op.cc292 // Check whether updates.shape = indices.shape[:batch_dim] +
298 const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1; local
302 "Must have updates.shape = indices.shape[:batch_dim] + ",
307 ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
310 if (updates.dims() < batch_dim) return shape_err();
311 if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
314 if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
317 for (int d = 0; d < batch_dim; ++d) {
320 for (int d = 0; d < updates.dims() - batch_dim;
[all...]
/external/tensorflow/tensorflow/core/ops/
H A Dimage_ops.cc28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, argument
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
532 DimensionHandle batch_dim;
534 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
538 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
/external/tensorflow/tensorflow/cc/gradients/
H A Darray_grad.cc228 int batch_dim; local
229 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
234 ReverseSequence::BatchDim(batch_dim)));
/external/tensorflow/tensorflow/core/framework/
H A Dcommon_shape_fns.cc316 DimensionHandle* batch_dim,
322 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
340 Status ShapeFromDimensions(DimensionHandle batch_dim, argument
348 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
315 DimensionsFromShape(ShapeHandle shape, TensorFormat format, DimensionHandle* batch_dim, gtl::MutableArraySlice<DimensionHandle> spatial_dims, DimensionHandle* filter_dim, InferenceContext* context) argument

Completed in 312 milliseconds