1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
17#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
18
#include <functional>
#include <string>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
36
37namespace tensorflow {
38
39class BaseRemoteRendezvous;
40class BaseRecvTensorCall;
41
42// RendezvousMgr keeps track of a set of local rendezvous instances.
43// All tensors sent by this worker are buffered in a RendezvousMgr
44// until the tensor is received.  Each global unique "step_id"
45// corresponds to one local rendezvous instance managed by a
46// RendezvousMgr.
47//
48// E.g.,
49//   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
50//   fork execution of a graph executor using "rendez" on thread 1;
51//   fork execution of another graph executor using "rendez" on thread 2;
52//   ...
53//   join threads 1 and 2;
54//
55// In the example above, execution in thread 1 and 2 communicates with
56// each other by send/recv operations through `rendez`.
57//
58// Tensors sent and received through a rendezvous managed by this
59// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
60class BaseRendezvousMgr : public RendezvousMgrInterface {
61 public:
62  explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
63
64  ~BaseRendezvousMgr() override;
65
66  // Returns Rendezvous supporting send and recv among workers in the
67  // "step_id".  The caller takes ownership of one reference on the
68  // returned Rendezvous instance.
69  //
70  // Note: the caller must guarantee to eventually call Initialize on the
71  // returned RemoteRendezvous
72  RemoteRendezvous* Find(int64 step_id) override;
73
74  // Finds the local rendezvous instance for the "step_id".  Runs
75  // "done" when the tensor for "key" is produced or an error occurs.
76  //
77  // This method is used by the rpc handler of RecvTensor.
78  void RecvLocalAsync(int64 step_id, const Rendezvous::ParsedKey& parsed,
79                      Rendezvous::DoneCallback done) override;
80
81  // Synchronous wrapper for RecvLocalAsync.
82  Status RecvLocal(int64 step_id, const Rendezvous::ParsedKey& parsed,
83                   Tensor* val, bool* is_dead) override;
84
85  // Removes rendezvous for "step_id".
86  //
87  // TODO(zhifengc): Have a background thread in worker that
88  // periodically calls CleanupAll().
89  void Cleanup(int64 step_id) override;
90
91  // Removed all rendezvous.
92  void CleanupAll() override;
93
94 protected:
95  virtual BaseRemoteRendezvous* Create(int64 step_id,
96                                       const WorkerEnv* worker_env) = 0;
97
98 private:
99  // Maps step_id to rendezvous.
100  typedef gtl::FlatMap<int64, BaseRemoteRendezvous*> Table;
101
102  // Not owned.
103  const WorkerEnv* const worker_env_;
104
105  mutex mu_;
106  Table table_ GUARDED_BY(mu_);
107
108  BaseRemoteRendezvous* FindOrCreate(int64 step_id);
109
110  TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
111};
112
// BaseRemoteRendezvous is a RemoteRendezvous which can handle either
// the producer or consumer being in a remote process.
//
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous().  This class just adds
// functionality to coordinate with remote workers; subclasses provide
// the actual remote transfer by implementing RecvFromRemoteAsync().
//
// The object starts out partially initialized (no WorkerSession) and is
// upgraded by Initialize().
class BaseRemoteRendezvous : public RemoteRendezvous {
 public:
  // Creates a rendezvous serving "step_id".  "env" is not owned.
  BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id);

  // Upgrades the BaseRemoteRendezvous to full initialization by supplying
  // the WorkerSession.  Presumably this is what makes session() non-null
  // and is_initialized() true — confirm in the implementation.
  Status Initialize(WorkerSession* session) override;

  // Forwards to local_, where the Tensor "val" will be buffered and
  // any waiting callback stored.
  Status Send(const ParsedKey& key, const Rendezvous::Args& args,
              const Tensor& val, const bool is_dead) override;

  // This method is called only by the RecvOp.  It tests to see
  // whether the value will be produced by a local or remote device
  // and handles accordingly.  In the local case it forwards to
  // local_, in the remote case it initiates an RPC request.
  void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                 DoneCallback done) override;

  // Aborts this rendezvous with "status".  The status is recorded in
  // status_, and calls registered afterwards via RegisterCall() are
  // aborted instead of tracked.
  void StartAbort(const Status& status) override;

  // This method is called only by the local Worker, forwarded through
  // the same method on RendezvousMgr.  This occurs when the Worker
  // has received a RecvTensor request, either locally or over the
  // network.  In either case it needs to retrieve a locally buffered
  // value from local_, and give it to its caller.
  //
  // Runs "done" as soon as the tensor for "parsed" is available or an error
  // is detected.
  //
  // REQUIRES: "parsed" is one that will be Sent into (i.e. buffered in) the
  // local rendezvous.
  void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);

 protected:
  // Subclass hook used for keys whose producer lives in another process:
  // initiates the receive of "parsed" from the remote worker and is
  // expected to eventually invoke "done".
  virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                   const Rendezvous::Args& args,
                                   DoneCallback done) = 0;

  // Returns true if "src" and "dst" are located in the same worker,
  // and hence may use a local rendezvous.
  virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
                            DeviceNameUtils::ParsedName dst);

  // If aborted, aborts "call". Otherwise, adds "call" into active_.
  void RegisterCall(BaseRecvTensorCall* call);

  // Removes "call" from active_ if "call" is in active_.
  void DeregisterCall(BaseRecvTensorCall* call);

  // Returns the WorkerSession (not owned); presumably the one passed to
  // Initialize(), and null before that.
  WorkerSession* session();

  // Lock-taking counterpart of is_initialized_locked(): whether a
  // WorkerSession has been set.
  bool is_initialized();

  // Non-public destructor; presumably so that destruction happens only
  // through the reference count handed out to callers.
  ~BaseRemoteRendezvous() override;

  const WorkerEnv* const env_;  // Not owned.
  const int64 step_id_;         // The step this rendezvous serves.

 private:
  Rendezvous* local_;  // Owns a Ref on this object.

  // Guards the GUARDED_BY(mu_) fields below.  Declared mutable so it can
  // be locked from const member functions.
  mutable mutex mu_;

  // Status given by StartAbort() if any.
  Status status_ GUARDED_BY(mu_);
  // Not owned.  Null until initialized (see is_initialized_locked()).
  WorkerSession* session_ GUARDED_BY(mu_);  // Not owned.

  // Data structures to handle calls when partially initialized.
  // Each entry captures the arguments of a RecvLocalAsync-style request;
  // presumably these are queued before Initialize() and replayed through
  // RecvLocalAsyncInternal() once a session is available.
  struct DeferredCall {
    const ParsedKey parsed;
    DoneCallback done;

    DeferredCall(const ParsedKey& parsed, DoneCallback done);
  };
  std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);

  // Active outstanding RecvTensor calls.
  gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);

  // True once a WorkerSession has been set.  Caller must hold mu_.
  bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return session_ != nullptr;
  }

  // If "is_src" is true, checks that the rendezvous key "parsed"'s
  // source is in this process. If "is_src" is false, checks that the
  // rendezvous key "parsed"'s destination is in this process.
  Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);

  // Callback handling the case when a rendezvous has been
  // accomplished in local_ and the consumer is local to this process.
  // Tensor "in" will be copied into "out". The key "parsed" encodes
  // the src and dst devices.  "done" receives the resulting status.
  void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                          const Rendezvous::Args& in_args,
                          const Rendezvous::Args& out_args, const Tensor& in,
                          Tensor* out, StatusCallback done);

  // Must be called only if fully initialized.
  void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);

  TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};
221
222class BaseRecvTensorCall {
223 public:
224  BaseRecvTensorCall() {}
225  virtual ~BaseRecvTensorCall() {}
226
227  virtual void Start(std::function<void()> recv_done) = 0;
228
229  virtual void StartAbort(const Status& s) = 0;
230
231  virtual Status status() const = 0;
232
233 private:
234  TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
235};
236
237}  // end namespace tensorflow
238
239#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
240