1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include "tensorflow/compiler/xla/client/computation_builder.h" 17 18#include <stddef.h> 19#include <array> 20#include <numeric> 21#include <set> 22#include <vector> 23 24#include "tensorflow/compiler/xla/ptr_util.h" 25#include "tensorflow/compiler/xla/shape_util.h" 26#include "tensorflow/compiler/xla/status_macros.h" 27#include "tensorflow/compiler/xla/types.h" 28#include "tensorflow/compiler/xla/util.h" 29#include "tensorflow/compiler/xla/xla.pb.h" 30#include "tensorflow/core/lib/core/errors.h" 31#include "tensorflow/core/lib/strings/strcat.h" 32#include "tensorflow/core/platform/logging.h" 33#include "tensorflow/core/platform/protobuf.h" 34 35namespace xla { 36 37ComputationBuilder::ComputationBuilder(Client* client, 38 const string& computation_name) 39 : name_(computation_name), client_(client) {} 40 41ComputationBuilder::~ComputationBuilder() {} 42 43void ComputationBuilder::NoteError(const Status& error) { 44 if (die_immediately_on_error_) { 45 LOG(FATAL) << "error building computation: " << error; 46 } 47 48 if (first_error_.ok()) { 49 first_error_ = error; 50 first_error_backtrace_.CreateCurrent(/*skip_count=*/1); 51 } 52} 53 54std::unique_ptr<ComputationBuilder> ComputationBuilder::CreateSubBuilder( 55 const string& computation_name) { 56 auto sub_builder = 
MakeUnique<ComputationBuilder>(client_, computation_name); 57 sub_builder->parent_builder_ = this; 58 sub_builder->die_immediately_on_error_ = die_immediately_on_error_; 59 return sub_builder; 60} 61 62Status ComputationBuilder::PrepareComputation() { 63 TF_RETURN_IF_ERROR(first_error_); 64 65 if (!computation_.IsNull()) { 66 return Status::OK(); 67 } 68 69 ComputationRequest request; 70 request.set_name(name_); 71 ComputationResponse response; 72 73 VLOG(2) << "making computation request"; 74 Status s = client_->stub()->Computation(&request, &response); 75 VLOG(2) << "done with computation request"; 76 77 if (!s.ok()) { 78 NoteError(s); 79 return first_error_; 80 } 81 82 computation_ = Computation(client_->stub(), response.computation()); 83 return Status::OK(); 84} 85 86Status ComputationBuilder::RunOp(OpRequest* op_request, 87 OpResponse* op_response) { 88 TF_RETURN_IF_ERROR(first_error_); 89 TF_RETURN_IF_ERROR(PrepareComputation()); 90 91 // Fill in fields that are set on every OpRequest. 
92 *op_request->mutable_computation() = computation_.handle(); 93 *op_request->mutable_metadata() = metadata_; 94 if (sharding_) { 95 *op_request->mutable_sharding() = *sharding_; 96 } 97 98 const string& op_name = 99 OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); 100 VLOG(2) << "running op request: " << op_name; 101 Status status = client_->stub()->Op(op_request, op_response); 102 VLOG(2) << "done with op request: " << op_name; 103 return status; 104} 105 106void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { 107 OpResponse op_response; 108 Status status = RunOp(op_request, &op_response); 109 if (!status.ok()) { 110 NoteError(status); 111 } 112} 113 114ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( 115 OpRequest* op_request) { 116 OpResponse op_response; 117 Status status = RunOp(op_request, &op_response); 118 if (!status.ok()) { 119 NoteError(status); 120 return ComputationDataHandle(); 121 } 122 if (op_response.output().handle() == 0) { 123 NoteError(InternalError("No output handle")); 124 return ComputationDataHandle(); 125 } 126 return op_response.output(); 127} 128 129bool ComputationBuilder::MakeWindow( 130 tensorflow::gtl::ArraySlice<int64> window_dimensions, 131 tensorflow::gtl::ArraySlice<int64> window_strides, 132 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 133 tensorflow::gtl::ArraySlice<int64> lhs_dilation, 134 tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) { 135 const auto verify_size = [&](const size_t x, const char* x_name) { 136 if (x == 0 || x == window_dimensions.size()) { 137 return true; 138 } else { 139 NoteError(InvalidArgument( 140 "%s", tensorflow::strings::StrCat( 141 "Window has different number of window dimensions than of ", 142 x_name, "\nNumber of window dimensions: ", 143 window_dimensions.size(), "\nNumber of ", x_name, ": ", x, 144 "\n") 145 .c_str())); // 146 return false; 147 } 148 }; 149 if (!verify_size(window_strides.size(), 
"window strides") || 150 !verify_size(padding.size(), "padding entries") || 151 !verify_size(lhs_dilation.size(), "lhs dilation factors") || 152 !verify_size(rhs_dilation.size(), "rhs dilation factors")) { 153 return false; 154 } 155 156 window->Clear(); 157 for (size_t i = 0; i < window_dimensions.size(); i++) { 158 auto dim = window->add_dimensions(); 159 dim->set_size(window_dimensions[i]); 160 if (!window_strides.empty()) { 161 dim->set_stride(window_strides[i]); 162 } else { 163 dim->set_stride(1); 164 } 165 if (!padding.empty()) { 166 dim->set_padding_low(padding[i].first); 167 dim->set_padding_high(padding[i].second); 168 } else { 169 dim->set_padding_low(0); 170 dim->set_padding_high(0); 171 } 172 if (!lhs_dilation.empty()) { 173 dim->set_base_dilation(lhs_dilation[i]); 174 } else { 175 dim->set_base_dilation(1); 176 } 177 if (!rhs_dilation.empty()) { 178 dim->set_window_dilation(rhs_dilation[i]); 179 } else { 180 dim->set_window_dilation(1); 181 } 182 dim->set_window_reversal(false); 183 } 184 return true; 185} 186 187ComputationDataHandle ComputationBuilder::ConstantLiteral( 188 const Literal& literal) { 189 OpRequest op_request; 190 ConstantRequest* request = op_request.mutable_constant_request(); 191 *request->mutable_literal() = literal.ToProto(); 192 VLOG(3) << "created constant: " << request->literal().ShortDebugString(); 193 return RunOpAndParseResponse(&op_request); 194} 195 196ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, 197 const Shape& shape, 198 const string& name) { 199 OpRequest op_request; 200 ParameterRequest* request = op_request.mutable_parameter_request(); 201 *request->mutable_shape() = shape; 202 request->set_parameter(parameter_number); 203 request->set_name(name); 204 return RunOpAndParseResponse(&op_request); 205} 206 207StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShapeWithoutNoteError( 208 const ComputationDataHandle& operand) { 209 GetLocalShapeRequest request; 210 
*request.mutable_computation() = computation_.handle(); 211 *request.mutable_operand() = operand; 212 GetLocalShapeResponse response; 213 214 VLOG(2) << "making get-shape request"; 215 TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); 216 VLOG(2) << "done with request"; 217 218 TF_RET_CHECK(response.has_shape()); 219 std::unique_ptr<Shape> shape = WrapUnique(response.release_shape()); 220 TF_RET_CHECK(shape != nullptr); 221 return std::move(shape); 222} 223 224StatusOr<std::unique_ptr<Shape>> ComputationBuilder::GetShape( 225 const ComputationDataHandle& operand) { 226 TF_RETURN_IF_ERROR(first_error_); 227 228 auto status_or_shape = GetShapeWithoutNoteError(operand); 229 if (!status_or_shape.ok()) { 230 NoteError(status_or_shape.status()); 231 return first_error_; 232 } 233 return status_or_shape; 234} 235 236StatusOr<ProgramShape> ComputationBuilder::GetProgramShape() { 237 TF_RETURN_IF_ERROR(first_error_); 238 239 GetComputationShapeRequest request; 240 *request.mutable_computation() = computation_.handle(); 241 GetComputationShapeResponse response; 242 243 VLOG(2) << "making get-program-shape-request"; 244 Status status = client_->stub()->GetComputationShape(&request, &response); 245 VLOG(2) << "done with get-program-shape-request"; 246 247 if (!status.ok()) { 248 first_error_ = status; 249 return status; 250 } 251 252 TF_RET_CHECK(response.has_program_shape()); 253 return std::move(*response.mutable_program_shape()); 254} 255 256ComputationDataHandle ComputationBuilder::CheckShape( 257 const ComputationDataHandle& operand, const Shape& expected_shape) { 258 std::unique_ptr<Shape> actual_shape = GetShape(operand).ConsumeValueOrDie(); 259 CHECK(ShapeUtil::Equal(expected_shape, *actual_shape)) 260 << "want " << ShapeUtil::HumanString(expected_shape) << " got " 261 << ShapeUtil::HumanString(*actual_shape); 262 return operand; 263} 264 265void ComputationBuilder::CheckSameShape(const ComputationDataHandle& lhs, 266 const ComputationDataHandle& 
rhs) { 267 std::unique_ptr<Shape> lhs_shape = GetShape(lhs).ConsumeValueOrDie(); 268 std::unique_ptr<Shape> rhs_shape = GetShape(rhs).ConsumeValueOrDie(); 269 VLOG(2) << "checking " << ShapeUtil::HumanString(*lhs_shape) << " equals " 270 << ShapeUtil::HumanString(*rhs_shape); 271 CHECK(ShapeUtil::Equal(*lhs_shape, *rhs_shape)) 272 << "lhs " << ShapeUtil::HumanString(*lhs_shape) << " rhs " 273 << ShapeUtil::HumanString(*rhs_shape); 274} 275 276ComputationDataHandle ComputationBuilder::Slice( 277 const ComputationDataHandle& operand, 278 tensorflow::gtl::ArraySlice<int64> start_indices, 279 tensorflow::gtl::ArraySlice<int64> limit_indices, 280 tensorflow::gtl::ArraySlice<int64> strides) { 281 OpRequest op_request; 282 SliceRequest* request = op_request.mutable_slice_request(); 283 *request->mutable_operand() = operand; 284 for (int64 index : start_indices) { 285 request->add_start_indices(index); 286 } 287 for (int64 index : limit_indices) { 288 request->add_limit_indices(index); 289 } 290 for (int64 index : strides) { 291 request->add_strides(index); 292 } 293 return RunOpAndParseResponse(&op_request); 294} 295 296ComputationDataHandle ComputationBuilder::SliceInDim( 297 const ComputationDataHandle& operand, int64 start_index, int64 limit_index, 298 int64 stride, int64 dimno) { 299 StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand); 300 if (!shape_status.ok()) { 301 NoteError(shape_status.status()); 302 return ComputationDataHandle{}; 303 } 304 const Shape& shape = *shape_status.ValueOrDie(); 305 std::vector<int64> starts(ShapeUtil::Rank(shape), 0); 306 std::vector<int64> limits(shape.dimensions().begin(), 307 shape.dimensions().end()); 308 std::vector<int64> strides(ShapeUtil::Rank(shape), 1); 309 starts[dimno] = start_index; 310 limits[dimno] = limit_index; 311 strides[dimno] = stride; 312 return Slice(operand, starts, limits, strides); 313} 314 315ComputationDataHandle ComputationBuilder::DynamicSlice( 316 const ComputationDataHandle& operand, 317 
const ComputationDataHandle& start_indices, 318 tensorflow::gtl::ArraySlice<int64> slice_sizes) { 319 OpRequest op_request; 320 DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); 321 *request->mutable_operand() = operand; 322 *request->mutable_start_indices() = start_indices; 323 for (int64 index : slice_sizes) { 324 request->add_slice_sizes(index); 325 } 326 return RunOpAndParseResponse(&op_request); 327} 328 329ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( 330 const ComputationDataHandle& operand, const ComputationDataHandle& update, 331 const ComputationDataHandle& start_indices) { 332 OpRequest op_request; 333 DynamicUpdateSliceRequest* request = 334 op_request.mutable_dynamic_update_slice_request(); 335 *request->mutable_operand() = operand; 336 *request->mutable_update() = update; 337 *request->mutable_start_indices() = start_indices; 338 return RunOpAndParseResponse(&op_request); 339} 340 341ComputationDataHandle ComputationBuilder::ConcatInDim( 342 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 343 int64 dimension) { 344 OpRequest op_request; 345 ConcatenateRequest* request = op_request.mutable_concatenate_request(); 346 for (const ComputationDataHandle& operand : operands) { 347 *request->add_operands() = operand; 348 } 349 request->set_dimension(dimension); 350 return RunOpAndParseResponse(&op_request); 351} 352 353ComputationDataHandle ComputationBuilder::Broadcast( 354 const ComputationDataHandle& operand, 355 tensorflow::gtl::ArraySlice<int64> broadcast_sizes) { 356 OpRequest op_request; 357 BroadcastRequest* request = op_request.mutable_broadcast_request(); 358 *request->mutable_operand() = operand; 359 for (int64 size : broadcast_sizes) { 360 request->add_broadcast_sizes(size); 361 } 362 return RunOpAndParseResponse(&op_request); 363} 364 365ComputationDataHandle ComputationBuilder::Pad( 366 const ComputationDataHandle& operand, 367 const ComputationDataHandle& padding_value, 368 const 
PaddingConfig& padding_config) { 369 OpRequest op_request; 370 PadRequest* request = op_request.mutable_pad_request(); 371 *request->mutable_operand() = operand; 372 *request->mutable_padding_value() = padding_value; 373 *request->mutable_padding_config() = padding_config; 374 return RunOpAndParseResponse(&op_request); 375} 376 377ComputationDataHandle ComputationBuilder::Reshape( 378 const ComputationDataHandle& operand, 379 tensorflow::gtl::ArraySlice<int64> dimensions, 380 tensorflow::gtl::ArraySlice<int64> new_sizes) { 381 OpRequest op_request; 382 ReshapeRequest* request = op_request.mutable_reshape_request(); 383 *request->mutable_operand() = operand; 384 for (int64 dimension : dimensions) { 385 request->add_dimensions(dimension); 386 } 387 for (int64 new_size : new_sizes) { 388 request->add_new_sizes(new_size); 389 } 390 return RunOpAndParseResponse(&op_request); 391} 392 393ComputationDataHandle ComputationBuilder::Reshape( 394 const ComputationDataHandle& operand, 395 tensorflow::gtl::ArraySlice<int64> new_sizes) { 396 if (!first_error_.ok()) { 397 return ComputationDataHandle(); 398 } 399 400 StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand); 401 if (!shape.ok()) { 402 return ComputationDataHandle(); 403 } 404 std::vector<int64> dimensions(shape.ValueOrDie()->dimensions().size()); 405 std::iota(dimensions.begin(), dimensions.end(), 0); 406 return Reshape(operand, dimensions, new_sizes); 407} 408 409ComputationDataHandle ComputationBuilder::Collapse( 410 const ComputationDataHandle& operand, 411 tensorflow::gtl::ArraySlice<int64> dims_to_collapse) { 412 if (!first_error_.ok()) { 413 return ComputationDataHandle(); 414 } 415 416 // Don't support out-of-order collapse here. 417 // Checks that the collapsed dimensions are in order and consecutive. 
418 for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1; 419 i < dims_to_collapse.size(); ++i) { 420 if (dims_to_collapse[i] - 1 != dims_to_collapse[i - 1]) { 421 NoteError(InvalidArgument( 422 "Collapsed dimensions are not in order and consecutive.")); 423 return ComputationDataHandle(); 424 } 425 } 426 427 // Create a new sizes vector from the old shape, replacing the collapsed 428 // dimensions by the product of their sizes. 429 StatusOr<std::unique_ptr<Shape>> shape_or_status = GetShape(operand); 430 if (!shape_or_status.ok()) { 431 return ComputationDataHandle(); 432 } 433 std::unique_ptr<Shape> original_shape = shape_or_status.ConsumeValueOrDie(); 434 435 VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); 436 VLOG(3) << "dims to collapse: " 437 << tensorflow::str_util::Join(dims_to_collapse, ","); 438 439 if (dims_to_collapse.size() <= 1) { 440 // Not collapsing anything, trivially we can return the operand versus 441 // enqueueing a trivial reshape. 442 return operand; 443 } 444 445 std::vector<int64> new_sizes; 446 for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { 447 if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { 448 new_sizes.push_back(original_shape->dimensions(i)); 449 } else { 450 new_sizes.back() *= original_shape->dimensions(i); 451 } 452 } 453 454 VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") 455 << "]"; 456 457 return Reshape(operand, new_sizes); 458} 459 460void ComputationBuilder::Trace(const string& tag, 461 const ComputationDataHandle& operand) { 462 OpRequest op_request; 463 TraceRequest* request = op_request.mutable_trace_request(); 464 request->set_tag(tag); 465 *request->mutable_operand() = operand; 466 RunOpAndNoteError(&op_request); 467} 468 469ComputationDataHandle ComputationBuilder::Select( 470 const ComputationDataHandle& pred, const ComputationDataHandle& on_true, 471 const ComputationDataHandle& on_false) { 472 return TernaryOp(TRIOP_SELECT, pred, 
on_true, on_false); 473} 474 475ComputationDataHandle ComputationBuilder::Tuple( 476 tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) { 477 OpRequest op_request; 478 VariadicOpRequest* request = op_request.mutable_variadic_op_request(); 479 request->set_varop(VAROP_TUPLE); 480 for (const ComputationDataHandle& operand : elements) { 481 *request->add_operands() = operand; 482 } 483 return RunOpAndParseResponse(&op_request); 484} 485 486ComputationDataHandle ComputationBuilder::GetTupleElement( 487 const ComputationDataHandle& tuple_data, int64 index) { 488 OpRequest op_request; 489 GetTupleElementRequest* request = 490 op_request.mutable_get_tuple_element_request(); 491 *request->mutable_operand() = tuple_data; 492 request->set_index(index); 493 return RunOpAndParseResponse(&op_request); 494} 495 496ComputationDataHandle ComputationBuilder::Eq( 497 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 498 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 499 return BinaryOp(BINOP_EQ, lhs, rhs, broadcast_dimensions); 500} 501 502ComputationDataHandle ComputationBuilder::Ne( 503 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 504 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 505 return BinaryOp(BINOP_NE, lhs, rhs, broadcast_dimensions); 506} 507 508ComputationDataHandle ComputationBuilder::Ge( 509 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 510 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 511 return BinaryOp(BINOP_GE, lhs, rhs, broadcast_dimensions); 512} 513 514ComputationDataHandle ComputationBuilder::Gt( 515 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 516 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 517 return BinaryOp(BINOP_GT, lhs, rhs, broadcast_dimensions); 518} 519 520ComputationDataHandle ComputationBuilder::Le( 521 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 522 
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 523 return BinaryOp(BINOP_LE, lhs, rhs, broadcast_dimensions); 524} 525 526ComputationDataHandle ComputationBuilder::Lt( 527 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 528 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 529 return BinaryOp(BINOP_LT, lhs, rhs, broadcast_dimensions); 530} 531 532ComputationDataHandle ComputationBuilder::Dot( 533 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { 534 StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs); 535 if (!lhs_shape_or_status.ok()) { 536 return ComputationDataHandle(); 537 } 538 std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); 539 540 DotDimensionNumbers dimension_numbers; 541 dimension_numbers.add_lhs_contracting_dimensions( 542 lhs_shape->dimensions_size() == 1 ? 0 : 1); 543 dimension_numbers.add_rhs_contracting_dimensions(0); 544 return DotGeneral(lhs, rhs, dimension_numbers); 545} 546 547ComputationDataHandle ComputationBuilder::DotGeneral( 548 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 549 const DotDimensionNumbers& dimension_numbers) { 550 OpRequest op_request; 551 DotRequest* request = op_request.mutable_dot_request(); 552 *request->mutable_lhs() = lhs; 553 *request->mutable_rhs() = rhs; 554 *request->mutable_dimension_numbers() = dimension_numbers; 555 return RunOpAndParseResponse(&op_request); 556} 557 558ComputationDataHandle ComputationBuilder::Conv( 559 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 560 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) { 561 return ConvWithGeneralDimensions( 562 lhs, rhs, window_strides, padding, 563 CreateDefaultConvDimensionNumbers(window_strides.size())); 564} 565 566ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( 567 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 568 tensorflow::gtl::ArraySlice<int64> 
window_strides, 569 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) { 570 return ConvGeneral(lhs, rhs, window_strides, padding, 571 CreateDefaultConvDimensionNumbers(window_strides.size())); 572} 573 574bool ComputationBuilder::VerifyConvolution( 575 const Shape& lhs_shape, const Shape& rhs_shape, 576 const ConvolutionDimensionNumbers& dimension_numbers) { 577 if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) { 578 NoteError( 579 InvalidArgument("Convolution arguments must have same number of " 580 "dimensions. Got: %s and %s", 581 ShapeUtil::HumanString(lhs_shape).c_str(), 582 ShapeUtil::HumanString(rhs_shape).c_str())); 583 return false; 584 } 585 int num_dims = ShapeUtil::Rank(lhs_shape); 586 if (num_dims < 2) { 587 NoteError(InvalidArgument( 588 "Convolution expects argument arrays with >= 3 dimensions. " 589 "Got: %s and %s", 590 ShapeUtil::HumanString(lhs_shape).c_str(), 591 ShapeUtil::HumanString(rhs_shape).c_str())); 592 return false; 593 } 594 int num_spatial_dims = num_dims - 2; 595 596 const auto check_spatial_dimensions = 597 [&](const char* const field_name, 598 const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>& 599 numbers) { 600 if (numbers.size() != num_spatial_dims) { 601 NoteError(InvalidArgument("Expected %d elements for %s, but got %d.", 602 num_spatial_dims, field_name, 603 numbers.size())); 604 return false; 605 } 606 for (int i = 0; i < numbers.size(); ++i) { 607 if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) { 608 NoteError( 609 InvalidArgument("Convolution %s[%d] is out of bounds: %lld", 610 field_name, i, numbers.Get(i))); 611 return false; 612 } 613 } 614 return true; 615 }; 616 return check_spatial_dimensions( 617 "input_spatial_dimensions", 618 dimension_numbers.input_spatial_dimensions()) && 619 check_spatial_dimensions( 620 "kernel_spatial_dimensions", 621 dimension_numbers.kernel_spatial_dimensions()) && 622 check_spatial_dimensions( 623 "output_spatial_dimensions", 624 
dimension_numbers.output_spatial_dimensions()); 625} 626 627ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( 628 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 629 tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding, 630 const ConvolutionDimensionNumbers& dimension_numbers) { 631 if (!first_error_.ok() || !PrepareComputation().ok()) { 632 return ComputationDataHandle(); 633 } 634 635 StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs); 636 if (!lhs_shape_or_status.ok()) { 637 return ComputationDataHandle(); 638 } 639 640 StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs); 641 if (!rhs_shape_or_status.ok()) { 642 return ComputationDataHandle(); 643 } 644 645 std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); 646 std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); 647 648 if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { 649 NoteError(InternalError("failed to verify convolution")); 650 return ComputationDataHandle(); 651 } 652 653 std::vector<int64> base_area_dimensions( 654 dimension_numbers.input_spatial_dimensions_size()); 655 for (std::vector<int64>::size_type i = 0; i < base_area_dimensions.size(); 656 ++i) { 657 base_area_dimensions[i] = 658 lhs_shape->dimensions(dimension_numbers.input_spatial_dimensions(i)); 659 } 660 661 std::vector<int64> window_dimensions( 662 dimension_numbers.kernel_spatial_dimensions_size()); 663 for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) { 664 window_dimensions[i] = 665 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); 666 } 667 668 return ConvGeneral(lhs, rhs, window_strides, 669 MakePadding(base_area_dimensions, window_dimensions, 670 window_strides, padding), 671 dimension_numbers); 672} 673 674ComputationDataHandle ComputationBuilder::ConvGeneral( 675 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 676 
tensorflow::gtl::ArraySlice<int64> window_strides, 677 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 678 const ConvolutionDimensionNumbers& dimension_numbers) { 679 return ConvGeneralDilated(lhs, rhs, window_strides, padding, {}, {}, 680 dimension_numbers); 681} 682 683ComputationDataHandle ComputationBuilder::ConvGeneralDilated( 684 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 685 tensorflow::gtl::ArraySlice<int64> window_strides, 686 tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding, 687 tensorflow::gtl::ArraySlice<int64> lhs_dilation, 688 tensorflow::gtl::ArraySlice<int64> rhs_dilation, 689 const ConvolutionDimensionNumbers& dimension_numbers) { 690 if (!first_error_.ok() || !PrepareComputation().ok()) { 691 return ComputationDataHandle(); 692 } 693 694 StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs); 695 if (!lhs_shape_or_status.ok()) { 696 return ComputationDataHandle(); 697 } 698 699 StatusOr<std::unique_ptr<Shape>> rhs_shape_or_status = GetShape(rhs); 700 if (!rhs_shape_or_status.ok()) { 701 return ComputationDataHandle(); 702 } 703 704 std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); 705 std::unique_ptr<Shape> rhs_shape = rhs_shape_or_status.ConsumeValueOrDie(); 706 if (!VerifyConvolution(*lhs_shape, *rhs_shape, dimension_numbers)) { 707 // Error is recorded in VerifyConvolution. 
708 return ComputationDataHandle(); 709 } 710 711 std::vector<int64> window_dimensions( 712 dimension_numbers.kernel_spatial_dimensions_size()); 713 for (std::vector<int64>::size_type i = 0; i < window_dimensions.size(); ++i) { 714 window_dimensions[i] = 715 rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); 716 } 717 718 OpRequest op_request; 719 ConvolveRequest* request = op_request.mutable_convolve_request(); 720 *request->mutable_lhs() = lhs; 721 *request->mutable_rhs() = rhs; 722 *request->mutable_dimension_numbers() = dimension_numbers; 723 724 if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, 725 rhs_dilation, request->mutable_window())) { 726 // Error is recorded in MakeWindow. 727 return ComputationDataHandle(); 728 } 729 730 return RunOpAndParseResponse(&op_request); 731} 732 733ComputationDataHandle ComputationBuilder::Fft( 734 const ComputationDataHandle& operand, const FftType fft_type, 735 const tensorflow::gtl::ArraySlice<int64> fft_length) { 736 OpRequest op_request; 737 FftRequest* request = op_request.mutable_fft_request(); 738 *request->mutable_operand() = operand; 739 request->set_fft_type(fft_type); 740 for (int64 dim_len : fft_length) { 741 request->add_fft_length(dim_len); 742 } 743 return RunOpAndParseResponse(&op_request); 744} 745 746ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, 747 const string& config) { 748 OpRequest op_request; 749 InfeedRequest* request = op_request.mutable_infeed_request(); 750 *request->mutable_shape() = shape; 751 *request->mutable_config() = config; 752 return RunOpAndParseResponse(&op_request); 753} 754 755void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, 756 const Shape& shape, 757 const string& outfeed_config) { 758 OpRequest op_request; 759 OutfeedRequest* request = op_request.mutable_outfeed_request(); 760 request->set_outfeed_config(outfeed_config); 761 *request->mutable_operand() = operand; 762 *request->mutable_shape() 
= shape; 763 RunOpAndNoteError(&op_request); 764} 765 766ComputationDataHandle ComputationBuilder::Call( 767 const Computation& computation, 768 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) { 769 OpRequest op_request; 770 CallRequest* request = op_request.mutable_call_request(); 771 *request->mutable_to_apply() = computation.handle(); 772 for (const ComputationDataHandle& operand : operands) { 773 *request->add_operands() = operand; 774 } 775 return RunOpAndParseResponse(&op_request); 776} 777 778ComputationDataHandle ComputationBuilder::CustomCall( 779 const string& call_target_name, 780 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 781 const Shape& shape) { 782 OpRequest op_request; 783 CustomCallRequest* request = op_request.mutable_custom_call_request(); 784 request->set_call_target_name(call_target_name); 785 for (const ComputationDataHandle& operand : operands) { 786 *request->add_operands() = operand; 787 } 788 *request->mutable_shape() = shape; 789 return RunOpAndParseResponse(&op_request); 790} 791 792ComputationDataHandle ComputationBuilder::HostCompute( 793 tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, 794 const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { 795 OpRequest op_request; 796 HostComputeRequest* request = op_request.mutable_host_compute_request(); 797 for (const ComputationDataHandle& operand : operands) { 798 *request->add_operands() = operand; 799 } 800 *request->mutable_shape() = shape; 801 request->set_channel_name(channel_name); 802 request->set_cost_estimate_ns(cost_estimate_ns); 803 return RunOpAndParseResponse(&op_request); 804} 805 806ComputationDataHandle ComputationBuilder::Complex( 807 const ComputationDataHandle& real, const ComputationDataHandle& imag, 808 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 809 return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); 810} 811 812ComputationDataHandle ComputationBuilder::Conj( 813 const 
ComputationDataHandle& operand) { 814 return Complex(Real(operand), Neg(Imag(operand))); 815} 816 817ComputationDataHandle ComputationBuilder::Add( 818 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 819 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 820 return BinaryOp(BINOP_ADD, lhs, rhs, broadcast_dimensions); 821} 822 823ComputationDataHandle ComputationBuilder::Sub( 824 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 825 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 826 return BinaryOp(BINOP_SUB, lhs, rhs, broadcast_dimensions); 827} 828 829ComputationDataHandle ComputationBuilder::Mul( 830 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 831 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 832 return BinaryOp(BINOP_MUL, lhs, rhs, broadcast_dimensions); 833} 834 835ComputationDataHandle ComputationBuilder::Div( 836 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 837 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 838 return BinaryOp(BINOP_DIV, lhs, rhs, broadcast_dimensions); 839} 840 841ComputationDataHandle ComputationBuilder::Rem( 842 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 843 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 844 return BinaryOp(BINOP_REM, lhs, rhs, broadcast_dimensions); 845} 846 847ComputationDataHandle ComputationBuilder::Max( 848 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 849 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 850 return BinaryOp(BINOP_MAX, lhs, rhs, broadcast_dimensions); 851} 852 853ComputationDataHandle ComputationBuilder::Min( 854 const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, 855 tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { 856 return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); 857} 858 859ComputationDataHandle ComputationBuilder::And( 860 const ComputationDataHandle& lhs, const 
ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Or(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions);
}

