1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3Licensed under the Apache License, Version 2.0 (the "License"); 4you may not use this file except in compliance with the License. 5You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9Unless required by applicable law or agreed to in writing, software 10distributed under the License is distributed on an "AS IS" BASIS, 11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12See the License for the specific language governing permissions and 13limitations under the License. 14==============================================================================*/ 15 16#include <iostream> 17#include <vector> 18 19#include "grpc++/grpc++.h" 20#include "grpc++/security/credentials.h" 21#include "grpc++/server_builder.h" 22 23#include "tensorflow/core/distributed_runtime/server_lib.h" 24 25#include "tensorflow/core/lib/core/errors.h" 26#include "tensorflow/core/lib/core/status.h" 27#include "tensorflow/core/lib/strings/str_util.h" 28#include "tensorflow/core/lib/strings/strcat.h" 29#include "tensorflow/core/platform/init_main.h" 30#include "tensorflow/core/protobuf/cluster.pb.h" 31#include "tensorflow/core/protobuf/tensorflow_server.pb.h" 32#include "tensorflow/core/public/session_options.h" 33#include "tensorflow/core/util/command_line_flags.h" 34 35// This binary starts a TensorFlow server (master and worker). 36// 37// TODO(mrry): Replace with a py_binary that uses `tf.GrpcServer()`. 38namespace tensorflow { 39namespace { 40 41Status FillServerDef(const string& cluster_spec, const string& job_name, 42 int task_index, ServerDef* options) { 43 options->set_protocol("grpc"); 44 options->set_job_name(job_name); 45 options->set_task_index(task_index); 46 47 size_t my_num_tasks = 0; 48 49 ClusterDef* const cluster = options->mutable_cluster(); 50 51 for (const string& job_str : str_util::Split(cluster_spec, ',')) { 52 JobDef* const job_def = cluster->add_job(); 53 // Split each entry in the flag into 2 pieces, separated by "|". 54 const std::vector<string> job_pieces = str_util::Split(job_str, '|'); 55 CHECK_EQ(2, job_pieces.size()) << job_str; 56 const string& job_name = job_pieces[0]; 57 job_def->set_name(job_name); 58 // Does a bit more validation of the tasks_per_replica. 59 const StringPiece spec = job_pieces[1]; 60 // job_str is of form <job_name>|<host_ports>. 61 const std::vector<string> host_ports = str_util::Split(spec, ';'); 62 for (size_t i = 0; i < host_ports.size(); ++i) { 63 (*job_def->mutable_tasks())[i] = host_ports[i]; 64 } 65 size_t num_tasks = host_ports.size(); 66 if (job_name == options->job_name()) { 67 my_num_tasks = host_ports.size(); 68 } 69 LOG(INFO) << "Peer " << job_name << " " << num_tasks << " {" 70 << str_util::Join(host_ports, ", ") << "}"; 71 } 72 if (my_num_tasks == 0) { 73 return errors::InvalidArgument("Job name \"", options->job_name(), 74 "\" does not appear in the cluster spec"); 75 } 76 if (options->task_index() >= my_num_tasks) { 77 return errors::InvalidArgument("Task index ", options->task_index(), 78 " is invalid (job \"", options->job_name(), 79 "\" contains ", my_num_tasks, " tasks"); 80 } 81 return Status::OK(); 82} 83 84} // namespace 85} // namespace tensorflow 86 87void Usage(char* const argv_0) { 88 std::cerr << "Usage: " << argv_0 89 << " --cluster_spec=SPEC --job_name=NAME --task_id=ID" << std::endl; 90 std::cerr << "Where:" << std::endl; 91 std::cerr << " SPEC is <JOB>(,<JOB>)*" << std::endl; 92 std::cerr << " JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*" << std::endl; 93 std::cerr << " NAME is a valid job name ([a-z][0-9a-z]*)" << std::endl; 94 std::cerr << " HOST is a hostname or IP address" << std::endl; 95 std::cerr << " PORT is a port number" << std::endl; 96} 97 98int main(int argc, char* argv[]) { 99 tensorflow::string cluster_spec; 100 tensorflow::string job_name; 101 int task_index = 0; 102 std::vector<tensorflow::Flag> flag_list = { 103 tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"), 104 tensorflow::Flag("job_name", &job_name, "job name"), 105 tensorflow::Flag("task_id", &task_index, "task id"), 106 }; 107 tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); 108 const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); 109 tensorflow::port::InitMain(argv[0], &argc, &argv); 110 if (!parse_result || argc != 1) { 111 std::cerr << usage << std::endl; 112 Usage(argv[0]); 113 return -1; 114 } 115 tensorflow::ServerDef server_def; 116 tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name, 117 task_index, &server_def); 118 if (!s.ok()) { 119 std::cerr << "ERROR: " << s.error_message() << std::endl; 120 Usage(argv[0]); 121 return -1; 122 } 123 std::unique_ptr<tensorflow::ServerInterface> server; 124 TF_QCHECK_OK(tensorflow::NewServer(server_def, &server)); 125 TF_QCHECK_OK(server->Start()); 126 TF_QCHECK_OK(server->Join()); 127} 128