1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
17#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
18
19#include <memory>
20#include <vector>
21
22#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
23#include "tensorflow/compiler/xla/service/hlo_instruction.h"
24#include "tensorflow/core/lib/core/status.h"
25#include "tensorflow/core/platform/stream_executor_no_cuda.h"
26
27namespace xla {
28namespace gpu {
29
// Forward declaration: this header only refers to GpuExecutable by const
// reference (as the parameter of Thunk::Initialize), so the full definition is
// not needed here.
class GpuExecutable;
31
32// Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
33// metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
34//
35// Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
36// to initialize and execute the invocation respectively. Its subclasses are
37// supposed to override these interfaces to launch a generated kernel or call an
38// external library function (such as operations in cuBLAS).
39//
40// This is thread-compatible.
41class Thunk {
42 public:
43  enum class Kind {
44    kConditional,
45    kConvolution,
46    kCopy,
47    kCudnnBatchNormBackward,
48    kCudnnBatchNormForwardInference,
49    kCudnnBatchNormForwardTraining,
50    kFft,
51    kGemm,
52    kInfeed,
53    kKernel,
54    kSequential,
55    kTuple,
56    kWhile,
57  };
58
59  // The hlo_instruction argument is meant to be the instruction this thunk was
60  // generated from, but Thunk never uses this argument other than to save it
61  // to Thunk::hlo_instruction, so it can be null.
62  explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
63      : kind_(kind), hlo_instruction_(hlo_instruction) {}
64  virtual ~Thunk() {}
65  Thunk(const Thunk&) = delete;
66  Thunk& operator=(const Thunk&) = delete;
67
68  Kind kind() const { return kind_; }
69  const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
70
71  // Prepares for executing the thunk. This method is called only once over
72  // Thunk's lifetime. For example, KernelThunk::Initialize loads the PTX of a
73  // kernel, which is the same in every execution.
74  virtual tensorflow::Status Initialize(const GpuExecutable& executable) {
75    return tensorflow::Status::OK();
76  }
77
78  // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream)
79  // before calling ExecuteOnStream(stream).  If it returns true, it's the
80  // user's responsibility to wait for all activity on the GPU to finish before
81  // calling ExecuteOnStream.
82  //
83  // This value is not required to be constant for a given Thunk.  For example,
84  // a Thunk that performs autotuning may return true for its first run and
85  // false thereafter.
86  virtual bool ShouldHaltAllActivityBeforeRunning(
87      perftools::gputools::Stream* /*stream*/) {
88    return false;
89  }
90
91  // Indicates whether thunks scheduled after this one should wait for this one
92  // to complete before running. For example, a convolution thunk creates a
93  // scratch allocator, then kicks off a convolution in cudnn via the stream
94  // executor. When the stream executor call returns, the scratch allocator goes
95  // out of scope, and the scratch memory is deallocated. In this case, the
96  // convolution thunk needs to return true so that future thunks wait for the
97  // convolution thunk to avoid reusing the deallocated memory until the
98  // convolution thunk is done with it.
99  virtual bool ShouldBlockFutureThunks() { return false; }
100
101  // Execute the kernel for the thunk on the given stream. This method must be
102  // called after Initialize and can be called multiple times over Thunk's
103  // lifetime. Stream argument must be non-null.
104  virtual tensorflow::Status ExecuteOnStream(
105      const BufferAllocations& buffer_allocations,
106      perftools::gputools::Stream* stream) = 0;
107
108 private:
109  Kind kind_;
110  const HloInstruction* hlo_instruction_;
111};
112
// An ordered sequence of thunks. Each element is held through a
// std::unique_ptr, so the sequence owns the thunks it contains.
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
115
116}  // namespace gpu
117}  // namespace xla
118
119#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
120