//===- SPIRVLowerBool.cpp - Lower instructions with bool operands ----------===//
2//
3//                     The LLVM/SPIRV Translator
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8// Copyright (c) 2014 Advanced Micro Devices, Inc. All rights reserved.
9//
10// Permission is hereby granted, free of charge, to any person obtaining a
11// copy of this software and associated documentation files (the "Software"),
12// to deal with the Software without restriction, including without limitation
13// the rights to use, copy, modify, merge, publish, distribute, sublicense,
14// and/or sell copies of the Software, and to permit persons to whom the
15// Software is furnished to do so, subject to the following conditions:
16//
17// Redistributions of source code must retain the above copyright notice,
18// this list of conditions and the following disclaimers.
19// Redistributions in binary form must reproduce the above copyright notice,
20// this list of conditions and the following disclaimers in the documentation
21// and/or other materials provided with the distribution.
22// Neither the names of Advanced Micro Devices, Inc., nor the names of its
23// contributors may be used to endorse or promote products derived from this
24// Software without specific prior written permission.
25// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
26// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
27// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
28// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
29// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
30// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH
31// THE SOFTWARE.
32//
33//===----------------------------------------------------------------------===//
34//
// This file implements a pass that lowers casts to and from bool (i1) values
// (trunc to i1, zext/sext from i1) into compare and select instructions.
36//
37//===----------------------------------------------------------------------===//
38#define DEBUG_TYPE "spvbool"
39
40#include "SPIRVInternal.h"
41#include "llvm/IR/InstVisitor.h"
42#include "llvm/IR/Instructions.h"
43#include "llvm/IR/IRBuilder.h"
44#include "llvm/IR/Verifier.h"
45#include "llvm/Pass.h"
46#include "llvm/PassSupport.h"
47#include "llvm/Support/CommandLine.h"
48#include "llvm/Support/Debug.h"
49#include "llvm/Support/raw_ostream.h"
50
51using namespace llvm;
52using namespace SPIRV;
53
54namespace SPIRV {
55cl::opt<bool> SPIRVLowerBoolValidate("spvbool-validate",
56    cl::desc("Validate module after lowering boolean instructions for SPIR-V"));
57
58class SPIRVLowerBool: public ModulePass,
59  public InstVisitor<SPIRVLowerBool> {
60public:
61  SPIRVLowerBool():ModulePass(ID), Context(nullptr) {
62    initializeSPIRVLowerBoolPass(*PassRegistry::getPassRegistry());
63  }
64  void replace(Instruction *I, Instruction *NewI) {
65    NewI->takeName(I);
66    I->replaceAllUsesWith(NewI);
67    I->dropAllReferences();
68    I->eraseFromParent();
69  }
70  bool isBoolType(Type *Ty) {
71    if (Ty->isIntegerTy(1))
72      return true;
73    if (auto VT = dyn_cast<VectorType>(Ty))
74      return isBoolType(VT->getElementType());
75    return false;
76  }
77  virtual void visitTruncInst(TruncInst &I) {
78    if (isBoolType(I.getType())) {
79      auto Op = I.getOperand(0);
80      auto Zero = getScalarOrVectorConstantInt(Op->getType(), 0, false);
81      auto Cmp = new ICmpInst(&I, CmpInst::ICMP_NE, Op, Zero);
82      replace(&I, Cmp);
83    }
84  }
85  virtual void visitZExtInst(ZExtInst &I) {
86    auto Op = I.getOperand(0);
87    if (isBoolType(Op->getType())) {
88      auto Ty = I.getType();
89      auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
90      auto One = getScalarOrVectorConstantInt(Ty, 1, false);
91      auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
92      replace(&I, Sel);
93    }
94  }
95  virtual void visitSExtInst(SExtInst &I) {
96    auto Op = I.getOperand(0);
97    if (isBoolType(Op->getType())) {
98      auto Ty = I.getType();
99      auto Zero = getScalarOrVectorConstantInt(Ty, 0, false);
100      auto One = getScalarOrVectorConstantInt(Ty, ~0, false);
101      auto Sel = SelectInst::Create(Op, One, Zero, "", &I);
102      replace(&I, Sel);
103    }
104  }
105  virtual bool runOnModule(Module &M) {
106    Context = &M.getContext();
107    visit(M);
108
109    if (SPIRVLowerBoolValidate) {
110      DEBUG(dbgs() << "After SPIRVLowerBool:\n" << M);
111      std::string Err;
112      raw_string_ostream ErrorOS(Err);
113      if (verifyModule(M, &ErrorOS)){
114        Err = std::string("Fails to verify module: ") + Err;
115        report_fatal_error(Err.c_str(), false);
116      }
117    }
118    return true;
119  }
120
121  static char ID;
122private:
123  LLVMContext *Context;
124};
125
126char SPIRVLowerBool::ID = 0;
127}
128
129INITIALIZE_PASS(SPIRVLowerBool, "spvbool",
130    "Lower instructions with bool operands", false, false)
131
132ModulePass *llvm::createSPIRVLowerBool() {
133  return new SPIRVLowerBool();
134}
135