/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_sharding.h"

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace xla {

using ::tensorflow::strings::StrCat;

HloSharding HloSharding::AssignDevice(int64 device_id) {
  return HloSharding(device_id);
}

HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) {
  CHECK_EQ(1, ShapeUtil::Rank(input_shape));
  CHECK_GT(num_tiles, 1);
  std::vector<int64> dimensions(1, num_tiles);
  Shape tile_shape = input_shape;
  auto& tile_dimension = (*tile_shape.mutable_dimensions())[0];
  tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
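  // For example (illustrative): tiling a 10-element 1D shape across 3 devices
  // yields a tile dimension of CeilOfRatio(10, 3) = 4, so the last tile is
  // only partially occupied.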
  Array<int64> assignment(dimensions);
  std::iota(assignment.begin(), assignment.end(), 0);
  return HloSharding(tile_shape, assignment);
}

string HloSharding::ToString() const {
  if (IsTuple()) {
    std::vector<string> parts;
    parts.reserve(tuple_elements_.size());
    for (const HloSharding& element : tuple_elements_) {
      parts.push_back(element.ToString());
    }
    return StrCat("{", tensorflow::str_util::Join(parts, ", "), "}");
  }

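  // Non-tuple shardings render in one of three forms: "{replicated}",
  // "{maximal device=N}", or the tile shape followed by the flattened device
  // assignment.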
  if (replicated_) {
    return "{replicated}";
  } else if (maximal_) {
    return StrCat(
        "{maximal device=", static_cast<int64>(*tile_assignment_.begin()), "}");
  } else {
    return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ",
                  "devices=", VectorString(tile_assignment_), "}");
  }
}

bool HloSharding::UsesDevice(int64 device) const {
  if (IsTuple()) {
    return std::any_of(
        tuple_elements_.begin(), tuple_elements_.end(),
        [&](const HloSharding& s) { return s.UsesDevice(device); });
  }
  const auto& devices = tile_assignment_;
  return replicated_ ||
         std::find(devices.begin(), devices.end(), device) != devices.end();
}

std::vector<int64> HloSharding::TileIndexForDevice(int64 device) const {
  CHECK(!ShapeUtil::IsTuple(tile_shape_));
  CHECK(!maximal_);
  CHECK(!IsTuple());
  std::vector<int64> ret_index;
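  // Walk the assignment array and record the multi-dimensional index of the
  // entry equal to `device`; device IDs are unique within a valid tile
  // assignment (see ValidateNonTuple), so at most one entry matches.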
  tile_assignment_.Each([&](tensorflow::gtl::ArraySlice<int64> index, int64 d) {
    if (d == device) {
      ret_index = {index.begin(), index.end()};
    }
  });
  CHECK(!ret_index.empty());
  return ret_index;
}

int64 HloSharding::DeviceForTileIndex(
    tensorflow::gtl::ArraySlice<int64> index) const {
  CHECK(!replicated_);
  CHECK(!IsTuple());
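  // A maximal sharding has a single tile, so every index maps to the one
  // assigned device.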
  if (maximal_) {
    return *tile_assignment_.begin();
  }
  CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size());
  return tile_assignment_(index);
}

std::vector<int64> HloSharding::TileOffsetForDevice(int64 device) const {
  CHECK(!IsTuple());

  if (maximal_) {
    // The tile index is always all zeroes when the sharding is maximal, and
    // tile_shape_ is not valid, so return a zero offset directly rather than
    // going through TileIndexForDevice(), which CHECK-fails on maximal
    // shardings.
    return std::vector<int64>(tile_assignment_.dimensions().size(), 0);
  }
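  // Scale the tile index element-wise by the tile shape; e.g. (illustrative)
  // tile index {1, 2} with tile shape [2, 3] gives offset {2, 6}.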
  std::vector<int64> index = TileIndexForDevice(device);
  for (int64 i = 0; i < index.size(); ++i) {
    index[i] *= tile_shape_.dimensions(i);
  }
  return index;
}

std::vector<int64> HloSharding::TileLimitForDevice(int64 device) const {
  CHECK(!IsTuple());
  CHECK(!maximal_);  // Maximal shardings do not have a valid tile shape.

  std::vector<int64> index = TileIndexForDevice(device);
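  // The limit is the exclusive upper bound of the tile, i.e. offset plus tile
  // shape; for edge tiles it may extend past the input shape when the
  // dimensions do not divide evenly.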
  for (int64 i = 0; i < index.size(); ++i) {
    index[i] = (index[i] + 1) * tile_shape_.dimensions(i);
  }
  return index;
}

StatusOr<int64> HloSharding::UniqueDevice() const {
  if (IsTuple()) {
    if (tuple_elements_.empty()) {
      return tensorflow::errors::InvalidArgument(
          "UniqueDevice() called on empty tuple");
    }
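    // Gather the per-element results and require that every element resolves
    // to the same single device.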
    std::vector<StatusOr<int64>> results;
    std::transform(tuple_elements_.begin(), tuple_elements_.end(),
                   std::back_inserter(results),
                   [](const HloSharding& s) { return s.UniqueDevice(); });
    if (std::all_of(results.begin(), results.end(),
                    [&](const StatusOr<int64>& s) {
                      return s.ok() && results[0].ok() &&
                             s.ValueOrDie() == results[0].ValueOrDie();
                    })) {
      return results[0];
    } else {
      return tensorflow::errors::InvalidArgument(
          "Tuple did not contain a unique device");
    }
  }
  if (!replicated_ && maximal_ && !IsTuple()) {
    return static_cast<int64>(*tile_assignment_.begin());
  }
  return tensorflow::errors::InvalidArgument(
      "UniqueDevice() called on sharding that executes on multiple devices");
}

bool HloSharding::HasUniqueDevice() const {
  if (IsTuple()) {
    return UniqueDevice().status().ok();
  } else {
    return !IsReplicated() && IsTileMaximal();
  }
}

Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const {
  if (!ShapeUtil::IsTuple(shape)) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Sharding is tuple-shaped but validation shape is not."));
  }
  // The easiest way to get the number of elements in a nested tuple is just to
  // create a shape tree. We could call GetAsShapeTree, but that will try to
  // apply our tuple_shardings_ to the shape tree, and that might cause a crash
  // at this point as we haven't validated them.
  ShapeTree<bool> bool_shape_tree(shape, false);
  int64 num_leaves =
      std::distance(bool_shape_tree.leaf_begin(), bool_shape_tree.leaf_end());
  if (num_leaves != tuple_elements_.size()) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation tuple shape has ", num_leaves,
               " leaf elements, but this sharding contains ",
               tuple_elements_.size(), " elements."));
  }

  // Now that we've validated the number of tuple elements, it's safe to
  // request a shape tree.
  ShapeTree<HloSharding> shape_tree = GetAsShapeTree(shape);
  for (const auto& index_to_sharding : shape_tree.leaves()) {
    Status status = index_to_sharding.second.ValidateNonTuple(
        ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices);
    if (!status.ok()) {
      tensorflow::errors::AppendToMessage(
          &status, StrCat("Note: While validating sharding tuple element ",
                          index_to_sharding.first.ToString(), " which is ",
                          index_to_sharding.second.ToString()));
      return status;
    }
  }
  return Status::OK();
}

Status HloSharding::Validate(const Shape& shape, int64 num_devices) const {
  Status status = IsTuple() ? ValidateTuple(shape, num_devices)
                            : ValidateNonTuple(shape, num_devices);
  if (!status.ok()) {
    tensorflow::errors::AppendToMessage(
        &status, StrCat("Note: While validating sharding ", ToString(),
                        " against shape ", ShapeUtil::HumanString(shape)));
  }
  return status;
}

Status HloSharding::ValidateNonTuple(const Shape& shape,
                                     int64 num_devices) const {
  if (ShapeUtil::IsTuple(shape)) {
    return tensorflow::errors::InvalidArgument(
        StrCat("Validation shape is a tuple but sharding is not."));
  }
  if (replicated_) {
    return Status::OK();
  }

  // Each core in the tile assignment must be less than the number of available
  // devices, and must appear only once.
  Status status = Status::OK();
  std::set<int64> seen_cores;
  tile_assignment_.Each(
      [&](tensorflow::gtl::ArraySlice<int64> indices, uint32 core) {
        // Don't overwrite a bad status, so we report the first error.
        if (status.ok()) {
          if (core >= num_devices) {
            status = tensorflow::errors::InvalidArgument(StrCat(
                "core ", core, " >= ", num_devices, " in tile assignment"));
          } else if (seen_cores.count(core) != 0) {
            status = tensorflow::errors::InvalidArgument(
                StrCat("core ", core, " is not unique in tile assignment"));
          }
        }
        seen_cores.insert(core);
      });
  if (!status.ok()) {
    return status;
  }

  if (IsTileMaximal()) {
    return Status::OK();
  }

  // The tile rank must be the same as the input rank.
  if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
    return tensorflow::errors::InvalidArgument(
        "Tile rank is different from the input rank. sharding=", ToString(),
        ", input_shape=", ShapeUtil::HumanString(shape));
  }

  // The tile shape must not be the same as the input shape without maximal_
  // also set. If this is the case, we're not actually sharded and the correct
  // constructor should have been used.
  if (ShapeUtil::Equal(shape, tile_shape_)) {
    return tensorflow::errors::InvalidArgument(
        "Tile shape is the same as the input shape. If a replicated sharding "
        "was intended, use HloSharding::Replicated(). If a device placement "
        "was intended, use HloSharding::AssignDevice()");
  }

  // The tile shape must not be greater than the input shape in any dimension.
  for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
    auto tile_dim = tile_shape_.dimensions(i);
    auto shape_dim = shape.dimensions(i);
    if (tile_dim > shape_dim) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Tile is larger than input shape (dimension ", i, ", ",
                 tile_dim, " > ", shape_dim, ")"));
    }
  }

  // The tile assignment tensor must be exactly dimensioned to
  // ceil(shape[dim] / tile[dim]) for every dimension contained within tile.
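  // For example (illustrative): a [10, 20] shape tiled by [3, 5] requires a
  // 4x4 tile assignment, since ceil(10/3) = 4 and ceil(20/5) = 4.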
  for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
    int64 expected_dim =
        CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
    if (tile_assignment_.dimensions()[i] != expected_dim) {
      return tensorflow::errors::InvalidArgument(
          StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
                 " expected ", expected_dim, " but got ",
                 tile_assignment_.dimensions()[i]));
    }
  }

  return Status::OK();
}

/*static*/ StatusOr<HloSharding> HloSharding::FromProto(
    const OpSharding& proto) {
  if (proto.type() == OpSharding::Type::OpSharding_Type_TUPLE) {
    std::vector<HloSharding> tuple_shardings;
    tuple_shardings.reserve(proto.tuple_shardings().size());
    for (const OpSharding& tuple_sharding_proto : proto.tuple_shardings()) {
      TF_ASSIGN_OR_RETURN(HloSharding sharding,
                          HloSharding::FromProto(tuple_sharding_proto));
      tuple_shardings.push_back(sharding);
    }
    return HloSharding(tuple_shardings);
  } else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
    return Replicate();
  } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
             proto.tile_assignment_devices().size() == 1) {
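    // A tiled sharding that names exactly one device is equivalent to a
    // maximal sharding on that device.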
    return HloSharding(proto.tile_assignment_devices(0));
  }
  // Some versions of gcc cannot infer the TileAssignment constructor from a
  // braced initializer-list, so create one manually.
  Array<int64> tile_assignment(
      std::vector<int64>(proto.tile_assignment_dimensions().begin(),
                         proto.tile_assignment_dimensions().end()));
  std::copy(proto.tile_assignment_devices().begin(),
            proto.tile_assignment_devices().end(), tile_assignment.begin());
  return HloSharding(proto.tile_shape(), tile_assignment);
}

OpSharding HloSharding::ToProto() const {
  OpSharding result;

  if (IsTuple()) {
    for (const HloSharding& element : tuple_elements_) {
      *result.add_tuple_shardings() = element.ToProto();
    }
    result.set_type(OpSharding::Type::OpSharding_Type_TUPLE);
    return result;
  }

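  // For non-tuple shardings, serialize the tile shape and the flattened device
  // assignment; the type field distinguishes replicated, maximal, and tiled
  // ("OTHER") shardings.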
  *result.mutable_tile_shape() = tile_shape_;
  for (int64 dim : tile_assignment_.dimensions()) {
    result.add_tile_assignment_dimensions(dim);
  }
  for (auto device : tile_assignment_) {
    result.add_tile_assignment_devices(device);
  }
  if (IsReplicated()) {
    result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
  } else if (IsTileMaximal()) {
    result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
  } else {
    result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
  }
  return result;
}

}  // namespace xla