11e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
21e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
31e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsLicensed under the Apache License, Version 2.0 (the "License");
41e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsyou may not use this file except in compliance with the License.
51e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsYou may obtain a copy of the License at
61e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
71e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins    http://www.apache.org/licenses/LICENSE-2.0
81e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
91e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsUnless required by applicable law or agreed to in writing, software
101e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsdistributed under the License is distributed on an "AS IS" BASIS,
111e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
121e67c90e2caceeff82d09793d1ef5fa0300d219bPeter HawkinsSee the License for the specific language governing permissions and
131e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinslimitations under the License.
141e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins==============================================================================*/
151e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
161e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
171e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
181e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
191e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/buffer_assignment.h"
201e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
211e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
221e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/gpu/thunk.h"
231e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/service/hlo_instruction.h"
241e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/compiler/xla/xla_data.pb.h"
251e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/lib/core/status.h"
261e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#include "tensorflow/core/platform/stream_executor_no_cuda.h"
271e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
281e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace xla {
291e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsnamespace gpu {
301e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
311e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// This class stores everything that StreamExecutor needs to launch a BLAS gemm.
321e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// It is generated by IrEmitter.
331e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins//
341e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins// This is thread-compatible.
351e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkinsclass GemmThunk : public Thunk {
361e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins public:
371e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Constructs a thunk that computes "output = lhs <dot> rhs" using BLAS gemm.
381e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // transpose_lhs and transpose_rhs indicate whether gemm should transpose the
391e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // lhs and rhs operand. hlo_instruction is as in Thunk.
408ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  GemmThunk(const BufferAllocation::Slice& lhs_buffer,
418ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower            const BufferAllocation::Slice& rhs_buffer,
428ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower            const BufferAllocation::Slice& output_buffer,
438ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower            const Shape& lhs_shape, const Shape& rhs_shape,
448ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower            const Shape& output_shape, bool transpose_lhs, bool transpose_rhs,
451e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins            const HloInstruction* hlo_instruction);
461e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
471e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  GemmThunk(const GemmThunk&) = delete;
481e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  GemmThunk& operator=(const GemmThunk&) = delete;
491e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
501e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  // Does the gemm operation for the thunk on "stream", which must be non-null.
511e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins  tensorflow::Status ExecuteOnStream(
521e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      const BufferAllocations& buffer_allocations,
531e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins      perftools::gputools::Stream* stream) override;
541e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
551a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar  // Returns true if we'll perform autotuning if run on the given stream.  If
561a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar  // so, we want the GPU to be quiescent during autotuning, so as not to
571a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar  // introduce noise in our results.
581a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar  bool ShouldHaltAllActivityBeforeRunning(
591a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar      perftools::gputools::Stream* stream) override {
601a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar    return autotune_results_.count(
611a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar               stream->parent()->GetDeviceDescription().name()) != 0;
621a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar  }
631a786ab335aabe9020cff4f0ab69a5844de70fbcJustin Lebar
641e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins private:
658ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const BufferAllocation::Slice lhs_buffer_;
668ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const BufferAllocation::Slice rhs_buffer_;
678ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const BufferAllocation::Slice output_buffer_;
681e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
698ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const Shape lhs_shape_;
708ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const Shape rhs_shape_;
718ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const Shape output_shape_;
721e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
738ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const bool transpose_lhs_;
748ff1c465c87fc3967c9d480646fac6d6205f856cA. Unique TensorFlower  const bool transpose_rhs_;
7501194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar
7601194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar  // Maps device names (StreamExecutor::DeviceDescription::name()) to autotune
7701194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar  // results.  The map's value is the best algorithm we've found for this thunk
7801194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar  // on this device, or an error if none of the algorithms worked and we should
7901194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar  // use the regular gemm without an algorithm.
8001194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar  std::unordered_map<string,
8101194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar                     StatusOr<::perftools::gputools::blas::AlgorithmType>>
8201194694948eb883e99af597d9dbbf3fc9f5c9e2Justin Lebar      autotune_results_;
831e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins};
841e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
851e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace gpu
861e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins}  // namespace xla
871e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins
881e67c90e2caceeff82d09793d1ef5fa0300d219bPeter Hawkins#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GEMM_THUNK_H_
89