Searched refs:batch_dim (Results 1 - 18 of 18) sorted by relevance

/external/tensorflow/tensorflow/core/kernels/
H A Dreverse_sequence_op.cc44 void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { argument
56 OP_REQUIRES(context, batch_dim != seq_dim,
57 errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim));
61 OP_REQUIRES(context, batch_dim < input.dims(),
62 errors::InvalidArgument("batch_dim must be < input.dims()", "( ",
63 batch_dim, " vs. ", input.dims(), ")"));
64 OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
65 errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
67 " vs. ", input.dim_size(batch_dim)));
78 void CheckErrorsGPU(OpKernelContext* context, int batch_dim, in argument
98 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument
104 CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) argument
[all...]
H A Dreverse_sequence_op.h32 ReverseGenerator(typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, argument
35 batch_dim_(batch_dim),
65 int32 batch_dim, int32 seq_dim,
68 generator::ReverseGenerator<T, Tlen, Dims> generator(input, batch_dim,
63 Compute( const Device& d, typename TTypes<T, Dims>::ConstTensor input, int32 batch_dim, int32 seq_dim, typename TTypes<Tlen>::ConstVec seq_lengths, typename TTypes<T, Dims>::Tensor output) argument
H A Dconv_grad_ops.cc116 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); local
117 dims->batch_size = input_shape.dim_size(batch_dim);
118 if (dims->batch_size != out_backprop_shape.dim_size(batch_dim)) {
122 "outbackprop batch: ", out_backprop_shape.dim_size(batch_dim),
123 " batch_dim: ", batch_dim);
H A Dscatter_nd_op.cc292 // Check whether updates.shape = indices.shape[:batch_dim] +
298 const int64 batch_dim = (indices.dims() > 1) ? indices.dims() - 1 : 1; local
302 "Must have updates.shape = indices.shape[:batch_dim] + ",
307 ", slice_dim: ", slice_dim, ", and batch_dim: ", batch_dim);
310 if (updates.dims() < batch_dim) return shape_err();
311 if (params_shape.dims() < slice_dim + (updates.dims() - batch_dim)) {
314 if (updates.dims() != batch_dim + params_shape.dims() - slice_dim) {
317 for (int d = 0; d < batch_dim; ++d) {
320 for (int d = 0; d < updates.dims() - batch_dim;
[all...]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
H A Dscatter_nd_op.cc29 // Check whether updates.shape = indices.shape[:batch_dim] +
41 const int64 batch_dim = indices_shape.dims() - 1; local
45 "Must have updates.shape = indices.shape[:batch_dim] + ",
50 ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
53 if (updates_shape.dims() < batch_dim) return shape_err();
55 num_index_dims + (updates_shape.dims() - batch_dim)) {
59 batch_dim + buffer_shape.dims() - num_index_dims) {
62 for (int d = 0; d < batch_dim; ++d) {
67 for (int d = 0; d < updates_shape.dims() - batch_dim;
[all...]
H A Dextract_image_patches_op.cc53 int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); variable
56 ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1,
61 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
65 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
127 dims.set_input_batch_dimension(batch_dim);
128 dims.set_output_batch_dimension(batch_dim);
H A Dconv_ops.cc199 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable
202 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
211 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
264 dims.set_input_batch_dimension(batch_dim);
265 dims.set_output_batch_dimension(batch_dim);
350 int batch_dim = GetTensorBatchDimIndex(num_dims(), data_format_); variable
353 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
362 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
399 dnums.set_input_batch_dimension(batch_dim);
400 dnums.set_output_batch_dimension(batch_dim);
[all...]
/external/tensorflow/tensorflow/python/ops/
H A Drnn.py424 batch_dim = 0
427 batch_dim = 1
429 def _reverse(input_, seq_lengths, seq_dim, batch_dim):
433 seq_dim=seq_dim, batch_dim=batch_dim)
440 seq_dim=time_dim, batch_dim=batch_dim)
449 seq_dim=time_dim, batch_dim=batch_dim)
H A Ddata_flow_ops.py384 batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
386 batch_dim = batch_dim.merge_with(
489 batch_dim = tensor_shape.Dimension(
493 tensor_shape.TensorShape([batch_dim]).concatenate(shape))
1089 batch_dim = None
1091 batch_dim = tensor_shape.Dimension(
1093 op.outputs[0].set_shape(tensor_shape.vector(batch_dim)) # indices
1094 op.outputs[1].set_shape(tensor_shape.vector(batch_dim)) # keys
1097 tensor_shape.TensorShape([batch_dim])
[all...]
H A Darray_ops.py2641 batch_dim=None):
2645 "batch_dim", batch_dim)
2650 batch_dim=batch_axis,
2658 gen_array_ops.reverse_sequence.__doc__, "batch_dim", "batch_axis"),
/external/tensorflow/tensorflow/contrib/gan/python/features/python/
H A Dvirtual_batchnorm_impl.py98 def _validate_call_input(tensor_list, batch_dim):
99 """Verifies that tensor shapes are compatible, except for `batch_dim`."""
102 del shape[batch_dim]
212 # than in the `nn.batch_normalization` case, due to `batch_dim`.
/external/tensorflow/tensorflow/core/ops/
H A Dimage_ops.cc28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim, argument
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
532 DimensionHandle batch_dim;
534 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
538 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
H A Darray_ops_test.cc1060 auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) {
1065 .Attr("batch_dim", batch_dim)
1076 // Validate seq_dim and batch_dim
1078 INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
H A Darray_ops.cc1243 .Attr("batch_dim: int = 0")
1253 int64 batch_dim;
1254 TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
1260 // Validate batch_dim and seq_dim against input.
1262 if (batch_dim >= input_rank) {
1264 "batch_dim must be < input rank: ", batch_dim, " vs. ", input_rank);
1271 DimensionHandle batch_dim_dim = c->Dim(input, batch_dim);
1275 // Replace batch_dim o
[all...]
/external/tensorflow/tensorflow/cc/gradients/
H A Darray_grad.cc228 int batch_dim; local
229 TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
234 ReverseSequence::BatchDim(batch_dim)));
/external/tensorflow/tensorflow/python/kernel_tests/
H A Dsparse_xent_op_test.py51 batch_dim = 0
53 batch_size = features.shape[batch_dim]
/external/tensorflow/tensorflow/core/framework/
H A Dcommon_shape_fns.cc316 DimensionHandle* batch_dim,
322 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format));
340 Status ShapeFromDimensions(DimensionHandle batch_dim, argument
348 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim;
315 DimensionsFromShape(ShapeHandle shape, TensorFormat format, DimensionHandle* batch_dim, gtl::MutableArraySlice<DimensionHandle> spatial_dims, DimensionHandle* filter_dim, InferenceContext* context) argument
/external/tensorflow/tensorflow/python/data/ops/
H A Ddataset_ops.py973 batch_dim = flat_tensors[0].get_shape()[0]
975 batch_dim.assert_is_compatible_with(t.get_shape()[0])

Completed in 630 milliseconds