1e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# 3e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# Licensed under the Apache License, Version 2.0 (the "License"); 4e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# you may not use this file except in compliance with the License. 5e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# You may obtain a copy of the License at 6e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# 7e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# http://www.apache.org/licenses/LICENSE-2.0 8e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# 9e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# Unless required by applicable law or agreed to in writing, software 10e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# distributed under the License is distributed on an "AS IS" BASIS, 11e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# See the License for the specific language governing permissions and 13e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# limitations under the License. 14e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai# ============================================================================== 15e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 16e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caifrom __future__ import absolute_import 17e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caifrom __future__ import division 18e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caifrom __future__ import print_function 19e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 20e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport os 21e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport numpy as np 22e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport tensorflow as tf 23e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiimport tensorflow.contrib.mpi_collectives as mpi 24e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caifrom tensorflow.python.platform import test 25e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 26e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 27e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiaverage_allreduce = False 28e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caimax_wrong_count = -1 29e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 30e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 31e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiclass AllreduceTest(test.TestCase): 32e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red, 33e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_correct): 34e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Find reduced/allreduced indices that are wrong and print all the 35e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # values from output, slices, reduced, allreduced, so we can debug 36e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # which is incorrect: 37e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai wrong_count = 0 38e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai red_dims = out_loc_red.shape 39e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai assert(len(red_dims) == 2) 40e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(red_dims[0]): 41e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for j in range(red_dims[1]): 42e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai suffix = "" 43e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if out_loc_red[i][j] != my_correct[i][j] or \ 44e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai out_all_red[i][j] != our_correct[i][j]: 45e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai suffix = "WRONG" 46e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai wrong_count += 1 47e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai print("{}\t{}\t{}\t{}\t{}\t{}" 48e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai .format(my_rank, i, j, out_loc_red[i][j], 49e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai out_all_red[i][j], suffix), flush=True) 50e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if max_wrong_count > 0 and wrong_count >= max_wrong_count: 51e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai return 52e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 53e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai def test_mpi_allreduce(self): 54e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Get MPI rank 55e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_rank = int(os.environ['PMI_RANK']) 56e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai num_ranks = int(os.environ['PMI_SIZE']) 57e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 58e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai stages = 13 59e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai batch_size = 1331 60e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai hidden_size = batch_size 61e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai out_size = batch_size 62e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 63e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Input placeholder (batch_size x hidden) - init to 1s 64e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size), 65e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai name="Input") 66e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 67e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Large matrices (hidden x out_dim) - init random 68e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai weights = [] 69e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(stages): 70e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai initer = tf.constant_initializer(pow(2.0, i + 1.0)) 71e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai weights.append(tf.get_variable("weights_{}".format(i), 72e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai shape=(hidden_size, out_size), 73e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai dtype=tf.float32, 74e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai initializer=initer)) 75e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 76e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Calculate output through dependent allreduces 77e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai stage_input = inputs 78e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(stages): 79e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai inter_output = tf.add(stage_input, weights[i], 80e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai name="add_red_{}".format(i)) 81e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai stage_input = mpi.allreduce(inter_output, 82e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai average=average_allreduce) 83e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 84e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai all_reduced = stage_input 85e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 86e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # Local reduced output for verification 87e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai local_input = inputs 88e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(stages): 89e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai inter_output = tf.add(local_input, weights[i], 90e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai name="addin_loc_{}".format(i)) 91e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)), 92e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai dtype=tf.float32, name="loc_redr_{}".format(i)) 93e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for r in range(num_ranks): 94e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_reducer = tf.add(my_reducer, inter_output, 95e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai name="add_loc_{}_{}".format(i, r)) 96e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if average_allreduce: 97e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai local_input = tf.div(my_reducer, num_ranks, 98e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai name="div_loc_{}".format(i)) 99e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai else: 100e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai local_input = my_reducer 101e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 102e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai local_reduced = local_input 103e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 104e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # NOTE: This assumes that device IDs are numbered the same as ranks 105e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai gpu_options = tf.GPUOptions(visible_device_list=str(my_rank)) 106e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai config = tf.ConfigProto(gpu_options=gpu_options) 107e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 108e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai # MPI Session to test allreduce 109e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai with mpi.Session(config=config) as sess: 110e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai sess.run(tf.global_variables_initializer()) 111e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 112e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai input_feed = np.ones((batch_size, hidden_size), dtype=np.float32) 113e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_output = input_feed[0][0] 114e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai spread_var = 100 115e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai input_feed = input_feed + my_rank * spread_var 116e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_output = input_feed[0][0] 117e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(stages): 118e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai curr_feed = my_output + pow(2.0, i + 1.0) 119e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_output = curr_feed * num_ranks + 1 120e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai curr_our_feed = our_output + pow(2.0, i + 1.0) 121e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if i == 0: 122e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai sum_ranks = num_ranks * (num_ranks - 1) / 2 123e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_output = curr_our_feed * num_ranks + \ 124e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai spread_var * sum_ranks 125e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai else: 126e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_output = curr_our_feed * num_ranks 127e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 128e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai print("rank {}: My output is {}".format(my_rank, my_output)) 129e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) 130e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai my_correct = my_correct + my_output 131e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai print("rank {}: Our output is {}".format(my_rank, our_output)) 132e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32) 133e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_correct = our_correct + our_output 134e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 135e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai for i in range(1000): 136e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if i % 100 == 0: 137e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai print("{}: iter {}".format(my_rank, i), flush=True) 138e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai feed_dict = {inputs: input_feed} 139e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai out_all_red, out_loc_red \ 140e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai = sess.run([all_reduced, local_reduced], 141e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai feed_dict=feed_dict) 142e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 143e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai if not np.allclose(out_loc_red, my_correct) or \ 144e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai not np.allclose(out_all_red, our_correct): 145e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai print("Test incorrect on iter {}".format(i), flush=True) 146e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red, 147e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai our_correct) 148e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai assert(np.allclose(out_loc_red, my_correct) and 149e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai np.allclose(out_all_red, our_correct)) 150e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 151e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai 152e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Caiif __name__ == '__main__': 153e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai test.main() 154