1//===- OCL21ToSPIRV.cpp - Transform OCL21 to SPIR-V builtins -----*- C++ -*-===//
2//
3//                     The LLVM/SPIRV Translator
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9//
10// Permission is hereby granted, free of charge, to any person obtaining a
11// copy of this software and associated documentation files (the "Software"),
12// to deal with the Software without restriction, including without limitation
13// the rights to use, copy, modify, merge, publish, distribute, sublicense,
14// and/or sell copies of the Software, and to permit persons to whom the
15// Software is furnished to do so, subject to the following conditions:
16//
17// Redistributions of source code must retain the above copyright notice,
18// this list of conditions and the following disclaimers.
19// Redistributions in binary form must reproduce the above copyright notice,
20// this list of conditions and the following disclaimers in the documentation
21// and/or other materials provided with the distribution.
22// Neither the names of Advanced Micro Devices, Inc., nor the names of its
23// contributors may be used to endorse or promote products derived from this
24// Software without specific prior written permission.
25// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31// THE SOFTWARE.
32//
33//===----------------------------------------------------------------------===//
34//
35// This file implements translation of OCL21 builtin functions.
36//
37//===----------------------------------------------------------------------===//
38#define DEBUG_TYPE "cl21tospv"
39
40#include "SPIRVInternal.h"
41#include "OCLUtil.h"
42#include "llvm/ADT/StringSwitch.h"
43#include "llvm/IR/InstVisitor.h"
44#include "llvm/IR/Instructions.h"
45#include "llvm/IR/IRBuilder.h"
46#include "llvm/IR/Verifier.h"
47#include "llvm/Pass.h"
48#include "llvm/PassSupport.h"
49#include "llvm/Support/Debug.h"
50#include "llvm/Support/raw_ostream.h"
51
52#include <set>
53
54using namespace llvm;
55using namespace SPIRV;
56using namespace OCLUtil;
57
58namespace SPIRV {
59
60class OCL21ToSPIRV: public ModulePass,
61  public InstVisitor<OCL21ToSPIRV> {
62public:
63  OCL21ToSPIRV():ModulePass(ID), M(nullptr), Ctx(nullptr), CLVer(0) {
64    initializeOCL21ToSPIRVPass(*PassRegistry::getPassRegistry());
65  }
66  virtual bool runOnModule(Module &M);
67  virtual void visitCallInst(CallInst &CI);
68
69  /// Transform SPIR-V convert function
70  //    __spirv{N}Op{ConvertOpName}(src, dummy)
71  ///   =>
72  ///   __spirv_{ConvertOpName}_R{TargeTyName}
73  void visitCallConvert(CallInst *CI, StringRef MangledName, Op OC);
74
75  /// Transform SPIR-V decoration
76  ///   x = __spirv_{OpName};
77  ///   y = __spirv{N}Op{Decorate}(x, type, value, dummy)
78  ///   =>
79  ///   y = __spirv_{OpName}{Postfix(type,value)}
80  void visitCallDecorate(CallInst *CI, StringRef MangledName);
81
82  /// Transform sub_group_barrier to __spirv_ControlBarrier.
83  /// sub_group_barrier(scope, flag) =>
84  ///   __spirv_ControlBarrier(subgroup, map(scope), map(flag))
85  void visitCallSubGroupBarrier(CallInst *CI);
86
87  /// Transform OCL C++ builtin function to SPIR-V builtin function.
88  /// Assuming there is no argument changes.
89  /// Should be called at last.
90  void transBuiltin(CallInst *CI, Op OC);
91
92  static char ID;
93private:
94  ConstantInt *addInt32(int I) {
95    return getInt32(M, I);
96  }
97
98  Module *M;
99  LLVMContext *Ctx;
100  unsigned CLVer;                   /// OpenCL version as major*10+minor
101  std::set<Value *> ValuesToDelete;
102};
103
104char OCL21ToSPIRV::ID = 0;
105
106bool
107OCL21ToSPIRV::runOnModule(Module& Module) {
108  M = &Module;
109  Ctx = &M->getContext();
110
111  auto Src = getSPIRVSource(&Module);
112  if (std::get<0>(Src) != spv::SourceLanguageOpenCL_CPP)
113    return false;
114
115  CLVer = std::get<1>(Src);
116  if (CLVer < kOCLVer::CL21)
117    return false;
118
119  DEBUG(dbgs() << "Enter OCL21ToSPIRV:\n");
120  visit(*M);
121
122  for (auto &I:ValuesToDelete)
123    if (auto Inst = dyn_cast<Instruction>(I))
124      Inst->eraseFromParent();
125  for (auto &I:ValuesToDelete)
126    if (auto GV = dyn_cast<GlobalValue>(I))
127      GV->eraseFromParent();
128
129  DEBUG(dbgs() << "After OCL21ToSPIRV:\n" << *M);
130  std::string Err;
131  raw_string_ostream ErrorOS(Err);
132  if (verifyModule(*M, &ErrorOS)){
133    DEBUG(errs() << "Fails to verify module: " << ErrorOS.str());
134  }
135  return true;
136}
137
138// The order of handling OCL builtin functions is important.
139// Workgroup functions need to be handled before pipe functions since
140// there are functions fall into both categories.
141void
142OCL21ToSPIRV::visitCallInst(CallInst& CI) {
143  DEBUG(dbgs() << "[visistCallInst] " << CI << '\n');
144  auto F = CI.getCalledFunction();
145  if (!F)
146    return;
147
148  auto MangledName = F->getName();
149  std::string DemangledName;
150
151  if (oclIsBuiltin(MangledName, &DemangledName)) {
152    if (DemangledName == kOCLBuiltinName::SubGroupBarrier) {
153      visitCallSubGroupBarrier(&CI);
154      return;
155    }
156  }
157
158  if (!oclIsBuiltin(MangledName, &DemangledName, true))
159    return;
160  DEBUG(dbgs() << "DemangledName:" << DemangledName << '\n');
161  StringRef Ref(DemangledName);
162
163  Op OC = OpNop;
164  if (!OpCodeNameMap::rfind(Ref.str(), &OC))
165    return;
166  DEBUG(dbgs() << "maps to opcode " << OC << '\n');
167
168  if (isCvtOpCode(OC)) {
169    visitCallConvert(&CI, MangledName, OC);
170    return;
171  }
172  if (OC == OpDecorate) {
173    visitCallDecorate(&CI, MangledName);
174    return;
175  }
176  transBuiltin(&CI, OC);
177}
178
179void OCL21ToSPIRV::visitCallConvert(CallInst* CI,
180    StringRef MangledName, Op OC) {
181  AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
182  mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
183    Args.pop_back();
184    return getSPIRVFuncName(OC, kSPIRVPostfix::Divider +
185      getPostfixForReturnType(CI,
186      OC == OpSConvert || OC == OpConvertFToS || OC == OpSatConvertUToS));
187  }, &Attrs);
188  ValuesToDelete.insert(CI);
189  ValuesToDelete.insert(CI->getCalledFunction());
190}
191
192void OCL21ToSPIRV::visitCallDecorate(CallInst* CI,
193    StringRef MangledName) {
194  auto Target = cast<CallInst>(CI->getArgOperand(0));
195  auto F = Target->getCalledFunction();
196  auto Name = F->getName().str();
197  std::string DemangledName;
198  oclIsBuiltin(Name, &DemangledName);
199  BuiltinFuncMangleInfo Info;
200  F->setName(mangleBuiltin(DemangledName + kSPIRVPostfix::Divider +
201      getPostfix(getArgAsDecoration(CI, 1), getArgAsInt(CI, 2)),
202      getTypes(getArguments(CI)), &Info));
203  CI->replaceAllUsesWith(Target);
204  ValuesToDelete.insert(CI);
205  ValuesToDelete.insert(CI->getCalledFunction());
206}
207
208void
209OCL21ToSPIRV::visitCallSubGroupBarrier(CallInst *CI) {
210  DEBUG(dbgs() << "[visitCallSubGroupBarrier] "<< *CI << '\n');
211  auto Lit = getBarrierLiterals(CI);
212  AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
213  mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
214      Args.resize(3);
215      Args[0] = addInt32(map<Scope>(std::get<2>(Lit)));
216      Args[1] = addInt32(map<Scope>(std::get<1>(Lit)));
217      Args[2] = addInt32(mapOCLMemFenceFlagToSPIRV(std::get<0>(Lit)));
218      return getSPIRVFuncName(OpControlBarrier);
219    }, &Attrs);
220}
221
222void
223OCL21ToSPIRV::transBuiltin(CallInst* CI, Op OC) {
224  AttributeSet Attrs = CI->getCalledFunction()->getAttributes();
225  assert(OC != OpExtInst && "not supported");
226  mutateCallInstSPIRV(M, CI, [=](CallInst *, std::vector<Value *> &Args){
227    return getSPIRVFuncName(OC);
228  }, &Attrs);
229  ValuesToDelete.insert(CI);
230  ValuesToDelete.insert(CI->getCalledFunction());
231}
232
233}
234
235INITIALIZE_PASS(OCL21ToSPIRV, "cl21tospv", "Transform OCL 2.1 to SPIR-V",
236    false, false)
237
238ModulePass *llvm::createOCL21ToSPIRV() {
239  return new OCL21ToSPIRV();
240}
241