# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import numpy as np
import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi
from tensorflow.python.platform import test
25e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
26e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
# Passed as `average=` to every mpi.allreduce built by the test: True requests
# the mean across ranks, False the sum.
average_allreduce = False
# Number of mismatching elements dumpFailure reports before it stops printing;
# any value <= 0 disables the limit.
max_wrong_count = -1
29e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
30e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
class AllreduceTest(test.TestCase):
  """Checks a chain of dependent `mpi.allreduce` ops against known values.

  The test must be launched under an MPI launcher that exports `PMI_RANK` and
  `PMI_SIZE`; both are read from the environment.
  """

  def dumpFailure(self, my_rank, out_loc_red, my_correct, out_all_red,
                  our_correct):
    """Prints every element of the reduced outputs, flagging the wrong ones.

    One tab-separated line is printed per element:
    rank, row, column, locally reduced value, allreduced value, and the
    marker "WRONG" when either value does not match its expected value.
    Printing stops early once `max_wrong_count` mismatches have been
    reported (when that limit is positive).

    Args:
      my_rank: Rank of this process, used only to label the output.
      out_loc_red: 2-D array produced by the local (non-MPI) reduction.
      my_correct: Expected values for `out_loc_red`.
      out_all_red: 2-D array produced by the MPI allreduce chain.
      our_correct: Expected values for `out_all_red`.
    """
    red_dims = out_loc_red.shape
    self.assertEqual(2, len(red_dims))

    # Flag elements with the same tolerance np.allclose applies in the test
    # itself, so "WRONG" marks exactly the elements that made the check fail
    # rather than every element that is not bit-for-bit equal.
    wrong = np.logical_not(
        np.logical_and(np.isclose(out_loc_red, my_correct),
                       np.isclose(out_all_red, our_correct)))

    wrong_count = 0
    try:
      for i in range(red_dims[0]):
        for j in range(red_dims[1]):
          suffix = ""
          if wrong[i][j]:
            suffix = "WRONG"
            wrong_count += 1
          print("{}\t{}\t{}\t{}\t{}\t{}"
                .format(my_rank, i, j, out_loc_red[i][j],
                        out_all_red[i][j], suffix))
          if max_wrong_count > 0 and wrong_count >= max_wrong_count:
            return
    finally:
      # print(..., flush=True) is rejected by Python 2 even with
      # print_function imported, so flush explicitly (once, not per line).
      sys.stdout.flush()

  def test_mpi_allreduce(self):
    """Runs 13 chained allreduce stages and verifies the result 1000 times.

    Each stage adds a constant weight matrix (filled with 2**(stage + 1)) to
    its input and allreduces the sum across ranks. A second, purely local
    graph adds the same stage output to a matrix of ones `num_ranks` times as
    a non-MPI reference. Rank r feeds an input filled with 1 + 100 * r, and
    the expected element values of both graphs are computed in plain Python.

    Raises:
      KeyError: If `PMI_RANK` or `PMI_SIZE` is missing from the environment.
    """
    # Get MPI rank
    my_rank = int(os.environ['PMI_RANK'])
    num_ranks = int(os.environ['PMI_SIZE'])

    stages = 13
    batch_size = 1331
    hidden_size = batch_size
    out_size = batch_size
    num_iters = 1000

    # Input placeholder (batch_size x hidden); fed below with a constant
    # matrix whose value depends on the rank.
    inputs = tf.placeholder(tf.float32, shape=(batch_size, hidden_size),
                            name="Input")

    # Large matrices (hidden x out_dim), each filled with the constant
    # 2**(i + 1) so that the expected results can be computed by hand.
    weights = []
    for i in range(stages):
      initer = tf.constant_initializer(pow(2.0, i + 1.0))
      weights.append(tf.get_variable("weights_{}".format(i),
                                     shape=(hidden_size, out_size),
                                     dtype=tf.float32,
                                     initializer=initer))

    # Calculate output through dependent allreduces
    stage_input = inputs
    for i in range(stages):
      inter_output = tf.add(stage_input, weights[i],
                            name="add_red_{}".format(i))
      stage_input = mpi.allreduce(inter_output,
                                  average=average_allreduce)

    all_reduced = stage_input

    # Local reduced output for verification
    local_input = inputs
    for i in range(stages):
      inter_output = tf.add(local_input, weights[i],
                            name="addin_loc_{}".format(i))
      my_reducer = tf.Variable(initial_value=np.ones((hidden_size, out_size)),
                               dtype=tf.float32, name="loc_redr_{}".format(i))
      for r in range(num_ranks):
        my_reducer = tf.add(my_reducer, inter_output,
                            name="add_loc_{}_{}".format(i, r))
      if average_allreduce:
        local_input = tf.div(my_reducer, num_ranks,
                             name="div_loc_{}".format(i))
      else:
        local_input = my_reducer

    local_reduced = local_input

    # NOTE: This assumes that device IDs are numbered the same as ranks
    gpu_options = tf.GPUOptions(visible_device_list=str(my_rank))
    config = tf.ConfigProto(gpu_options=gpu_options)

    # MPI Session to test allreduce
    with mpi.Session(config=config) as sess:
      sess.run(tf.global_variables_initializer())

      input_feed = np.ones((batch_size, hidden_size), dtype=np.float32)
      our_output = input_feed[0][0]
      spread_var = 100
      input_feed = input_feed + my_rank * spread_var
      my_output = input_feed[0][0]

      # Replay both graphs on a single element to obtain the expected values.
      for i in range(stages):
        weight = pow(2.0, i + 1.0)
        # Local graph: ones + num_ranks copies of (input + weight).
        my_output = (my_output + weight) * num_ranks + 1
        # MPI graph: sum of (input + weight) over all ranks.
        our_output = (our_output + weight) * num_ranks
        if i == 0:
          # Only the first stage sees per-rank inputs (1 + spread_var * rank);
          # afterwards every rank holds the same allreduced value.
          sum_ranks = num_ranks * (num_ranks - 1) / 2
          our_output = our_output + spread_var * sum_ranks
        if average_allreduce:
          # Both graphs divide the stage result by num_ranks when averaging.
          my_output = my_output / num_ranks
          our_output = our_output / num_ranks

      print("rank {}: My output is {}".format(my_rank, my_output))
      my_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      my_correct = my_correct + my_output
      print("rank {}: Our output is {}".format(my_rank, our_output))
      our_correct = np.zeros((batch_size, hidden_size), dtype=np.float32)
      our_correct = our_correct + our_output

      feed_dict = {inputs: input_feed}
      for i in range(num_iters):
        if i % 100 == 0:
          print("{}: iter {}".format(my_rank, i))
          sys.stdout.flush()
        out_all_red, out_loc_red \
          = sess.run([all_reduced, local_reduced],
                     feed_dict=feed_dict)

        if not (np.allclose(out_loc_red, my_correct) and
                np.allclose(out_all_red, our_correct)):
          print("Test incorrect on iter {}".format(i))
          sys.stdout.flush()
          self.dumpFailure(my_rank, out_loc_red, my_correct, out_all_red,
                           our_correct)
          # self.fail is not stripped by `python -O`, unlike a bare assert.
          self.fail("rank {}: allreduce output mismatch on iter {}"
                    .format(my_rank, i))
151e2e3a943c0a28b7656325acb3fcd035743d55ea0Shanqing Cai
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
154