1//===- SPIRVLowerOCLBlocks.cpp - Lower OpenCL blocks ------------*- C++ -*-===//
2//
3//                     The LLVM/SPIR-V Translator
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9//
10// Permission is hereby granted, free of charge, to any person obtaining a
11// copy of this software and associated documentation files (the "Software"),
12// to deal with the Software without restriction, including without limitation
13// the rights to use, copy, modify, merge, publish, distribute, sublicense,
14// and/or sell copies of the Software, and to permit persons to whom the
15// Software is furnished to do so, subject to the following conditions:
16//
17// Redistributions of source code must retain the above copyright notice,
18// this list of conditions and the following disclaimers.
19// Redistributions in binary form must reproduce the above copyright notice,
20// this list of conditions and the following disclaimers in the documentation
21// and/or other materials provided with the distribution.
22// Neither the names of Advanced Micro Devices, Inc., nor the names of its
23// contributors may be used to endorse or promote products derived from this
24// Software without specific prior written permission.
25// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31// THE SOFTWARE.
32//
33//===----------------------------------------------------------------------===//
34/// \file
35///
36/// This file implements lowering of OpenCL blocks to functions.
37///
38//===----------------------------------------------------------------------===//
39
40#ifndef OCLLOWERBLOCKS_H_
41#define OCLLOWERBLOCKS_H_
42
43#include "SPIRVInternal.h"
44#include "OCLUtil.h"
45
46#include "llvm/ADT/DenseMap.h"
47#include "llvm/ADT/SetVector.h"
48#include "llvm/ADT/StringSwitch.h"
49#include "llvm/ADT/Triple.h"
50#include "llvm/Analysis/AliasAnalysis.h"
51#include "llvm/Analysis/AssumptionCache.h"
52#include "llvm/Analysis/CallGraph.h"
53#include "llvm/IR/Verifier.h"
54#include "llvm/Bitcode/ReaderWriter.h"
55#include "llvm/IR/Constants.h"
56#include "llvm/IR/DerivedTypes.h"
57#include "llvm/IR/Function.h"
58#include "llvm/IR/InstrTypes.h"
59#include "llvm/IR/Instructions.h"
60#include "llvm/IR/Module.h"
61#include "llvm/IR/Operator.h"
62#include "llvm/Pass.h"
63#include "llvm/PassSupport.h"
64#include "llvm/Support/Casting.h"
65#include "llvm/Support/Debug.h"
66#include "llvm/Support/raw_ostream.h"
67#include "llvm/Support/ToolOutputFile.h"
68#include "llvm/Transforms/Utils/Cloning.h"
69
70#include <iostream>
71#include <list>
72#include <memory>
73#include <set>
74#include <sstream>
75#include <vector>
76
77#define DEBUG_TYPE "spvblocks"
78
79using namespace llvm;
80using namespace SPIRV;
81using namespace OCLUtil;
82
83namespace SPIRV{
84
85/// Lower SPIR2 blocks to function calls.
86///
87/// SPIR2 representation of blocks:
88///
89/// block = spir_block_bind(bitcast(block_func), context_len, context_align,
90///   context)
91/// block_func_ptr = bitcast(spir_get_block_invoke(block))
92/// context_ptr = spir_get_block_context(block)
93/// ret = block_func_ptr(context_ptr, args)
94///
95/// Propagates block_func to each spir_get_block_invoke through def-use chain of
96/// spir_block_bind, so that
97/// ret = block_func(context, args)
98class SPIRVLowerOCLBlocks: public ModulePass {
99public:
100  SPIRVLowerOCLBlocks():ModulePass(ID), M(nullptr){
101    initializeSPIRVLowerOCLBlocksPass(*PassRegistry::getPassRegistry());
102  }
103
104  virtual void getAnalysisUsage(AnalysisUsage &AU) const {
105    AU.addRequired<CallGraphWrapperPass>();
106    //AU.addRequired<AliasAnalysis>();
107    AU.addRequired<AssumptionCacheTracker>();
108  }
109
110  virtual bool runOnModule(Module &Module) {
111    M = &Module;
112    lowerBlockBind();
113    lowerGetBlockInvoke();
114    lowerGetBlockContext();
115    erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
116    erase(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT));
117    erase(M->getFunction(SPIR_INTRINSIC_BLOCK_BIND));
118    DEBUG(dbgs() << "------- After OCLLowerBlocks ------------\n" <<
119                    *M << '\n');
120    return true;
121  }
122
123  static char ID;
124private:
125  const static int MaxIter = 1000;
126  Module *M;
127
128  bool
129  lowerBlockBind() {
130    auto F = M->getFunction(SPIR_INTRINSIC_BLOCK_BIND);
131    if (!F)
132      return false;
133    int Iter = MaxIter;
134    while(lowerBlockBind(F) && Iter > 0){
135      Iter--;
136      DEBUG(dbgs() << "-------------- after iteration " << MaxIter - Iter <<
137          " --------------\n" << *M << '\n');
138    }
139    assert(Iter > 0 && "Too many iterations");
140    return true;
141  }
142
143  bool
144  eraseUselessFunctions() {
145    bool changed = false;
146    for (auto I = M->begin(), E = M->end(); I != E;) {
147      Function *F = static_cast<Function*>(I++);
148      if (!GlobalValue::isInternalLinkage(F->getLinkage()) &&
149          !F->isDeclaration())
150        continue;
151
152      dumpUsers(F, "[eraseUselessFunctions] ");
153      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
154        auto U = *UI++;
155        if (auto CE = dyn_cast<ConstantExpr>(U)){
156          if (CE->use_empty()) {
157            CE->dropAllReferences();
158            changed = true;
159          }
160        }
161      }
162      if (F->use_empty()) {
163        erase(F);
164        changed = true;
165      }
166    }
167    return changed;
168  }
169
170  void
171  lowerGetBlockInvoke() {
172    if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE)) {
173      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
174        auto CI = dyn_cast<CallInst>(*UI++);
175        assert(CI && "Invalid usage of spir_get_block_invoke");
176        lowerGetBlockInvoke(CI);
177      }
178    }
179  }
180
181  void
182  lowerGetBlockContext() {
183    if (auto F = M->getFunction(SPIR_INTRINSIC_GET_BLOCK_CONTEXT)) {
184      for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
185        auto CI = dyn_cast<CallInst>(*UI++);
186        assert(CI && "Invalid usage of spir_get_block_context");
187        lowerGetBlockContext(CI);
188      }
189    }
190  }
191  /// Lower calls of spir_block_bind.
192  /// Return true if the Module is changed.
193  bool
194  lowerBlockBind(Function *BlockBindFunc) {
195    bool changed = false;
196    for (auto I = BlockBindFunc->user_begin(), E = BlockBindFunc->user_end();
197        I != E;) {
198      DEBUG(dbgs() << "[lowerBlockBind] " << **I << '\n');
199      // Handle spir_block_bind(bitcast(block_func), context_len,
200      // context_align, context)
201      auto CallBlkBind = cast<CallInst>(*I++);
202      Function *InvF = nullptr;
203      Value *Ctx = nullptr;
204      Value *CtxLen = nullptr;
205      Value *CtxAlign = nullptr;
206      getBlockInvokeFuncAndContext(CallBlkBind, &InvF, &Ctx, &CtxLen,
207          &CtxAlign);
208      for (auto II = CallBlkBind->user_begin(), EE = CallBlkBind->user_end();
209          II != EE;) {
210        auto BlkUser = *II++;
211        SPIRVDBG(dbgs() << "  Block user: " << *BlkUser << '\n');
212        if (auto Ret = dyn_cast<ReturnInst>(BlkUser)) {
213          bool Inlined = false;
214          changed |= lowerReturnBlock(Ret, CallBlkBind, Inlined);
215          if (Inlined)
216            return true;
217        } else if (auto CI = dyn_cast<CallInst>(BlkUser)){
218          auto CallBindF = CI->getCalledFunction();
219          auto Name = CallBindF->getName();
220          std::string DemangledName;
221          if (Name == SPIR_INTRINSIC_GET_BLOCK_INVOKE) {
222            assert(CI->getArgOperand(0) == CallBlkBind);
223            changed |= lowerGetBlockInvoke(CI, cast<Function>(InvF));
224          } else if (Name == SPIR_INTRINSIC_GET_BLOCK_CONTEXT) {
225            assert(CI->getArgOperand(0) == CallBlkBind);
226            // Handle context_ptr = spir_get_block_context(block)
227            lowerGetBlockContext(CI, Ctx);
228            changed = true;
229          } else if (oclIsBuiltin(Name, &DemangledName)) {
230            lowerBlockBuiltin(CI, InvF, Ctx, CtxLen, CtxAlign, DemangledName);
231            changed = true;
232          } else
233            llvm_unreachable("Invalid block user");
234        }
235      }
236      erase(CallBlkBind);
237    }
238    changed |= eraseUselessFunctions();
239    return changed;
240  }
241
242  void
243  lowerGetBlockContext(CallInst *CallGetBlkCtx, Value *Ctx = nullptr) {
244    if (!Ctx)
245      getBlockInvokeFuncAndContext(CallGetBlkCtx->getArgOperand(0), nullptr,
246          &Ctx);
247    CallGetBlkCtx->replaceAllUsesWith(Ctx);
248    DEBUG(dbgs() << "  [lowerGetBlockContext] " << *CallGetBlkCtx << " => " <<
249        *Ctx << "\n\n");
250    erase(CallGetBlkCtx);
251  }
252
253  bool
254  lowerGetBlockInvoke(CallInst *CallGetBlkInvoke,
255      Function *InvokeF = nullptr) {
256    bool changed = false;
257    for (auto UI = CallGetBlkInvoke->user_begin(),
258        UE = CallGetBlkInvoke->user_end();
259        UI != UE;) {
260      // Handle block_func_ptr = bitcast(spir_get_block_invoke(block))
261      auto CallInv = cast<Instruction>(*UI++);
262      auto Cast = dyn_cast<BitCastInst>(CallInv);
263      if (Cast)
264        CallInv = dyn_cast<Instruction>(*CallInv->user_begin());
265      DEBUG(dbgs() << "[lowerGetBlockInvoke]  " << *CallInv);
266      // Handle ret = block_func_ptr(context_ptr, args)
267      auto CI = cast<CallInst>(CallInv);
268      auto F = CI->getCalledValue();
269      if (InvokeF == nullptr) {
270        getBlockInvokeFuncAndContext(CallGetBlkInvoke->getArgOperand(0),
271            &InvokeF, nullptr);
272        assert(InvokeF);
273      }
274      assert(F->getType() == InvokeF->getType());
275      CI->replaceUsesOfWith(F, InvokeF);
276      DEBUG(dbgs() << " => " << *CI << "\n\n");
277      erase(Cast);
278      changed = true;
279    }
280    erase(CallGetBlkInvoke);
281    return changed;
282  }
283
284  void
285  lowerBlockBuiltin(CallInst *CI, Function *InvF, Value *Ctx, Value *CtxLen,
286      Value *CtxAlign, const std::string& DemangledName) {
287    mutateCallInstSPIRV (M, CI, [=](CallInst *CI, std::vector<Value *> &Args) {
288      size_t I = 0;
289      size_t E = Args.size();
290      for (; I != E; ++I) {
291        if (isPointerToOpaqueStructType(Args[I]->getType(),
292            SPIR_TYPE_NAME_BLOCK_T)) {
293          break;
294        }
295      }
296      assert (I < E);
297      Args[I] = castToVoidFuncPtr(InvF);
298      if (I + 1 == E) {
299        Args.push_back(Ctx);
300        Args.push_back(CtxLen);
301        Args.push_back(CtxAlign);
302      } else {
303        Args.insert(Args.begin() + I + 1, CtxAlign);
304        Args.insert(Args.begin() + I + 1, CtxLen);
305        Args.insert(Args.begin() + I + 1, Ctx);
306      }
307      if (DemangledName == kOCLBuiltinName::EnqueueKernel) {
308        // Insert event arguments if there are not.
309        if (!isa<IntegerType>(Args[3]->getType())) {
310          Args.insert(Args.begin() + 3, getInt32(M, 0));
311          Args.insert(Args.begin() + 4, getOCLNullClkEventPtr());
312        }
313        if (!isOCLClkEventPtrType(Args[5]->getType()))
314          Args.insert(Args.begin() + 5, getOCLNullClkEventPtr());
315      }
316      return getSPIRVFuncName(OCLSPIRVBuiltinMap::map(DemangledName));
317    });
318  }
319  /// Transform return of a block.
320  /// The function returning a block is inlined since the context cannot be
321  /// passed to another function.
322  /// Returns true of module is changed.
323  bool
324  lowerReturnBlock(ReturnInst *Ret, Value *CallBlkBind, bool &Inlined) {
325    auto F = Ret->getParent()->getParent();
326    auto changed = false;
327    for (auto UI = F->user_begin(), UE = F->user_end(); UI != UE;) {
328      auto U = *UI++;
329      dumpUsers(U);
330      auto Inst = dyn_cast<Instruction>(U);
331      if (Inst && Inst->use_empty()) {
332        erase(Inst);
333        changed = true;
334        continue;
335      }
336      auto CI = dyn_cast<CallInst>(U);
337      if(!CI || CI->getCalledFunction() != F)
338        continue;
339
340      DEBUG(dbgs() << "[lowerReturnBlock] inline " << F->getName() << '\n');
341      auto CG = &getAnalysis<CallGraphWrapperPass>().getCallGraph();
342      auto ACT = &getAnalysis<AssumptionCacheTracker>();
343      //auto AA = &getAnalysis<AliasAnalysis>();
344      //InlineFunctionInfo IFI(CG, M->getDataLayout(), AA, ACT);
345      InlineFunctionInfo IFI(CG, ACT);
346      InlineFunction(CI, IFI);
347      Inlined = true;
348    }
349    return changed || Inlined;
350  }
351
352  void
353  getBlockInvokeFuncAndContext(Value *Blk, Function **PInvF, Value **PCtx,
354      Value **PCtxLen = nullptr, Value **PCtxAlign = nullptr){
355    Function *InvF = nullptr;
356    Value *Ctx = nullptr;
357    Value *CtxLen = nullptr;
358    Value *CtxAlign = nullptr;
359    if (auto CallBlkBind = dyn_cast<CallInst>(Blk)) {
360      assert(CallBlkBind->getCalledFunction()->getName() ==
361          SPIR_INTRINSIC_BLOCK_BIND && "Invalid block");
362      InvF = dyn_cast<Function>(
363          CallBlkBind->getArgOperand(0)->stripPointerCasts());
364      CtxLen = CallBlkBind->getArgOperand(1);
365      CtxAlign = CallBlkBind->getArgOperand(2);
366      Ctx = CallBlkBind->getArgOperand(3);
367    } else if (auto F = dyn_cast<Function>(Blk->stripPointerCasts())) {
368      InvF = F;
369      Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
370    } else if (auto Load = dyn_cast<LoadInst>(Blk)) {
371      auto Op = Load->getPointerOperand();
372      if (auto GV = dyn_cast<GlobalVariable>(Op)) {
373        if (GV->isConstant()) {
374          InvF = cast<Function>(GV->getInitializer()->stripPointerCasts());
375          Ctx = Constant::getNullValue(IntegerType::getInt8PtrTy(M->getContext()));
376        } else {
377          llvm_unreachable("load non-constant block?");
378        }
379      } else {
380        llvm_unreachable("Loading block from non global?");
381      }
382    } else {
383      llvm_unreachable("Invalid block");
384    }
385    DEBUG(dbgs() << "  Block invocation func: " << InvF->getName() << '\n' <<
386        "  Block context: " << *Ctx << '\n');
387    assert(InvF && Ctx && "Invalid block");
388    if (PInvF)
389      *PInvF = InvF;
390    if (PCtx)
391      *PCtx = Ctx;
392    if (PCtxLen)
393      *PCtxLen = CtxLen;
394    if (PCtxAlign)
395      *PCtxAlign = CtxAlign;
396  }
397  void
398  erase(Instruction *I) {
399    if (!I)
400      return;
401    if (I->use_empty()) {
402      I->dropAllReferences();
403      I->eraseFromParent();
404    }
405    else
406      dumpUsers(I);
407  }
408  void
409  erase(ConstantExpr *I) {
410    if (!I)
411      return;
412    if (I->use_empty()) {
413      I->dropAllReferences();
414      I->destroyConstant();
415    } else
416      dumpUsers(I);
417  }
418  void
419  erase(Function *F) {
420    if (!F)
421      return;
422    if (!F->use_empty()) {
423      dumpUsers(F);
424      return;
425    }
426    F->dropAllReferences();
427    auto &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
428    CG.removeFunctionFromModule(new CallGraphNode(F));
429  }
430
431  llvm::PointerType* getOCLClkEventType() {
432    return getOrCreateOpaquePtrType(M, SPIR_TYPE_NAME_CLK_EVENT_T,
433        SPIRAS_Global);
434  }
435
436  llvm::PointerType* getOCLClkEventPtrType() {
437    return PointerType::get(getOCLClkEventType(), SPIRAS_Generic);
438  }
439
440  bool isOCLClkEventPtrType(Type *T) {
441    if (auto PT = dyn_cast<PointerType>(T))
442      return isPointerToOpaqueStructType(
443        PT->getElementType(), SPIR_TYPE_NAME_CLK_EVENT_T);
444    return false;
445  }
446
447  llvm::Constant* getOCLNullClkEventPtr() {
448    return Constant::getNullValue(getOCLClkEventPtrType());
449  }
450
451  void dumpGetBlockInvokeUsers(StringRef Prompt) {
452    DEBUG(dbgs() << Prompt);
453    dumpUsers(M->getFunction(SPIR_INTRINSIC_GET_BLOCK_INVOKE));
454  }
455};
456
// Definition of the pass identifier that the constructor passes to
// ModulePass(ID).
char SPIRVLowerOCLBlocks::ID = 0;
458}
459
// Register the pass under the name "spvblocks". The listed dependencies
// mirror the analyses required in getAnalysisUsage(); the AliasAnalysis
// dependency is disabled there as well.
INITIALIZE_PASS_BEGIN(SPIRVLowerOCLBlocks, "spvblocks",
    "SPIR-V lower OCL blocks", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
//INITIALIZE_AG_DEPENDENCY(AliasAnalysis)
INITIALIZE_PASS_END(SPIRVLowerOCLBlocks, "spvblocks",
    "SPIR-V lower OCL blocks", false, false)
467
468ModulePass *llvm::createSPIRVLowerOCLBlocks() {
469  return new SPIRVLowerOCLBlocks();
470}
471
472#endif /* OCLLOWERBLOCKS_H_ */
473