1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <iostream>
17#include <vector>
18
19#include "grpc++/grpc++.h"
20#include "grpc++/security/credentials.h"
21#include "grpc++/server_builder.h"
22
23#include "tensorflow/core/distributed_runtime/server_lib.h"
24
25#include "tensorflow/core/lib/core/errors.h"
26#include "tensorflow/core/lib/core/status.h"
27#include "tensorflow/core/lib/strings/str_util.h"
28#include "tensorflow/core/lib/strings/strcat.h"
29#include "tensorflow/core/platform/init_main.h"
30#include "tensorflow/core/protobuf/cluster.pb.h"
31#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
32#include "tensorflow/core/public/session_options.h"
33#include "tensorflow/core/util/command_line_flags.h"
34
35// This binary starts a TensorFlow server (master and worker).
36//
37// TODO(mrry): Replace with a py_binary that uses `tf.GrpcServer()`.
38namespace tensorflow {
39namespace {
40
41Status FillServerDef(const string& cluster_spec, const string& job_name,
42                     int task_index, ServerDef* options) {
43  options->set_protocol("grpc");
44  options->set_job_name(job_name);
45  options->set_task_index(task_index);
46
47  size_t my_num_tasks = 0;
48
49  ClusterDef* const cluster = options->mutable_cluster();
50
51  for (const string& job_str : str_util::Split(cluster_spec, ',')) {
52    JobDef* const job_def = cluster->add_job();
53    // Split each entry in the flag into 2 pieces, separated by "|".
54    const std::vector<string> job_pieces = str_util::Split(job_str, '|');
55    CHECK_EQ(2, job_pieces.size()) << job_str;
56    const string& job_name = job_pieces[0];
57    job_def->set_name(job_name);
58    // Does a bit more validation of the tasks_per_replica.
59    const StringPiece spec = job_pieces[1];
60    // job_str is of form <job_name>|<host_ports>.
61    const std::vector<string> host_ports = str_util::Split(spec, ';');
62    for (size_t i = 0; i < host_ports.size(); ++i) {
63      (*job_def->mutable_tasks())[i] = host_ports[i];
64    }
65    size_t num_tasks = host_ports.size();
66    if (job_name == options->job_name()) {
67      my_num_tasks = host_ports.size();
68    }
69    LOG(INFO) << "Peer " << job_name << " " << num_tasks << " {"
70              << str_util::Join(host_ports, ", ") << "}";
71  }
72  if (my_num_tasks == 0) {
73    return errors::InvalidArgument("Job name \"", options->job_name(),
74                                   "\" does not appear in the cluster spec");
75  }
76  if (options->task_index() >= my_num_tasks) {
77    return errors::InvalidArgument("Task index ", options->task_index(),
78                                   " is invalid (job \"", options->job_name(),
79                                   "\" contains ", my_num_tasks, " tasks");
80  }
81  return Status::OK();
82}
83
84}  // namespace
85}  // namespace tensorflow
86
// Writes a description of the expected command line, followed by the grammar
// of the --cluster_spec value, to stderr. `argv_0` is printed as the program
// name.
void Usage(char* const argv_0) {
  // Grammar lines printed verbatim, one per line, after the synopsis.
  static const char* const kGrammarLines[] = {
      "Where:",
      "    SPEC is <JOB>(,<JOB>)*",
      "    JOB  is <NAME>|<HOST:PORT>(;<HOST:PORT>)*",
      "    NAME is a valid job name ([a-z][0-9a-z]*)",
      "    HOST is a hostname or IP address",
      "    PORT is a port number",
  };

  std::cerr << "Usage: " << argv_0
            << " --cluster_spec=SPEC --job_name=NAME --task_id=ID" << std::endl;
  for (const char* const grammar_line : kGrammarLines) {
    std::cerr << grammar_line << std::endl;
  }
}
97
98int main(int argc, char* argv[]) {
99  tensorflow::string cluster_spec;
100  tensorflow::string job_name;
101  int task_index = 0;
102  std::vector<tensorflow::Flag> flag_list = {
103      tensorflow::Flag("cluster_spec", &cluster_spec, "cluster spec"),
104      tensorflow::Flag("job_name", &job_name, "job name"),
105      tensorflow::Flag("task_id", &task_index, "task id"),
106  };
107  tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
108  const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
109  tensorflow::port::InitMain(argv[0], &argc, &argv);
110  if (!parse_result || argc != 1) {
111    std::cerr << usage << std::endl;
112    Usage(argv[0]);
113    return -1;
114  }
115  tensorflow::ServerDef server_def;
116  tensorflow::Status s = tensorflow::FillServerDef(cluster_spec, job_name,
117                                                   task_index, &server_def);
118  if (!s.ok()) {
119    std::cerr << "ERROR: " << s.error_message() << std::endl;
120    Usage(argv[0]);
121    return -1;
122  }
123  std::unique_ptr<tensorflow::ServerInterface> server;
124  TF_QCHECK_OK(tensorflow::NewServer(server_def, &server));
125  TF_QCHECK_OK(server->Start());
126  TF_QCHECK_OK(server->Join());
127}
128