Lines Matching refs:ctx

38   PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims)
39 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
40 if (ctx->num_inputs() == 1) {
43 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int));
44 OP_REQUIRES(ctx, ksize_int.size() == num_dims(),
48 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int));
49 OP_REQUIRES(ctx, stride_int.size() == num_dims(),
59 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding));
70 virtual const xla::Computation* Reduction(XlaOpKernelContext* ctx,
75 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
78 void Compile(XlaOpKernelContext* ctx) override {
79 xla::ComputationDataHandle input = ctx->Input(0);
80 const TensorShape input_shape = ctx->InputShape(0);
84 if (ctx->num_inputs() != 1) {
85 const TensorShape ksize_shape = ctx->InputShape(1);
87 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
90 OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(),
95 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize));
97 const TensorShape stride_shape = ctx->InputShape(2);
99 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
102 OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(),
107 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride));
109 OP_REQUIRES(ctx, input_shape.dims() == num_dims(),
115 xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow(
116 input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize,
118 ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape));
131 MaxPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
132 : PoolingOp(ctx, /*num_spatial_dims=*/num_spatial_dims) {}
139 const xla::Computation* Reduction(XlaOpKernelContext* ctx,
141 return ctx->GetOrCreateMax(dtype);
145 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
153 explicit MaxPool2DOp(OpKernelConstruction* ctx)
154 : MaxPoolOp(ctx, /*num_spatial_dims=*/2) {
156 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
157 OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
169 explicit MaxPool3DOp(OpKernelConstruction* ctx)
170 : MaxPoolOp(ctx, /*num_spatial_dims=*/3) {}
178 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
190 XlaHelpers::IntegerLiteral(ctx->builder(), dtype, window_size);
191 return ctx->builder()->Div(output, divisor);
212 auto ones = ctx->builder()->Broadcast(
213 XlaHelpers::One(ctx->builder(), dtype), input_dim_sizes);
217 auto counts = ctx->builder()->ReduceWindow(
218 ones, XlaHelpers::Zero(ctx->builder(), dtype),
219 *ctx->GetOrCreateAdd(dtype), window_ksize, window_stride,
222 return ctx->builder()->Div(output, counts, window_dims);
228 AvgPoolOp(OpKernelConstruction* ctx, int num_spatial_dims)
229 : PoolingOp(ctx, num_spatial_dims) {}
236 const xla::Computation* Reduction(XlaOpKernelContext* ctx,
238 return ctx->GetOrCreateAdd(dtype);
242 XlaOpKernelContext* ctx, const xla::ComputationDataHandle& output,
244 return AvgPoolDivideByCount(ctx, output, dtype, input_shape, padding_,
252 explicit AvgPool2DOp(OpKernelConstruction* ctx)
253 : AvgPoolOp(ctx, /*num_spatial_dims=*/2) {
255 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str));
256 OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_),
264 explicit AvgPool3DOp(OpKernelConstruction* ctx)
265 : AvgPoolOp(ctx, /*num_spatial_dims=*/3) {}
277 MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
278 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
279 if (ctx->num_inputs() == 3) {
280 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
281 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
283 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
288 void Compile(XlaOpKernelContext* ctx) override {
289 if (ctx->num_inputs() != 3) {
291 ctx, ctx->num_inputs() == 5,
293 const TensorShape ksize_shape = ctx->InputShape(3);
295 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape),
298 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_));
300 const TensorShape stride_shape = ctx->InputShape(4);
302 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape),
305 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_));
308 OP_REQUIRES(ctx, ksize_.size() == num_dims(),
312 OP_REQUIRES(ctx, stride_.size() == num_dims(),
317 const TensorShape tensor_in_shape = ctx->InputShape(0);
318 const TensorShape tensor_out_shape = ctx->InputShape(1);
319 const TensorShape out_backprop_shape = ctx->InputShape(2);
322 OP_REQUIRES(ctx, tensor_in_shape.dims() == num_dims(),
325 OP_REQUIRES(ctx, tensor_out_shape.dims() == num_dims(),
329 OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
335 auto input = ctx->Input(0);
336 auto out_backprop = ctx->Input(2);
342 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_type(2), &element_type));
344 XlaHelpers::Zero(ctx->builder(), input_type(2));
345 auto select = CreateScalarGeComputation(element_type, ctx->builder());
346 auto scatter = CreateScalarAddComputation(element_type, ctx->builder());
347 xla::ComputationDataHandle gradients = ctx->builder()->SelectAndScatter(
351 ctx->SetOutput(0, gradients);
364 explicit MaxPool2DGradOp(OpKernelConstruction* ctx)
365 : MaxPoolGradOp(ctx, /*num_spatial_dims=*/2) {
367 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
368 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
380 explicit MaxPool3DGradOp(OpKernelConstruction* ctx)
381 : MaxPoolGradOp(ctx, /*num_spatial_dims=*/3) {}
388 AvgPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims)
389 : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
390 OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_));
391 OP_REQUIRES(ctx, ksize_.size() == num_dims(),
395 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_));
396 OP_REQUIRES(ctx, stride_.size() == num_dims(),
400 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
401 OP_REQUIRES(ctx, ksize_[0] == 1 && stride_[0] == 1,
408 void Compile(XlaOpKernelContext* ctx) override {
410 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &gradients_shape));
412 const TensorShape out_backprop_shape = ctx->InputShape(1);
415 OP_REQUIRES(ctx, gradients_shape.dims() == num_dims(),
420 OP_REQUIRES(ctx, out_backprop_shape.dims() == num_dims(),
453 ctx, ConvBackpropComputeDimensions(
458 auto out_backprop = ctx->Input(1);
472 ctx, out_backprop, dtype, gradients_shape, xla_padding, ksize_,
486 auto zero = XlaHelpers::Zero(ctx->builder(), dtype);
488 ctx->builder()->Pad(out_backprop_div, zero, padding_config);
492 xla::ComputationDataHandle in_backprop = ctx->builder()->ReduceWindow(
493 padded_gradients, zero, *ctx->GetOrCreateAdd(dtype), ksize_,
496 ctx->SetOutput(0, in_backprop);
509 explicit AvgPool2DGradOp(OpKernelConstruction* ctx)
510 : AvgPoolGradOp(ctx, /*num_spatial_dims=*/2) {
512 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
513 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
522 explicit AvgPool3DGradOp(OpKernelConstruction* ctx)
523 : AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {}