cancellation.h revision 12629a0a754d274cf1262f09db0290c6782e0adb
1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
17#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
18
19#include <atomic>
20#include <functional>
21
22#include "tensorflow/core/lib/core/notification.h"
23#include "tensorflow/core/lib/core/status.h"
24#include "tensorflow/core/lib/gtl/flatmap.h"
25#include "tensorflow/core/lib/hash/hash.h"
26#include "tensorflow/core/platform/mutex.h"
27#include "tensorflow/core/platform/thread_annotations.h"
28#include "tensorflow/core/platform/types.h"
29
30namespace tensorflow {
31
32// A token that can be used to register and deregister a
33// CancelCallback with a CancellationManager.
34//
35// CancellationToken values must be created by a call to
36// CancellationManager::get_cancellation_token.
37typedef int64 CancellationToken;
38
// Function run when a step is cancelled; it takes no arguments and returns
// nothing.
//
// NOTE(mrry): See caveats about CancelCallback implementations in the
// comment for CancellationManager::RegisterCallback.
using CancelCallback = std::function<void()>;
44
45class CancellationManager {
46 public:
47  // A value that won't be returned by get_cancellation_token().
48  static const CancellationToken kInvalidToken;
49
50  CancellationManager();
51  ~CancellationManager();
52
53  // Run all callbacks associated with this manager.
54  void StartCancel();
55
56  // Returns true iff StartCancel() has been called.
57  bool IsCancelled() { return is_cancelled_.load(std::memory_order_acquire); }
58
59  // Returns a token that must be used in calls to RegisterCallback
60  // and DeregisterCallback.
61  CancellationToken get_cancellation_token();
62
63  // Attempts to register the given callback to be invoked when this
64  // manager is cancelled. Returns true if the callback was
65  // registered; returns false if this manager was already cancelled,
66  // and the callback was not registered.
67  //
68  // If this method returns false, it is the caller's responsibility
69  // to perform any cancellation cleanup.
70  //
71  // This method is tricky to use correctly. The following usage pattern
72  // is recommended:
73  //
74  // class ObjectWithCancellableOperation {
75  //   mutex mu_;
76  //   void CancellableOperation(CancellationManager* cm,
77  //                             std::function<void(Status)> callback) {
78  //     bool already_cancelled;
79  //     CancellationToken token = cm->get_cancellation_token();
80  //     {
81  //       mutex_lock(mu_);
82  //       already_cancelled = cm->RegisterCallback(
83  //           [this, token]() { Cancel(token); });
84  //       if (!already_cancelled) {
85  //         // Issue asynchronous operation. Associate the pending operation
86  //         // with `token` in some object state, or provide another way for
87  //         // the Cancel method to look up the operation for cancellation.
88  //         // Ensure that `cm->DeregisterCallback(token)` is called without
89  //         // holding `mu_`, before `callback` is invoked.
90  //         // ...
91  //       }
92  //     }
93  //     if (already_cancelled) {
94  //       callback(errors::Cancelled("Operation was cancelled"));
95  //     }
96  //   }
97  //
98  //   void Cancel(CancellationToken token) {
99  //     mutex_lock(mu_);
100  //     // Take action to cancel the operation with the given cancellation
101  //     // token.
102  //   }
103  //
104  // NOTE(mrry): The caller should take care that (i) the calling code
105  // is robust to `callback` being invoked asynchronously (e.g. from
106  // another thread), (ii) `callback` is deregistered by a call to
107  // this->DeregisterCallback(token) when the operation completes
108  // successfully, and (iii) `callback` does not invoke any method
109  // on this cancellation manager. Furthermore, it is important that
110  // the eventual caller of the complementary DeregisterCallback does not
111  // hold any mutexes that are required by `callback`.
112  bool RegisterCallback(CancellationToken token, CancelCallback callback);
113
114  // Deregister the callback that, when registered, was associated
115  // with the given cancellation token. Returns true iff the callback
116  // was deregistered and will not be invoked; otherwise returns false
117  // after the callback has been invoked, blocking if necessary.
118  //
119  // NOTE(mrry): This method may block if cancellation is in progress.
120  // The caller of this method must not hold any mutexes that are required
121  // to invoke any cancellation callback that has been registered with this
122  // cancellation manager.
123  bool DeregisterCallback(CancellationToken token);
124
125 private:
126  bool is_cancelling_;
127  std::atomic_bool is_cancelled_;
128
129  mutex mu_;
130  Notification cancelled_notification_;
131  CancellationToken next_cancellation_token_ GUARDED_BY(mu_);
132  gtl::FlatMap<CancellationToken, CancelCallback> callbacks_ GUARDED_BY(mu_);
133};
134
135}  // namespace tensorflow
136
137#endif  // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
138