// Elementwise logical/bitwise negation.
ComputationDataHandle ComputationBuilder::Not(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NOT, operand);
}

// Bit-shift ops: shift amounts come from `rhs`, elementwise.
ComputationDataHandle ComputationBuilder::ShiftLeft(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightArithmetic(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::ShiftRightLogical(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions);
}

// Elementwise absolute value.
ComputationDataHandle ComputationBuilder::Abs(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ABS, operand);
}

// Two-argument arctangent; note the (y, x) argument order, as in atan2(y, x).
ComputationDataHandle ComputationBuilder::Atan2(
    const ComputationDataHandle& y, const ComputationDataHandle& x,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions);
}

ComputationDataHandle ComputationBuilder::Exp(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_EXP, operand);
}

ComputationDataHandle
ComputationBuilder::Floor(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_FLOOR, operand);
}

ComputationDataHandle ComputationBuilder::Ceil(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_CEIL, operand);
}

// Rounds to nearest; the opcode name indicates ties round away from zero.
ComputationDataHandle ComputationBuilder::Round(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_ROUND_NEAREST_AFZ, operand);
}

ComputationDataHandle ComputationBuilder::Log(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_LOG, operand);
}

ComputationDataHandle ComputationBuilder::Sign(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIGN, operand);
}

