1#!/usr/bin/python
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16
17"""Generates YAML configuration files for distributed TensorFlow workers.
18
19The workers will be run in a Kubernetes (k8s) container cluster.
20"""
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import sys
27
28import k8s_tensorflow_lib
29
30# Note: It is intentional that we do not import tensorflow in this script. The
31# machine that launches a TensorFlow k8s cluster does not have to have the
32# Python package of TensorFlow installed on it.
33
34
35DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
36DEFAULT_PORT = 2222
37
38
39def main():
40  """Do arg parsing."""
41  parser = argparse.ArgumentParser()
42  parser.register(
43      'type', 'bool', lambda v: v.lower() in ('true', 't', 'y', 'yes'))
44  parser.add_argument('--num_workers',
45                      type=int,
46                      default=2,
47                      help='How many worker pods to run')
48  parser.add_argument('--num_parameter_servers',
49                      type=int,
50                      default=1,
51                      help='How many paramater server pods to run')
52  parser.add_argument('--grpc_port',
53                      type=int,
54                      default=DEFAULT_PORT,
55                      help='GRPC server port (Default: %d)' % DEFAULT_PORT)
56  parser.add_argument('--request_load_balancer',
57                      type='bool',
58                      default=False,
59                      help='To request worker0 to be exposed on a public IP '
60                      'address via an external load balancer, enabling you to '
61                      'run client processes from outside the cluster')
62  parser.add_argument('--docker_image',
63                      type=str,
64                      default=DEFAULT_DOCKER_IMAGE,
65                      help='Override default docker image for the TensorFlow '
66                      'GRPC server')
67  parser.add_argument('--name_prefix',
68                      type=str,
69                      default='tf',
70                      help='Prefix for job names. Jobs will be named as '
71                      '<name_prefix>_worker|ps<task_id>')
72  parser.add_argument('--use_shared_volume',
73                      type='bool',
74                      default=True,
75                      help='Whether to mount /shared directory from host to '
76                      'the pod')
77  args = parser.parse_args()
78
79  if args.num_workers <= 0:
80    sys.stderr.write('--num_workers must be greater than 0; received %d\n'
81                     % args.num_workers)
82    sys.exit(1)
83  if args.num_parameter_servers <= 0:
84    sys.stderr.write(
85        '--num_parameter_servers must be greater than 0; received %d\n'
86        % args.num_parameter_servers)
87    sys.exit(1)
88
89  # Generate contents of yaml config
90  yaml_config = k8s_tensorflow_lib.GenerateConfig(
91      args.num_workers,
92      args.num_parameter_servers,
93      args.grpc_port,
94      args.request_load_balancer,
95      args.docker_image,
96      args.name_prefix,
97      env_vars=None,
98      use_shared_volume=args.use_shared_volume)
99  print(yaml_config)  # pylint: disable=superfluous-parens
100
101
102if __name__ == '__main__':
103  main()
104