1/*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "Builtin.h"

#include "cxxabi.h"
#include "spirit.h"
#include "transformer.h"

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
28
29namespace android {
30namespace spirit {
31
32namespace {
33
34Instruction *translateClampVector(const char *name,
35                                  const FunctionCallInst *call, Transformer *tr,
36                                  Builder *b, Module *m) {
37  int width = name[10] - '0';
38  if (width < 2 || width > 4) {
39    return nullptr;
40  }
41
42  uint32_t extOpCode = 0;
43  switch (name[strlen(name) - 1]) {
44  case 'f':
45    extOpCode = 43;
46    break; // FClamp
47  // TODO: Do we get _Z5clampDV_uuu at all? Does LLVM convert u into i?
48  case 'u':
49    extOpCode = 44;
50    break; // UClamp
51  case 'i':
52    extOpCode = 45;
53    break; // SClamp
54  default:
55    return nullptr;
56  }
57
58  std::vector<IdRef> minConstituents(width, call->mOperand2[1]);
59  std::unique_ptr<Instruction> min(
60      b->MakeCompositeConstruct(call->mResultType, minConstituents));
61  tr->insert(min.get());
62
63  std::vector<IdRef> maxConstituents(width, call->mOperand2[2]);
64  std::unique_ptr<Instruction> max(
65      b->MakeCompositeConstruct(call->mResultType, maxConstituents));
66  tr->insert(max.get());
67
68  std::vector<IdRef> extOpnds = {call->mOperand2[0], min.get(), max.get()};
69  return b->MakeExtInst(call->mResultType, m->getGLExt(), extOpCode, extOpnds);
70}
71
72Instruction *translateExtInst(const uint32_t extOpCode,
73                              const FunctionCallInst *call, Builder *b,
74                              Module *m) {
75  return b->MakeExtInst(call->mResultType, m->getGLExt(), extOpCode,
76                        {call->mOperand2[0]});
77}
78
79} // anonymous namespace
80
81typedef std::function<Instruction *(const char *, const FunctionCallInst *,
82                                    Transformer *, Builder *, Module *)>
83    InstTrTy;
84
85class BuiltinLookupTable {
86public:
87  BuiltinLookupTable() {
88    for (sNameCode const *p = &mFPMathFuncOpCode[0]; p->name; p++) {
89      const char *name = p->name;
90      const uint32_t extOpCode = p->code;
91      addMapping(name, {"*"}, {{"float+"}}, {1, 2, 3, 4},
92                 [extOpCode](const char *, const FunctionCallInst *call,
93                             Transformer *, Builder *b, Module *m) {
94                   return translateExtInst(extOpCode, call, b, m);
95                 });
96    }
97
98    addMapping("abs", {"*"}, {{"int+"}, {"char+"}}, {1, 2, 3, 4},
99               [](const char *, const FunctionCallInst *call, Transformer *,
100                  Builder *b, Module *m) {
101                 return translateExtInst(5, call, b, m); // SAbs
102               });
103
104    addMapping("clamp", {"*"},
105               {{"int+", "int", "int"}, {"float+", "float", "float"}},
106               {1, 2, 3, 4}, [](const char *name, const FunctionCallInst *call,
107                                Transformer *tr, Builder *b, Module *m) {
108                 return translateClampVector(name, call, tr, b, m);
109               });
110
111    addMapping("convert", {"char+", "int+", "uchar+", "uint+"},
112               {{"char+"}, {"int+"}, {"uchar+"}, {"uint+"}}, {1, 2, 3, 4},
113               [](const char *, const FunctionCallInst *call, Transformer *,
114                  Builder *b, Module *) -> Instruction * {
115                 return b->MakeUConvert(call->mResultType, call->mOperand2[0]);
116               });
117
118    addMapping(
119        "convert", {"char+", "int+", "uchar+", "uint+"}, {{"float+"}},
120        {1, 2, 3, 4}, [](const char *, const FunctionCallInst *call,
121                         Transformer *, Builder *b, Module *) -> Instruction * {
122          return b->MakeConvertFToU(call->mResultType, call->mOperand2[0]);
123        });
124
125    addMapping(
126        "convert", {"float+"}, {{"char+"}, {"int+"}, {"uchar+"}, {"uint+"}},
127        {1, 2, 3, 4}, [](const char *, const FunctionCallInst *call,
128                         Transformer *, Builder *b, Module *) {
129          return b->MakeConvertUToF(call->mResultType, call->mOperand2[0]);
130        });
131
132    addMapping("dot", {"*"}, {{"float+"}}, {1, 2, 3, 4},
133               [](const char *, const FunctionCallInst *call, Transformer *,
134                  Builder *b, Module *) {
135                 return b->MakeDot(call->mResultType, call->mOperand2[0],
136                                   call->mOperand2[1]);
137               });
138
139    addMapping("min", {"*"}, {{"uint+"}, {"uchar+"}}, {1, 2, 3, 4},
140               [](const char *, const FunctionCallInst *call, Transformer *,
141                  Builder *b, Module *m) {
142                 return translateExtInst(38, call, b, m); // UMin
143               });
144
145    addMapping("min", {"*"}, {{"int+"}, {"char+"}}, {1, 2, 3, 4},
146               [](const char *, const FunctionCallInst *call, Transformer *,
147                  Builder *b, Module *m) {
148                 return translateExtInst(39, call, b, m); // SMin
149               });
150
151    addMapping("max", {"*"}, {{"uint+"}, {"uchar+"}}, {1, 2, 3, 4},
152               [](const char *, const FunctionCallInst *call, Transformer *,
153                  Builder *b, Module *m) {
154                 return translateExtInst(41, call, b, m); // UMax
155               });
156
157    addMapping("max", {"*"}, {{"int+"}, {"char+"}}, {1, 2, 3, 4},
158               [](const char *, const FunctionCallInst *call, Transformer *,
159                  Builder *b, Module *m) {
160                 return translateExtInst(42, call, b, m); // SMax
161               });
162
163    addMapping("rsUnpackColor8888", {"*"}, {{"uchar+"}}, {4},
164               [](const char *, const FunctionCallInst *call, Transformer *,
165                  Builder *b, Module *m) {
166                 auto cast = b->MakeBitcast(m->getUnsignedIntType(32),
167                                            call->mOperand2[0]);
168                 return b->MakeExtInst(call->mResultType, m->getGLExt(), 64,
169                                       {cast}); // UnpackUnorm4x8
170               });
171
172    addMapping("rsPackColorTo8888", {"*"}, {{"float+"}}, {4},
173               [](const char *, const FunctionCallInst *call, Transformer *,
174                  Builder *b, Module *m) {
175                 // PackUnorm4x8
176                 auto packed = b->MakeExtInst(call->mResultType, m->getGLExt(),
177                                              55, {call->mOperand2[0]});
178                 return b->MakeBitcast(
179                     m->getVectorType(m->getUnsignedIntType(8), 4), packed);
180               });
181  }
182
183  static const BuiltinLookupTable &getInstance() {
184    static BuiltinLookupTable table;
185    return table;
186  }
187
188  void addMapping(const char *funcName,
189                  const std::vector<std::string> &retTypes,
190                  const std::vector<std::vector<std::string>> &argTypes,
191                  const std::vector<uint8_t> &vecWidths, InstTrTy fp) {
192    for (auto width : vecWidths) {
193      for (auto retType : retTypes) {
194        std::string suffixed(funcName);
195        if (retType != "*") {
196          if (retType.back() == '+') {
197            retType.pop_back();
198            if (width > 1) {
199              retType.append(1, '0' + width);
200            }
201          }
202          suffixed.append("_").append(retType);
203        }
204
205        for (auto argList : argTypes) {
206          std::string args("(");
207          bool first = true;
208          for (auto argType : argList) {
209            if (first) {
210              first = false;
211            } else {
212              args.append(", ");
213            }
214            if (argType.front() == 'u') {
215              argType.replace(0, 1, "unsigned ");
216            }
217            if (argType.back() == '+') {
218              argType.pop_back();
219              if (width > 1) {
220                argType.append(" vector[");
221                argType.append(1, '0' + width);
222                argType.append("]");
223              }
224            }
225            args.append(argType);
226          }
227          args.append(")");
228          mFuncNameMap[suffixed + args] = fp;
229        }
230      }
231    }
232  }
233
234  InstTrTy lookupTranslation(const char *mangled) const {
235    const char *demangled =
236        __cxxabiv1::__cxa_demangle(mangled, nullptr, nullptr, nullptr);
237
238    if (!demangled) {
239      // All RS runtime/builtin functions are overloaded, therefore
240      // name-mangled.
241      return nullptr;
242    }
243
244    std::string strDemangled(demangled);
245
246    auto it = mFuncNameMap.find(strDemangled);
247    if (it == mFuncNameMap.end()) {
248      return nullptr;
249    }
250    return it->second;
251  }
252
253private:
254  std::map<std::string, InstTrTy> mFuncNameMap;
255
256  struct sNameCode {
257    const char *name;
258    uint32_t code;
259  };
260
261  static sNameCode constexpr mFPMathFuncOpCode[] = {
262      {"abs", 4},        {"sin", 13},   {"cos", 14},   {"tan", 15},
263      {"asin", 16},      {"acos", 17},  {"atan", 18},  {"sinh", 19},
264      {"cosh", 20},      {"tanh", 21},  {"asinh", 22}, {"acosh", 23},
265      {"atanh", 24},     {"atan2", 25}, {"pow", 26},   {"exp", 27},
266      {"log", 28},       {"exp2", 29},  {"log2", 30},  {"sqrt", 31},
267      {"modf", 35},      {"min", 37},   {"max", 40},   {"length", 66},
268      {"normalize", 69}, {nullptr, 0},
269  };
270
271}; // BuiltinLookupTable
272
// Out-of-class definition of the static constexpr opcode table declared in
// BuiltinLookupTable. The constructor iterates over the table (an odr-use),
// which requires this namespace-scope definition before C++17.
BuiltinLookupTable::sNameCode constexpr BuiltinLookupTable::mFPMathFuncOpCode[];
274
275class BuiltinTransformer : public Transformer {
276public:
277  // BEGIN: cleanup unrelated to builtin functions, but necessary for LLVM-SPIRV
278  // converter generated code.
279
280  // TODO: Move these in its own pass
281
282  std::vector<uint32_t> runAndSerialize(Module *module, int *error) override {
283    module->addExtInstImport("GLSL.std.450");
284    return Transformer::runAndSerialize(module, error);
285  }
286
287  Instruction *transform(CapabilityInst *inst) override {
288    if (inst->mOperand1 == Capability::Linkage ||
289        inst->mOperand1 == Capability::Kernel) {
290      return nullptr;
291    }
292    return inst;
293  }
294
295  Instruction *transform(ExtInstImportInst *inst) override {
296    if (inst->mOperand1.compare("OpenCL.std") == 0) {
297      return nullptr;
298    }
299    return inst;
300  }
301
302  Instruction *transform(SourceInst *inst) override {
303    if (inst->mOperand1 == SourceLanguage::Unknown) {
304      return nullptr;
305    }
306    return inst;
307  }
308
309  Instruction *transform(DecorateInst *inst) override {
310    if (inst->mOperand2 == Decoration::LinkageAttributes ||
311        inst->mOperand2 == Decoration::Alignment) {
312      return nullptr;
313    }
314    return inst;
315  }
316
317  // END: cleanup unrelated to builtin functions
318
319  Instruction *transform(FunctionCallInst *call) {
320    FunctionInst *func =
321        static_cast<FunctionInst *>(call->mOperand1.mInstruction);
322    // TODO: attach name to the instruction to avoid linear search in the debug
323    // section, i.e.,
324    // const char *name = func->getName();
325    const char *name = getModule()->lookupNameByInstruction(func);
326    if (!name) {
327      return call;
328    }
329
330    // Maps name into a SPIR-V instruction
331    auto fpTranslate =
332        BuiltinLookupTable::getInstance().lookupTranslation(name);
333    if (!fpTranslate) {
334      return call;
335    }
336    Instruction *inst = fpTranslate(name, call, this, &mBuilder, getModule());
337
338    if (inst) {
339      inst->setId(call->getId());
340    }
341
342    return inst;
343  }
344
345private:
346  Builder mBuilder;
347};
348
349} // namespace spirit
350} // namespace android
351
352namespace rs2spirv {
353
354android::spirit::Pass *CreateBuiltinPass() {
355  return new android::spirit::BuiltinTransformer();
356}
357
358} // namespace rs2spirv
359
360