Lines Matching refs:ctx

160     XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
170 *ctx->GetOrCreateAdd(dtype),
177 explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
179 : XlaOpKernel(ctx),
182 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
183 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
184 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
187 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
188 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
194 void Compile(XlaOpKernelContext* ctx) override {
195 OP_REQUIRES(ctx, strides_.size() == num_dims(),
202 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
206 OP_REQUIRES(ctx, dilations_.size() == num_dims(),
211 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
216 OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
222 const TensorShape input_shape = ctx->InputShape(0);
225 const TensorShape filter_shape = ctx->InputShape(1);
229 ctx, input_shape.dims() == num_dims(),
233 ctx, filter_shape.dims() == num_dims(),
242 OP_REQUIRES(ctx, in_depth == input_shape.dim_size(feature_dim),
247 xla::ComputationBuilder* b = ctx->builder();
249 xla::ComputationDataHandle filter = ctx->Input(1);
253 filter_shape, ctx->input_type(0), filter, b);
281 ctx, GetWindowedOutputSizeVerboseV2(
288 b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding,
290 ctx->SetOutput(0, conv);
307 explicit Conv2DOp(OpKernelConstruction* ctx)
308 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
314 explicit Conv3DOp(OpKernelConstruction* ctx)
315 : ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
321 explicit DepthwiseConv2DOp(OpKernelConstruction* ctx)
322 : ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
329 explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
331 : XlaOpKernel(ctx),
334 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
335 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
336 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
338 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
339 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
345 void Compile(XlaOpKernelContext* ctx) override {
346 OP_REQUIRES(ctx, strides_.size() == num_dims(),
353 ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1,
357 OP_REQUIRES(ctx, dilations_.size() == num_dims(),
362 ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
367 OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
374 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
376 const TensorShape filter_shape = ctx->InputShape(1);
377 const TensorShape out_backprop_shape = ctx->InputShape(2);
384 OP_REQUIRES_OK(ctx,
390 xla::ComputationBuilder* b = ctx->builder();
391 auto filter = ctx->Input(1);
392 auto out_backprop = ctx->Input(2);
430 filter_shape, ctx->input_type(1), filter, b);
443 ctx->SetOutput(0, in_backprop);
460 explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
461 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
469 explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
470 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
478 explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
479 : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
487 explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
489 : XlaOpKernel(ctx),
492 OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
493 OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
494 OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
496 OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
497 OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
503 void Compile(XlaOpKernelContext* ctx) override {
508 ctx, (strides_[n_dim] == 1 && strides_[c_dim] == 1),
512 OP_REQUIRES(ctx, dilations_.size() == num_dims(),
517 ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
522 OP_REQUIRES(ctx, dilations_[input_dim] >= 1,
528 const TensorShape activations_shape = ctx->InputShape(0);
530 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
531 const TensorShape out_backprop_shape = ctx->InputShape(2);
539 OP_REQUIRES_OK(ctx,
545 xla::ComputationBuilder* b = ctx->builder();
546 xla::ComputationDataHandle activations = ctx->Input(0);
547 xla::ComputationDataHandle gradients = ctx->Input(2);
643 ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
645 ctx->SetOutput(0, filter_backprop);
662 explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
663 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
672 explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx)
673 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
682 explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
683 : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}