1/*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "Wrapper.h"

#include "llvm/IR/Module.h"

#include "Builtin.h"
#include "Context.h"
#include "GlobalAllocSPIRITPass.h"
#include "RSAllocationUtils.h"
#include "bcinfo/MetadataExtractor.h"
#include "builder.h"
#include "instructions.h"
#include "module.h"
#include "pass.h"

#include <algorithm>
#include <cassert>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
33
34using bcinfo::MetadataExtractor;
35
36namespace android {
37namespace spirit {
38
39VariableInst *AddBuffer(Instruction *elementType, uint32_t binding, Builder &b,
40                        Module *m) {
41  auto ArrTy = m->getRuntimeArrayType(elementType);
42  const size_t stride = m->getSize(elementType);
43  ArrTy->decorate(Decoration::ArrayStride)->addExtraOperand(stride);
44  auto StructTy = m->getStructType(ArrTy);
45  StructTy->decorate(Decoration::BufferBlock);
46  StructTy->memberDecorate(0, Decoration::Offset)->addExtraOperand(0);
47
48  auto StructPtrTy = m->getPointerType(StorageClass::Uniform, StructTy);
49
50  VariableInst *bufferVar = b.MakeVariable(StructPtrTy, StorageClass::Uniform);
51  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
52  bufferVar->decorate(Decoration::Binding)->addExtraOperand(binding);
53  m->addVariable(bufferVar);
54
55  return bufferVar;
56}
57
58bool AddWrapper(const char *name, const uint32_t signature,
59                const uint32_t numInput, Builder &b, Module *m) {
60  FunctionDefinition *kernel = m->lookupFunctionDefinitionByName(name);
61  if (kernel == nullptr) {
62    // In the metadata for RenderScript LLVM bitcode, the first foreach kernel
63    // is always reserved for the root kernel, even though in the most recent RS
64    // apps it does not exist. Simply bypass wrapper generation here, and return
65    // true for this case.
66    // Otherwise, if a non-root kernel function cannot be found, it is a
67    // fatal internal error which is really unexpected.
68    return (strncmp(name, "root", 4) == 0);
69  }
70
71  // The following three cases are not supported
72  if (!MetadataExtractor::hasForEachSignatureKernel(signature)) {
73    // Not handling old-style kernel
74    return false;
75  }
76
77  if (MetadataExtractor::hasForEachSignatureUsrData(signature)) {
78    // Not handling the user argument
79    return false;
80  }
81
82  if (MetadataExtractor::hasForEachSignatureCtxt(signature)) {
83    // Not handling the context argument
84    return false;
85  }
86
87  TypeVoidInst *VoidTy = m->getVoidType();
88  TypeFunctionInst *FuncTy = m->getFunctionType(VoidTy, nullptr, 0);
89  FunctionDefinition *Func =
90      b.MakeFunctionDefinition(VoidTy, FunctionControl::None, FuncTy);
91  m->addFunctionDefinition(Func);
92
93  Block *Blk = b.MakeBlock();
94  Func->addBlock(Blk);
95
96  Blk->addInstruction(b.MakeLabel());
97
98  TypeIntInst *UIntTy = m->getUnsignedIntType(32);
99
100  Instruction *XValue = nullptr;
101  Instruction *YValue = nullptr;
102  Instruction *ZValue = nullptr;
103  Instruction *Index = nullptr;
104  VariableInst *InvocationId = nullptr;
105  VariableInst *NumWorkgroups = nullptr;
106
107  if (MetadataExtractor::hasForEachSignatureIn(signature) ||
108      MetadataExtractor::hasForEachSignatureOut(signature) ||
109      MetadataExtractor::hasForEachSignatureX(signature) ||
110      MetadataExtractor::hasForEachSignatureY(signature) ||
111      MetadataExtractor::hasForEachSignatureZ(signature)) {
112    TypeVectorInst *V3UIntTy = m->getVectorType(UIntTy, 3);
113    InvocationId = m->getInvocationId();
114    auto IID = b.MakeLoad(V3UIntTy, InvocationId);
115    Blk->addInstruction(IID);
116
117    XValue = b.MakeCompositeExtract(UIntTy, IID, {0});
118    Blk->addInstruction(XValue);
119
120    YValue = b.MakeCompositeExtract(UIntTy, IID, {1});
121    Blk->addInstruction(YValue);
122
123    ZValue = b.MakeCompositeExtract(UIntTy, IID, {2});
124    Blk->addInstruction(ZValue);
125
126    // TODO: Use SpecConstant for workgroup size
127    auto ConstOne = m->getConstant(UIntTy, 1U);
128    auto GroupSize =
129        m->getConstantComposite(V3UIntTy, ConstOne, ConstOne, ConstOne);
130
131    auto GroupSizeX = b.MakeCompositeExtract(UIntTy, GroupSize, {0});
132    Blk->addInstruction(GroupSizeX);
133
134    auto GroupSizeY = b.MakeCompositeExtract(UIntTy, GroupSize, {1});
135    Blk->addInstruction(GroupSizeY);
136
137    NumWorkgroups = m->getNumWorkgroups();
138    auto NumGroup = b.MakeLoad(V3UIntTy, NumWorkgroups);
139    Blk->addInstruction(NumGroup);
140
141    auto NumGroupX = b.MakeCompositeExtract(UIntTy, NumGroup, {0});
142    Blk->addInstruction(NumGroupX);
143
144    auto NumGroupY = b.MakeCompositeExtract(UIntTy, NumGroup, {1});
145    Blk->addInstruction(NumGroupY);
146
147    auto GlobalSizeX = b.MakeIMul(UIntTy, GroupSizeX, NumGroupX);
148    Blk->addInstruction(GlobalSizeX);
149
150    auto GlobalSizeY = b.MakeIMul(UIntTy, GroupSizeY, NumGroupY);
151    Blk->addInstruction(GlobalSizeY);
152
153    auto RowsAlongZ = b.MakeIMul(UIntTy, GlobalSizeY, ZValue);
154    Blk->addInstruction(RowsAlongZ);
155
156    auto NumRows = b.MakeIAdd(UIntTy, YValue, RowsAlongZ);
157    Blk->addInstruction(NumRows);
158
159    auto NumCellsFromYZ = b.MakeIMul(UIntTy, GlobalSizeX, NumRows);
160    Blk->addInstruction(NumCellsFromYZ);
161
162    Index = b.MakeIAdd(UIntTy, NumCellsFromYZ, XValue);
163    Blk->addInstruction(Index);
164  }
165
166  std::vector<IdRef> inputs;
167
168  ConstantInst *ConstZero = m->getConstant(UIntTy, 0);
169
170  for (uint32_t i = 0; i < numInput; i++) {
171    FunctionParameterInst *param = kernel->getParameter(i);
172    Instruction *elementType = param->mResultType.mInstruction;
173    VariableInst *inputBuffer = AddBuffer(elementType, i + 2, b, m);
174
175    TypePointerInst *PtrTy =
176        m->getPointerType(StorageClass::Function, elementType);
177    AccessChainInst *Ptr =
178        b.MakeAccessChain(PtrTy, inputBuffer, {ConstZero, Index});
179    Blk->addInstruction(Ptr);
180
181    Instruction *input = b.MakeLoad(elementType, Ptr);
182    Blk->addInstruction(input);
183
184    inputs.push_back(IdRef(input));
185  }
186
187  // TODO: Convert from unsigned int to signed int if that is what the kernel
188  // function takes for the coordinate parameters
189  if (MetadataExtractor::hasForEachSignatureX(signature)) {
190    inputs.push_back(XValue);
191    if (MetadataExtractor::hasForEachSignatureY(signature)) {
192      inputs.push_back(YValue);
193      if (MetadataExtractor::hasForEachSignatureZ(signature)) {
194        inputs.push_back(ZValue);
195      }
196    }
197  }
198
199  auto resultType = kernel->getReturnType();
200  auto kernelCall =
201      b.MakeFunctionCall(resultType, kernel->getInstruction(), inputs);
202  Blk->addInstruction(kernelCall);
203
204  if (MetadataExtractor::hasForEachSignatureOut(signature)) {
205    VariableInst *OutputBuffer = AddBuffer(resultType, 1, b, m);
206    auto resultPtrType = m->getPointerType(StorageClass::Function, resultType);
207    AccessChainInst *OutPtr =
208        b.MakeAccessChain(resultPtrType, OutputBuffer, {ConstZero, Index});
209    Blk->addInstruction(OutPtr);
210    Blk->addInstruction(b.MakeStore(OutPtr, kernelCall));
211  }
212
213  Blk->addInstruction(b.MakeReturn());
214
215  std::string wrapperName("entry_");
216  wrapperName.append(name);
217
218  EntryPointDefinition *entry = b.MakeEntryPointDefinition(
219      ExecutionModel::GLCompute, Func, wrapperName.c_str());
220
221  entry->setLocalSize(1, 1, 1);
222
223  if (Index != nullptr) {
224    entry->addToInterface(InvocationId);
225    entry->addToInterface(NumWorkgroups);
226  }
227
228  m->addEntryPoint(entry);
229
230  return true;
231}
232
233bool DecorateGlobalBuffer(llvm::Module &LM, Builder &b, Module *m) {
234  Instruction *inst = m->lookupByName("__GPUBlock");
235  if (inst == nullptr) {
236    return true;
237  }
238
239  VariableInst *bufferVar = static_cast<VariableInst *>(inst);
240  bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
241  bufferVar->decorate(Decoration::Binding)->addExtraOperand(0);
242
243  TypePointerInst *StructPtrTy =
244      static_cast<TypePointerInst *>(bufferVar->mResultType.mInstruction);
245  TypeStructInst *StructTy =
246      static_cast<TypeStructInst *>(StructPtrTy->mOperand2.mInstruction);
247  StructTy->decorate(Decoration::BufferBlock);
248
249  // Decorate each member with proper offsets
250
251  const auto GlobalsB = LM.globals().begin();
252  const auto GlobalsE = LM.globals().end();
253  const auto Found =
254      std::find_if(GlobalsB, GlobalsE, [](const llvm::GlobalVariable &GV) {
255        return GV.getName() == "__GPUBlock";
256      });
257
258  if (Found == GlobalsE) {
259    return true; // GPUBlock not found - not an error by itself.
260  }
261
262  const llvm::GlobalVariable &G = *Found;
263
264  bool IsCorrectTy = false;
265  if (const auto *LPtrTy = llvm::dyn_cast<llvm::PointerType>(G.getType())) {
266    if (auto *LStructTy =
267            llvm::dyn_cast<llvm::StructType>(LPtrTy->getElementType())) {
268      IsCorrectTy = true;
269
270      const auto &DLayout = LM.getDataLayout();
271      const auto *SLayout = DLayout.getStructLayout(LStructTy);
272      assert(SLayout);
273      if (SLayout == nullptr) {
274        std::cerr << "struct layout is null" << std::endl;
275        return false;
276      }
277      for (uint32_t i = 0, e = LStructTy->getNumElements(); i != e; ++i) {
278        auto decor = StructTy->memberDecorate(i, Decoration::Offset);
279        if (!decor) {
280          std::cerr << "failed creating member decoration for field " << i
281                    << std::endl;
282          return false;
283        }
284        const uint32_t offset = (uint32_t)SLayout->getElementOffset(i);
285        decor->addExtraOperand(offset);
286      }
287    }
288  }
289
290  if (!IsCorrectTy) {
291    return false;
292  }
293
294  llvm::SmallVector<rs2spirv::RSAllocationInfo, 2> RSAllocs;
295  if (!getRSAllocationInfo(LM, RSAllocs)) {
296    // llvm::errs() << "Extracting rs_allocation info failed\n";
297    return true;
298  }
299
300  // TODO: clean up the binding number assignment
301  size_t BindingNum = 3;
302  for (const auto &A : RSAllocs) {
303    Instruction *inst = m->lookupByName(A.VarName.c_str());
304    if (inst == nullptr) {
305      return false;
306    }
307    VariableInst *bufferVar = static_cast<VariableInst *>(inst);
308    bufferVar->decorate(Decoration::DescriptorSet)->addExtraOperand(0);
309    bufferVar->decorate(Decoration::Binding)->addExtraOperand(BindingNum++);
310  }
311
312  return true;
313}
314
315void AddHeader(Module *m) {
316  m->addCapability(Capability::Shader);
317  // TODO: avoid duplicated capability
318  // m->addCapability(Capability::Addresses);
319  m->setMemoryModel(AddressingModel::Physical32, MemoryModel::GLSL450);
320
321  m->addSource(SourceLanguage::GLSL, 450);
322  m->addSourceExtension("GL_ARB_separate_shader_objects");
323  m->addSourceExtension("GL_ARB_shading_language_420pack");
324  m->addSourceExtension("GL_GOOGLE_cpp_style_line_directive");
325  m->addSourceExtension("GL_GOOGLE_include_directive");
326}
327
328namespace {
329
330class StorageClassVisitor : public DoNothingVisitor {
331public:
332  void visit(TypePointerInst *inst) override {
333    matchAndReplace(inst->mOperand1);
334  }
335
336  void visit(TypeForwardPointerInst *inst) override {
337    matchAndReplace(inst->mOperand2);
338  }
339
340  void visit(VariableInst *inst) override { matchAndReplace(inst->mOperand1); }
341
342private:
343  void matchAndReplace(StorageClass &storage) {
344    if (storage == StorageClass::Function) {
345      storage = StorageClass::Uniform;
346    }
347  }
348};
349
350void FixGlobalStorageClass(Module *m) {
351  StorageClassVisitor v;
352  m->getGlobalSection()->accept(&v);
353}
354
355} // anonymous namespace
356
357bool AddWrappers(llvm::Module &LM,
358                 android::spirit::Module *m) {
359  rs2spirv::Context &Ctxt = rs2spirv::Context::getInstance();
360  const bcinfo::MetadataExtractor &metadata = Ctxt.getMetadata();
361  android::spirit::Builder b;
362
363  m->setBuilder(&b);
364
365  FixGlobalStorageClass(m);
366
367  AddHeader(m);
368
369  DecorateGlobalBuffer(LM, b, m);
370
371  const size_t numKernel = metadata.getExportForEachSignatureCount();
372  const char **kernelName = metadata.getExportForEachNameList();
373  const uint32_t *kernelSigature = metadata.getExportForEachSignatureList();
374  const uint32_t *inputCount = metadata.getExportForEachInputCountList();
375
376  for (size_t i = 0; i < numKernel; i++) {
377    bool success =
378        AddWrapper(kernelName[i], kernelSigature[i], inputCount[i], b, m);
379    if (!success) {
380      return false;
381    }
382  }
383
384  m->consolidateAnnotations();
385  return true;
386}
387
388class WrapperPass : public Pass {
389public:
390  WrapperPass(const llvm::Module &LM) : mLLVMModule(const_cast<llvm::Module&>(LM)) {}
391
392  Module *run(Module *m, int *error) override {
393    bool success = AddWrappers(mLLVMModule, m);
394    if (error) {
395      *error = success ? 0 : -1;
396    }
397    return m;
398  }
399
400private:
401  llvm::Module &mLLVMModule;
402};
403
404} // namespace spirit
405} // namespace android
406
407namespace rs2spirv {
408
409android::spirit::Pass* CreateWrapperPass(const llvm::Module &LLVMModule) {
410  return new android::spirit::WrapperPass(LLVMModule);
411}
412
413} // namespace rs2spirv
414