ComputationDataHandle ComputationBuilder::Cos(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_COS, operand);
}

ComputationDataHandle ComputationBuilder::Sin(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SIN, operand);
}

ComputationDataHandle ComputationBuilder::Tanh(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_TANH, operand);
}

// Real/imaginary component extraction for complex operands.
ComputationDataHandle ComputationBuilder::Real(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_REAL, operand);
}

ComputationDataHandle ComputationBuilder::Imag(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IMAG, operand);
}

// Elementwise predicate: true where the value is neither Inf nor NaN.
ComputationDataHandle ComputationBuilder::IsFinite(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_IS_FINITE, operand);
}

// Permutes the operand's dimensions according to `permutation`.
ComputationDataHandle ComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  OpRequest op_request;
  TransposeRequest* request = op_request.mutable_transpose_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : permutation) {
    request->add_dimensions(dimension);
  }
  return
RunOpAndParseResponse(&op_request);
}

// Reverses the operand along the given dimensions.
ComputationDataHandle ComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  OpRequest op_request;
  ReverseRequest* request = op_request.mutable_reverse_request();
  *request->mutable_operand() = operand;
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

ComputationDataHandle ComputationBuilder::Sort(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_SORT, operand);
}

// sqrt(x) implemented as x^0.5.
ComputationDataHandle ComputationBuilder::SqrtF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(0.5),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Pow(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return BinaryOp(BINOP_POW, lhs, rhs, broadcast_dimensions);
}

