1//
2//Copyright (C) 2015 LunarG, Inc.
3//
4//All rights reserved.
5//
6//Redistribution and use in source and binary forms, with or without
7//modification, are permitted provided that the following conditions
8//are met:
9//
10//    Redistributions of source code must retain the above copyright
11//    notice, this list of conditions and the following disclaimer.
12//
13//    Redistributions in binary form must reproduce the above
14//    copyright notice, this list of conditions and the following
15//    disclaimer in the documentation and/or other materials provided
16//    with the distribution.
17//
18//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
19//    contributors may be used to endorse or promote products derived
20//    from this software without specific prior written permission.
21//
22//THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23//"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24//LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
25//FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
26//COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
27//INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
28//BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
29//LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
30//CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31//LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
32//ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33//POSSIBILITY OF SUCH DAMAGE.
34//
35
36#include "SPVRemapper.h"
37#include "doc.h"
38
39#if !defined (use_cpp11)
40// ... not supported before C++11
41#else // defined (use_cpp11)
42
43#include <algorithm>
44#include <cassert>
45#include "../glslang/Include/Common.h"
46
47namespace spv {
48
49    // By default, just abort on error.  Can be overridden via RegisterErrorHandler
50    spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };
51    // By default, eat log messages.  Can be overridden via RegisterLogHandler
52    spirvbin_t::logfn_t   spirvbin_t::logHandler   = [](const std::string&) { };
53
54    // This can be overridden to provide other message behavior if needed
55    void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const
56    {
57        if (verbose >= minVerbosity)
58            logHandler(std::string(indent, ' ') + txt);
59    }
60
61    // hash opcode, with special handling for OpExtInst
62    std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)
63    {
64        const spv::Op opCode = asOpCode(word);
65
66        std::uint32_t offset = 0;
67
68        switch (opCode) {
69        case spv::OpExtInst:
70            offset += asId(word + 4); break;
71        default:
72            break;
73        }
74
75        return opCode * 19 + offset; // 19 = small prime
76    }
77
78    spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const
79    {
80        static const int maxCount = 1<<30;
81
82        switch (opCode) {
83        case spv::OpTypeFloat:        // fall through...
84        case spv::OpTypePointer:      return range_t(2, 3);
85        case spv::OpTypeInt:          return range_t(2, 4);
86        // TODO: case spv::OpTypeImage:
87        // TODO: case spv::OpTypeSampledImage:
88        case spv::OpTypeSampler:      return range_t(3, 8);
89        case spv::OpTypeVector:       // fall through
90        case spv::OpTypeMatrix:       // ...
91        case spv::OpTypePipe:         return range_t(3, 4);
92        case spv::OpConstant:         return range_t(3, maxCount);
93        default:                      return range_t(0, 0);
94        }
95    }
96
97    spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const
98    {
99        static const int maxCount = 1<<30;
100
101        if (isConstOp(opCode))
102            return range_t(1, 2);
103
104        switch (opCode) {
105        case spv::OpTypeVector:       // fall through
106        case spv::OpTypeMatrix:       // ...
107        case spv::OpTypeSampler:      // ...
108        case spv::OpTypeArray:        // ...
109        case spv::OpTypeRuntimeArray: // ...
110        case spv::OpTypePipe:         return range_t(2, 3);
111        case spv::OpTypeStruct:       // fall through
112        case spv::OpTypeFunction:     return range_t(2, maxCount);
113        case spv::OpTypePointer:      return range_t(3, 4);
114        default:                      return range_t(0, 0);
115        }
116    }
117
118    spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const
119    {
120        static const int maxCount = 1<<30;
121
122        switch (opCode) {
123        case spv::OpTypeArray:         // fall through...
124        case spv::OpTypeRuntimeArray:  return range_t(3, 4);
125        case spv::OpConstantComposite: return range_t(3, maxCount);
126        default:                       return range_t(0, 0);
127        }
128    }
129
130    // Is this an opcode we should remove when using --strip?
131    bool spirvbin_t::isStripOp(spv::Op opCode) const
132    {
133        switch (opCode) {
134        case spv::OpSource:
135        case spv::OpSourceExtension:
136        case spv::OpName:
137        case spv::OpMemberName:
138        case spv::OpLine:           return true;
139        default:                    return false;
140        }
141    }
142
143    bool spirvbin_t::isFlowCtrl(spv::Op opCode) const
144    {
145        switch (opCode) {
146        case spv::OpBranchConditional:
147        case spv::OpBranch:
148        case spv::OpSwitch:
149        case spv::OpLoopMerge:
150        case spv::OpSelectionMerge:
151        case spv::OpLabel:
152        case spv::OpFunction:
153        case spv::OpFunctionEnd:    return true;
154        default:                    return false;
155        }
156    }
157
158    bool spirvbin_t::isTypeOp(spv::Op opCode) const
159    {
160        switch (opCode) {
161        case spv::OpTypeVoid:
162        case spv::OpTypeBool:
163        case spv::OpTypeInt:
164        case spv::OpTypeFloat:
165        case spv::OpTypeVector:
166        case spv::OpTypeMatrix:
167        case spv::OpTypeImage:
168        case spv::OpTypeSampler:
169        case spv::OpTypeArray:
170        case spv::OpTypeRuntimeArray:
171        case spv::OpTypeStruct:
172        case spv::OpTypeOpaque:
173        case spv::OpTypePointer:
174        case spv::OpTypeFunction:
175        case spv::OpTypeEvent:
176        case spv::OpTypeDeviceEvent:
177        case spv::OpTypeReserveId:
178        case spv::OpTypeQueue:
179        case spv::OpTypeSampledImage:
180        case spv::OpTypePipe:         return true;
181        default:                      return false;
182        }
183    }
184
185    bool spirvbin_t::isConstOp(spv::Op opCode) const
186    {
187        switch (opCode) {
188        case spv::OpConstantNull:       error("unimplemented constant type");
189        case spv::OpConstantSampler:    error("unimplemented constant type");
190
191        case spv::OpConstantTrue:
192        case spv::OpConstantFalse:
193        case spv::OpConstantComposite:
194        case spv::OpConstant:         return true;
195        default:                      return false;
196        }
197    }
198
199    const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };
200    const auto op_fn_nop   = [](spv::Id&)          { };
201
202    // g++ doesn't like these defined in the class proper in an anonymous namespace.
203    // Dunno why.  Also MSVC doesn't like the constexpr keyword.  Also dunno why.
204    // Defining them externally seems to please both compilers, so, here they are.
205    const spv::Id spirvbin_t::unmapped    = spv::Id(-10000);
206    const spv::Id spirvbin_t::unused      = spv::Id(-10001);
207    const int     spirvbin_t::header_size = 5;
208
209    spv::Id spirvbin_t::nextUnusedId(spv::Id id)
210    {
211        while (isNewIdMapped(id))  // search for an unused ID
212            ++id;
213
214        return id;
215    }
216
217    spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)
218    {
219        assert(id != spv::NoResult && newId != spv::NoResult);
220
221        if (id >= idMapL.size())
222            idMapL.resize(id+1, unused);
223
224        if (newId != unmapped && newId != unused) {
225            if (isOldIdUnused(id))
226                error(std::string("ID unused in module: ") + std::to_string(id));
227
228            if (!isOldIdUnmapped(id))
229                error(std::string("ID already mapped: ") + std::to_string(id) + " -> "
230                + std::to_string(localId(id)));
231
232            if (isNewIdMapped(newId))
233                error(std::string("ID already used in module: ") + std::to_string(newId));
234
235            msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));
236            setMapped(newId);
237            largestNewId = std::max(largestNewId, newId);
238        }
239
240        return idMapL[id] = newId;
241    }
242
243    // Parse a literal string from the SPIR binary and return it as an std::string
244    // Due to C++11 RValue references, this doesn't copy the result string.
245    std::string spirvbin_t::literalString(unsigned word) const
246    {
247        std::string literal;
248
249        literal.reserve(16);
250
251        const char* bytes = reinterpret_cast<const char*>(spv.data() + word);
252
253        while (bytes && *bytes)
254            literal += *bytes++;
255
256        return literal;
257    }
258
259
260    void spirvbin_t::applyMap()
261    {
262        msg(3, 2, std::string("Applying map: "));
263
264        // Map local IDs through the ID map
265        process(inst_fn_nop, // ignore instructions
266            [this](spv::Id& id) {
267                id = localId(id);
268                assert(id != unused && id != unmapped);
269            }
270        );
271    }
272
273
274    // Find free IDs for anything we haven't mapped
275    void spirvbin_t::mapRemainder()
276    {
277        msg(3, 2, std::string("Remapping remainder: "));
278
279        spv::Id     unusedId  = 1;  // can't use 0: that's NoResult
280        spirword_t  maxBound  = 0;
281
282        for (spv::Id id = 0; id < idMapL.size(); ++id) {
283            if (isOldIdUnused(id))
284                continue;
285
286            // Find a new mapping for any used but unmapped IDs
287            if (isOldIdUnmapped(id))
288                localId(id, unusedId = nextUnusedId(unusedId));
289
290            if (isOldIdUnmapped(id))
291                error(std::string("old ID not mapped: ") + std::to_string(id));
292
293            // Track max bound
294            maxBound = std::max(maxBound, localId(id) + 1);
295        }
296
297        bound(maxBound); // reset header ID bound to as big as it now needs to be
298    }
299
300    void spirvbin_t::stripDebug()
301    {
302        if ((options & STRIP) == 0)
303            return;
304
305        // build local Id and name maps
306        process(
307            [&](spv::Op opCode, unsigned start) {
308                // remember opcodes we want to strip later
309                if (isStripOp(opCode))
310                    stripInst(start);
311                return true;
312            },
313            op_fn_nop);
314    }
315
316    void spirvbin_t::buildLocalMaps()
317    {
318        msg(2, 2, std::string("build local maps: "));
319
320        mapped.clear();
321        idMapL.clear();
322//      preserve nameMap, so we don't clear that.
323        fnPos.clear();
324        fnPosDCE.clear();
325        fnCalls.clear();
326        typeConstPos.clear();
327        typeConstPosR.clear();
328        entryPoint = spv::NoResult;
329        largestNewId = 0;
330
331        idMapL.resize(bound(), unused);
332
333        int         fnStart = 0;
334        spv::Id     fnRes   = spv::NoResult;
335
336        // build local Id and name maps
337        process(
338            [&](spv::Op opCode, unsigned start) {
339                // remember opcodes we want to strip later
340                if ((options & STRIP) && isStripOp(opCode))
341                    stripInst(start);
342
343                if (opCode == spv::Op::OpName) {
344                    const spv::Id    target = asId(start+1);
345                    const std::string  name = literalString(start+2);
346                    nameMap[name] = target;
347
348                } else if (opCode == spv::Op::OpFunctionCall) {
349                    ++fnCalls[asId(start + 3)];
350                } else if (opCode == spv::Op::OpEntryPoint) {
351                    entryPoint = asId(start + 2);
352                } else if (opCode == spv::Op::OpFunction) {
353                    if (fnStart != 0)
354                        error("nested function found");
355                    fnStart = start;
356                    fnRes   = asId(start + 2);
357                } else if (opCode == spv::Op::OpFunctionEnd) {
358                    assert(fnRes != spv::NoResult);
359                    if (fnStart == 0)
360                        error("function end without function start");
361                    fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));
362                    fnStart = 0;
363                } else if (isConstOp(opCode)) {
364                    assert(asId(start + 2) != spv::NoResult);
365                    typeConstPos.insert(start);
366                    typeConstPosR[asId(start + 2)] = start;
367                } else if (isTypeOp(opCode)) {
368                    assert(asId(start + 1) != spv::NoResult);
369                    typeConstPos.insert(start);
370                    typeConstPosR[asId(start + 1)] = start;
371                }
372
373                return false;
374            },
375
376            [this](spv::Id& id) { localId(id, unmapped); }
377        );
378    }
379
380    // Validate the SPIR header
381    void spirvbin_t::validate() const
382    {
383        msg(2, 2, std::string("validating: "));
384
385        if (spv.size() < header_size)
386            error("file too short: ");
387
388        if (magic() != spv::MagicNumber)
389            error("bad magic number");
390
391        // field 1 = version
392        // field 2 = generator magic
393        // field 3 = result <id> bound
394
395        if (schemaNum() != 0)
396            error("bad schema, must be 0");
397    }
398
399
400    int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)
401    {
402        const auto     instructionStart = word;
403        const unsigned wordCount = asWordCount(instructionStart);
404        const spv::Op  opCode    = asOpCode(instructionStart);
405        const int      nextInst  = word++ + wordCount;
406
407        if (nextInst > int(spv.size()))
408            error("spir instruction terminated too early");
409
410        // Base for computing number of operands; will be updated as more is learned
411        unsigned numOperands = wordCount - 1;
412
413        if (instFn(opCode, instructionStart))
414            return nextInst;
415
416        // Read type and result ID from instruction desc table
417        if (spv::InstructionDesc[opCode].hasType()) {
418            idFn(asId(word++));
419            --numOperands;
420        }
421
422        if (spv::InstructionDesc[opCode].hasResult()) {
423            idFn(asId(word++));
424            --numOperands;
425        }
426
427        // Extended instructions: currently, assume everything is an ID.
428        // TODO: add whatever data we need for exceptions to that
429        if (opCode == spv::OpExtInst) {
430            word        += 2; // instruction set, and instruction from set
431            numOperands -= 2;
432
433            for (unsigned op=0; op < numOperands; ++op)
434                idFn(asId(word++)); // ID
435
436            return nextInst;
437        }
438
439        // Store IDs from instruction in our map
440        for (int op = 0; numOperands > 0; ++op, --numOperands) {
441            switch (spv::InstructionDesc[opCode].operands.getClass(op)) {
442            case spv::OperandId:
443                idFn(asId(word++));
444                break;
445
446            case spv::OperandVariableIds:
447                for (unsigned i = 0; i < numOperands; ++i)
448                    idFn(asId(word++));
449                return nextInst;
450
451            case spv::OperandVariableLiterals:
452                // for clarity
453                // if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {
454                //     ++word;
455                //     --numOperands;
456                // }
457                // word += numOperands;
458                return nextInst;
459
460            case spv::OperandVariableLiteralId:
461                while (numOperands > 0) {
462                    ++word;             // immediate
463                    idFn(asId(word++)); // ID
464                    numOperands -= 2;
465                }
466                return nextInst;
467
468            case spv::OperandLiteralString: {
469                const int stringWordCount = literalStringWords(literalString(word));
470                word += stringWordCount;
471                numOperands -= (stringWordCount-1); // -1 because for() header post-decrements
472                break;
473            }
474
475            // Execution mode might have extra literal operands.  Skip them.
476            case spv::OperandExecutionMode:
477                return nextInst;
478
479            // Single word operands we simply ignore, as they hold no IDs
480            case spv::OperandLiteralNumber:
481            case spv::OperandSource:
482            case spv::OperandExecutionModel:
483            case spv::OperandAddressing:
484            case spv::OperandMemory:
485            case spv::OperandStorage:
486            case spv::OperandDimensionality:
487            case spv::OperandSamplerAddressingMode:
488            case spv::OperandSamplerFilterMode:
489            case spv::OperandSamplerImageFormat:
490            case spv::OperandImageChannelOrder:
491            case spv::OperandImageChannelDataType:
492            case spv::OperandImageOperands:
493            case spv::OperandFPFastMath:
494            case spv::OperandFPRoundingMode:
495            case spv::OperandLinkageType:
496            case spv::OperandAccessQualifier:
497            case spv::OperandFuncParamAttr:
498            case spv::OperandDecoration:
499            case spv::OperandBuiltIn:
500            case spv::OperandSelect:
501            case spv::OperandLoop:
502            case spv::OperandFunction:
503            case spv::OperandMemorySemantics:
504            case spv::OperandMemoryAccess:
505            case spv::OperandScope:
506            case spv::OperandGroupOperation:
507            case spv::OperandKernelEnqueueFlags:
508            case spv::OperandKernelProfilingInfo:
509            case spv::OperandCapability:
510                ++word;
511                break;
512
513            default:
514                assert(0 && "Unhandled Operand Class");
515                break;
516            }
517        }
518
519        return nextInst;
520    }
521
522    // Make a pass over all the instructions and process them given appropriate functions
523    spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)
524    {
525        // For efficiency, reserve name map space.  It can grow if needed.
526        nameMap.reserve(32);
527
528        // If begin or end == 0, use defaults
529        begin = (begin == 0 ? header_size          : begin);
530        end   = (end   == 0 ? unsigned(spv.size()) : end);
531
532        // basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...
533        unsigned nextInst = unsigned(spv.size());
534
535        for (unsigned word = begin; word < end; word = nextInst)
536            nextInst = processInstruction(word, instFn, idFn);
537
538        return *this;
539    }
540
541    // Apply global name mapping to a single module
542    void spirvbin_t::mapNames()
543    {
544        static const std::uint32_t softTypeIdLimit = 3011;  // small prime.  TODO: get from options
545        static const std::uint32_t firstMappedID   = 3019;  // offset into ID space
546
547        for (const auto& name : nameMap) {
548            std::uint32_t hashval = 1911;
549            for (const char c : name.first)
550                hashval = hashval * 1009 + c;
551
552            if (isOldIdUnmapped(name.second))
553                localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
554        }
555    }
556
557    // Map fn contents to IDs of similar functions in other modules
558    void spirvbin_t::mapFnBodies()
559    {
560        static const std::uint32_t softTypeIdLimit = 19071;  // small prime.  TODO: get from options
561        static const std::uint32_t firstMappedID   =  6203;  // offset into ID space
562
563        // Initial approach: go through some high priority opcodes first and assign them
564        // hash values.
565
566        spv::Id               fnId       = spv::NoResult;
567        std::vector<unsigned> instPos;
568        instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.
569
570        // Build local table of instruction start positions
571        process(
572            [&](spv::Op, unsigned start) { instPos.push_back(start); return true; },
573            op_fn_nop);
574
575        // Window size for context-sensitive canonicalization values
576        // Empirical best size from a single data set.  TODO: Would be a good tunable.
577        // We essentially perform a little convolution around each instruction,
578        // to capture the flavor of nearby code, to hopefully match to similar
579        // code in other modules.
580        static const unsigned windowSize = 2;
581
582        for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {
583            const unsigned start  = instPos[entry];
584            const spv::Op  opCode = asOpCode(start);
585
586            if (opCode == spv::OpFunction)
587                fnId   = asId(start + 2);
588
589            if (opCode == spv::OpFunctionEnd)
590                fnId = spv::NoResult;
591
592            if (fnId != spv::NoResult) { // if inside a function
593                if (spv::InstructionDesc[opCode].hasResult()) {
594                    const unsigned word    = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);
595                    const spv::Id  resId   = asId(word);
596                    std::uint32_t  hashval = fnId * 17; // small prime
597
598                    for (unsigned i = entry-1; i >= entry-windowSize; --i) {
599                        if (asOpCode(instPos[i]) == spv::OpFunction)
600                            break;
601                        hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
602                    }
603
604                    for (unsigned i = entry; i <= entry + windowSize; ++i) {
605                        if (asOpCode(instPos[i]) == spv::OpFunctionEnd)
606                            break;
607                        hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
608                    }
609
610                    if (isOldIdUnmapped(resId))
611                        localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
612                }
613            }
614        }
615
616        spv::Op          thisOpCode(spv::OpNop);
617        std::unordered_map<int, int> opCounter;
618        int              idCounter(0);
619        fnId = spv::NoResult;
620
621        process(
622            [&](spv::Op opCode, unsigned start) {
623                switch (opCode) {
624                case spv::OpFunction:
625                    // Reset counters at each function
626                    idCounter = 0;
627                    opCounter.clear();
628                    fnId = asId(start + 2);
629                    break;
630
631                case spv::OpImageSampleImplicitLod:
632                case spv::OpImageSampleExplicitLod:
633                case spv::OpImageSampleDrefImplicitLod:
634                case spv::OpImageSampleDrefExplicitLod:
635                case spv::OpImageSampleProjImplicitLod:
636                case spv::OpImageSampleProjExplicitLod:
637                case spv::OpImageSampleProjDrefImplicitLod:
638                case spv::OpImageSampleProjDrefExplicitLod:
639                case spv::OpDot:
640                case spv::OpCompositeExtract:
641                case spv::OpCompositeInsert:
642                case spv::OpVectorShuffle:
643                case spv::OpLabel:
644                case spv::OpVariable:
645
646                case spv::OpAccessChain:
647                case spv::OpLoad:
648                case spv::OpStore:
649                case spv::OpCompositeConstruct:
650                case spv::OpFunctionCall:
651                    ++opCounter[opCode];
652                    idCounter = 0;
653                    thisOpCode = opCode;
654                    break;
655                default:
656                    thisOpCode = spv::OpNop;
657                }
658
659                return false;
660            },
661
662            [&](spv::Id& id) {
663                if (thisOpCode != spv::OpNop) {
664                    ++idCounter;
665                    const std::uint32_t hashval = opCounter[thisOpCode] * thisOpCode * 50047 + idCounter + fnId * 117;
666
667                    if (isOldIdUnmapped(id))
668                        localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
669                }
670            });
671    }
672
673    // EXPERIMENTAL: forward IO and uniform load/stores into operands
674    // This produces invalid Schema-0 SPIRV
675    void spirvbin_t::forwardLoadStores()
676    {
677        idset_t fnLocalVars; // set of function local vars
678        idmap_t idMap;       // Map of load result IDs to what they load
679
680        // EXPERIMENTAL: Forward input and access chain loads into consumptions
681        process(
682            [&](spv::Op opCode, unsigned start) {
683                // Add inputs and uniforms to the map
684                if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
685                    (spv[start+3] == spv::StorageClassUniform ||
686                    spv[start+3] == spv::StorageClassUniformConstant ||
687                    spv[start+3] == spv::StorageClassInput))
688                    fnLocalVars.insert(asId(start+2));
689
690                if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)
691                    fnLocalVars.insert(asId(start+2));
692
693                if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
694                    idMap[asId(start+2)] = asId(start+3);
695                    stripInst(start);
696                }
697
698                return false;
699            },
700
701            [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
702        );
703
704        // EXPERIMENTAL: Implicit output stores
705        fnLocalVars.clear();
706        idMap.clear();
707
708        process(
709            [&](spv::Op opCode, unsigned start) {
710                // Add inputs and uniforms to the map
711                if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
712                    (spv[start+3] == spv::StorageClassOutput))
713                    fnLocalVars.insert(asId(start+2));
714
715                if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
716                    idMap[asId(start+2)] = asId(start+1);
717                    stripInst(start);
718                }
719
720                return false;
721            },
722            op_fn_nop);
723
724        process(
725            inst_fn_nop,
726            [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
727        );
728
729        strip();          // strip out data we decided to eliminate
730    }
731
732    // optimize loads and stores
733    void spirvbin_t::optLoadStore()
734    {
735        idset_t    fnLocalVars;  // candidates for removal (only locals)
736        idmap_t    idMap;        // Map of load result IDs to what they load
737        blockmap_t blockMap;     // Map of IDs to blocks they first appear in
738        int        blockNum = 0; // block count, to avoid crossing flow control
739
740        // Find all the function local pointers stored at most once, and not via access chains
741        process(
742            [&](spv::Op opCode, unsigned start) {
743                const int wordCount = asWordCount(start);
744
745                // Count blocks, so we can avoid crossing flow control
746                if (isFlowCtrl(opCode))
747                    ++blockNum;
748
749                // Add local variables to the map
750                if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4)) {
751                    fnLocalVars.insert(asId(start+2));
752                    return true;
753                }
754
755                // Ignore process vars referenced via access chain
756                if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {
757                    fnLocalVars.erase(asId(start+3));
758                    idMap.erase(asId(start+3));
759                    return true;
760                }
761
762                if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
763                    const spv::Id varId = asId(start+3);
764
765                    // Avoid loads before stores
766                    if (idMap.find(varId) == idMap.end()) {
767                        fnLocalVars.erase(varId);
768                        idMap.erase(varId);
769                    }
770
771                    // don't do for volatile references
772                    if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {
773                        fnLocalVars.erase(varId);
774                        idMap.erase(varId);
775                    }
776
777                    // Handle flow control
778                    if (blockMap.find(varId) == blockMap.end()) {
779                        blockMap[varId] = blockNum;  // track block we found it in.
780                    } else if (blockMap[varId] != blockNum) {
781                        fnLocalVars.erase(varId);  // Ignore if crosses flow control
782                        idMap.erase(varId);
783                    }
784
785                    return true;
786                }
787
788                if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
789                    const spv::Id varId = asId(start+1);
790
791                    if (idMap.find(varId) == idMap.end()) {
792                        idMap[varId] = asId(start+2);
793                    } else {
794                        // Remove if it has more than one store to the same pointer
795                        fnLocalVars.erase(varId);
796                        idMap.erase(varId);
797                    }
798
799                    // don't do for volatile references
800                    if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {
801                        fnLocalVars.erase(asId(start+3));
802                        idMap.erase(asId(start+3));
803                    }
804
805                    // Handle flow control
806                    if (blockMap.find(varId) == blockMap.end()) {
807                        blockMap[varId] = blockNum;  // track block we found it in.
808                    } else if (blockMap[varId] != blockNum) {
809                        fnLocalVars.erase(varId);  // Ignore if crosses flow control
810                        idMap.erase(varId);
811                    }
812
813                    return true;
814                }
815
816                return false;
817            },
818
819            // If local var id used anywhere else, don't eliminate
820            [&](spv::Id& id) {
821                if (fnLocalVars.count(id) > 0) {
822                    fnLocalVars.erase(id);
823                    idMap.erase(id);
824                }
825            }
826        );
827
828        process(
829            [&](spv::Op opCode, unsigned start) {
830                if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)
831                    idMap[asId(start+2)] = idMap[asId(start+3)];
832                return false;
833            },
834            op_fn_nop);
835
836        // Chase replacements to their origins, in case there is a chain such as:
837        //   2 = store 1
838        //   3 = load 2
839        //   4 = store 3
840        //   5 = load 4
841        // We want to replace uses of 5 with 1.
842        for (const auto& idPair : idMap) {
843            spv::Id id = idPair.first;
844            while (idMap.find(id) != idMap.end())  // Chase to end of chain
845                id = idMap[id];
846
847            idMap[idPair.first] = id;              // replace with final result
848        }
849
850        // Remove the load/store/variables for the ones we've discovered
851        process(
852            [&](spv::Op opCode, unsigned start) {
853                if ((opCode == spv::OpLoad  && fnLocalVars.count(asId(start+3)) > 0) ||
854                    (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||
855                    (opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {
856
857                    stripInst(start);
858                    return true;
859                }
860
861                return false;
862            },
863
864            [&](spv::Id& id) {
865                if (idMap.find(id) != idMap.end()) id = idMap[id];
866            }
867        );
868
869        strip();          // strip out data we decided to eliminate
870    }
871
872    // remove bodies of uncalled functions
873    void spirvbin_t::dceFuncs()
874    {
875        msg(3, 2, std::string("Removing Dead Functions: "));
876
877        // TODO: There are more efficient ways to do this.
878        bool changed = true;
879
880        while (changed) {
881            changed = false;
882
883            for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {
884                if (fn->first == entryPoint) { // don't DCE away the entry point!
885                    ++fn;
886                    continue;
887                }
888
889                const auto call_it = fnCalls.find(fn->first);
890
891                if (call_it == fnCalls.end() || call_it->second == 0) {
892                    changed = true;
893                    stripRange.push_back(fn->second);
894                    fnPosDCE.insert(*fn);
895
896                    // decrease counts of called functions
897                    process(
898                        [&](spv::Op opCode, unsigned start) {
899                            if (opCode == spv::Op::OpFunctionCall) {
900                                const auto call_it = fnCalls.find(asId(start + 3));
901                                if (call_it != fnCalls.end()) {
902                                    if (--call_it->second <= 0)
903                                        fnCalls.erase(call_it);
904                                }
905                            }
906
907                            return true;
908                        },
909                        op_fn_nop,
910                        fn->second.first,
911                        fn->second.second);
912
913                    fn = fnPos.erase(fn);
914                } else ++fn;
915            }
916        }
917    }
918
919    // remove unused function variables + decorations
920    void spirvbin_t::dceVars()
921    {
922        msg(3, 2, std::string("DCE Vars: "));
923
924        std::unordered_map<spv::Id, int> varUseCount;
925
926        // Count function variable use
927        process(
928            [&](spv::Op opCode, unsigned start) {
929                if (opCode == spv::OpVariable) {
930                    ++varUseCount[asId(start+2)];
931                    return true;
932                } else if (opCode == spv::OpEntryPoint) {
933                    const int wordCount = asWordCount(start);
934                    for (int i = 4; i < wordCount; i++) {
935                        ++varUseCount[asId(start+i)];
936                    }
937                    return true;
938                } else
939                    return false;
940            },
941
942            [&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }
943        );
944
945        // Remove single-use function variables + associated decorations and names
946        process(
947            [&](spv::Op opCode, unsigned start) {
948                if ((opCode == spv::OpVariable && varUseCount[asId(start+2)] == 1)  ||
949                    (opCode == spv::OpDecorate && varUseCount[asId(start+1)] == 1)  ||
950                    (opCode == spv::OpName     && varUseCount[asId(start+1)] == 1)) {
951                        stripInst(start);
952                }
953                return true;
954            },
955            op_fn_nop);
956    }
957
958    // remove unused types
959    void spirvbin_t::dceTypes()
960    {
961        std::vector<bool> isType(bound(), false);
962
963        // for speed, make O(1) way to get to type query (map is log(n))
964        for (const auto typeStart : typeConstPos)
965            isType[asTypeConstId(typeStart)] = true;
966
967        std::unordered_map<spv::Id, int> typeUseCount;
968
969        // Count total type usage
970        process(inst_fn_nop,
971            [&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }
972        );
973
974        // Remove types from deleted code
975        for (const auto& fn : fnPosDCE)
976            process(inst_fn_nop,
977            [&](spv::Id& id) { if (isType[id]) --typeUseCount[id]; },
978            fn.second.first, fn.second.second);
979
980        // Remove single reference types
981        for (const auto typeStart : typeConstPos) {
982            const spv::Id typeId = asTypeConstId(typeStart);
983            if (typeUseCount[typeId] == 1) {
984                --typeUseCount[typeId];
985                stripInst(typeStart);
986            }
987        }
988    }
989
990
991#ifdef NOTDEF
    // Recursively compare the local type/constant with ID "lt" against the
    // entry with ID "gt" in globalTypes.  Returns true only when the opcode
    // and word count agree and the literal, constant and sub-type operand
    // ranges all compare equal.
    //
    // NOTE(review): this function sits inside the #ifdef NOTDEF region and is
    // not compiled.  It refers to `spir`, while the rest of the visible code
    // uses `spv` for the word vector -- confirm the names before re-enabling.
    bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const
    {
        // Find the local type id "lt" and global type id "gt"
        const auto lt_it = typeConstPosR.find(lt);
        if (lt_it == typeConstPosR.end())
            return false;

        const auto typeStart = lt_it->second;

        // Search for entry in global table
        const auto gtype = globalTypes.find(gt);
        if (gtype == globalTypes.end())
            return false;

        const auto& gdata = gtype->second;

        // local wordcount and opcode
        const int     wordCount   = asWordCount(typeStart);
        const spv::Op opCode      = asOpCode(typeStart);

        // no type match if opcodes don't match, or operand count doesn't match
        if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))
            return false;

        // NOTE(review): numOperands is never read below.
        const unsigned numOperands = wordCount - 2; // all types have a result

        // Recursively match every ID operand in [range.first, range.second),
        // clamped to the instruction's word count, against the same word
        // offset in the global entry.
        const auto cmpIdRange = [&](range_t range) {
            for (int x=range.first; x<std::min(range.second, wordCount); ++x)
                if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))
                    return false;
            return true;
        };

        const auto cmpConst   = [&]() { return cmpIdRange(constRange(opCode)); };
        const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode));  };

        // Compare literals in range [start,end)
        const auto cmpLiteral = [&]() {
            const auto range = literalRange(opCode);
            return std::equal(spir.begin() + typeStart + range.first,
                spir.begin() + typeStart + std::min(range.second, wordCount),
                gdata.begin() + range.first);
        };

        assert(isTypeOp(opCode) || isConstOp(opCode));

        // The first five opcodes below never match; everything else is
        // compared operand by operand.
        switch (opCode) {
        case spv::OpTypeOpaque:       // TODO: disable until we compare the literal strings.
        case spv::OpTypeQueue:        return false;
        case spv::OpTypeEvent:        // fall through...
        case spv::OpTypeDeviceEvent:  // ...
        case spv::OpTypeReserveId:    return false;
            // for samplers, we don't handle the optional parameters yet
        case spv::OpTypeSampler:      return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;
        default:                      return cmpLiteral() && cmpConst() && cmpSubType();
        }
    }
