1/*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33
34using bcinfo::MetadataExtractor;
35
36namespace android {
37namespace spirit {
38
39VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40                        Module *m) {
41  auto ArrTy = m->getRuntimeArrayType(elementType);
42  const size_t stride = m->getSize(elementType);
43  ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44  auto StructTy = m->getStructType(ArrTy);
45  StructTy->decorate(Decoration::BufferBlock);
46  StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47
48  auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49
50  VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52  bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53  m->addVariable(bufferVar);
54
55  return bufferVar;
56}
57
58bool AddWrapper(const char *name, const uint32_t signature,
59                const uint32_t numInput, Builder &b, Module *m) {
60  FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61  if (kernel == nullptr) {
62    // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63    // is always reserved for the root kernel, even though in the most recent RS
64    // apps it does not exist. Simply bypass wrapper generation here, and return
65    // true for this case.
66    // Otherwise, if a non-root kernel function cannot be found, it is a
67    // fatal internal error which is really unexpected.
68    return (strncmp(name, "root", 4) == 0);
69  }
70
71  // The following three cases are not supported
72  if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73    // Not handling old-style kernel
74    return false;
75  }
76
77  if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78    // Not handling the user argument
79    return false;
80  }
81
82  if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83    // Not handling the context argument
84    return false;
85  }
86
87  TypeVoidInst *VoidTy = m->getVoidType();
88  TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89  FunctionDefinition *Func =
90      b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91  m->addFunctionDefinition(Func);
92
93  Block *Blk = b.MakeBlock();
94  Func->addBlock(Blk);
95
96  Blk->addInstruction(b.MakeLabel());
97
98  TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99
100  Instruction *XValue = nullptr;
101  Instruction *YValue = nullptr;
102  Instruction *ZValue = nullptr;
103  Instruction *Index = nullptr;
104  VariableInst *InvocationId = nullptr;
105  VariableInst *NumWorkgroups = nullptr;
106
107  if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108      MetadataExtractor::hasForEachSignatureOut(signature) ||
109      MetadataExtractor::hasForEachSignatureX(signature) ||
110      MetadataExtractor::hasForEachSignatureY(signature) ||
111      MetadataExtractor::hasForEachSignatureZ(signature)) {
112    TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113    InvocationId = m->getInvocationId();
114    auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115    Blk->addInstruction(IID);
116
117    XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118    Blk->addInstruction(XValue);
119
120    YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121    Blk->addInstruction(YValue);
122
123    ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124    Blk->addInstruction(ZValue);
125
126    // TODO: Use SpecConstant for workgroup size
127    auto ConstOne = m->getConstant(UIntTy, 1U);
128    auto GroupSize =
129        m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130
131    auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132    Blk->addInstruction(GroupSizeX);
133
134    auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135    Blk->addInstruction(GroupSizeY);
136
137    NumWorkgroups = m->getNumWorkgroups();
138    auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139    Blk->addInstruction(NumGroup);
140
141    auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142    Blk->addInstruction(NumGroupX);
143
144    auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145    Blk->addInstruction(NumGroupY);
146
147    auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148    Blk->addInstruction(GlobalSizeX);
149
150    auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151    Blk->addInstruction(GlobalSizeY);
152
153    auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154    Blk->addInstruction(RowsAlongZ);
155
156    auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157    Blk->addInstruction(NumRows);
158
159    auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160    Blk->addInstruction(NumCellsFromYZ);
161
162    Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163    Blk->addInstruction(Index);
164  }
165
166  std::vector<IdRef> inputs;
167
168  ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169
170  for (uint32_t i = 0; i < numInput; i++) {
171    FunctionParameterInst *param = kernel->getParameter(i);
172    Instruction *elementType = param->mResultType.mInstruction;
173    VariableInst *inputBuffer = AddBuffer(elementType, i + 3, b, m);
174
175    TypePointerInst *PtrTy =
176        m->getPointerType(StorageClass::Function, elementType);
177    AccessChainInst *Ptr =
178        b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179    Blk->addInstruction(Ptr);
180
181    Instruction *input = b.MakeLoad(elementType, Ptr);
182    Blk->addInstruction(input);
183
184    inputs.push_back(IdRef(input));
185  }
186
187  // TODO: Convert from unsigned int to signed int if that is what the kernel
188  // function takes for the coordinate parameters
189  if (MetadataExtractor::hasForEachSignatureX(signature)) {
190    inputs.push_back(XValue);
191    if (MetadataExtractor::hasForEachSignatureY(signature)) {
192      inputs.push_back(YValue);
193      if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194        inputs.push_back(ZValue);
195      }
196    }
197  }
198
199  auto resultType = kernel->getReturnType();
200  auto kernelCall =
201      b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202  Blk->addInstruction(kernelCall);
203
204  if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205    VariableInst *OutputBuffer = AddBuffer(resultType, 2, b, m);
206    auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207    AccessChainInst *OutPtr =
208        b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209    Blk->addInstruction(OutPtr);
210    Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211  }
212
213  Blk->addInstruction(b.MakeReturn());
214
215  std::string wrapperName("entry_");
216  wrapperName.append(name);
217
218  EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219      ExecutionModel::GLCompute, Func, wrapperName.c_str());
220
221  entry->setLocalSize(1, 1, 1);
222
223  if (Index != nullptr) {
224    entry->addToInterface(InvocationId);
225    entry->addToInterface(NumWorkgroups);
226  }
227
228  m->addEntryPoint(entry);
229
230  return true;
231}
232
233bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234  Instruction *inst = m->lookupByName("__GPUBlock");
235  if (inst == nullptr) {
236    return true;
237  }
238
239  VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241  bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242
243  TypePointerInst *StructPtrTy =
244      static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245  TypeStructInst *StructTy =
246      static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247  StructTy->decorate(Decoration::BufferBlock);
248
249  // Decorate each member with proper offsets
250
251  const auto GlobalsB = LM.globals().begin();
252  const auto GlobalsE = LM.globals().end();
253  const auto Found =
254      std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255        return GV.getName() == "__GPUBlock";
256      });
257
258  if (Found == GlobalsE) {
259    return true; // GPUBlock not found - not an error by itself.
260  }
261
262  const llvm::GlobalVariable &G = *Found;
263
264  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
265  bool IsCorrectTy = false;
266  if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
267    if (auto *LStructTy =
268            llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
269      IsCorrectTy = true;
270
271      const auto &DLayout = LM.getDataLayout();
272      const auto *SLayout = DLayout.getStructLayout(LStructTy);
273      assert(SLayout);
274      if (SLayout == nullptr) {
275        std::cerr << "struct layout is null" << std::endl;
276        return false;
277      }
278      std::vector<uint32_t> offsets;
279      for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
280        auto decor = StructTy->memberDecorate(i, Decoration::Offset);
281        if (!decor) {
282          std::cerr << "failed creating member decoration for field " << i
283                    << std::endl;
284          return false;
285        }
286        const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
287        decor->addExtraOperand(offset);
288        offsets.push_back(offset);
289      }
290      std::stringstream ssOffsets;
291      // TODO: define this string in a central place
292      ssOffsets << ".rsov.ExportedVars:";
293      for(uint32_t slot = 0; slot < Ctxt.getNumExportVar(); slot++) {
294        const uint32_t index = Ctxt.getExportVarIndex(slot);
295        const uint32_t offset = offsets[index];
296        ssOffsets << offset << ';';
297      }
298      m->addString(ssOffsets.str().c_str());
299
300      std::stringstream ssGlobalSize;
301      ssGlobalSize << ".rsov.GlobalSize:" << Ctxt.getGlobalSize();
302      m->addString(ssGlobalSize.str().c_str());
303    }
304  }
305
306  if (!IsCorrectTy) {
307    return false;
308  }
309
310  llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
311  if (!getRSAllocationInfo(LM, RSAllocs)) {
312    // llvm::errs() << "Extracting rs_allocation info failed\n";
313    return true;
314  }
315
316  // TODO: clean up the binding number assignment
317  size_t BindingNum = 3;
318  for (const auto &A : RSAllocs) {
319    Instruction *inst = m->lookupByName(A.VarName.c_str());
320    if (inst == nullptr) {
321      return false;
322    }
323    VariableInst *bufferVar = static_cast<VariableInst *>(inst);
324    bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
325    bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
326  }
327
328  return true;
329}
330
331void AddHeader(Module *m) {
332  m->addCapability(Capability::Shader);
333  m->setMemoryModel(AddressingModel::Logical, MemoryModel::GLSL450);
334
335  m->addSource(SourceLanguage::GLSL, 450);
336  m->addSourceExtension("GL_ARB_separate_shader_objects");
337  m->addSourceExtension("GL_ARB_shading_language_420pack");
338  m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
339  m->addSourceExtension("GL_GOOGLE_include_directive");
340}
341
342namespace {
343
344class StorageClassVisitor : public DoNothingVisitor {
345public:
346  void visit(TypePointerInst *inst) override {
347    matchAndReplace(inst->mOperand1);
348  }
349
350  void visit(TypeForwardPointerInst *inst) override {
351    matchAndReplace(inst->mOperand2);
352  }
353
354  void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
355
356private:
357  void matchAndReplace(StorageClass &storage) {
358    if (storage == StorageClass::Function) {
359      storage = StorageClass::Uniform;
360    }
361  }
362};
363
364void FixGlobalStorageClass(Module *m) {
365  StorageClassVisitor v;
366  m->getGlobalSection()->accept(&v);
367}
368
369} // anonymous namespace
370
371bool AddWrappers(llvm::Module &LM,
372                 android::spirit::Module *m) {
373  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
374  const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
375  android::spirit::Builder b;
376
377  m->setBuilder(&b);
378
379  FixGlobalStorageClass(m);
380
381  AddHeader(m);
382
383  DecorateGlobalBuffer(LM, b, m);
384
385  const size_t numKernel = metadata.getExportForEachSignatureCount();
386  const char **kernelName = metadata.getExportForEachNameList();
387  const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
388  const uint32_t *inputCount = metadata.getExportForEachInputCountList();
389
390  for (size_t i = 0; i < numKernel; i++) {
391    bool success =
392        AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
393    if (!success) {
394      return false;
395    }
396  }
397
398  m->consolidateAnnotations();
399  return true;
400}
401
402class WrapperPass : public Pass {
403public:
404  WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
405
406  Module *run(Module *m, int *error) override {
407    bool success = AddWrappers(mLLVMModule, m);
408    if (error) {
409      *error = success ? 0 : -1;
410    }
411    return m;
412  }
413
414private:
415  llvm::Module &mLLVMModule;
416};
417
418} // namespace spirit
419} // namespace android
420
421namespace rs2spirv {
422
423android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
424  return new android::spirit::WrapperPass(LLVMModule);
425}
426
427} // namespace rs2spirv
428