// Converts the operand's elements to `new_element_type` (value-preserving
// conversion, not a bitcast).
ComputationDataHandle ComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // The shape query serves to validate `operand` against the service; the
  // returned shape itself is not used further here.
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

// Reinterprets the operand's bits as `new_element_type` (bitcast conversion).
ComputationDataHandle ComputationBuilder::BitcastConvertType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  if (!first_error_.ok() ||
!PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  // As in ConvertElementType, the shape query only validates the operand.
  StatusOr<std::unique_ptr<Shape>> shape_status = GetShape(operand);
  if (!shape_status.ok()) {
    return ComputationDataHandle();
  }
  std::unique_ptr<Shape> original = shape_status.ConsumeValueOrDie();

  OpRequest op_request;
  ConvertRequest* request = op_request.mutable_bitcast_convert_request();
  *request->mutable_operand() = operand;
  request->set_new_element_type(new_element_type);
  return RunOpAndParseResponse(&op_request);
}

// x^2 implemented via the pow op.
ComputationDataHandle ComputationBuilder::SquareF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(2.0),
                  /*broadcast_dimensions=*/{});
}

// 1/x implemented as x^-1.
ComputationDataHandle ComputationBuilder::ReciprocalF32(
    const ComputationDataHandle& operand) {
  return BinaryOp(BINOP_POW, operand, ConstantR0<float>(-1.0),
                  /*broadcast_dimensions=*/{});
}