1049
1050
1051    // Look for an equivalent type in the globalTypes map
1052    spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const
1053    {
1054        // Try a recursive type match on each in turn, and return a match if we find one
1055        for (const auto& gt : globalTypes)
1056            if (matchType(globalTypes, lt, gt.first))
1057                return gt.first;
1058
1059        return spv::NoType;
1060    }
1061#endif // NOTDEF
1062
1063    // Return start position in SPV of given type.  error if not found.
1064    unsigned spirvbin_t::typePos(spv::Id id) const
1065    {
1066        const auto tid_it = typeConstPosR.find(id);
1067        if (tid_it == typeConstPosR.end())
1068            error("type ID not found");
1069
1070        return tid_it->second;
1071    }
1072
1073    // Hash types to canonical values.  This can return ID collisions (it's a bit
1074    // inevitable): it's up to the caller to handle that gracefully.
1075    std::uint32_t spirvbin_t::hashType(unsigned typeStart) const
1076    {
1077        const unsigned wordCount   = asWordCount(typeStart);
1078        const spv::Op  opCode      = asOpCode(typeStart);
1079
1080        switch (opCode) {
1081        case spv::OpTypeVoid:         return 0;
1082        case spv::OpTypeBool:         return 1;
1083        case spv::OpTypeInt:          return 3 + (spv[typeStart+3]);
1084        case spv::OpTypeFloat:        return 5;
1085        case spv::OpTypeVector:
1086            return 6 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1087        case spv::OpTypeMatrix:
1088            return 30 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
1089        case spv::OpTypeImage:
1090            return 120 + hashType(typePos(spv[typeStart+2])) +
1091                spv[typeStart+3] +            // dimensionality
1092                spv[typeStart+4] * 8 * 16 +   // depth
1093                spv[typeStart+5] * 4 * 16 +   // arrayed
1094                spv[typeStart+6] * 2 * 16 +   // multisampled
1095                spv[typeStart+7] * 1 * 16;    // format
1096        case spv::OpTypeSampler:
1097            return 500;
1098        case spv::OpTypeSampledImage:
1099            return 502;
1100        case spv::OpTypeArray:
1101            return 501 + hashType(typePos(spv[typeStart+2])) * spv[typeStart+3];
1102        case spv::OpTypeRuntimeArray:
1103            return 5000  + hashType(typePos(spv[typeStart+2]));
1104        case spv::OpTypeStruct:
1105            {
1106                std::uint32_t hash = 10000;
1107                for (unsigned w=2; w < wordCount; ++w)
1108                    hash += w * hashType(typePos(spv[typeStart+w]));
1109                return hash;
1110            }
1111
1112        case spv::OpTypeOpaque:         return 6000 + spv[typeStart+2];
1113        case spv::OpTypePointer:        return 100000  + hashType(typePos(spv[typeStart+3]));
1114        case spv::OpTypeFunction:
1115            {
1116                std::uint32_t hash = 200000;
1117                for (unsigned w=2; w < wordCount; ++w)
1118                    hash += w * hashType(typePos(spv[typeStart+w]));
1119                return hash;
1120            }
1121
1122        case spv::OpTypeEvent:           return 300000;
1123        case spv::OpTypeDeviceEvent:     return 300001;
1124        case spv::OpTypeReserveId:       return 300002;
1125        case spv::OpTypeQueue:           return 300003;
1126        case spv::OpTypePipe:            return 300004;
1127
1128        case spv::OpConstantNull:        return 300005;
1129        case spv::OpConstantSampler:     return 300006;
1130
1131        case spv::OpConstantTrue:        return 300007;
1132        case spv::OpConstantFalse:       return 300008;
1133        case spv::OpConstantComposite:
1134            {
1135                std::uint32_t hash = 300011 + hashType(typePos(spv[typeStart+1]));
1136                for (unsigned w=3; w < wordCount; ++w)
1137                    hash += w * hashType(typePos(spv[typeStart+w]));
1138                return hash;
1139            }
1140        case spv::OpConstant:
1141            {
1142                std::uint32_t hash = 400011 + hashType(typePos(spv[typeStart+1]));
1143                for (unsigned w=3; w < wordCount; ++w)
1144                    hash += w * spv[typeStart+w];
1145                return hash;
1146            }
1147
1148        default:
1149            error("unknown type opcode");
1150            return 0;
1151        }
1152    }
1153
1154    void spirvbin_t::mapTypeConst()
1155    {
1156        globaltypes_t globalTypeMap;
1157
1158        msg(3, 2, std::string("Remapping Consts & Types: "));
1159
1160        static const std::uint32_t softTypeIdLimit = 3011; // small prime.  TODO: get from options
1161        static const std::uint32_t firstMappedID   = 8;    // offset into ID space
1162
1163        for (auto& typeStart : typeConstPos) {
1164            const spv::Id       resId     = asTypeConstId(typeStart);
1165            const std::uint32_t hashval   = hashType(typeStart);
1166
1167            if (isOldIdUnmapped(resId))
1168                localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
1169        }
1170    }
1171
1172
1173    // Strip a single binary by removing ranges given in stripRange
1174    void spirvbin_t::strip()
1175    {
1176        if (stripRange.empty()) // nothing to do
1177            return;
1178
1179        // Sort strip ranges in order of traversal
1180        std::sort(stripRange.begin(), stripRange.end());
1181
1182        // Allocate a new binary big enough to hold old binary
1183        // We'll step this iterator through the strip ranges as we go through the binary
1184        auto strip_it = stripRange.begin();
1185
1186        int strippedPos = 0;
1187        for (unsigned word = 0; word < unsigned(spv.size()); ++word) {
1188            if (strip_it != stripRange.end() && word >= strip_it->second)
1189                ++strip_it;
1190
1191            if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)
1192                spv[strippedPos++] = spv[word];
1193        }
1194
1195        spv.resize(strippedPos);
1196        stripRange.clear();
1197
1198        buildLocalMaps();
1199    }
1200
    // Run the remapping pipeline over the binary currently held in spv.
    // opts is stored in options and then tested as a bitmask: each flag
    // checked below enables one optional pass.  The passes run in the fixed
    // order written here, followed by the unconditional mapRemainder(),
    // applyMap() and a final strip().
    void spirvbin_t::remap(std::uint32_t opts)
    {
        options = opts;

        // Set up opcode tables from SpvDoc
        spv::Parameterize();

        validate();  // validate header
        buildLocalMaps();

        msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));

        strip();        // strip out data we decided to eliminate

        // Optional passes.  dceFuncs/dceVars/dceTypes do not call strip()
        // themselves, so what they mark is removed by the final strip() below.
        if (options & OPT_LOADSTORE) optLoadStore();
        if (options & OPT_FWD_LS)    forwardLoadStores();
        if (options & DCE_FUNCS)     dceFuncs();
        if (options & DCE_VARS)      dceVars();
        if (options & DCE_TYPES)     dceTypes();
        if (options & MAP_TYPES)     mapTypeConst();
        if (options & MAP_NAMES)     mapNames();
        if (options & MAP_FUNCS)     mapFnBodies();

        mapRemainder(); // map any unmapped IDs
        applyMap();     // Now remap each shader to the new IDs we've come up with
        strip();        // strip out data we decided to eliminate
    }
1229
1230    // remap from a memory image
1231    void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)
1232    {
1233        spv.swap(in_spv);
1234        remap(opts);
1235        spv.swap(in_spv);
1236    }
1237
1238} // namespace SPV
1239
1240#endif // defined (use_cpp11)
1241
1242