/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/computation_builder.h"

#include <stddef.h>
#include <array>
#include <numeric>
#include <set>
#include <vector>

#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {

ComputationBuilder::ComputationBuilder(Client* client,
                                       const string& computation_name)
    : name_(computation_name), client_(client) {}

ComputationBuilder::~ComputationBuilder() {}

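// Records the error on the builder. Only the first error and its backtrace
// are kept; once set, subsequent op requests short-circuit on first_error_
// and Build() reports it. If die_immediately_on_error_ is set, the error is
// fatal instead.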
void ComputationBuilder::NoteError(const Status& error) {
  if (die_immediately_on_error_) {
    LOG(FATAL) << "error building computation: " << error;
  }

  if (first_error_.ok()) {
    first_error_ = error;
    first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
  }
}

std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder(
    const string& computation_name) {
  auto sub_builder = MakeUnique<ComputationBuilder>(client_, computation_name);
  sub_builder->parent_builder_ = this;
  sub_builder->die_immediately_on_error_ = die_immediately_on_error_;
  return sub_builder;
}

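// Lazily creates the computation on the service: the first call sends a
// ComputationRequest and caches the returned handle; subsequent calls are
// no-ops.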
Status ComputationBuilder::PrepareComputation() {
  TF_RETURN_IF_ERROR(first_error_);

  if (!computation_.IsNull()) {
    return Status::OK();
  }

  ComputationRequest request;
  request.set_name(name_);
  ComputationResponse response;

  VLOG(2) << "making computation request";
  Status s = client_->stub()->Computation(&request, &response);
  VLOG(2) << "done with computation request";

  if (!s.ok()) {
    NoteError(s);
    return first_error_;
  }

  computation_ = Computation(client_->stub(), response.computation());
  return Status::OK();
}

Status ComputationBuilder::RunOp(OpRequest* op_request,
                                 OpResponse* op_response) {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RETURN_IF_ERROR(PrepareComputation());

  // Fill in fields that are set on every OpRequest.
  *op_request->mutable_computation() = computation_.handle();
  *op_request->mutable_metadata() = metadata_;
  if (sharding_) {
    *op_request->mutable_sharding() = *sharding_;
  }

  const string& op_name =
      OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name();
  VLOG(2) << "running op request: " << op_name;
  Status status = client_->stub()->Op(op_request, op_response);
  VLOG(2) << "done with op request: " << op_name;
  return status;
}

void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) {
  OpResponse op_response;
  Status status = RunOp(op_request, &op_response);
  if (!status.ok()) {
    NoteError(status);
  }
}

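// Runs the op and returns the output handle parsed from the response. On
// failure the error is noted and a default-constructed (invalid) handle is
// returned so callers can continue chaining ops; Build() surfaces the first
// recorded error.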
ComputationDataHandle ComputationBuilder::RunOpAndParseResponse(
    OpRequest* op_request) {
  OpResponse op_response;
  Status status = RunOp(op_request, &op_response);
  if (!status.ok()) {
    NoteError(status);
    return ComputationDataHandle();
  }
  if (op_response.output().handle() == 0) {
    NoteError(InternalError("No output handle"));
    return ComputationDataHandle();
  }
  return op_response.output();
}

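// Builds a Window proto from per-dimension sizes. Empty strides, padding, or
// dilation slices select the defaults: stride 1, zero padding, and dilation
// factor 1 for every window dimension.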
bool ComputationBuilder::MakeWindow(
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
  const auto verify_size = [&](const size_t x, const char* x_name) {
    if (x == 0 || x == window_dimensions.size()) {
      return true;
    } else {
      NoteError(InvalidArgument(
          "%s", tensorflow::strings::StrCat(
                    "Window has different number of window dimensions than of ",
                    x_name, "\nNumber of window dimensions: ",
                    window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
                    "\n")
                    .c_str()));
      return false;
    }
  };
  if (!verify_size(window_strides.size(), "window strides") ||
      !verify_size(padding.size(), "padding entries") ||
      !verify_size(lhs_dilation.size(), "lhs dilation factors") ||
      !verify_size(rhs_dilation.size(), "rhs dilation factors")) {
    return false;
  }

  window->Clear();
  for (size_t i = 0; i < window_dimensions.size(); i++) {
    auto dim = window->add_dimensions();
    dim->set_size(window_dimensions[i]);
    if (!window_strides.empty()) {
      dim->set_stride(window_strides[i]);
    } else {
      dim->set_stride(1);
    }
    if (!padding.empty()) {
      dim->set_padding_low(padding[i].first);
      dim->set_padding_high(padding[i].second);
    } else {
      dim->set_padding_low(0);
      dim->set_padding_high(0);
    }
    if (!lhs_dilation.empty()) {
      dim->set_base_dilation(lhs_dilation[i]);
    } else {
      dim->set_base_dilation(1);
    }
    if (!rhs_dilation.empty()) {
      dim->set_window_dilation(rhs_dilation[i]);
    } else {
      dim->set_window_dilation(1);
    }
    dim->set_window_reversal(false);
  }
  return true;
}

ComputationDataHandle ComputationBuilder::ConstantLiteral(
    const Literal& literal) {
  OpRequest op_request;
  ConstantRequest* request = op_request.mutable_constant_request();
  *request->mutable_literal() = literal.ToProto();
  VLOG(3) << "created constant: " << request->literal().ShortDebugString();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number,
                                                    const Shape& shape,
                                                    const string& name) {
  OpRequest op_request;
  ParameterRequest* request = op_request.mutable_parameter_request();
  *request->mutable_shape() = shape;
  request->set_parameter(parameter_number);
  request->set_name(name);
  return RunOpAndParseResponse(&op_request);
}

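// Shape query that leaves error handling to the caller: unlike GetShape(),
// a failure here is returned directly rather than recorded via NoteError().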
StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError(
    const ComputationDataHandle& operand) {
  GetLocalShapeRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  GetLocalShapeResponse response;

  VLOG(2) << "making get-shape request";
  TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response));
  VLOG(2) << "done with request";

  TF_RET_CHECK(response.has_shape());
  std::unique_ptr<Shape> shape = WrapUnique(response.release_shape());
  TF_RET_CHECK(shape != nullptr);
  return std::move(shape);
}

StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  auto status_or_shape = GetShapeWithoutNoteError(operand);
  if (!status_or_shape.ok()) {
    NoteError(status_or_shape.status());
    return first_error_;
  }
  return status_or_shape;
}

StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() {
  TF_RETURN_IF_ERROR(first_error_);

  GetComputationShapeRequest request;
  *request.mutable_computation() = computation_.handle();
  GetComputationShapeResponse response;

  VLOG(2) << "making get-program-shape-request";
  Status status = client_->stub()->GetComputationShape(&request, &response);
  VLOG(2) << "done with get-program-shape-request";

  if (!status.ok()) {
    NoteError(status);
    return first_error_;
  }

  TF_RET_CHECK(response.has_program_shape());
  return std::move(*response.mutable_program_shape());
}

ComputationDataHandle ComputationBuilder::CheckShape(
    const ComputationDataHandle& operand, const Shape& expected_shape) {
  std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie();
  CHECK(ShapeUtil::Equal(expected_shape, *actual_shape))
      << "want " << ShapeUtil::HumanString(expected_shape) << " got "
      << ShapeUtil::HumanString(*actual_shape);
  return operand;
}

void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs,
                                        const ComputationDataHandle& rhs) {
  std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie();
  VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals "
          << ShapeUtil::HumanString(*rhs_shape);
  CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape))
      << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs "
      << ShapeUtil::HumanString(*rhs_shape);
}

ComputationDataHandle ComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  OpRequest op_request;
  SliceRequest* request = op_request.mutable_slice_request();
  *request->mutable_operand() = operand;
  for (int64 index : start_indices) {
    request->add_start_indices(index);
  }
  for (int64 index : limit_indices) {
    request->add_limit_indices(index);
  }
  for (int64 index : strides) {
    request->add_strides(index);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::SliceInDim(
    const ComputationDataHandle& operand, int64 start_index, int64 limit_index,
    int64 stride, int64 dimno) {
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    NoteError(shape_status.status());
    return ComputationDataHandle{};
  }
  const Shape& shape = *shape_status.ValueOrDie();
  std::vector<int64> starts(ShapeUtil::Rank(shape), 0);
  std::vector<int64> limits(shape.dimensions().begin(),
                            shape.dimensions().end());
  std::vector<int64> strides(ShapeUtil::Rank(shape), 1);
  starts[dimno] = start_index;
  limits[dimno] = limit_index;
  strides[dimno] = stride;
  return Slice(operand, starts, limits, strides);
}

ComputationDataHandle ComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  OpRequest op_request;
  DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request();
  *request->mutable_operand() = operand;
  *request->mutable_start_indices() = start_indices;
  for (int64 index : slice_sizes) {
    request->add_slice_sizes(index);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  OpRequest op_request;
  DynamicUpdateSliceRequest* request =
      op_request.mutable_dynamic_update_slice_request();
  *request->mutable_operand() = operand;
  *request->mutable_update() = update;
  *request->mutable_start_indices() = start_indices;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ConcatInDim(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    int64 dimension) {
  OpRequest op_request;
  ConcatenateRequest* request = op_request.mutable_concatenate_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  request->set_dimension(dimension);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  OpRequest op_request;
  BroadcastRequest* request = op_request.mutable_broadcast_request();
  *request->mutable_operand() = operand;
  for (int64 size : broadcast_sizes) {
    request->add_broadcast_sizes(size);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Pad(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& padding_value,
    const PaddingConfig& padding_config) {
  OpRequest op_request;
  PadRequest* request = op_request.mutable_pad_request();
  *request->mutable_operand() = operand;
  *request->mutable_padding_value() = padding_value;
  *request->mutable_padding_config() = padding_config;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  OpRequest op_request;
  ReshapeRequest* request = op_request.mutable_reshape_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  for (int64 new_size : new_sizes) {
    request->add_new_sizes(new_size);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size());
  std::iota(dimensions.begin(), dimensions.end(), 0);
  return Reshape(operand, dimensions, new_sizes);
}

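// Example: collapsing dims {1,2} of an array with shape [2,3,4] produces an
// array of shape [2,12], replacing the collapsed dimensions with the product
// of their sizes. Only in-order, consecutive dimensions may be collapsed.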
ComputationDataHandle ComputationBuilder::Collapse(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dims_to_collapse) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  // Don't support out-of-order collapse here.
  // Checks that the collapsed dimensions are in order and consecutive.
  for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
       i < dims_to_collapse.size(); ++i) {
    if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) {
      NoteError(InvalidArgument(
          "Collapsed dimensions are not in order and consecutive."));
      return ComputationDataHandle();
    }
  }

  // Create a new sizes vector from the old shape, replacing the collapsed
  // dimensions by the product of their sizes.
  StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand);
  if (!shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie();

  VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape);
  VLOG(3) << "dims to collapse: "
          << tensorflow::str_util::Join(dims_to_collapse, ",");

  if (dims_to_collapse.size() <= 1) {
    // Not collapsing anything, trivially we can return the operand versus
    // enqueueing a trivial reshape.
    return operand;
  }

  std::vector<int64> new_sizes;
  for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) {
    if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) {
      new_sizes.push_back(original_shape->dimensions(i));
    } else {
      new_sizes.back() *= original_shape->dimensions(i);
    }
  }

  VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
          << "]";

  return Reshape(operand, new_sizes);
}

void ComputationBuilder::Trace(const string& tag,
                               const ComputationDataHandle& operand) {
  OpRequest op_request;
  TraceRequest* request = op_request.mutable_trace_request();
  request->set_tag(tag);
  *request->mutable_operand() = operand;
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Select(
    const ComputationDataHandle& pred, const ComputationDataHandle& on_true,
    const ComputationDataHandle& on_false) {
  return TernaryOp(TRIOP_SELECT, pred, on_true, on_false);
}

ComputationDataHandle ComputationBuilder::Tuple(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
  OpRequest op_request;
  VariadicOpRequest* request = op_request.mutable_variadic_op_request();
  request->set_varop(VAROP_TUPLE);
  for (const ComputationDataHandle& operand : elements) {
    *request->add_operands() = operand;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  OpRequest op_request;
  GetTupleElementRequest* request =
      op_request.mutable_get_tuple_element_request();
  *request->mutable_operand() = tuple_data;
  request->set_index(index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Eq(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Ne(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Ge(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Gt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Le(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Lt(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions);
}

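// Default dot product: contracts dimension 1 of lhs (or dimension 0 when lhs
// is rank 1) against dimension 0 of rhs, i.e. a standard matrix or
// matrix-vector product, then defers to DotGeneral.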
ComputationDataHandle ComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();

  DotDimensionNumbers dimension_numbers;
  dimension_numbers.add_lhs_contracting_dimensions(
      lhs_shape->dimensions_size() == 1 ? 0 : 1);
  dimension_numbers.add_rhs_contracting_dimensions(0);
  return DotGeneral(lhs, rhs, dimension_numbers);
}

ComputationDataHandle ComputationBuilder::DotGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  OpRequest op_request;
  DotRequest* request = op_request.mutable_dot_request();
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_dimension_numbers() = dimension_numbers;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Conv(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  return ConvWithGeneralDimensions(
      lhs, rhs, window_strides, padding,
      CreateDefaultConvDimensionNumbers(window_strides.size()));
}

ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  return ConvGeneral(lhs, rhs, window_strides, padding,
                     CreateDefaultConvDimensionNumbers(window_strides.size()));
}

bool ComputationBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
    NoteError(
        InvalidArgument("Convolution arguments must have same number of "
                        "dimensions. Got: %s and %s",
                        ShapeUtil::HumanString(lhs_shape).c_str(),
                        ShapeUtil::HumanString(rhs_shape).c_str()));
    return false;
  }
  int num_dims = ShapeUtil::Rank(lhs_shape);
  if (num_dims < 2) {
    NoteError(InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape).c_str(),
        ShapeUtil::HumanString(rhs_shape).c_str()));
    return false;
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
                                    num_spatial_dims, field_name,
                                    numbers.size()));
          return false;
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            NoteError(
                InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
                                field_name, i, numbers.Get(i)));
            return false;
          }
        }
        return true;
      };
  return check_spatial_dimensions(
             "input_spatial_dimensions",
             dimension_numbers.input_spatial_dimensions()) &&
         check_spatial_dimensions(
             "kernel_spatial_dimensions",
             dimension_numbers.kernel_spatial_dimensions()) &&
         check_spatial_dimensions(
             "output_spatial_dimensions",
             dimension_numbers.output_spatial_dimensions());
}

ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();

  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution.
    return ComputationDataHandle();
  }

  std::vector<int64> base_area_dimensions(
      dimension_numbers.input_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size();
       ++i) {
    base_area_dimensions[i] =
        lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i));
  }

  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  return ConvGeneral(lhs, rhs, window_strides,
                     MakePadding(base_area_dimensions, window_dimensions,
                                 window_strides, padding),
                     dimension_numbers);
}

ComputationDataHandle ComputationBuilder::ConvGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {},
                            dimension_numbers);
}

ComputationDataHandle ComputationBuilder::ConvGeneralDilated(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
  if (!lhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs);
  if (!rhs_shape_or_status.ok()) {
    return ComputationDataHandle();
  }

  std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
  std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie();
  if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) {
    // Error is recorded in VerifyConvolution.
    return ComputationDataHandle();
  }

  std::vector<int64> window_dimensions(
      dimension_numbers.kernel_spatial_dimensions_size());
  for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) {
    window_dimensions[i] =
        rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i));
  }

  OpRequest op_request;
  ConvolveRequest* request = op_request.mutable_convolve_request();
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_dimension_numbers() = dimension_numbers;

  if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation,
                  rhs_dilation, request->mutable_window())) {
    // Error is recorded in MakeWindow.
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Fft(
    const ComputationDataHandle& operand, const FftType fft_type,
    const tensorflow::gtl::ArraySlice<int64> fft_length) {
  OpRequest op_request;
  FftRequest* request = op_request.mutable_fft_request();
  *request->mutable_operand() = operand;
  request->set_fft_type(fft_type);
  for (int64 dim_len : fft_length) {
    request->add_fft_length(dim_len);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape,
                                                 const string& config) {
  OpRequest op_request;
  InfeedRequest* request = op_request.mutable_infeed_request();
  *request->mutable_shape() = shape;
  *request->mutable_config() = config;
  return RunOpAndParseResponse(&op_request);
}

void ComputationBuilder::Outfeed(const ComputationDataHandle& operand,
                                 const Shape& shape,
                                 const string& outfeed_config) {
  OpRequest op_request;
  OutfeedRequest* request = op_request.mutable_outfeed_request();
  request->set_outfeed_config(outfeed_config);
  *request->mutable_operand() = operand;
  *request->mutable_shape() = shape;
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Call(
    const Computation& computation,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
  OpRequest op_request;
  CallRequest* request = op_request.mutable_call_request();
  *request->mutable_to_apply() = computation.handle();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::CustomCall(
    const string& call_target_name,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const Shape& shape) {
  OpRequest op_request;
  CustomCallRequest* request = op_request.mutable_custom_call_request();
  request->set_call_target_name(call_target_name);
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_shape() = shape;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::HostCompute(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
  OpRequest op_request;
  HostComputeRequest* request = op_request.mutable_host_compute_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_shape() = shape;
  request->set_channel_name(channel_name);
  request->set_cost_estimate_ns(cost_estimate_ns);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Complex(
    const ComputationDataHandle& real, const ComputationDataHandle& imag,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Conj(
    const ComputationDataHandle& operand) {
  return Complex(Real(operand), Neg(Imag(operand)));
}

ComputationDataHandle ComputationBuilder::Add(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Sub(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Mul(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Div(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Rem(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Max(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Min(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::And(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Or(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Not(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NOT, operand);
}

ComputationDataHandle ComputationBuilder::ShiftLeft(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightLogical(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Abs(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ABS, operand);
}

ComputationDataHandle ComputationBuilder::Atan2(
    const ComputationDataHandle& y, const ComputationDataHandle& x,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Exp(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_EXP, operand);
}

ComputationDataHandle ComputationBuilder::Floor(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_FLOOR, operand);
}

ComputationDataHandle ComputationBuilder::Ceil(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_CEIL, operand);
}

ComputationDataHandle ComputationBuilder::Round(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
}

ComputationDataHandle ComputationBuilder::Log(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_LOG, operand);
}

ComputationDataHandle ComputationBuilder::Sign(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIGN, operand);
}

ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_COS, operand);
}

ComputationDataHandle ComputationBuilder::Sin(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIN, operand);
}

ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_TANH, operand);
}

ComputationDataHandle ComputationBuilder::Real(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_REAL, operand);
}

ComputationDataHandle ComputationBuilder::Imag(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IMAG, operand);
}

ComputationDataHandle ComputationBuilder::IsFinite(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IS_FINITE, operand);
}

ComputationDataHandle ComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  OpRequest op_request;
  TransposeRequest* request = op_request.mutable_transpose_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : permutation) {
    request->add_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  OpRequest op_request;
  ReverseRequest* request = op_request.mutable_reverse_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Sort(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SORT, operand);
}

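// SqrtF32, SquareF32, and ReciprocalF32 are convenience wrappers lowering to
// Pow with constant exponents 0.5, 2.0, and -1.0 respectively.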
ComputationDataHandle ComputationBuilder::SqrtF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Pow(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
}

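// Note: ConvertElementType and BitcastConvertType fetch the operand's shape
// below, presumably to validate the handle and record any error; the shape
// value itself is otherwise unused.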
ComputationDataHandle ComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_bitcast_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::ReciprocalF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Neg(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NEGATE, operand);
}

ComputationDataHandle ComputationBuilder::Clamp(
    const ComputationDataHandle& min, const ComputationDataHandle& operand,
    const ComputationDataHandle& max) {
  return TernaryOp(TRIOP_CLAMP, min, operand, max);
}

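// UnaryOp, BinaryOp, RngOp, and TernaryOp are the shared plumbing for the
// simple element-wise builders above: each packs an opcode and operands into
// an OpRequest and submits it via RunOpAndParseResponse.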
ComputationDataHandle ComputationBuilder::UnaryOp(
    UnaryOperation unop, const ComputationDataHandle& operand) {
  OpRequest op_request;
  UnaryOpRequest* request = op_request.mutable_unary_op_request();
  request->set_unop(unop);
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BinaryOp(
    BinaryOperation binop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  OpRequest op_request;
  BinaryOpRequest* request = op_request.mutable_binary_op_request();
  request->set_binop(binop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  for (int64 dimension : broadcast_dimensions) {
    request->add_broadcast_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::RngOp(
    RandomDistribution distribution,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
    const Shape& shape) {
  OpRequest op_request;
  RngRequest* request = op_request.mutable_rng_request();
  request->set_distribution(distribution);
  for (const ComputationDataHandle& param : parameters) {
    *request->add_parameter() = param;
  }
  *request->mutable_shape() = shape;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::TernaryOp(
    TernaryOperation triop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
  OpRequest op_request;
  TernaryOpRequest* request = op_request.mutable_ternary_op_request();
  request->set_triop(triop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_ehs() = ehs;
  return RunOpAndParseResponse(&op_request);
}

Status ComputationBuilder::SetReturnValue(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  SetReturnValueRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;

  SetReturnValueResponse response;

  VLOG(2) << "making set-handle-to-execute request";
  Status s = client_->stub()->SetReturnValue(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    NoteError(s);
    return first_error_;
  }

  return Status::OK();
}

StatusOr<bool> ComputationBuilder::IsConstant(
    const ComputationDataHandle& operand, int64 num_parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  IsConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  request.set_num_parameters(num_parameters);
  IsConstantResponse response;

  VLOG(2) << "making IsConstant request";
  Status s = client_->stub()->IsConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }
  return response.is_constant();
}

StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
    const ComputationDataHandle& operand, const Layout* output_layout,
    tensorflow::gtl::ArraySlice<Literal> parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  ComputeConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = *output_layout;
  }
  for (const auto& param : parameters) {
    *request.add_parameters() = param.ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant request";
  Status s = client_->stub()->ComputeConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstant "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}

ComputationDataHandle ComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
  OpRequest op_request;
  MapRequest* request = op_request.mutable_map_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_to_apply() = computation.handle();
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  for (const ComputationDataHandle& sop : static_operands) {
    *request->add_static_operands() = sop;
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

ComputationDataHandle ComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

ComputationDataHandle ComputationBuilder::While(
    const Computation& condition, const Computation& body,
    const ComputationDataHandle& init) {
  OpRequest op_request;
  WhileRequest* request = op_request.mutable_while_request();
  *request->mutable_condition() = condition.handle();
  *request->mutable_body() = body.handle();
  *request->mutable_init() = init;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Gather(
    const ComputationDataHandle& input,
    const ComputationDataHandle& gather_indices,
    const GatherDimensionNumbers& dimension_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds) {
  OpRequest op_request;
  GatherRequest* gather_request = op_request.mutable_gather_request();
  *gather_request->mutable_input() = input;
  *gather_request->mutable_gather_indices() = gather_indices;
  *gather_request->mutable_dimension_numbers() = dimension_numbers;
  for (int64 window_bound : window_bounds) {
    gather_request->add_window_bounds(window_bound);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const Computation& true_computation,
    const ComputationDataHandle& false_operand,
    const Computation& false_computation) {
  OpRequest op_request;
  ConditionalRequest* request = op_request.mutable_conditional_request();
  *request->mutable_predicate() = predicate;
  *request->mutable_true_operand() = true_operand;
  *request->mutable_true_computation() = true_computation.handle();
  *request->mutable_false_operand() = false_operand;
  *request->mutable_false_computation() = false_computation.handle();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  OpRequest op_request;
  ReduceRequest* request = op_request.mutable_reduce_request();
  *request->mutable_operand() = operand;
  *request->mutable_init_value() = init_value;
  for (int64 dimension : dimensions_to_reduce) {
    request->add_dimensions(dimension);
  }
  *request->mutable_to_apply() = computation.handle();
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ReduceAll(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  std::vector<int64> all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie()));
  std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
  return Reduce(operand, init_value, computation, all_dimnos);
}

ComputationDataHandle ComputationBuilder::ReduceWindow(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  Status padding_valid =
      ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                            window_dimensions, window_strides);
  if (!padding_valid.ok()) {
    NoteError(padding_valid);
    return ComputationDataHandle();
  }

  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding);
  return ReduceWindowWithGeneralPadding(operand, init_value, computation,
                                        window_dimensions, window_strides,
                                        padding_values);
}

ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  OpRequest op_request;
  ReduceWindowRequest* request = op_request.mutable_reduce_window_request();
  *request->mutable_operand() = operand;
  *request->mutable_to_apply() = computation.handle();
  *request->mutable_init_value() = init_value;

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormTrainingRequest* request =
      op_request.mutable_batch_norm_training_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormInferenceRequest* request =
      op_request.mutable_batch_norm_inference_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = variance;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& mean, const ComputationDataHandle& var,
    const ComputationDataHandle& grad_output, float epsilon,
    int64 feature_index) {
  OpRequest op_request;
  BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = var;
  *request->mutable_grad_output() = grad_output;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  OpRequest op_request;
  CrossReplicaSumRequest* request =
      op_request.mutable_cross_replica_sum_request();
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::SelectAndScatter(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  return SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides,
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding),
      source, init_value, scatter);
}

ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  OpRequest op_request;
  SelectAndScatterRequest* request =
      op_request.mutable_select_and_scatter_request();
  *request->mutable_operand() = operand;
  *request->mutable_select() = select.handle();
  *request->mutable_source() = source;
  *request->mutable_init_value() = init_value;
  *request->mutable_scatter() = scatter.handle();

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  OpRequest op_request;
  ReducePrecisionRequest* request =
      op_request.mutable_reduce_precision_request();
  *request->mutable_operand() = operand;
  request->set_exponent_bits(exponent_bits);
  request->set_mantissa_bits(mantissa_bits);
  return RunOpAndParseResponse(&op_request);
}

void ComputationBuilder::Send(const ComputationDataHandle& operand,
                              const ChannelHandle& handle) {
  OpRequest op_request;
  SendRequest* request = op_request.mutable_send_request();
  *request->mutable_operand() = operand;
  *request->mutable_channel_handle() = handle;
  *op_request.mutable_computation() = computation_.handle();
  RunOpAndNoteError(&op_request);
}

ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
                                               const ChannelHandle& handle) {
  OpRequest op_request;
  RecvRequest* request = op_request.mutable_recv_request();
  *request->mutable_shape() = shape;
  *request->mutable_channel_handle() = handle;
  return RunOpAndParseResponse(&op_request);
}

Computation ComputationBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->NoteError(
        AddStatus(build_status.status(),
                  tensorflow::strings::StrCat("error from: ", name_)));
    return Computation();
  }
  return build_status.ConsumeValueOrDie();
}

StatusOr<Computation> ComputationBuilder::Build() {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }

  if (computation_.IsNull()) {
    return FailedPrecondition("no computation was built");
  }

  return {std::move(computation_)};
}

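// Default convolution layout: batch and feature dimensions are taken from
// the kConv* constants declared in the header, and the spatial dimensions
// are assigned to indices 2 .. num_spatial_dims+1 for the input, kernel, and
// output alike.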
/* static */ ConvolutionDimensionNumbers
ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}

/* static */ StatusOr<ConvolutionDimensionNumbers>
ComputationBuilder::CreateConvDimensionNumbers(
    int64 input_batch, int64 input_feature, int64 input_first_spatial,
    int64 input_second_spatial, int64 output_batch, int64 output_feature,
    int64 output_first_spatial, int64 output_second_spatial,
    int64 kernel_output_feature, int64 kernel_input_feature,
    int64 kernel_first_spatial, int64 kernel_second_spatial) {
  if (std::set<int64>({input_batch, input_feature, input_first_spatial,
                       input_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
        "%lld)",
        input_batch, input_feature, input_first_spatial, input_second_spatial);
  }
  if (std::set<int64>({kernel_output_feature, kernel_input_feature,
                       kernel_first_spatial, kernel_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
        "%lld)",
        kernel_output_feature, kernel_input_feature, kernel_first_spatial,
        kernel_second_spatial);
  }
  if (std::set<int64>({output_batch, output_feature, output_first_spatial,
                       output_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
        "%lld)",
        output_batch, output_feature, output_first_spatial,
        output_second_spatial);
  }
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(input_batch);
  dimension_numbers.set_input_feature_dimension(input_feature);
  dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
  dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
  dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
  dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
  dimension_numbers.set_output_batch_dimension(output_batch);
  dimension_numbers.set_output_feature_dimension(output_feature);
  dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
  dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
  return dimension_numbers;
}

}  // namespace xla