1#!/usr/bin/env bash
2# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16#
# This is the entry-point script for testing TensorFlow's distributed runtime.
# It builds a docker image with the necessary gcloud and Kubernetes (k8s) tools
# installed, and then executes k8s cluster preparation and distributed
# TensorFlow runs from within a container based on that image.
21#
22# Usage:
23#   remote_test.sh <whl_url>
24#                  [--setup_cluster_only]
25#                  [--num_workers <NUM_WORKERS>]
26#                  [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
27#                  [--sync_replicas]
28#
29# Arguments:
30# <whl_url>
31#   Specify custom TensorFlow whl file URL to install in the test Docker image.
32#
33# --setup_cluster_only:
34#       Setup the TensorFlow k8s cluster only, and do not perform testing of
35#       the distributed runtime.
36#
37# --num_workers <NUM_WORKERS>:
38#   Specifies the number of worker pods to start
39#
# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
41#   Specifies the number of parameter servers to start
42#
43# --sync_replicas
#   Use the synchronized-replica mode. The parameter updates from the replicas
#   (workers) will be aggregated before being applied, which avoids stale
#   parameter updates.
47#
48#
49#
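# If any of the following environment variables has a non-empty value, it will
# be passed into the docker container to override the default value (see
# dist_test.sh). The exceptions are TF_DIST_GCLOUD_KEY_FILE, which is mounted
# into the container as a file, and TF_DIST_DOCKER_NO_CACHE, which only affects
# the local docker build.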
53#   TF_DIST_GRPC_SERVER_URL:      URL to an existing TensorFlow GRPC server.
54#                                 If set to any non-empty and valid value (e.g.,
55#                                 grpc://1.2.3.4:2222), it will cause the test
#                                 to bypass the k8s cluster setup and
#                                 teardown process, and just use this URL
#                                 as the master session.
59#   TF_DIST_GCLOUD_PROJECT:       gcloud project in which the GKE cluster
60#                                 will be created (takes effect only if
61#                                 TF_DIST_GRPC_SERVER_URL is empty, same below)
62#   TF_DIST_GCLOUD_COMPUTE_ZONE:  gcloud compute zone.
63#   TF_DIST_CONTAINER_CLUSTER:    name of the GKE cluster
#   TF_DIST_GCLOUD_KEY_FILE:      path to the gcloud service JSON key file
65#   TF_DIST_GRPC_PORT:            port on which to create the TensorFlow GRPC
66#                                 servers
#   TF_DIST_DOCKER_NO_CACHE:      if set to any value other than 0, do not use
#                                 the cache when building docker images
68
#######################################
# Print an error message to stderr and exit the script with status 1.
# Arguments:
#   $@ - message words, joined with spaces.
# Outputs:
#   Writes the message to stderr.
#######################################
die() {
  # "$*" keeps the message intact (no whitespace collapsing or globbing).
  printf '%s\n' "$*" >&2
  exit 1
}
73
# Name of the test-client image that this script builds and then runs.
readonly DOCKER_IMG_NAME="tensorflow/tf-dist-test-client"

# Absolute path of the directory containing this script (used as the source of
# the docker build context). "--" stops a path beginning with "-" from being
# parsed as an option by dirname or cd.
DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
readonly DIR
78
#######################################
# Build the "-e" flags that forward TF_DIST_* settings into the container.
#
# Each variable in the list below that has a non-empty value is forwarded with
# a bare "-e NAME" flag, which makes "docker run" copy the value from its own
# environment. Unlike "-e NAME=${NAME}", this keeps values that contain
# whitespace or glob characters intact when the result is word-split on the
# docker command line. The variables are expected to come from the caller's
# environment (see the header), so they are already exported to docker.
# Outputs:
#   Space-separated flags on stdout; empty when no variable is set.
#######################################
docker_env_flags() {
  local name
  local -a flags=()
  for name in TF_DIST_GRPC_SERVER_URL \
              TF_DIST_GCLOUD_PROJECT \
              TF_DIST_GCLOUD_COMPUTE_ZONE \
              TF_DIST_CONTAINER_CLUSTER \
              TF_DIST_GRPC_PORT; do
    # ${!name} is bash indirect expansion: the value of the variable named $name.
    if [[ -n "${!name}" ]]; then
      flags+=(-e "${name}")
    fi
  done
  printf '%s' "${flags[*]}"
}

# Prepare environment variables for the docker container.
DOCKER_ENV_FLAGS="$(docker_env_flags)"
101
# Choose whether "docker build" bypasses its layer cache: any value of
# TF_DIST_DOCKER_NO_CACHE other than empty or "0" turns the cache off.
case "${TF_DIST_DOCKER_NO_CACHE}" in
  "" | 0)
    NO_CACHE_FLAG=""
    ;;
  *)
    NO_CACHE_FLAG="--no-cache"
    ;;
esac
107
# Parse command-line arguments: the first positional argument is the
# (mandatory) URL of the TensorFlow whl file to install in the image.
WHL_URL="${1}"
[[ -n "${WHL_URL}" ]] || die "whl URL is not specified"
113
# Create docker build context directory.
BUILD_DIR="$(mktemp -d)" || die "Failed to create temporary build directory"
# Remove the build context on any exit path, including when a step below dies.
trap 'rm -rf -- "${BUILD_DIR}"' EXIT

echo ""
echo "Using custom whl file URL: ${WHL_URL}"
echo "Building in temporary directory: ${BUILD_DIR}"

# Copy the contents of this script's directory (which must include the
# Dockerfile used below) into the build context. The glob is left outside the
# quotes so that it still expands.
cp -r "${DIR}"/* "${BUILD_DIR}/" || \
  die "Failed to copy files to ${BUILD_DIR}"

# Download whl file into the build context directory. The URL is quoted so
# that whitespace or glob characters such as '?' and '*' are passed literally.
wget -P "${BUILD_DIR}" "${WHL_URL}" || \
  die "Failed to download tensorflow whl file from URL: ${WHL_URL}"

# Build docker image for test. NO_CACHE_FLAG is intentionally unquoted so that
# an empty value adds no argument at all.
# shellcheck disable=SC2086
docker build ${NO_CACHE_FLAG} \
    -t "${DOCKER_IMG_NAME}" -f "${BUILD_DIR}/Dockerfile" "${BUILD_DIR}" || \
    die "Failed to build docker image: ${DOCKER_IMG_NAME}"

# Clean up docker build context directory; ":?" guards against an empty path.
rm -rf -- "${BUILD_DIR:?}"
trap - EXIT
134
# Run docker image for test.
# The gcloud service key is bind-mounted at the path the container expects;
# it defaults to ${HOME}/gcloud-secrets/tensorflow-testing.json.
KEY_FILE="${TF_DIST_GCLOUD_KEY_FILE:-${HOME}/gcloud-secrets/tensorflow-testing.json}"

# DOCKER_ENV_FLAGS is deliberately unquoted: it holds zero or more
# space-separated "-e ..." flags that must become separate words.
# All of this script's arguments are forwarded verbatim ("$@") to dist_test.sh.
# NOTE(review): that includes the whl URL in $1; presumably dist_test.sh
# ignores it — confirm.
# shellcheck disable=SC2086
docker run --rm -v "${KEY_FILE}:/var/gcloud/secrets/tensorflow-testing.json" \
  ${DOCKER_ENV_FLAGS} \
  "${DOCKER_IMG_NAME}" \
  /var/tf-dist-test/scripts/dist_test.sh "$@"
142