ComputationDataHandle ComputationBuilder::Neg(
    const ComputationDataHandle& operand) {
  return UnaryOp(UNOP_NEGATE, operand);
}

// Clamps `operand` elementwise into [min, max].
ComputationDataHandle ComputationBuilder::Clamp(
    const ComputationDataHandle& min, const ComputationDataHandle& operand,
    const ComputationDataHandle& max) {
  return TernaryOp(TRIOP_CLAMP, min, operand, max);
}

// Shared helper: enqueues a generic unary op with opcode `unop`.
ComputationDataHandle ComputationBuilder::UnaryOp(
    UnaryOperation unop, const ComputationDataHandle& operand) {
  OpRequest op_request;
  UnaryOpRequest* request = op_request.mutable_unary_op_request();
  request->set_unop(unop);
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

// Shared helper: enqueues a generic binary op with opcode `binop`.
ComputationDataHandle ComputationBuilder::BinaryOp(
    BinaryOperation binop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  OpRequest op_request;
BinaryOpRequest* request = op_request.mutable_binary_op_request();
  request->set_binop(binop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  for (int64 dimension : broadcast_dimensions) {
    request->add_broadcast_dimensions(dimension);
  }
  return RunOpAndParseResponse(&op_request);
}

// Shared helper: enqueues a random-number op drawing from `distribution`,
// parameterized by `parameters`, producing values of shape `shape`.
ComputationDataHandle ComputationBuilder::RngOp(
    RandomDistribution distribution,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> parameters,
    const Shape& shape) {
  OpRequest op_request;
  RngRequest* request = op_request.mutable_rng_request();
  request->set_distribution(distribution);
  for (const ComputationDataHandle& param : parameters) {
    *request->add_parameter() = param;
  }
  *request->mutable_shape() = shape;
  return RunOpAndParseResponse(&op_request);
}

// Shared helper: enqueues a generic ternary op with opcode `triop`.
ComputationDataHandle ComputationBuilder::TernaryOp(
    TernaryOperation triop, const ComputationDataHandle& lhs,
    const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) {
  OpRequest op_request;
  TernaryOpRequest* request = op_request.mutable_ternary_op_request();
  request->set_triop(triop);
  *request->mutable_lhs() = lhs;
  *request->mutable_rhs() = rhs;
  *request->mutable_ehs() = ehs;
  return RunOpAndParseResponse(&op_request);
}

// Declares `operand` as the root (return value) of the computation via an RPC
// to the service; errors are recorded through NoteError.
Status ComputationBuilder::SetReturnValue(
    const ComputationDataHandle& operand) {
  TF_RETURN_IF_ERROR(first_error_);

  SetReturnValueRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;

  SetReturnValueResponse response;

  VLOG(2) << "making set-handle-to-execute request";
  Status s = client_->stub()->SetReturnValue(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    NoteError(s);
    return first_error_;
  }

  return Status::OK();
}

StatusOr<bool>
ComputationBuilder::IsConstant(
    const ComputationDataHandle& operand, int64 num_parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  IsConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  request.set_num_parameters(num_parameters);
  IsConstantResponse response;

  VLOG(2) << "making IsConstant request";
  Status s = client_->stub()->IsConstant(&request, &response);
  VLOG(2) << "done with request";

  // Unlike SetReturnValue, a failed RPC here is returned directly rather than
  // recorded via NoteError.
  if (!s.ok()) {
    return s;
  }
  return response.is_constant();
}

// Asks the service to constant-fold `operand` (optionally with the given
// `parameters` and requested `output_layout`) and returns the literal result.
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
    const ComputationDataHandle& operand, const Layout* output_layout,
    tensorflow::gtl::ArraySlice<Literal> parameters) {
  TF_RETURN_IF_ERROR(first_error_);

  ComputeConstantRequest request;
  *request.mutable_computation() = computation_.handle();
  *request.mutable_operand() = operand;
  if (output_layout != nullptr) {
    *request.mutable_output_layout() = *output_layout;
  }
  for (const auto& param : parameters) {
    *request.add_parameters() = param.ToProto();
  }

  ComputeConstantResponse response;

  VLOG(2) << "making compute-constant request";
  Status s = client_->stub()->ComputeConstant(&request, &response);
  VLOG(2) << "done with request";

  if (!s.ok()) {
    return s;
  }

  VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";

  if (!response.has_literal()) {
    return InternalError(
        "no computed literal in the provided response in ComputeConstant "
        "request");
  }
  return Literal::CreateFromProto(response.literal());
}

// Enqueues a map of `computation` over `operands` along `dimensions`, with
// optional `static_operands` passed through unmapped.
ComputationDataHandle ComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
  OpRequest op_request;
  MapRequest* request = op_request.mutable_map_request();
  for (const ComputationDataHandle& operand : operands) {
    *request->add_operands() = operand;
  }
  *request->mutable_to_apply() = computation.handle();
  for (int64 dimension : dimensions) {
    request->add_dimensions(dimension);
  }
  for (const ComputationDataHandle& sop : static_operands) {
    *request->add_static_operands() = sop;
  }
  return RunOpAndParseResponse(&op_request);
}

