/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("NcclAllReduce")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Outputs a tensor containing the reduction across all input tensors passed to
ops within the same `shared_name`.

The graph should be constructed so that if one op runs with shared_name value
`c`, then `num_devices` ops will run with shared_name value `c`. Failure to do
so will cause the graph execution to fail to complete.

input: The input to the reduction.
data: The value of the reduction across all `num_devices` devices.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");
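// A minimal sketch (comment form only, not part of the build) of the
// `shared_name` contract described above, written with the framework's
// NodeDefBuilder. The node and input names here are illustrative: each of the
// `num_devices` participating devices contributes one NcclAllReduce node
// carrying the same `shared_name` and `num_devices` values.
//
//   #include "tensorflow/core/framework/node_def_builder.h"
//   #include "tensorflow/core/lib/strings/strcat.h"
//
//   constexpr int kNumDevices = 2;
//   for (int d = 0; d < kNumDevices; ++d) {
//     NodeDef node_def;
//     // One all-reduce participant per device; all share shared_name "c",
//     // so all kNumDevices nodes form a single NCCL reduction group.
//     TF_CHECK_OK(NodeDefBuilder(strings::StrCat("allreduce_d", d),
//                                "NcclAllReduce")
//                     .Input(strings::StrCat("input_d", d), 0, DT_FLOAT)
//                     .Attr("reduction", "sum")
//                     .Attr("num_devices", kNumDevices)
//                     .Attr("shared_name", "c")
//                     .Device(strings::StrCat("/gpu:", d))
//                     .Finalize(&node_def));
//   }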
// Note: This op has no kernel implementation, but is replaced by
// _NcclReduceSend and _NcclReduceRecv during the graph optimization stage.
REGISTER_OP("NcclReduce")
    .Input("input: num_devices * T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Reduces `input` from `num_devices` using `reduction` to a single device.

The graph should be constructed so that all inputs have a valid device
assignment, and the op itself is assigned one of these devices.

input: The input to the reduction.
data: The value of the reduction across all `num_devices` devices.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
)doc");

REGISTER_OP("_NcclReduceSend")
    .Input("input: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces `input` to the _NcclReduceRecv op registered with the same
`shared_name`. The graph should be constructed so that `num_devices - 1`
devices run `_NcclReduceSend` and one device runs the `_NcclReduceRecv` op
with shared_name value `c`. Failure to do so will cause the graph execution
to fail to complete.

input: The input to the reduction.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");

REGISTER_OP("_NcclReduceRecv")
    .Input("input: T")
    .Output("data: T")
    .Attr("reduction: {'min', 'max', 'prod', 'sum'}")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Replacement node for NcclReduce.

Reduces `input` from this op and the _NcclReduceSend ops registered with the
same `shared_name`. The graph should be constructed so that `num_devices - 1`
devices run `_NcclReduceSend` and one device runs the `_NcclReduceRecv` op
with shared_name value `c`. Failure to do so will cause the graph execution
to fail to complete.

input: The input to the reduction.
data: The reduced data received from this op and the _NcclReduceSend ops.
reduction: The reduction operation to perform.
num_devices: The number of devices participating in this reduction.
shared_name: Identifier that is shared between ops of the same reduction.
)doc");

// Note: This op has no kernel implementation, but is replaced by
// _NcclBroadcastSend and _NcclBroadcastRecv during the graph optimization
// stage.
REGISTER_OP("NcclBroadcast")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("shape: shape")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Sends `input` to all devices that are connected to the output.

The graph should be constructed so that all ops connected to the output have a
valid device assignment, and the op itself is assigned one of these devices.

input: The input to the broadcast.
output: The same as `input`.
shape: The shape of the input tensor.
)doc");

REGISTER_OP("_NcclBroadcastSend")
    .Input("input: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs)
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Sends `input` to the _NcclBroadcastRecv ops registered with the same
`shared_name`. The graph should be constructed so that one device runs
`_NcclBroadcastSend` and `num_devices - 1` devices run `_NcclBroadcastRecv`
ops with shared_name value `c`. Failure to do so will cause the graph
execution to fail to complete.

input: The input to the broadcast.
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");

REGISTER_OP("_NcclBroadcastRecv")
    .Input("shape: int32")
    .Output("output: T")
    .Attr("T: {float, float64, int32, int64}")
    .Attr("num_devices: int")
    .Attr("shared_name: string")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // The output shape arrives as the value of the `shape` input tensor
      // (there is no data edge from the sender), so build the output's
      // ShapeHandle from that tensor.
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
      c->set_output(0, out);
      return Status::OK();
    })
    .Doc(R"doc(
Replacement node for NcclBroadcast.

Receives data of shape `shape` from the _NcclBroadcastSend op registered with
the same `shared_name`. The graph should be constructed so that one device runs
`_NcclBroadcastSend` and `num_devices - 1` devices run `_NcclBroadcastRecv` ops
with shared_name value `c`. Failure to do so will cause the graph execution to
fail to complete.

shape: The shape of the output.
output: The broadcast data received from the _NcclBroadcastSend op.
num_devices: The number of devices participating in this broadcast.
shared_name: Identifier that is shared between ops of the same broadcast.
)doc");
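// A minimal sketch (comment form only, illustrative node and input names) of
// the send/recv pairing described above: one _NcclBroadcastSend on the source
// device plus `num_devices - 1` _NcclBroadcastRecv nodes on the remaining
// devices, tied together by the same `shared_name`. Because the receiver's
// only input is the int32 `shape` tensor, its element type `T` cannot be
// inferred from an input edge and must be set explicitly.
//
//   NodeDef send, recv;
//   TF_CHECK_OK(NodeDefBuilder("bcast_send", "_NcclBroadcastSend")
//                   .Input("input_tensor", 0, DT_FLOAT)
//                   .Attr("num_devices", 2)
//                   .Attr("shared_name", "c")
//                   .Device("/gpu:0")
//                   .Finalize(&send));
//   TF_CHECK_OK(NodeDefBuilder("bcast_recv", "_NcclBroadcastRecv")
//                   .Input("shape_tensor", 0, DT_INT32)
//                   .Attr("T", DT_FLOAT)  // Must match the sender's dtype.
//                   .Attr("num_devices", 2)
//                   .Attr("shared_name", "c")
//                   .Device("/gpu:1")
//                   .Finalize(&recv));
//
// Passing the shape as a tensor lets the receiver allocate its output without
// any data edge from the sending device.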

}  // namespace tensorflow