1/*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17#include "module.h"
18
19#include <set>
20
21#include "builder.h"
22#include "core_defs.h"
23#include "instructions.h"
24#include "types_generated.h"
25#include "word_stream.h"
26
27namespace android {
28namespace spirit {
29
30Module *Module::mInstance = nullptr;
31
32Module *Module::getCurrentModule() {
33  if (mInstance == nullptr) {
34    return mInstance = new Module();
35  }
36  return mInstance;
37}
38
39Module::Module()
40    : mNextId(1), mCapabilitiesDeleter(mCapabilities),
41      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
42      mEntryPointInstsDeleter(mEntryPointInsts),
43      mExecutionModesDeleter(mExecutionModes),
44      mEntryPointsDeleter(mEntryPoints),
45      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
46  mInstance = this;
47}
48
49Module::Module(Builder *b)
50    : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
51      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
52      mEntryPointInstsDeleter(mEntryPointInsts),
53      mExecutionModesDeleter(mExecutionModes),
54      mEntryPointsDeleter(mEntryPoints),
55      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
56  mInstance = this;
57}
58
59bool Module::resolveIds() {
60  auto &table = mIdTable;
61
62  std::unique_ptr<IVisitor> v0(
63      CreateInstructionVisitor([&table](Instruction *inst) {
64        if (inst->hasResult()) {
65          table.insert(std::make_pair(inst->getId(), inst));
66        }
67      }));
68  v0->visit(this);
69
70  mNextId = mIdTable.rbegin()->first + 1;
71
72  int err = 0;
73  std::unique_ptr<IVisitor> v(
74      CreateInstructionVisitor([&table, &err](Instruction *inst) {
75        for (auto ref : inst->getAllIdRefs()) {
76          if (ref) {
77            auto it = table.find(ref->mId);
78            if (it != table.end()) {
79              ref->mInstruction = it->second;
80            } else {
81              std::cout << "Found no instruction for id " << ref->mId
82                        << std::endl;
83              err++;
84            }
85          }
86        }
87      }));
88  v->visit(this);
89  return err == 0;
90}
91
92bool Module::DeserializeInternal(InputWordStream &IS) {
93  if (IS.empty()) {
94    return false;
95  }
96
97  IS >> &mMagicNumber;
98  if (mMagicNumber != 0x07230203) {
99    errs() << "Wrong Magic Number: " << mMagicNumber;
100    return false;
101  }
102
103  if (IS.empty()) {
104    return false;
105  }
106
107  IS >> &mVersion.mWord;
108  if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109    return false;
110  }
111
112  if (IS.empty()) {
113    return false;
114  }
115
116  IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117
118  DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119  DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120  DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121
122  mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123  if (!mMemoryModel) {
124    errs() << "Missing memory model specification.\n";
125    return false;
126  }
127
128  DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129  DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130  for (auto entry : mEntryPoints) {
131    mEntryPointInsts.push_back(entry->getInstruction());
132    for (auto mode : mExecutionModes) {
133      entry->applyExecutionMode(mode);
134    }
135  }
136
137  mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138  mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139  mGlobals.reset(Deserialize<GlobalSection>(IS));
140
141  DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142
143  if (mFunctionDefinitions.empty()) {
144    errs() << "Missing function definitions.\n";
145    for (int i = 0; i < 4; i++) {
146      uint32_t w;
147      IS >> &w;
148      std::cout << std::hex << w << " ";
149    }
150    std::cout << std::endl;
151    return false;
152  }
153
154  return true;
155}
156
157void Module::initialize() {
158  mMagicNumber = 0x07230203;
159  mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160  mGeneratorMagicNumber = 0x00070000;
161  mBound = 0;
162  mReserved = 0;
163  mAnnotations.reset(new AnnotationSection());
164}
165
166void Module::SerializeHeader(OutputWordStream &OS) const {
167  OS << mMagicNumber;
168  OS << mVersion.mWord << mGeneratorMagicNumber;
169  if (mBound == 0) {
170    OS << mIdTable.end()->first + 1;
171  } else {
172    OS << std::max(mBound, mNextId);
173  }
174  OS << mReserved;
175}
176
177void Module::Serialize(OutputWordStream &OS) const {
178  SerializeHeader(OS);
179  Entity::Serialize(OS);
180}
181
182Module *Module::addCapability(Capability cap) {
183  mCapabilities.push_back(mBuilder->MakeCapability(cap));
184  return this;
185}
186
187Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188  mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189  return this;
190}
191
192Module *Module::addExtInstImport(const char *extName) {
193  ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194  mExtInstImports.push_back(extInst);
195  if (strcmp(extName, "GLSL.std.450") == 0) {
196    mGLExt = extInst;
197  }
198  return this;
199}
200
201Module *Module::addSource(SourceLanguage lang, int version) {
202  if (!mDebugInfo) {
203    mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204  }
205  mDebugInfo->addSource(lang, version);
206  return this;
207}
208
209Module *Module::addSourceExtension(const char *ext) {
210  if (!mDebugInfo) {
211    mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212  }
213  mDebugInfo->addSourceExtension(ext);
214  return this;
215}
216
217Module *Module::addString(const char *str) {
218  if (!mDebugInfo) {
219    mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220  }
221  mDebugInfo->addString(str);
222  return this;
223}
224
225Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226  mEntryPoints.push_back(entry);
227  auto newModes = entry->getExecutionModes();
228  mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229                         newModes.end());
230  return this;
231}
232
233GlobalSection *Module::getGlobalSection() {
234  if (!mGlobals) {
235    mGlobals.reset(new GlobalSection());
236  }
237  return mGlobals.get();
238}
239
240ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
241  return getGlobalSection()->getConstant(type, value);
242}
243
244ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
245  return getGlobalSection()->getConstant(type, value);
246}
247
248ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
249  return getGlobalSection()->getConstant(type, value);
250}
251
252ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
253                                                    ConstantInst *components[],
254                                                    size_t width) {
255  return getGlobalSection()->getConstantComposite(type, components, width);
256}
257
258ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
259                                                    ConstantInst *comp0,
260                                                    ConstantInst *comp1,
261                                                    ConstantInst *comp2) {
262  // TODO: verify that component types are the same and consistent with the
263  // resulting vector type
264  ConstantInst *comps[] = {comp0, comp1, comp2};
265  return getConstantComposite(type, comps, 3);
266}
267
268ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
269                                                    ConstantInst *comp0,
270                                                    ConstantInst *comp1,
271                                                    ConstantInst *comp2,
272                                                    ConstantInst *comp3) {
273  // TODO: verify that component types are the same and consistent with the
274  // resulting vector type
275  ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
276  return getConstantComposite(type, comps, 4);
277}
278
279TypeVoidInst *Module::getVoidType() {
280  return getGlobalSection()->getVoidType();
281}
282
283TypeIntInst *Module::getIntType(int bits, bool isSigned) {
284  return getGlobalSection()->getIntType(bits, isSigned);
285}
286
287TypeIntInst *Module::getUnsignedIntType(int bits) {
288  return getIntType(bits, false);
289}
290
291TypeFloatInst *Module::getFloatType(int bits) {
292  return getGlobalSection()->getFloatType(bits);
293}
294
295TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
296  return getGlobalSection()->getVectorType(componentType, width);
297}
298
299TypePointerInst *Module::getPointerType(StorageClass storage,
300                                        Instruction *pointeeType) {
301  return getGlobalSection()->getPointerType(storage, pointeeType);
302}
303
304TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
305  return getGlobalSection()->getRuntimeArrayType(elementType);
306}
307
308TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
309  return getGlobalSection()->getStructType(fieldType, numField);
310}
311
312TypeStructInst *Module::getStructType(Instruction *fieldType) {
313  return getStructType(&fieldType, 1);
314}
315
316TypeFunctionInst *Module::getFunctionType(Instruction *retType,
317                                          Instruction *const argType[],
318                                          size_t numArg) {
319  return getGlobalSection()->getFunctionType(retType, argType, numArg);
320}
321
322TypeFunctionInst *
323Module::getFunctionType(Instruction *retType,
324                        const std::vector<Instruction *> &argTypes) {
325  return getGlobalSection()->getFunctionType(retType, argTypes.data(),
326                                             argTypes.size());
327}
328
329size_t Module::getSize(TypeVoidInst *) { return 0; }
330
331size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
332
333size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
334
335size_t Module::getSize(TypeVectorInst *vTy) {
336  return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
337}
338
339size_t Module::getSize(TypePointerInst *) {
340  return 4; // TODO: or 8?
341}
342
343size_t Module::getSize(TypeStructInst *structTy) {
344  size_t sz = 0;
345  for (auto ty : structTy->mOperand1) {
346    sz += getSize(ty.mInstruction);
347  }
348  return sz;
349}
350
351size_t Module::getSize(TypeFunctionInst *) {
352  return 4; // TODO: or 8? Is this just the size of a pointer?
353}
354
355size_t Module::getSize(Instruction *inst) {
356  switch (inst->getOpCode()) {
357  case OpTypeVoid:
358    return getSize(static_cast<TypeVoidInst *>(inst));
359  case OpTypeInt:
360    return getSize(static_cast<TypeIntInst *>(inst));
361  case OpTypeFloat:
362    return getSize(static_cast<TypeFloatInst *>(inst));
363  case OpTypeVector:
364    return getSize(static_cast<TypeVectorInst *>(inst));
365  case OpTypeStruct:
366    return getSize(static_cast<TypeStructInst *>(inst));
367  case OpTypeFunction:
368    return getSize(static_cast<TypeFunctionInst *>(inst));
369  default:
370    return 0;
371  }
372}
373
374Module *Module::addFunctionDefinition(FunctionDefinition *func) {
375  mFunctionDefinitions.push_back(func);
376  return this;
377}
378
379Instruction *Module::lookupByName(const char *name) const {
380  return mDebugInfo->lookupByName(name);
381}
382
383FunctionDefinition *
384Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
385  for (auto fdef : mFunctionDefinitions) {
386    if (fdef->getInstruction() == inst) {
387      return fdef;
388    }
389  }
390  return nullptr;
391}
392
393FunctionDefinition *
394Module::lookupFunctionDefinitionByName(const char *name) const {
395  FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
396  return getFunctionDefinitionFromInstruction(inst);
397}
398
399const char *Module::lookupNameByInstruction(const Instruction *inst) const {
400  return mDebugInfo->lookupNameByInstruction(inst);
401}
402
403VariableInst *Module::getInvocationId() {
404  return getGlobalSection()->getInvocationId();
405}
406
407VariableInst *Module::getNumWorkgroups() {
408  return getGlobalSection()->getNumWorkgroups();
409}
410
411Module *Module::addStructType(TypeStructInst *structType) {
412  getGlobalSection()->addStructType(structType);
413  return this;
414}
415
416Module *Module::addVariable(VariableInst *var) {
417  getGlobalSection()->addVariable(var);
418  return this;
419}
420
421void Module::consolidateAnnotations() {
422  std::vector<Instruction *> annotations(mAnnotations->begin(),
423                                      mAnnotations->end());
424  std::unique_ptr<IVisitor> v(
425      CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
426        const auto &ann = inst->getAnnotations();
427        annotations.insert(annotations.end(), ann.begin(), ann.end());
428      }));
429  v->visit(this);
430  mAnnotations->clear();
431  mAnnotations->addAnnotations(annotations.begin(), annotations.end());
432}
433
434EntryPointDefinition::EntryPointDefinition(Builder *builder,
435                                           ExecutionModel execModel,
436                                           FunctionDefinition *func,
437                                           const char *name)
438    : Entity(builder), mFunction(func->getInstruction()),
439      mExecutionModel(execModel) {
440  mName = strndup(name, strlen(name));
441  mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
442}
443
444bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
445  if (IS.empty()) {
446    return false;
447  }
448
449  if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
450    return true;
451  }
452
453  return false;
454}
455
456EntryPointDefinition *
457EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
458  if (mode->mOperand1.mInstruction == mFunction) {
459    addExecutionMode(mode);
460  }
461  return this;
462}
463
464EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
465  mInterface.push_back(var);
466  mEntryPointInst->mOperand4.push_back(var);
467  return this;
468}
469
470EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
471                                                         uint32_t height,
472                                                         uint32_t depth) {
473  mLocalSize.mWidth = width;
474  mLocalSize.mHeight = height;
475  mLocalSize.mDepth = depth;
476
477  auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
478  mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
479
480  addExecutionMode(mode);
481
482  return this;
483}
484
485bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
486  while (true) {
487    if (auto str = Deserialize<StringInst>(IS)) {
488      mSources.push_back(str);
489    } else if (auto src = Deserialize<SourceInst>(IS)) {
490      mSources.push_back(src);
491    } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
492      mSources.push_back(srcExt);
493    } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
494      mSources.push_back(srcCont);
495    } else {
496      break;
497    }
498  }
499
500  while (true) {
501    if (auto name = Deserialize<NameInst>(IS)) {
502      mNames.push_back(name);
503    } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
504      mNames.push_back(memName);
505    } else {
506      break;
507    }
508  }
509
510  return true;
511}
512
513DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
514                                              int version) {
515  SourceInst *source = mBuilder->MakeSource(lang, version);
516  mSources.push_back(source);
517  return this;
518}
519
520DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
521  SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
522  mSources.push_back(inst);
523  return this;
524}
525
526DebugInfoSection *DebugInfoSection::addString(const char *str) {
527  StringInst *source = mBuilder->MakeString(str);
528  mSources.push_back(source);
529  return this;
530}
531
532Instruction *DebugInfoSection::lookupByName(const char *name) const {
533  for (auto inst : mNames) {
534    if (inst->getOpCode() == OpName) {
535      NameInst *nameInst = static_cast<NameInst *>(inst);
536      if (nameInst->mOperand2.compare(name) == 0) {
537        return nameInst->mOperand1.mInstruction;
538      }
539    }
540    // Ignore member names
541  }
542  return nullptr;
543}
544
545const char *
546DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
547  for (auto inst : mNames) {
548    if (inst->getOpCode() == OpName) {
549      NameInst *nameInst = static_cast<NameInst *>(inst);
550      if (nameInst->mOperand1.mInstruction == target) {
551        return nameInst->mOperand2.c_str();
552      }
553    }
554    // Ignore member names
555  }
556  return nullptr;
557}
558
559AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
560
561AnnotationSection::AnnotationSection(Builder *b)
562    : Entity(b), mAnnotationsDeleter(mAnnotations) {}
563
564bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
565  while (true) {
566    if (auto decor = Deserialize<DecorateInst>(IS)) {
567      mAnnotations.push_back(decor);
568    } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
569      mAnnotations.push_back(decor);
570    } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
571      mAnnotations.push_back(decor);
572    } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
573      mAnnotations.push_back(decor);
574    } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
575      mAnnotations.push_back(decor);
576    } else {
577      break;
578    }
579  }
580  return true;
581}
582
583GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
584
585GlobalSection::GlobalSection(Builder *builder)
586    : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
587
588namespace {
589
590template <typename T>
591T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
592                std::vector<Instruction *> *globals) {
593  T *derived;
594  for (auto inst : *globals) {
595    if (inst->getOpCode() == T::mOpCode) {
596      T *derived = static_cast<T *>(inst);
597      if (criteria(derived)) {
598        return derived;
599      }
600    }
601  }
602  derived = factory();
603  globals->push_back(derived);
604  return derived;
605}
606
607} // anonymous namespace
608
609bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
610  while (true) {
611#define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
612  if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
613    mGlobalDefs.push_back(typeInst);                                           \
614    continue;                                                                  \
615  }
616#include "const_inst_dispatches_generated.h"
617#include "type_inst_dispatches_generated.h"
618#undef HANDLE_INSTRUCTION
619
620    if (auto globalInst = Deserialize<VariableInst>(IS)) {
621      // Check if this is function scoped
622      if (globalInst->mOperand1 == StorageClass::Function) {
623        Module::errs() << "warning: Variable (id = " << globalInst->mResult;
624        Module::errs() << ") has function scope in global section.\n";
625        // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
626        // As a workaround, accept such SPIR-V code here, and fix it up later
627        // in the rs2spirv compiler by correcting the storage class.
628        // In a stricter deserializer, such code should be rejected, and we
629        // should return false here.
630      }
631      mGlobalDefs.push_back(globalInst);
632      continue;
633    }
634
635    if (auto globalInst = Deserialize<UndefInst>(IS)) {
636      mGlobalDefs.push_back(globalInst);
637      continue;
638    }
639    break;
640  }
641  return true;
642}
643
644ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
645  return findOrCreate<ConstantInst>(
646      [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
647      [=]() -> ConstantInst * {
648        LiteralContextDependentNumber cdn = {.intValue = value};
649        return mBuilder->MakeConstant(type, cdn);
650      },
651      &mGlobalDefs);
652}
653
654ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
655  return findOrCreate<ConstantInst>(
656      [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
657      [=]() -> ConstantInst * {
658        LiteralContextDependentNumber cdn = {.intValue = (int)value};
659        return mBuilder->MakeConstant(type, cdn);
660      },
661      &mGlobalDefs);
662}
663
664ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
665  return findOrCreate<ConstantInst>(
666      [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
667      [=]() -> ConstantInst * {
668        LiteralContextDependentNumber cdn = {.floatValue = value};
669        return mBuilder->MakeConstant(type, cdn);
670      },
671      &mGlobalDefs);
672}
673
674ConstantCompositeInst *
675GlobalSection::getConstantComposite(TypeVectorInst *type,
676                                    ConstantInst *components[], size_t width) {
677  return findOrCreate<ConstantCompositeInst>(
678      [=](ConstantCompositeInst *c) {
679        if (c->mOperand1.size() != width) {
680          return false;
681        }
682        for (size_t i = 0; i < width; i++) {
683          if (c->mOperand1[i].mInstruction != components[i]) {
684            return false;
685          }
686        }
687        return true;
688      },
689      [=]() -> ConstantCompositeInst * {
690        ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
691        for (size_t i = 0; i < width; i++) {
692          c->mOperand1.push_back(components[i]);
693        }
694        return c;
695      },
696      &mGlobalDefs);
697}
698
699TypeVoidInst *GlobalSection::getVoidType() {
700  return findOrCreate<TypeVoidInst>(
701      [=](TypeVoidInst *) -> bool { return true; },
702      [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
703      &mGlobalDefs);
704}
705
706TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
707  if (isSigned) {
708    switch (bits) {
709#define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED)                                \
710  case BITS: {                                                                 \
711    return findOrCreate<TypeIntInst>(                                          \
712        [=](TypeIntInst *intTy) -> bool {                                      \
713          return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED;       \
714        },                                                                     \
715        [=]() -> TypeIntInst * {                                               \
716          return mBuilder->MakeTypeInt(BITS, SIGNED);                          \
717        },                                                                     \
718        &mGlobalDefs);                                                         \
719  }
720      HANDLE_INT_SIZE(Int, 8, 1);
721      HANDLE_INT_SIZE(Int, 16, 1);
722      HANDLE_INT_SIZE(Int, 32, 1);
723      HANDLE_INT_SIZE(Int, 64, 1);
724    default:
725      Module::errs() << "unexpected int type";
726    }
727  } else {
728    switch (bits) {
729      HANDLE_INT_SIZE(UInt, 8, 0);
730      HANDLE_INT_SIZE(UInt, 16, 0);
731      HANDLE_INT_SIZE(UInt, 32, 0);
732      HANDLE_INT_SIZE(UInt, 64, 0);
733    default:
734      Module::errs() << "unexpected int type";
735    }
736  }
737#undef HANDLE_INT_SIZE
738  return nullptr;
739}
740
741TypeFloatInst *GlobalSection::getFloatType(int bits) {
742  switch (bits) {
743#define HANDLE_FLOAT_SIZE(BITS)                                                \
744  case BITS: {                                                                 \
745    return findOrCreate<TypeFloatInst>(                                        \
746        [=](TypeFloatInst *floatTy) -> bool {                                  \
747          return floatTy->mOperand1 == BITS;                                   \
748        },                                                                     \
749        [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); },    \
750        &mGlobalDefs);                                                         \
751  }
752    HANDLE_FLOAT_SIZE(16);
753    HANDLE_FLOAT_SIZE(32);
754    HANDLE_FLOAT_SIZE(64);
755  default:
756    Module::errs() << "unexpeced floating point type";
757  }
758#undef HANDLE_FLOAT_SIZE
759  return nullptr;
760}
761
762TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
763                                             int width) {
764  // TODO: verify that componentType is basic numeric types
765
766  return findOrCreate<TypeVectorInst>(
767      [=](TypeVectorInst *vecTy) -> bool {
768        return vecTy->mOperand1.mInstruction == componentType &&
769               vecTy->mOperand2 == width;
770      },
771      [=]() -> TypeVectorInst * {
772        return mBuilder->MakeTypeVector(componentType, width);
773      },
774      &mGlobalDefs);
775}
776
777TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
778                                               Instruction *pointeeType) {
779  return findOrCreate<TypePointerInst>(
780      [=](TypePointerInst *type) -> bool {
781        return type->mOperand1 == storage &&
782               type->mOperand2.mInstruction == pointeeType;
783      },
784      [=]() -> TypePointerInst * {
785        return mBuilder->MakeTypePointer(storage, pointeeType);
786      },
787      &mGlobalDefs);
788}
789
790TypeRuntimeArrayInst *
791GlobalSection::getRuntimeArrayType(Instruction *elemType) {
792  return findOrCreate<TypeRuntimeArrayInst>(
793      [=](TypeRuntimeArrayInst * /*type*/) -> bool {
794        // return type->mOperand1.mInstruction == elemType;
795        return false;
796      },
797      [=]() -> TypeRuntimeArrayInst * {
798        return mBuilder->MakeTypeRuntimeArray(elemType);
799      },
800      &mGlobalDefs);
801}
802
803TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
804                                             int numField) {
805  TypeStructInst *structTy = mBuilder->MakeTypeStruct();
806  for (int i = 0; i < numField; i++) {
807    structTy->mOperand1.push_back(fieldType[i]);
808  }
809  mGlobalDefs.push_back(structTy);
810  return structTy;
811}
812
813TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
814                                                 Instruction *const argType[],
815                                                 size_t numArg) {
816  return findOrCreate<TypeFunctionInst>(
817      [=](TypeFunctionInst *type) -> bool {
818        if (type->mOperand1.mInstruction != retType ||
819            type->mOperand2.size() != numArg) {
820          return false;
821        }
822        for (size_t i = 0; i < numArg; i++) {
823          if (type->mOperand2[i].mInstruction != argType[i]) {
824            return false;
825          }
826        }
827        return true;
828      },
829      [=]() -> TypeFunctionInst * {
830        TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
831        for (size_t i = 0; i < numArg; i++) {
832          funcTy->mOperand2.push_back(argType[i]);
833        }
834        return funcTy;
835      },
836      &mGlobalDefs);
837}
838
839GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
840  mGlobalDefs.push_back(structType);
841  return this;
842}
843
844GlobalSection *GlobalSection::addVariable(VariableInst *var) {
845  mGlobalDefs.push_back(var);
846  return this;
847}
848
849VariableInst *GlobalSection::getInvocationId() {
850  if (mInvocationId) {
851    return mInvocationId.get();
852  }
853
854  TypeIntInst *UIntTy = getIntType(32, false);
855  TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
856  TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
857
858  VariableInst *InvocationId =
859      mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
860  InvocationId->decorate(Decoration::BuiltIn)
861      ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
862
863  mInvocationId.reset(InvocationId);
864
865  return InvocationId;
866}
867
868VariableInst *GlobalSection::getNumWorkgroups() {
869  if (mNumWorkgroups) {
870    return mNumWorkgroups.get();
871  }
872
873  TypeIntInst *UIntTy = getIntType(32, false);
874  TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
875  TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
876
877  VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
878  GNum->decorate(Decoration::BuiltIn)
879      ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
880
881  mNumWorkgroups.reset(GNum);
882
883  return GNum;
884}
885
886bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
887  if (!Deserialize<FunctionInst>(IS)) {
888    return false;
889  }
890
891  DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
892
893  if (!Deserialize<FunctionEndInst>(IS)) {
894    return false;
895  }
896
897  return true;
898}
899
900template <> Instruction *Deserialize(InputWordStream &IS) {
901  Instruction *inst;
902
903  switch ((*IS) & 0xFFFF) {
904#define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
905  case OPCODE:                                                                 \
906    inst = Deserialize<INST_CLASS>(IS);                                        \
907    break;
908#include "instruction_dispatches_generated.h"
909#undef HANDLE_INSTRUCTION
910  default:
911    Module::errs() << "unrecognized instruction";
912    inst = nullptr;
913  }
914
915  return inst;
916}
917
918bool Block::DeserializeInternal(InputWordStream &IS) {
919  Instruction *inst;
920  while (((*IS) & 0xFFFF) != OpFunctionEnd &&
921         (inst = Deserialize<Instruction>(IS))) {
922    mInsts.push_back(inst);
923    if (inst->getOpCode() == OpBranch ||
924        inst->getOpCode() == OpBranchConditional ||
925        inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
926        inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
927        inst->getOpCode() == OpUnreachable) {
928      break;
929    }
930  }
931  return !mInsts.empty();
932}
933
934FunctionDefinition::FunctionDefinition()
935    : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
936
937FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
938                                       FunctionEndInst *end)
939    : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
940      mBlocksDeleter(mBlocks) {}
941
942bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
943  mFunc.reset(Deserialize<FunctionInst>(IS));
944  if (!mFunc) {
945    return false;
946  }
947
948  DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
949  DeserializeZeroOrMore<Block>(IS, mBlocks);
950
951  mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
952  if (!mFuncEnd) {
953    return false;
954  }
955
956  return true;
957}
958
959Instruction *FunctionDefinition::getReturnType() const {
960  return mFunc->mResultType.mInstruction;
961}
962
963} // namespace spirit
964} // namespace android
965