// Random normal draw parameterized by mean `mu` and std deviation `sigma`.
ComputationDataHandle ComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_NORMAL, {mu, sigma}, shape);
}

// Random uniform draw over [a, b).
ComputationDataHandle ComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape);
}

// Enqueues a while loop: repeatedly applies `body` to the loop state while
// `condition` holds, starting from `init`.
ComputationDataHandle ComputationBuilder::While(
    const Computation& condition, const Computation& body,
    const ComputationDataHandle& init) {
  OpRequest op_request;
  WhileRequest* request = op_request.mutable_while_request();
  *request->mutable_condition() = condition.handle();
  *request->mutable_body() = body.handle();
  *request->mutable_init() = init;
  return RunOpAndParseResponse(&op_request);
}

// Enqueues a gather of windows of `input` at positions given by
// `gather_indices`, as described by `dimension_numbers` and `window_bounds`.
ComputationDataHandle ComputationBuilder::Gather(
    const ComputationDataHandle& input,
    const ComputationDataHandle& gather_indices,
    const GatherDimensionNumbers& dimension_numbers,
    tensorflow::gtl::ArraySlice<int64> window_bounds) {
  OpRequest op_request;
  GatherRequest* gather_request = op_request.mutable_gather_request();
  *gather_request->mutable_input() = input;
  *gather_request->mutable_gather_indices() = gather_indices;
*gather_request->mutable_dimension_numbers() = dimension_numbers;
  for (int64 window_bound : window_bounds) {
    gather_request->add_window_bounds(window_bound);
  }
  return RunOpAndParseResponse(&op_request);
}

