1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// SWIG typemaps and declarations for building, compiling, and
17// executing XLA computations, wrapping most of what is declared in
18// local_computation_builder.h.
19//
20// The typemaps below implement/assert the following correspondences
21// (with elaborations below):
22//
23//    C++                                  Python
24// -------------------------------------+---------------------------------------
25//  ComputationDataHandle              <-> int
26//  ArraySlice<int64>                  <-  sequence of int
27//  ArraySlice<ComputationDataHandle>  <-  sequence of int
28//  Literal                            <-> (nested tuple of) numpy ndarray
29//  std::vector<Literal>               <-  sequence of (nested tuple of) ndarray
30//  Shape                               -> pair holding (dtype, dimensions)
31//                                     <-  object duck-typed as xla_client.Shape
32//  std::vector<Shape>                 <-  sequence of xla_client.Shape objects
33//  PrimitiveType                      <-  int
//  ArraySlice<pair<int64, int64>>     <-  sequence of int pairs
35//  PaddingConfig proto                <-  corresponding Python proto
36//  ConvolutionDimensionNumbers proto  <-  corresponding Python proto
37//  DotDimensionNumbers proto          <-  corresponding Python proto
38//
39// Arrows indicate whether a conversion only ever occurs in one
40// direction, or whether it is maintained bidirectionally.
41//
42// The Python objects corresponding to C++ Literals have the type:
43//
44//   T = ndarray | (T, ...)
45//
46// where a terminal numpy ndarray translates to a Literal with a
47// non-tuple Shape, an XLA primitive element type corresponding to the
48// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates
49// to a tuple-shaped Literal whose tuple components are translated
50// recursively. For example, if x is a numpy ndarray in Python, with
51// shape (2, 3) and dtype of dtype('float32'), then x translates to a
52// Literal with rank 2, dimension 2 and 3, and XLA primitive type
53// F32. Meanwhile,
54//
55//   (x, (x, x), (x,)),
56//
57// translates to a tuple-shaped XLA Literal, whose component subshapes
58// are a 2x3 F32-shaped literal followed by two tuple-shaped literals.
59//
60// Shapes output by C++ become Python objects with the type:
61//
62//   T            = (dtype, S)
63//   S            = DIMENSIONS | TUPLE_SHAPES
64//   DIMENSIONS   = (int, ...)
65//   TUPLE_SHAPES = (T, ...)
66//
67// In the pair described by the T rule, the terminal dtype determines
68// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is
69// dtype('O'), numpy's object dtype, the structure represents a tuple
70// shape and the expansion of the non-terminal S is
71// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type
72// and S expands into DIMENSIONS giving dimension sizes. For example:
73//
74//   (dtype('float32'), (3, 5, 7))
75//
76// describes a 3x5x7 array of F32s, and
77//
78//   (dtype('O'), ((dtype('float32'), (2, 3)),
79//                 (dtype('float64'), (4, 5))))
80//
81// describes a tuple shape with two subshapes: the first a 2x3 F32,
82// and the other a 4x5 F64.
83//
84// The Python int corresponding to a PrimitiveType enum must be valid
85// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32).
86//
87// The SWIG object wrappers generated by this file are not intended
88// for end use, but rather for internal use in the Python XLA client,
89// xla_client.py.
90//
91// One central reason for the Python-side indirection is that the
92// Python-side objects produced by the typemaps in this file are
93// further packaged up by xla_client before being passed on. For
94// instance, xla_client wraps the long produced for a C++
95// ComputationDataHandle in a Python ComputationDataHandle proto,
96// rather than exposing a raw long outside of the client. Similarly,
97// the Python pair produced for a C++ Shape is further wrapped in a
98// Python class (xla_client.Shape) so as not to expose the raw pair
99// externally.
100//
101// Other SWIG object wrappers (e.g. of LocalComputation) are further
102// wrapped by xla_client in order to set up a custom destructor that
103// triggers memory deallocation on the C++ side.
104
105%module(threads="1") local_computation_builder
106
107// Keep the GIL except where explicitly specified.
108%nothread;
109
110%include "tensorflow/python/platform/base.i"
111
112%{
113// Must be included first
114#include "tensorflow/python/lib/core/numpy.h"
115
116#include "tensorflow/compiler/xla/literal_util.h"
117#include "tensorflow/compiler/xla/shape_util.h"
118#include "tensorflow/compiler/xla/xla_data.pb.h"
119#include "tensorflow/core/lib/gtl/array_slice.h"
120#include "tensorflow/compiler/xla/python/numpy_bridge.h"
121#include "tensorflow/compiler/xla/python/local_computation_builder.h"
122
123using namespace xla;
124using namespace xla::swig;
125
126namespace xla {
127namespace swig {
128
129bool GetIntAttr(PyObject* o, const char* field, int64* result) {
130  PyObject* fo = PyObject_GetAttrString(o, field);
131  if (!fo) {
132    return false;
133  }
134  const int64 value = numpy::PyIntOrPyLongToLong(fo);
135  if (value == -1 && PyErr_Occurred()) {
136    Py_DECREF(fo);
137    return false;
138  }
139  Py_DECREF(fo);
140  *result = value;
141  return true;
142}
143
144}
145}
146%}
147
148// Required to use PyArray_* functions.
149%init %{
150tensorflow::ImportNumpy();
151%}
152
153// ComputationDataHandle
154
// Python int -> const ComputationDataHandle&: the integer is stored as the
// handle of a typemap-local temporary, whose address is passed to C++.
%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) {
  const int64 raw_handle = numpy::PyIntOrPyLongToLong($input);
  // -1 only signals failure when a Python exception is also pending.
  const bool conversion_failed = (raw_handle == -1) && PyErr_Occurred();
  if (conversion_failed) {
    return NULL;
  }
  temp.set_handle(raw_handle);
  $1 = &temp;
}
163
// ComputationDataHandle -> Python int holding the handle value.
%typemap(out) ComputationDataHandle {
  const auto raw_handle = $1.handle();
  $result = numpy::LongToPyIntOrPyLong(raw_handle);
}
167
// StatusOr<CompiledLocalComputation*> -> wrapped pointer, or a Python
// RuntimeError carrying the status text when the StatusOr holds an error.
%typemap(out) StatusOr<xla::swig::CompiledLocalComputation*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
  auto* value = $1.ValueOrDie();
  {
    // Shadow the result variable with the unwrapped pointer so that the
    // plain-pointer out typemap can be reused verbatim.
    auto* $1 = value;
    $typemap(out, xla::swig::CompiledLocalComputation*)
  }
}
180
// StatusOr<unique_ptr<Literal>> -> Python object built from the literal, or
// a Python RuntimeError carrying the status text on failure.
// NOTE(review): a typemap for this same type is defined again further down
// in this file; confirm which definition is meant to be in effect.
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
  std::unique_ptr<Literal> value = $1.ConsumeValueOrDie();
  $result = numpy::PyObjectFromXlaLiteral(*value);
}
190
// StatusOr<LocalComputation*> -> wrapped pointer, or a Python RuntimeError
// carrying the status text when the StatusOr holds an error.
%typemap(out) StatusOr<xla::swig::LocalComputation*> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
  auto* value = $1.ValueOrDie();
  {
    // Shadow the result variable with the unwrapped pointer so that the
    // plain-pointer out typemap can be reused verbatim.
    auto* $1 = value;
    $typemap(out, xla::swig::LocalComputation*)
  }
}
203
// StatusOr<Shape> -> Python shape info object, or a Python RuntimeError
// carrying the status text on failure.
%typemap(out) StatusOr<Shape> {
  if (!$1.ok()) {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
  $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie());
}
212
// Status -> None on success, or a Python RuntimeError carrying the status
// text on failure.
%typemap(out) Status {
  if (!$1.ok()) {
    PyErr_SetString(
        PyExc_RuntimeError, $1.ToString().c_str());
    return NULL;
  }
  // The wrapper hands the result back to the interpreter as an owned
  // reference, so None must be retained before being returned; otherwise
  // every successful call would drop one reference from None.
  Py_INCREF(Py_None);
  $result = Py_None;
}
221
222// ArraySlice<int64>
223
// Python sequence of ints -> ArraySlice<int64>. The converted values live in
// the typemap-local vector `temps`, which backs the slice passed to C++.
// Raises TypeError if the argument is not a sequence or an element cannot be
// converted to an int; other Python errors are propagated unchanged.
%typemap(in) tensorflow::gtl::ArraySlice<int64>
    (std::vector<int64> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      Py_DECREF(o);
      return NULL;
    }
    temps[i] = numpy::PyIntOrPyLongToLong(py_int);
    if (temps[i] == -1 && PyErr_Occurred()) {
      Py_DECREF(py_int);
      Py_DECREF(o);
      return NULL;
    }
    Py_DECREF(py_int);
    Py_DECREF(o);
  }
  $1 = temps;
}
253
// ArraySlice<ComputationDataHandle>
255
// Python sequence of ints -> ArraySlice<ComputationDataHandle>. Each int
// becomes the handle of one element of the typemap-local vector `temps`,
// which backs the slice passed to C++. Raises TypeError if the argument is
// not a sequence or an element cannot be converted to an int; other Python
// errors are propagated unchanged.
%typemap(in) tensorflow::gtl::ArraySlice<ComputationDataHandle>
    (std::vector<ComputationDataHandle> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  temps.resize(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    PyObject* py_int = numpy::PyNumberToPyInt(o);
    if (!py_int) {
      PyErr_SetString(
          PyExc_TypeError,
          "Argument sequence element cannot be converted to int");
      // Release the element on this error path as well.
      Py_DECREF(o);
      return NULL;
    }
    const int64 handle = numpy::PyIntOrPyLongToLong(py_int);
    if (handle == -1 && PyErr_Occurred()) {
      Py_DECREF(py_int);
      Py_DECREF(o);
      return NULL;
    }
    temps[i].set_handle(handle);
    Py_DECREF(py_int);
    Py_DECREF(o);
  }
  $1 = temps;
}
285
286// LocalShapedBuffer*
287
// Python sequence of wrapped LocalShapedBuffer objects ->
// ArraySlice<LocalShapedBuffer*>. The unwrapped pointers are collected in the
// typemap-local vector `temps`, which backs the slice passed to C++.
%typemap(in) tensorflow::gtl::ArraySlice<xla::swig::LocalShapedBuffer*>
    (std::vector<LocalShapedBuffer*> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    LocalShapedBuffer* lsbp;
    if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*),
                         SWIG_POINTER_EXCEPTION)) == -1) {
      // NOTE(review): depending on the SWIG version the conversion may fail
      // without raising; make sure NULL is never returned without an error.
      if (!PyErr_Occurred()) {
        PyErr_SetString(PyExc_TypeError,
                        "Argument sequence element is not a LocalShapedBuffer");
      }
      // Release the element on this error path as well.
      Py_DECREF(o);
      return NULL;
    }
    temps.push_back(lsbp);
    Py_DECREF(o);
  }
  $1 = temps;
}
308
309// Literal
310
// Python value (nested tuple of ndarrays) -> const Literal&. The converted
// literal is owned by the typemap-local `literal_status`; a failed
// conversion raises a Python RuntimeError carrying the status text.
%typemap(in) const Literal& (StatusOr< std::unique_ptr<Literal> > literal_status) {
  literal_status = numpy::XlaLiteralFromPyObject($input);
  if (literal_status.ok()) {
    $1 = literal_status.ValueOrDie().get();
  } else {
    PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
    return NULL;
  }
}
319
// unique_ptr<Literal> -> Python object built from the pointed-to literal.
%typemap(out) std::unique_ptr<Literal> {
  const Literal& literal = *$1;
  $result = numpy::PyObjectFromXlaLiteral(literal);
}
323
// StatusOr<unique_ptr<Literal>> -> Python object built from the literal, or
// a Python RuntimeError carrying the status text on failure.
%typemap(out) StatusOr< std::unique_ptr<Literal> > {
  if ($1.ok()) {
    $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie());
  } else {
    PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
    return NULL;
  }
}
331
// Python sequence of values -> const std::vector<Literal>&. Each element is
// converted with XlaLiteralFromPyObject and moved into the typemap-local
// vector `temps`. Raises TypeError if the argument is not a sequence and
// RuntimeError (with the status text) if an element fails to convert.
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised and must not be handed to
    // the converter.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    StatusOr< std::unique_ptr<Literal> > literal_status = numpy::XlaLiteralFromPyObject(o);
    if (!literal_status.ok()) {
      PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str());
      Py_DECREF(o);
      return NULL;
    }
    temps.push_back(std::move(*literal_status.ConsumeValueOrDie()));
    Py_DECREF(o);
  }
  $1 = &temps;
}
351
352// OpMetadata
353
// Python object -> const OpMetadata&. The converted value is stored in the
// typemap-local `temp`; a failed conversion raises a Python RuntimeError
// carrying the status text.
%typemap(in) const OpMetadata& (OpMetadata temp) {
  StatusOr<OpMetadata> metadata_or = numpy::OpMetadataFromPyObject($input);
  if (metadata_or.ok()) {
    temp = std::move(metadata_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, metadata_or.status().ToString().c_str());
    return NULL;
  }
}
363
364// Shape
365
// Python shape object -> const Shape&. The converted value is stored in the
// typemap-local `temp`; a failed conversion raises a Python RuntimeError
// carrying the status text.
%typemap(in) const Shape& (Shape temp) {
  StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
  if (shape_or.ok()) {
    temp = std::move(shape_or).ValueOrDie();
    $1 = &temp;
  } else {
    PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
    return NULL;
  }
}
375
// Python shape object or None -> const optional<Shape>&. None maps to an
// empty optional; anything else is converted with XlaShapeFromPyShape, and a
// failed conversion raises a Python RuntimeError carrying the status text.
%typemap(in) const tensorflow::gtl::optional<Shape>& (
    tensorflow::gtl::optional<Shape> temp) {
  if ($input != Py_None) {
    StatusOr<Shape> shape_or = numpy::XlaShapeFromPyShape($input);
    if (!shape_or.ok()) {
      PyErr_SetString(PyExc_RuntimeError, shape_or.status().ToString().c_str());
      return NULL;
    }
    temp = std::move(shape_or).ValueOrDie();
  } else {
    temp = tensorflow::gtl::nullopt;
  }
  $1 = &temp;
}
391
// unique_ptr<Shape> -> Python shape info built from the pointed-to shape.
%typemap(out) std::unique_ptr<Shape> {
  const Shape& shape = *$1;
  $result = numpy::PyShapeInfoFromXlaShape(shape);
}
395
// Python sequence of shape objects -> const std::vector<Shape>&. Each
// element is converted with XlaShapeFromPyShape and appended to the
// typemap-local vector `temps`. Raises TypeError if the argument is not a
// sequence and RuntimeError (with the status text) if an element fails to
// convert.
%typemap(in) const std::vector<Shape>& (std::vector<Shape> temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised and must not be handed to
    // the converter or released.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
    Py_DECREF(o);
    if (!statusor.ok()) {
      PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
      return NULL;
    }
    temps.push_back(statusor.ConsumeValueOrDie());
  }
  $1 = &temps;
}
414
// Python sequence of (shape object or None) ->
// const std::vector<optional<Shape>>&. None elements become empty optionals;
// other elements are converted with XlaShapeFromPyShape. Raises TypeError if
// the argument is not a sequence and RuntimeError (with the status text) if
// an element fails to convert.
%typemap(in) const std::vector<tensorflow::gtl::optional<Shape> >& (
    std::vector<tensorflow::gtl::optional<Shape> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  for (Py_ssize_t i = 0; i < size; ++i) {
    // New reference; NULL means the lookup raised.
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    if (o == Py_None) {
      // The reference obtained for this element must be released on the
      // None path too.
      Py_DECREF(o);
      temps.push_back(tensorflow::gtl::nullopt);
    } else {
      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
      Py_DECREF(o);
      if (!statusor.ok()) {
        PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str());
        return NULL;
      }
      temps.push_back(statusor.ConsumeValueOrDie());
    }
  }
  $1 = &temps;
}
438
439// PrimitiveType
440
// Python int -> PrimitiveType. Raises TypeError if the argument cannot be
// converted to an int or is not a valid PrimitiveType enum value.
%typemap(in) PrimitiveType {
  PyObject* py_int = numpy::PyNumberToPyInt($input);
  if (!py_int) {
    PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int");
    return NULL;
  }
  const long value = numpy::PyIntOrPyLongToLong(py_int);
  // The intermediate int object is not needed past this point; release it
  // once here so that every path below, including success, is leak-free.
  const bool conversion_failed = (value == -1) && PyErr_Occurred();
  Py_DECREF(py_int);
  if (conversion_failed) {
    return NULL;
  }
  if (!PrimitiveType_IsValid(value)) {
    PyErr_SetString(
        PyExc_TypeError, "Argument not valid for PrimitiveType enum");
    return NULL;
  }
  $1 = static_cast<PrimitiveType>(value);
}
460
// ArraySlice<pair<int64, int64>>
462
// Python sequence of int pairs -> ArraySlice<pair<int64, int64>>. Each
// element is read with the tuple API, so elements are expected to be tuples.
// The converted pairs live in the typemap-local vector `temps`, which backs
// the slice passed to C++. Raises TypeError if the argument is not a
// sequence or a pair item cannot be converted to an int; other Python errors
// are propagated unchanged.
%typemap(in) tensorflow::gtl::ArraySlice<std::pair<int64, int64> >
    (std::vector<std::pair<int64, int64> > temps) {
  if (!PySequence_Check($input)) {
    PyErr_SetString(PyExc_TypeError, "Argument is not a sequence");
    return NULL;
  }
  const Py_ssize_t size = PySequence_Size($input);
  if (size == -1) {
    // Length query failed; a Python exception is already set.
    return NULL;
  }
  temps.reserve(size);
  for (Py_ssize_t i = 0; i < size; ++i) {
    PyObject* o = PySequence_GetItem($input, i);
    if (!o) {
      return NULL;
    }
    // PyTuple_GetItem returns borrowed references, so `first` and `second`
    // are only valid while `o` is held and must not be released themselves.
    PyObject* first = PyTuple_GetItem(o, 0);
    if (!first) {
      Py_DECREF(o);
      return NULL;
    }
    PyObject* first_pyint = numpy::PyNumberToPyInt(first);
    if (!first_pyint) {
      PyErr_SetString(
          PyExc_TypeError,
          "First pair item cannot be converted to int");
      Py_DECREF(o);
      return NULL;
    }
    PyObject* second = PyTuple_GetItem(o, 1);
    if (!second) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      return NULL;
    }
    PyObject* second_pyint = numpy::PyNumberToPyInt(second);
    if (!second_pyint) {
      PyErr_SetString(
          PyExc_TypeError,
          "Second pair item cannot be converted to int");
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      return NULL;
    }
    const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint);
    if (first_value == -1 && PyErr_Occurred()) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      Py_DECREF(second_pyint);
      return NULL;
    }
    const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint);
    if (second_value == -1 && PyErr_Occurred()) {
      Py_DECREF(o);
      Py_DECREF(first_pyint);
      Py_DECREF(second_pyint);
      return NULL;
    }
    temps.push_back(std::make_pair(first_value, second_value));
    // Release the intermediate int objects on the success path as well.
    Py_DECREF(first_pyint);
    Py_DECREF(second_pyint);
    Py_DECREF(o);
  }
  $1 = temps;
}
523
524// DotDimensionNumbers
525
// Python object -> const DotDimensionNumbers&. Reads the four sequence
// attributes lhs_contracting_dimensions, rhs_contracting_dimensions,
// lhs_batch_dimensions and rhs_batch_dimensions from the argument and
// appends each integer element to the matching repeated field of the
// typemap-local `dimension_numbers`. Any failing attribute lookup, length
// query, item lookup or integer conversion aborts the call with the Python
// exception already set by the failing API.
%typemap(in) const DotDimensionNumbers&
    (DotDimensionNumbers dimension_numbers) {
  int length;

  /* lhs_contracting_dimensions */
  PyObject* lhs_contracting_dimensions = PyObject_GetAttrString(
      $input, "lhs_contracting_dimensions");
  if (!lhs_contracting_dimensions) {
    return NULL;
  }

  length = PySequence_Size(lhs_contracting_dimensions);
  if (length == -1) {
    Py_DECREF(lhs_contracting_dimensions);
    return NULL;
  }

  for (int i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i);
    if (!item) {
      Py_DECREF(lhs_contracting_dimensions);
      return NULL;
    }
    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    if (dimension == -1 && PyErr_Occurred()) {
      Py_DECREF(item);
      Py_DECREF(lhs_contracting_dimensions);
      return NULL;
    }
    dimension_numbers.add_lhs_contracting_dimensions(dimension);
    Py_DECREF(item);
  }
  Py_DECREF(lhs_contracting_dimensions);

  /* rhs_contracting_dimensions */
  PyObject* rhs_contracting_dimensions = PyObject_GetAttrString(
      $input, "rhs_contracting_dimensions");
  // Check the attribute that was just fetched, not the lhs object, which
  // has already been released above.
  if (!rhs_contracting_dimensions) {
    return NULL;
  }

  length = PySequence_Size(rhs_contracting_dimensions);
  if (length == -1) {
    Py_DECREF(rhs_contracting_dimensions);
    return NULL;
  }

  for (int i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i);
    if (!item) {
      Py_DECREF(rhs_contracting_dimensions);
      return NULL;
    }
    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    if (dimension == -1 && PyErr_Occurred()) {
      Py_DECREF(item);
      Py_DECREF(rhs_contracting_dimensions);
      return NULL;
    }
    dimension_numbers.add_rhs_contracting_dimensions(dimension);
    Py_DECREF(item);
  }
  Py_DECREF(rhs_contracting_dimensions);

  /* lhs_batch_dimensions */
  PyObject* lhs_batch_dimensions = PyObject_GetAttrString(
      $input, "lhs_batch_dimensions");
  if (!lhs_batch_dimensions) {
    return NULL;
  }

  length = PySequence_Size(lhs_batch_dimensions);
  if (length == -1) {
    Py_DECREF(lhs_batch_dimensions);
    return NULL;
  }

  for (int i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i);
    if (!item) {
      Py_DECREF(lhs_batch_dimensions);
      return NULL;
    }
    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    if (dimension == -1 && PyErr_Occurred()) {
      Py_DECREF(item);
      Py_DECREF(lhs_batch_dimensions);
      return NULL;
    }
    dimension_numbers.add_lhs_batch_dimensions(dimension);
    Py_DECREF(item);
  }
  Py_DECREF(lhs_batch_dimensions);

  /* rhs_batch_dimensions */
  PyObject* rhs_batch_dimensions = PyObject_GetAttrString(
      $input, "rhs_batch_dimensions");
  if (!rhs_batch_dimensions) {
    return NULL;
  }

  length = PySequence_Size(rhs_batch_dimensions);
  if (length == -1) {
    Py_DECREF(rhs_batch_dimensions);
    return NULL;
  }

  for (int i = 0; i < length; ++i) {
    PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i);
    if (!item) {
      Py_DECREF(rhs_batch_dimensions);
      return NULL;
    }
    const int64 dimension = numpy::PyIntOrPyLongToLong(item);
    if (dimension == -1 && PyErr_Occurred()) {
      Py_DECREF(item);
      Py_DECREF(rhs_batch_dimensions);
      return NULL;
    }
    dimension_numbers.add_rhs_batch_dimensions(dimension);
    Py_DECREF(item);
  }
  Py_DECREF(rhs_batch_dimensions);

  $1 = &dimension_numbers;
}
652
653// PaddingConfig
654
// Python object -> const PaddingConfig&. Iterates over the argument's
// `dimensions` sequence and, for each element, copies its edge_padding_low,
// edge_padding_high and interior_padding integer attributes into a new
// dimension entry of the typemap-local `padding_config`. Any failing lookup
// or conversion aborts the call with the Python exception already set.
%typemap(in) const PaddingConfig&
    (PaddingConfig padding_config) {
  PyObject* py_dimensions = PyObject_GetAttrString($input, "dimensions");
  if (py_dimensions == NULL) {
    return NULL;
  }

  const int num_dimensions = PySequence_Size(py_dimensions);
  if (num_dimensions == -1) {
    Py_DECREF(py_dimensions);
    return NULL;
  }

  for (int index = 0; index < num_dimensions; ++index) {
    PyObject* py_dimension = PySequence_GetItem(py_dimensions, index);
    if (py_dimension == NULL) {
      Py_DECREF(py_dimensions);
      return NULL;
    }

    // Read all three attributes before touching the proto so that a failure
    // leaves no partially filled entry behind. Evaluation short-circuits at
    // the first attribute that cannot be read.
    int64 low, high, interior;
    const bool attrs_ok =
        GetIntAttr(py_dimension, "edge_padding_low", &low)
        && GetIntAttr(py_dimension, "edge_padding_high", &high)
        && GetIntAttr(py_dimension, "interior_padding", &interior);
    Py_DECREF(py_dimension);
    if (!attrs_ok) {
      Py_DECREF(py_dimensions);
      return NULL;
    }

    PaddingConfig::PaddingConfigDimension* entry =
        padding_config.add_dimensions();
    entry->set_edge_padding_low(low);
    entry->set_edge_padding_high(high);
    entry->set_interior_padding(interior);
  }
  Py_DECREF(py_dimensions);

  $1 = &padding_config;
}
694
695// ConvolutionDimensionNumbers
696
// Python object -> const ConvolutionDimensionNumbers&. Copies six scalar
// integer attributes and three integer-sequence attributes from the argument
// into the typemap-local `dimension_numbers`. Any failing attribute lookup,
// length query, item lookup or integer conversion aborts the call with the
// Python exception already set by the failing API.
%typemap(in) const ConvolutionDimensionNumbers&
    (ConvolutionDimensionNumbers dimension_numbers) {
  /* Scalar attributes. */
  int64 attr_value;

  if (!GetIntAttr($input, "input_batch_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_input_batch_dimension(attr_value);

  if (!GetIntAttr($input, "input_feature_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_input_feature_dimension(attr_value);

  if (!GetIntAttr($input, "output_batch_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_output_batch_dimension(attr_value);

  if (!GetIntAttr($input, "output_feature_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_output_feature_dimension(attr_value);

  if (!GetIntAttr($input, "kernel_output_feature_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_kernel_output_feature_dimension(attr_value);

  if (!GetIntAttr($input, "kernel_input_feature_dimension", &attr_value)) {
    return NULL;
  }
  dimension_numbers.set_kernel_input_feature_dimension(attr_value);

  /* Sequence attributes; `py_dims` and `num_dims` are reused for each. */
  PyObject* py_dims;
  int num_dims;

  /* input_spatial_dimensions */
  py_dims = PyObject_GetAttrString($input, "input_spatial_dimensions");
  if (py_dims == NULL) {
    return NULL;
  }
  num_dims = PySequence_Size(py_dims);
  if (num_dims == -1) {
    Py_DECREF(py_dims);
    return NULL;
  }
  for (int d = 0; d < num_dims; ++d) {
    PyObject* py_dim = PySequence_GetItem(py_dims, d);
    if (py_dim == NULL) {
      Py_DECREF(py_dims);
      return NULL;
    }
    const int64 dim_value = numpy::PyIntOrPyLongToLong(py_dim);
    if (dim_value == -1 && PyErr_Occurred()) {
      Py_DECREF(py_dim);
      Py_DECREF(py_dims);
      return NULL;
    }
    dimension_numbers.add_input_spatial_dimensions(dim_value);
    Py_DECREF(py_dim);
  }
  Py_DECREF(py_dims);

  /* kernel_spatial_dimensions */
  py_dims = PyObject_GetAttrString($input, "kernel_spatial_dimensions");
  if (py_dims == NULL) {
    return NULL;
  }
  num_dims = PySequence_Size(py_dims);
  if (num_dims == -1) {
    Py_DECREF(py_dims);
    return NULL;
  }
  for (int d = 0; d < num_dims; ++d) {
    PyObject* py_dim = PySequence_GetItem(py_dims, d);
    if (py_dim == NULL) {
      Py_DECREF(py_dims);
      return NULL;
    }
    const int64 dim_value = numpy::PyIntOrPyLongToLong(py_dim);
    if (dim_value == -1 && PyErr_Occurred()) {
      Py_DECREF(py_dim);
      Py_DECREF(py_dims);
      return NULL;
    }
    dimension_numbers.add_kernel_spatial_dimensions(dim_value);
    Py_DECREF(py_dim);
  }
  Py_DECREF(py_dims);

  /* output_spatial_dimensions */
  py_dims = PyObject_GetAttrString($input, "output_spatial_dimensions");
  if (py_dims == NULL) {
    return NULL;
  }
  num_dims = PySequence_Size(py_dims);
  if (num_dims == -1) {
    Py_DECREF(py_dims);
    return NULL;
  }
  for (int d = 0; d < num_dims; ++d) {
    PyObject* py_dim = PySequence_GetItem(py_dims, d);
    if (py_dim == NULL) {
      Py_DECREF(py_dims);
      return NULL;
    }
    const int64 dim_value = numpy::PyIntOrPyLongToLong(py_dim);
    if (dim_value == -1 && PyErr_Occurred()) {
      Py_DECREF(py_dim);
      Py_DECREF(py_dims);
      return NULL;
    }
    dimension_numbers.add_output_spatial_dimensions(dim_value);
    Py_DECREF(py_dim);
  }
  Py_DECREF(py_dims);

  $1 = &dimension_numbers;
}
814
815// ExecutableBuildOptions
816
// Python object or None -> const ExecutableBuildOptions*. None maps to a
// NULL pointer. Otherwise the argument's `generate_hlo_graph` (string or
// None) and `result_shape` (shape object or None) attributes are copied into
// the typemap-local `build_options`, whose address is passed to C++.
// Raises TypeError if `generate_hlo_graph` is neither a string nor None, or
// if `result_shape` cannot be converted to an XLA shape.
%typemap(in) const ExecutableBuildOptions*
    (ExecutableBuildOptions build_options) {
  if ($input == Py_None) {
    $1 = NULL;
  } else {
    PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
    if (!o) {
      return NULL;
    }
    if (o != Py_None) {
      if (!PyString_Check(o)) {
        PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
        // Release the attribute on this error path as well.
        Py_DECREF(o);
        return NULL;
      }
      build_options.set_generate_hlo_graph(PyString_AsString(o));
    }
    Py_DECREF(o);

    o = PyObject_GetAttrString($input, "result_shape");
    if (o == nullptr) {
      return nullptr;
    }
    if (o != Py_None) {
      StatusOr<Shape> statusor = numpy::XlaShapeFromPyShape(o);
      if (!statusor.ok()) {
        PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str());
        Py_DECREF(o);
        return NULL;
      }
      build_options.set_result_layout(statusor.ValueOrDie());
    }
    Py_DECREF(o);

    $1 = &build_options;
  }
}
853
854%ignoreall
855%unignore xla;
856%unignore xla::swig;
857%unignore xla::swig::InitializeReplicaCount;
858%unignore xla::swig::GetReplicaCount;
859%unignore xla::swig::TransferToInfeedLocal;
860%unignore xla::swig::TransferToInfeedLocalReplica;
861%unignore xla::swig::TransferFromOutfeedLocalReplica;
862%unignore xla::swig::LocalShapedBuffer;
863%unignore xla::swig::LocalShapedBuffer::FromLiteral;
864%unignore xla::swig::LocalShapedBuffer::ToLiteral;
865%unignore xla::swig::CompiledLocalComputation;
866%unignore xla::swig::CompiledLocalComputation::Execute;
867%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers;
868%unignore xla::swig::LocalComputation;
869%unignore xla::swig::LocalComputation::Compile;
870%unignore xla::swig::LocalComputation::GetReturnValueShape;
871%unignore xla::swig::LocalComputationBuilder;
872%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder;
873%unignore xla::swig::LocalComputationBuilder::Build;
874%unignore xla::swig::LocalComputationBuilder::SetOpMetadata;
875%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata;
876%unignore xla::swig::LocalComputationBuilder::Parameter;
877%unignore xla::swig::LocalComputationBuilder::GetShape;
878%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape;
879%unignore xla::swig::LocalComputationBuilder::Infeed;
880%unignore xla::swig::LocalComputationBuilder::Outfeed;
881%unignore xla::swig::LocalComputationBuilder::ConstantLiteral;
882%unignore xla::swig::LocalComputationBuilder::ConstantR0;
883%unignore xla::swig::LocalComputationBuilder::Broadcast;
884%unignore xla::swig::LocalComputationBuilder::Pad;
885%unignore xla::swig::LocalComputationBuilder::Reshape;
886%unignore xla::swig::LocalComputationBuilder::Collapse;
887%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum;
888%unignore xla::swig::LocalComputationBuilder::Slice;
889%unignore xla::swig::LocalComputationBuilder::DynamicSlice;
890%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice;
891%unignore xla::swig::LocalComputationBuilder::ConcatInDim;
892%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding;
893%unignore xla::swig::LocalComputationBuilder::Select;
894%unignore xla::swig::LocalComputationBuilder::Tuple;
895%unignore xla::swig::LocalComputationBuilder::GetTupleElement;
896%unignore xla::swig::LocalComputationBuilder::ConvertElementType;
897%unignore xla::swig::LocalComputationBuilder::Call;
898%unignore xla::swig::LocalComputationBuilder::Transpose;
899%unignore xla::swig::LocalComputationBuilder::Rev;
900%unignore xla::swig::LocalComputationBuilder::Clamp;
901%unignore xla::swig::LocalComputationBuilder::Map;
902%unignore xla::swig::LocalComputationBuilder::Reduce;
903%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding;
904%unignore xla::swig::LocalComputationBuilder::RngNormal;
905%unignore xla::swig::LocalComputationBuilder::RngUniform;
906%unignore xla::swig::LocalComputationBuilder::RngBernoulli;
907%unignore xla::swig::LocalComputationBuilder::While;
908%unignore xla::swig::LocalComputationBuilder::Conditional;
909%unignore xla::swig::LocalComputationBuilder::Eq;
910%unignore xla::swig::LocalComputationBuilder::Ne;
911%unignore xla::swig::LocalComputationBuilder::Ge;
912%unignore xla::swig::LocalComputationBuilder::Gt;
913%unignore xla::swig::LocalComputationBuilder::Lt;
914%unignore xla::swig::LocalComputationBuilder::Le;
915%unignore xla::swig::LocalComputationBuilder::Dot;
916%unignore xla::swig::LocalComputationBuilder::DotGeneral;
917%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated;
918%unignore xla::swig::LocalComputationBuilder::Add;
919%unignore xla::swig::LocalComputationBuilder::Sub;
920%unignore xla::swig::LocalComputationBuilder::Mul;
921%unignore xla::swig::LocalComputationBuilder::Div;
922%unignore xla::swig::LocalComputationBuilder::Rem;
923%unignore xla::swig::LocalComputationBuilder::Max;
924%unignore xla::swig::LocalComputationBuilder::Min;
925%unignore xla::swig::LocalComputationBuilder::And;
926%unignore xla::swig::LocalComputationBuilder::Or;
927%unignore xla::swig::LocalComputationBuilder::Not;
928%unignore xla::swig::LocalComputationBuilder::Abs;
929%unignore xla::swig::LocalComputationBuilder::Exp;
930%unignore xla::swig::LocalComputationBuilder::Floor;
931%unignore xla::swig::LocalComputationBuilder::Ceil;
932%unignore xla::swig::LocalComputationBuilder::Round;
933%unignore xla::swig::LocalComputationBuilder::Log;
934%unignore xla::swig::LocalComputationBuilder::Sign;
935%unignore xla::swig::LocalComputationBuilder::Cos;
936%unignore xla::swig::LocalComputationBuilder::Sin;
937%unignore xla::swig::LocalComputationBuilder::Tanh;
938%unignore xla::swig::LocalComputationBuilder::SqrtF32;
939%unignore xla::swig::LocalComputationBuilder::SquareF32;
940%unignore xla::swig::LocalComputationBuilder::Pow;
941%unignore xla::swig::LocalComputationBuilder::IsFinite;
942%unignore xla::swig::LocalComputationBuilder::ReciprocalF32;
943%unignore xla::swig::LocalComputationBuilder::Neg;
944%unignore xla::swig::LocalComputationBuilder::Sort;
945%unignore xla::swig::DeleteLocalShapedBuffer;
946%unignore xla::swig::DeleteLocalComputation;
947%unignore xla::swig::DeleteCompiledLocalComputation;
948
949%thread;
950%include "tensorflow/compiler/xla/python/local_computation_builder.h"
951%nothread;
952
953%unignoreall
954