/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/local_service.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_tracker.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
42#include "tensorflow/core/lib/strings/strcat.h" 43#include "tensorflow/core/platform/logging.h" 44#include "tensorflow/core/platform/stream_executor_no_cuda.h" 45 46namespace se = ::perftools::gputools; 47 48namespace xla { 49 50/* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService( 51 const ServiceOptions& options) { 52 perftools::gputools::Platform* platform = options.platform(); 53 if (platform == nullptr) { 54 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform()); 55 } 56 57 BackendOptions backend_options; 58 backend_options.set_platform(platform).set_intra_op_parallelism_threads( 59 options.intra_op_parallelism_threads()); 60 TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend, 61 Backend::CreateBackend(backend_options)); 62 63 std::unique_ptr<LocalService> service( 64 new LocalService(options, std::move(backend))); 65 return std::move(service); 66} 67 68LocalService::LocalService(const ServiceOptions& options, 69 std::unique_ptr<Backend> execute_backend) 70 : Service(options, std::move(execute_backend)) {} 71 72StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( 73 const ComputationHandle& computation, 74 const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, 75 const ExecutableBuildOptions& build_options) { 76 TF_ASSIGN_OR_RETURN(UserComputation * user_computation, 77 computation_tracker_.Resolve(computation)); 78 VersionedComputationHandle versioned_handle = 79 user_computation->GetVersionedHandle(); 80 81 TF_ASSIGN_OR_RETURN( 82 std::shared_ptr<const ProgramShape> program_shape, 83 user_computation->ComputeProgramShape(versioned_handle.version)); 84 85 // Validate incoming layouts. 
86 if (argument_layouts.size() != program_shape->parameters_size()) { 87 return InvalidArgument( 88 "Invalid number of arguments for computation: expected %d, got %zu.", 89 program_shape->parameters_size(), argument_layouts.size()); 90 } 91 for (int i = 0; i < argument_layouts.size(); ++i) { 92 const Shape& argument_shape = *argument_layouts[i]; 93 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); 94 if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { 95 tensorflow::gtl::optional<const OpMetadata*> metadata = 96 user_computation->ParameterMetadata(i); 97 auto metadata_string = [&metadata]() -> string { 98 if (!metadata.has_value()) { 99 return ""; 100 } 101 CHECK(metadata.value() != nullptr); 102 const OpMetadata& m = *metadata.value(); 103 if (!m.source_file().empty()) { 104 return tensorflow::strings::Printf( 105 " (%s:%d)", m.source_file().c_str(), m.source_line()); 106 } 107 return ""; 108 }; 109 return InvalidArgument( 110 "Invalid argument shape for argument %d%s, expected %s, got %s.", i, 111 metadata_string().c_str(), 112 ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), 113 ShapeUtil::HumanString(argument_shape).c_str()); 114 } 115 } 116 if (build_options.result_layout() != nullptr) { 117 TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( 118 *build_options.result_layout(), program_shape->result())); 119 } 120 121 ExecutionOptions execution_options = CreateDefaultExecutionOptions(); 122 if (build_options.generate_hlo_graph().has_value()) { 123 execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( 124 build_options.generate_hlo_graph().value()); 125 } 126 if (build_options.result_layout() != nullptr) { 127 *execution_options.mutable_shape_with_output_layout() = 128 *build_options.result_layout(); 129 } else { 130 *execution_options.mutable_shape_with_output_layout() = 131 program_shape->result(); 132 LayoutUtil::SetToDefaultLayout( 133 execution_options.mutable_shape_with_output_layout()); 134 
} 135 TF_ASSIGN_OR_RETURN( 136 std::unique_ptr<HloModuleConfig> module_config, 137 CreateModuleConfig(*program_shape, argument_layouts, &execution_options, 138 *user_computation)); 139 140 TF_ASSIGN_OR_RETURN( 141 se::StreamExecutor * executor, 142 execute_backend_->stream_executor(build_options.device_ordinal())); 143 144 return BuildExecutable(versioned_handle, std::move(module_config), 145 execute_backend_.get(), executor, 146 build_options.device_allocator()); 147} 148 149StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { 150 return backend().computation_placer()->DeviceId( 151 replica_number, /*computation=*/0, options_.number_of_replicas(), 152 /*computation_count=*/1); 153} 154 155} // namespace xla 156