// Enqueues a conditional: applies `true_computation` to `true_operand` when
// `predicate` is true, otherwise `false_computation` to `false_operand`.
ComputationDataHandle ComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const Computation& true_computation,
    const ComputationDataHandle& false_operand,
    const Computation& false_computation) {
  OpRequest op_request;
  ConditionalRequest* request = op_request.mutable_conditional_request();
  *request->mutable_predicate() = predicate;
  *request->mutable_true_operand() = true_operand;
  *request->mutable_true_computation() = true_computation.handle();
  *request->mutable_false_operand() = false_operand;
  *request->mutable_false_computation() = false_computation.handle();
  return RunOpAndParseResponse(&op_request);
}

// Enqueues a reduction of `operand` along `dimensions_to_reduce`, combining
// elements with `computation` starting from `init_value`.
ComputationDataHandle ComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  OpRequest op_request;
  ReduceRequest* request = op_request.mutable_reduce_request();
  *request->mutable_operand() = operand;
  *request->mutable_init_value() = init_value;
  for (int64 dimension : dimensions_to_reduce) {
    request->add_dimensions(dimension);
  }
  *request->mutable_to_apply() = computation.handle();
  return RunOpAndParseResponse(&op_request);
}

// Convenience wrapper: reduces over every dimension of the operand.
ComputationDataHandle ComputationBuilder::ReduceAll(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation) {
  if (!first_error_.ok() || !PrepareComputation().ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if
(!shape.ok()) {
    return ComputationDataHandle();
  }

  // Enumerate all dimension numbers [0, rank) and reduce over all of them.
  std::vector<int64> all_dimnos(ShapeUtil::Rank(*shape.ValueOrDie()));
  std::iota(all_dimnos.begin(), all_dimnos.end(), 0);
  return Reduce(operand, init_value, computation, all_dimnos);
}

// Windowed reduction with a symbolic padding mode; computes concrete padding
// amounts from the operand shape and delegates to the general-padding variant.
ComputationDataHandle ComputationBuilder::ReduceWindow(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }

  Status padding_valid =
      ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                            window_dimensions, window_strides);
  if (!padding_valid.ok()) {
    // Invalid window/stride configuration becomes the builder's first error.
    first_error_ = padding_valid;
    return ComputationDataHandle();
  }

  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding);
  return ReduceWindowWithGeneralPadding(operand, init_value, computation,
                                        window_dimensions, window_strides,
                                        padding_values);
}

// Windowed reduction with explicit per-dimension (low, high) padding.
ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value, const Computation& computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  OpRequest op_request;
  ReduceWindowRequest* request = op_request.mutable_reduce_window_request();
  *request->mutable_operand() = operand;
  *request->mutable_to_apply() = computation.handle();
*request->mutable_init_value() = init_value;

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

// Enqueues batch-norm training: normalizes `operand` over `feature_index`
// using batch statistics, then scales and offsets.
ComputationDataHandle ComputationBuilder::BatchNormTraining(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormTrainingRequest* request =
      op_request.mutable_batch_norm_training_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

// Enqueues batch-norm inference: like training, but uses the supplied
// precomputed `mean` and `variance` instead of batch statistics.
ComputationDataHandle ComputationBuilder::BatchNormInference(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& offset, const ComputationDataHandle& mean,
    const ComputationDataHandle& variance, float epsilon, int64 feature_index) {
  OpRequest op_request;
  BatchNormInferenceRequest* request =
      op_request.mutable_batch_norm_inference_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_offset() = offset;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = variance;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

// Enqueues the batch-norm gradient op, producing gradients w.r.t. the
// activations, scale and offset.
ComputationDataHandle ComputationBuilder::BatchNormGrad(
    const ComputationDataHandle& operand, const ComputationDataHandle& scale,
    const ComputationDataHandle& mean, const ComputationDataHandle& var,
    const ComputationDataHandle& grad_output, float epsilon,
int64 feature_index) {
  OpRequest op_request;
  BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request();
  *request->mutable_operand() = operand;
  *request->mutable_scale() = scale;
  *request->mutable_mean() = mean;
  *request->mutable_variance() = var;
  *request->mutable_grad_output() = grad_output;
  request->set_epsilon(epsilon);
  request->set_feature_index(feature_index);
  return RunOpAndParseResponse(&op_request);
}

// Enqueues a cross-replica (all-reduce) sum of `operand`.
ComputationDataHandle ComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  OpRequest op_request;
  CrossReplicaSumRequest* request =
      op_request.mutable_cross_replica_sum_request();
  *request->mutable_operand() = operand;
  return RunOpAndParseResponse(&op_request);
}

// Select-and-scatter with a symbolic padding mode; computes concrete padding
// from the operand shape and delegates to the general-padding variant.
ComputationDataHandle ComputationBuilder::SelectAndScatter(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  if (!first_error_.ok()) {
    return ComputationDataHandle();
  }

  StatusOr<std::unique_ptr<Shape>> shape = GetShape(operand);
  if (!shape.ok()) {
    return ComputationDataHandle();
  }
  return SelectAndScatterWithGeneralPadding(
      operand, select, window_dimensions, window_strides,
      MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()),
                  window_dimensions, window_strides, padding),
      source, init_value, scatter);
}

