/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

24REGISTER_OP("NcclAllReduce")
25    .Input("input: T")
26    .Output("data: T")
27    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
28    .Attr("T: {float, float64, int32, int64}")
29    .Attr("num_devices: int")
30    .Attr("shared_name: string")
31    .SetIsStateful()
32    .SetShapeFn(shape_inference::UnchangedShape)
33    .Doc(R"doc(
34Outputs a tensor containing the reduction across all input tensors passed to ops
35within the same `shared_name.
36
37The graph should be constructed so if one op runs with shared_name value `c`,
38then `num_devices` ops will run with shared_name value `c`.  Failure to do so
39will cause the graph execution to fail to complete.
40
41input: the input to the reduction
42data: the value of the reduction across all `num_devices` devices.
43reduction: the reduction operation to perform.
44num_devices: The number of devices participating in this reduction.
45shared_name: Identifier that shared between ops of the same reduction.
46)doc");

// Note: This op has no kernel implementation, but is replaced by
// _NcclReduceSend and _NcclReduceRecv during the graph optimization stage.
REGISTER_OP("NcclReduce")
    .Input("input: num_devices * T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Reduces `input` from `num_devices` devices to a single device using
`reduction`.

The graph should be constructed so that all inputs have a valid device
assignment, and the op itself is assigned one of these devices.

input: The input to the reduction.
data: The value of the reduction across all `num_devices` devices.
reduction: The reduction operation to perform.
)doc");
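
// A minimal construction sketch (names hypothetical): the op takes a list of
// `num_devices` tensors, one per device, and NodeDefBuilder infers the
// `num_devices` attr from the length of the input list:
//
//   std::vector<NodeDefBuilder::NodeOut> srcs = {{"grad_gpu0", 0, DT_FLOAT},
//                                                {"grad_gpu1", 0, DT_FLOAT}};
//   NodeDef def;
//   TF_CHECK_OK(NodeDefBuilder("reduce_to_gpu0", "NcclReduce")
//                   .Input(srcs)
//                   .Attr("reduction", "sum")
//                   .Finalize(&def));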

REGISTER_OP("_NcclReduceSend")
    .Input("input: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces `input` to the _NcclReduceRecv op registered in the same
`shared_name`.
The graph should be constructed so that `num_devices-1` devices run
`_NcclReduceSend` and one device runs the `_NcclReduceRecv` op with shared_name
value `c`. Failure to do so will cause the graph execution to fail to complete.

input: The input to the reduction.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");

REGISTER_OP("_NcclReduceRecv")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces `input` from this op and from the _NcclReduceSend ops registered in the
same `shared_name`.
The graph should be constructed so that `num_devices-1` devices run
`_NcclReduceSend` and one device runs the `_NcclReduceRecv` op with shared_name
value `c`. Failure to do so will cause the graph execution to fail to complete.

input: The input to the reduction.
data: The reduced data received from this op and the _NcclReduceSend ops.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduce.
)doc");
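
// A hypothetical 2-GPU rewrite of a single NcclReduce placed on gpu1 (device
// and shared_name values are illustrative):
//
//   gpu0: _NcclReduceSend(input_gpu0)        {num_devices: 2, shared_name: "r0"}
//   gpu1: data = _NcclReduceRecv(input_gpu1) {num_devices: 2, shared_name: "r0"}
//
// The receiver contributes its own input and produces the reduced result; the
// senders produce no output.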

// Note: This op has no kernel implementation, but is replaced by
// _NcclBroadcastSend and _NcclBroadcastRecv during the graph optimization
// stage.
REGISTER_OP("NcclBroadcast")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Sends `input` to all devices that are connected to the output.

The graph should be constructed so that all ops connected to the output have a
valid device assignment, and the op itself is assigned one of these devices.

input: The input to the broadcast.
output: The same as input.
shape: The shape of the input tensor.
)doc");
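
// Sketch of the rewrite (hypothetical device names): the optimizer replaces
//
//   gpu0: y = NcclBroadcast(x)   // y consumed by ops on gpu0 and gpu1
//
// with a _NcclBroadcastSend of `x` on gpu0 and a _NcclBroadcastRecv on gpu1,
// both registered under a freshly generated shared_name.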

REGISTER_OP("_NcclBroadcastSend")
    .Input("input: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Sends `input` to the _NcclBroadcastRecv ops registered in the same
`shared_name`.
The graph should be constructed so that one device runs `_NcclBroadcastSend`
and `num_devices-1` devices run `_NcclBroadcastRecv` ops with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.

input: The input to the broadcast.
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");

REGISTER_OP("_NcclBroadcastRecv")
    .Input("shape: int32")
    .Output("output: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // The output shape is carried in the `shape` input tensor.
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);
      return Status::OK();
    })
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Receives data of shape `shape` from the _NcclBroadcastSend op registered in the
same `shared_name`.
The graph should be constructed so that one device runs `_NcclBroadcastSend`
and `num_devices-1` devices run `_NcclBroadcastRecv` ops with shared_name value
`c`. Failure to do so will cause the graph execution to fail to complete.

shape: The shape of the output.
output: The broadcast data received from the _NcclBroadcastSend op.
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
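
// A hypothetical 2-GPU broadcast pair after the rewrite (shared_name value is
// illustrative):
//
//   gpu0: _NcclBroadcastSend(t)                {num_devices: 2, shared_name: "b0"}
//   gpu1: out = _NcclBroadcastRecv(shape_of_t) {num_devices: 2, shared_name: "b0"}
//
// The receiver's `shape` input must describe the shape of the tensor being
// sent, so the output can be allocated before the data arrives.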

}  // namespace tensorflow