// Select-and-scatter with explicit per-dimension (low, high) padding.
ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const Computation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const Computation& scatter) {
  OpRequest op_request;
  SelectAndScatterRequest* request =
      op_request.mutable_select_and_scatter_request();
  *request->mutable_operand() = operand;
  *request->mutable_select() = select.handle();
  *request->mutable_source() = source;
  *request->mutable_init_value() = init_value;
  *request->mutable_scatter() = scatter.handle();

  if (!MakeWindow(window_dimensions, window_strides, padding, {}, {},
                  request->mutable_window())) {
    NoteError(InternalError("failed to make window"));
    return ComputationDataHandle();
  }

  return RunOpAndParseResponse(&op_request);
}

// Enqueues a reduce-precision op that truncates the float representation of
// `operand` to the given exponent/mantissa bit widths.
ComputationDataHandle ComputationBuilder::ReducePrecision(
    const ComputationDataHandle& operand, const int exponent_bits,
    const int mantissa_bits) {
  OpRequest op_request;
  ReducePrecisionRequest* request =
      op_request.mutable_reduce_precision_request();
  *request->mutable_operand() = operand;
  request->set_exponent_bits(exponent_bits);
  request->set_mantissa_bits(mantissa_bits);
  return RunOpAndParseResponse(&op_request);
}

// Enqueues a send of `operand` over channel `handle`; produces no value, so
// errors are recorded rather than returned.
void ComputationBuilder::Send(const ComputationDataHandle& operand,
                              const ChannelHandle& handle) {
  OpRequest op_request;
  SendRequest* request = op_request.mutable_send_request();
  *request->mutable_operand() = operand;
  *request->mutable_channel_handle() = handle;
  *op_request.mutable_computation() = computation_.handle();
  RunOpAndNoteError(&op_request);
}

// Enqueues a receive of a value of shape `shape` from channel `handle`.
ComputationDataHandle ComputationBuilder::Recv(const Shape& shape,
                                               const ChannelHandle& handle) {
  OpRequest op_request;
  RecvRequest* request = op_request.mutable_recv_request();
  *request->mutable_shape() = shape;
  *request->mutable_channel_handle() = handle;
return RunOpAndParseResponse(&op_request);
}

// Builds this sub-builder's computation; on failure, records the error on the
// parent builder (annotated with this builder's name) and returns an empty
// Computation.
Computation ComputationBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  auto build_status = Build();
  if (!build_status.ok()) {
    parent_builder_->NoteError(
        AddStatus(build_status.status(),
                  tensorflow::strings::StrCat("error from: ", name_)));
    return Computation();
  }
  return build_status.ConsumeValueOrDie();
}

// Finalizes and returns the computation, or the first recorded error
// (augmented with the backtrace captured when the error was noted).
StatusOr<Computation> ComputationBuilder::Build() {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }

  if (computation_.IsNull()) {
    return FailedPrecondition("no computation was built");
  }

  // Moving out of computation_ leaves the builder without a computation; the
  // builder is not reusable after a successful Build().
  return {std::move(computation_)};
}

// Default convolution dimension numbers: batch/feature in dims 0/1 (kConv*
// constants) and spatial dimensions starting at dim 2, for input, kernel and
// output alike.
/* static */ ConvolutionDimensionNumbers
ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) {
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_input_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_output_batch_dimension(kConvBatchDimension);
  dimension_numbers.set_output_feature_dimension(kConvFeatureDimension);
  dimension_numbers.set_kernel_output_feature_dimension(
      kConvKernelOutputDimension);
  dimension_numbers.set_kernel_input_feature_dimension(
      kConvKernelInputDimension);
  for (int i = 0; i < num_spatial_dims; ++i) {
    dimension_numbers.add_input_spatial_dimensions(i + 2);
    dimension_numbers.add_kernel_spatial_dimensions(i + 2);
    dimension_numbers.add_output_spatial_dimensions(i + 2);
  }
  return dimension_numbers;
}

// Builds 2D-convolution dimension numbers from explicit dimension indices,
// validating that each group of four indices is distinct.
/* static */ StatusOr<ConvolutionDimensionNumbers>
ComputationBuilder::CreateConvDimensionNumbers(
    int64 input_batch, int64 input_feature, int64 input_first_spatial,
    int64 input_second_spatial,
int64 output_batch, int64 output_feature,
    int64 output_first_spatial, int64 output_second_spatial,
    int64 kernel_output_feature, int64 kernel_input_feature,
    int64 kernel_first_spatial, int64 kernel_second_spatial) {
  // Each group of four dimension numbers must be pairwise distinct; a set of
  // size != 4 indicates a duplicate.
  if (std::set<int64>({input_batch, input_feature, input_first_spatial,
                       input_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the input are not unique: (%lld, %lld, %lld, "
        "%lld)",
        input_batch, input_feature, input_first_spatial, input_second_spatial);
  }
  if (std::set<int64>({kernel_output_feature, kernel_input_feature,
                       kernel_first_spatial, kernel_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the weight are not unique: (%lld, %lld, %lld, "
        "%lld)",
        kernel_output_feature, kernel_input_feature, kernel_first_spatial,
        kernel_second_spatial);
  }
  if (std::set<int64>({output_batch, output_feature, output_first_spatial,
                       output_second_spatial})
          .size() != 4) {
    return FailedPrecondition(
        "dimension numbers for the output are not unique: (%lld, %lld, %lld, "
        "%lld)",
        output_batch, output_feature, output_first_spatial,
        output_second_spatial);
  }
  ConvolutionDimensionNumbers dimension_numbers;
  dimension_numbers.set_input_batch_dimension(input_batch);
  dimension_numbers.set_input_feature_dimension(input_feature);
  dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
  dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
  dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
  dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
  dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
  dimension_numbers.set_output_batch_dimension(output_batch);
dimension_numbers.set_output_feature_dimension(output_feature);
  dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
  dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
  return dimension_numbers;
}

}  // namespace xla