SkSLSPIRVCodeGenerator.cpp revision 8a83ca4e9afc9e3c08b4e8c33a74392f9b3154d7
1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLSPIRVCodeGenerator.h"
9
10#include "GLSL.std.450.h"
11
12#include "ir/SkSLExpressionStatement.h"
13#include "ir/SkSLExtension.h"
14#include "ir/SkSLIndexExpression.h"
15#include "ir/SkSLVariableReference.h"
16#include "SkSLCompiler.h"
17
18namespace SkSL {
19
// Generator magic number written into the SPIR-V header.
// FIXME: we should probably register a magic number
static constexpr int32_t SKSL_MAGIC = 0x0;
21
22void SPIRVCodeGenerator::setupIntrinsics() {
23#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24                                    GLSLstd450 ## x, GLSLstd450 ## x)
25#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26                                                             GLSLstd450 ## ifFloat, \
27                                                             GLSLstd450 ## ifInt, \
28                                                             GLSLstd450 ## ifUInt, \
29                                                             SpvOpUndef)
30#define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31                                                           SpvOp ## x)
32#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33                                   k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34                                   k ## x ## _SpecialIntrinsic)
35    fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
36    fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
37    fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
38    fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39    fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
40    fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
41    fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
42    fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
43    fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
44    fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
45    fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
46    fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
47    fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
48    fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
49    fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
50    fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
51    fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
52    fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
53    fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
54    fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
55    fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
56    fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
57    fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
58    fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
59    fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
60    fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
61    fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
62    fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
63    fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
64    fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
65    fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66    fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
67    fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68    fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
69    fIntrinsicMap[String("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
70    fIntrinsicMap[String("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
71    fIntrinsicMap[String("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
72    fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
73                                                             SpvOpUndef, SpvOpUndef, SpvOpUndef);
74    fIntrinsicMap[String("mix")]           = ALL_GLSL(FMix);
75    fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
76    fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
77    fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
78    fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
79    fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
80
81#define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
82                   fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
83    PACK(Snorm4x8);
84    PACK(Unorm4x8);
85    PACK(Snorm2x16);
86    PACK(Unorm2x16);
87    PACK(Half2x16);
88    PACK(Double2x32);
89    fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
90    fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
91    fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
92    fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
93    fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
94    fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
95    fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
96    fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
97    fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
98    fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
99                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
100    fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
101                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
102    fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
103                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
104    fIntrinsicMap[String("texture")]     = SPECIAL(Texture);
105    fIntrinsicMap[String("texelFetch")]  = SPECIAL(TexelFetch);
106    fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107
108    fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109                                                                SpvOpUndef, SpvOpUndef, SpvOpAny);
110    fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111                                                                SpvOpUndef, SpvOpUndef, SpvOpAll);
112    fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
113                                                                SpvOpFOrdEqual, SpvOpIEqual,
114                                                                SpvOpIEqual, SpvOpLogicalEqual);
115    fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
116                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
117                                                                SpvOpINotEqual,
118                                                                SpvOpLogicalNotEqual);
119    fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
120                                                                SpvOpFOrdLessThan, SpvOpSLessThan,
121                                                                SpvOpULessThan, SpvOpUndef);
122    fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
123                                                                SpvOpFOrdLessThanEqual,
124                                                                SpvOpSLessThanEqual,
125                                                                SpvOpULessThanEqual,
126                                                                SpvOpUndef);
127    fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                SpvOpFOrdGreaterThan,
129                                                                SpvOpSGreaterThan,
130                                                                SpvOpUGreaterThan,
131                                                                SpvOpUndef);
132    fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                SpvOpFOrdGreaterThanEqual,
134                                                                SpvOpSGreaterThanEqual,
135                                                                SpvOpUGreaterThanEqual,
136                                                                SpvOpUndef);
137    fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
138    fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
139// interpolateAt* not yet supported...
140}
141
142void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143    out.write((const char*) &word, sizeof(word));
144}
145
146static bool is_float(const Context& context, const Type& type) {
147    if (type.kind() == Type::kVector_Kind) {
148        return is_float(context, type.componentType());
149    }
150    return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151           type == *context.fDouble_Type;
152}
153
154static bool is_signed(const Context& context, const Type& type) {
155    if (type.kind() == Type::kVector_Kind) {
156        return is_signed(context, type.componentType());
157    }
158    return type == *context.fInt_Type || type == *context.fShort_Type;
159}
160
161static bool is_unsigned(const Context& context, const Type& type) {
162    if (type.kind() == Type::kVector_Kind) {
163        return is_unsigned(context, type.componentType());
164    }
165    return type == *context.fUInt_Type || type == *context.fUShort_Type;
166}
167
168static bool is_bool(const Context& context, const Type& type) {
169    if (type.kind() == Type::kVector_Kind) {
170        return is_bool(context, type.componentType());
171    }
172    return type == *context.fBool_Type;
173}
174
175static bool is_out(const Variable& var) {
176    return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177}
178
179void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
180    ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
181    ASSERT(opCode != SpvOpUndef);
182    switch (opCode) {
183        case SpvOpReturn:      // fall through
184        case SpvOpReturnValue: // fall through
185        case SpvOpKill:        // fall through
186        case SpvOpBranch:      // fall through
187        case SpvOpBranchConditional:
188            ASSERT(fCurrentBlock);
189            fCurrentBlock = 0;
190            break;
191        case SpvOpConstant:          // fall through
192        case SpvOpConstantTrue:      // fall through
193        case SpvOpConstantFalse:     // fall through
194        case SpvOpConstantComposite: // fall through
195        case SpvOpTypeVoid:          // fall through
196        case SpvOpTypeInt:           // fall through
197        case SpvOpTypeFloat:         // fall through
198        case SpvOpTypeBool:          // fall through
199        case SpvOpTypeVector:        // fall through
200        case SpvOpTypeMatrix:        // fall through
201        case SpvOpTypeArray:         // fall through
202        case SpvOpTypePointer:       // fall through
203        case SpvOpTypeFunction:      // fall through
204        case SpvOpTypeRuntimeArray:  // fall through
205        case SpvOpTypeStruct:        // fall through
206        case SpvOpTypeImage:         // fall through
207        case SpvOpTypeSampledImage:  // fall through
208        case SpvOpVariable:          // fall through
209        case SpvOpFunction:          // fall through
210        case SpvOpFunctionParameter: // fall through
211        case SpvOpFunctionEnd:       // fall through
212        case SpvOpExecutionMode:     // fall through
213        case SpvOpMemoryModel:       // fall through
214        case SpvOpCapability:        // fall through
215        case SpvOpExtInstImport:     // fall through
216        case SpvOpEntryPoint:        // fall through
217        case SpvOpSource:            // fall through
218        case SpvOpSourceExtension:   // fall through
219        case SpvOpName:              // fall through
220        case SpvOpMemberName:        // fall through
221        case SpvOpDecorate:          // fall through
222        case SpvOpMemberDecorate:
223            break;
224        default:
225            ASSERT(fCurrentBlock);
226    }
227    this->writeWord((length << 16) | opCode, out);
228}
229
230void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
231    fCurrentBlock = label;
232    this->writeInstruction(SpvOpLabel, label, out);
233}
234
235void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
236    this->writeOpCode(opCode, 1, out);
237}
238
239void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
240    this->writeOpCode(opCode, 2, out);
241    this->writeWord(word1, out);
242}
243
244void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
245    out.write(string, length);
246    switch (length % 4) {
247        case 1:
248            out.write8(0);
249            // fall through
250        case 2:
251            out.write8(0);
252            // fall through
253        case 3:
254            out.write8(0);
255            break;
256        default:
257            this->writeWord(0, out);
258    }
259}
260
261void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
262    this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
263    this->writeString(string.fChars, string.fLength, out);
264}
265
266
267void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
268                                          OutputStream& out) {
269    this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
270    this->writeWord(word1, out);
271    this->writeString(string.fChars, string.fLength, out);
272}
273
274void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
275                                          StringFragment string, OutputStream& out) {
276    this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
277    this->writeWord(word1, out);
278    this->writeWord(word2, out);
279    this->writeString(string.fChars, string.fLength, out);
280}
281
282void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283                                          OutputStream& out) {
284    this->writeOpCode(opCode, 3, out);
285    this->writeWord(word1, out);
286    this->writeWord(word2, out);
287}
288
289void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
290                                          int32_t word3, OutputStream& out) {
291    this->writeOpCode(opCode, 4, out);
292    this->writeWord(word1, out);
293    this->writeWord(word2, out);
294    this->writeWord(word3, out);
295}
296
297void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298                                          int32_t word3, int32_t word4, OutputStream& out) {
299    this->writeOpCode(opCode, 5, out);
300    this->writeWord(word1, out);
301    this->writeWord(word2, out);
302    this->writeWord(word3, out);
303    this->writeWord(word4, out);
304}
305
306void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
307                                          int32_t word3, int32_t word4, int32_t word5,
308                                          OutputStream& out) {
309    this->writeOpCode(opCode, 6, out);
310    this->writeWord(word1, out);
311    this->writeWord(word2, out);
312    this->writeWord(word3, out);
313    this->writeWord(word4, out);
314    this->writeWord(word5, out);
315}
316
317void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
318                                          int32_t word3, int32_t word4, int32_t word5,
319                                          int32_t word6, OutputStream& out) {
320    this->writeOpCode(opCode, 7, out);
321    this->writeWord(word1, out);
322    this->writeWord(word2, out);
323    this->writeWord(word3, out);
324    this->writeWord(word4, out);
325    this->writeWord(word5, out);
326    this->writeWord(word6, out);
327}
328
329void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
330                                          int32_t word3, int32_t word4, int32_t word5,
331                                          int32_t word6, int32_t word7, OutputStream& out) {
332    this->writeOpCode(opCode, 8, out);
333    this->writeWord(word1, out);
334    this->writeWord(word2, out);
335    this->writeWord(word3, out);
336    this->writeWord(word4, out);
337    this->writeWord(word5, out);
338    this->writeWord(word6, out);
339    this->writeWord(word7, out);
340}
341
342void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
343                                          int32_t word3, int32_t word4, int32_t word5,
344                                          int32_t word6, int32_t word7, int32_t word8,
345                                          OutputStream& out) {
346    this->writeOpCode(opCode, 9, out);
347    this->writeWord(word1, out);
348    this->writeWord(word2, out);
349    this->writeWord(word3, out);
350    this->writeWord(word4, out);
351    this->writeWord(word5, out);
352    this->writeWord(word6, out);
353    this->writeWord(word7, out);
354    this->writeWord(word8, out);
355}
356
357void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
358    for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
359        if (fCapabilities & bit) {
360            this->writeInstruction(SpvOpCapability, (SpvId) i, out);
361        }
362    }
363    if (fProgram.fKind == Program::kGeometry_Kind) {
364        this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
365    }
366}
367
368SpvId SPIRVCodeGenerator::nextId() {
369    return fIdCount++;
370}
371
372void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
373                                     SpvId resultId) {
374    this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
375    // go ahead and write all of the field types, so we don't inadvertently write them while we're
376    // in the middle of writing the struct instruction
377    std::vector<SpvId> types;
378    for (const auto& f : type.fields()) {
379        types.push_back(this->getType(*f.fType, memoryLayout));
380    }
381    this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
382    this->writeWord(resultId, fConstantBuffer);
383    for (SpvId id : types) {
384        this->writeWord(id, fConstantBuffer);
385    }
386    size_t offset = 0;
387    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
388        size_t size = memoryLayout.size(*type.fields()[i].fType);
389        size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
390        const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
391        if (fieldLayout.fOffset >= 0) {
392            if (fieldLayout.fOffset < (int) offset) {
393                fErrors.error(type.fOffset,
394                              "offset of field '" + type.fields()[i].fName + "' must be at "
395                              "least " + to_string((int) offset));
396            }
397            if (fieldLayout.fOffset % alignment) {
398                fErrors.error(type.fOffset,
399                              "offset of field '" + type.fields()[i].fName + "' must be a multiple"
400                              " of " + to_string((int) alignment));
401            }
402            offset = fieldLayout.fOffset;
403        } else {
404            size_t mod = offset % alignment;
405            if (mod) {
406                offset += alignment - mod;
407            }
408        }
409        this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
410        this->writeLayout(fieldLayout, resultId, i);
411        if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
412            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
413                                   (SpvId) offset, fDecorationBuffer);
414        }
415        if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
416            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
417                                   fDecorationBuffer);
418            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
419                                   (SpvId) memoryLayout.stride(*type.fields()[i].fType),
420                                   fDecorationBuffer);
421        }
422        offset += size;
423        Type::Kind kind = type.fields()[i].fType->kind();
424        if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
425            offset += alignment - offset % alignment;
426        }
427    }
428}
429
430Type SPIRVCodeGenerator::getActualType(const Type& type) {
431    if (type == *fContext.fHalf_Type) {
432        return *fContext.fFloat_Type;
433    }
434    if (type == *fContext.fShort_Type) {
435        return *fContext.fInt_Type;
436    }
437    if (type == *fContext.fUShort_Type) {
438        return *fContext.fUInt_Type;
439    }
440    if (type.kind() == Type::kMatrix_Kind || type.kind() == Type::kVector_Kind) {
441        if (type.componentType() == *fContext.fHalf_Type) {
442            return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
443        }
444        if (type.componentType() == *fContext.fShort_Type) {
445            return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
446        }
447        if (type.componentType() == *fContext.fUShort_Type) {
448            return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
449        }
450    }
451    return type;
452}
453
454SpvId SPIRVCodeGenerator::getType(const Type& type) {
455    return this->getType(type, fDefaultLayout);
456}
457
458SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
459    Type type = this->getActualType(rawType);
460    String key = type.name() + to_string((int) layout.fStd);
461    auto entry = fTypeMap.find(key);
462    if (entry == fTypeMap.end()) {
463        SpvId result = this->nextId();
464        switch (type.kind()) {
465            case Type::kScalar_Kind:
466                if (type == *fContext.fBool_Type) {
467                    this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
468                } else if (type == *fContext.fInt_Type) {
469                    this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
470                } else if (type == *fContext.fUInt_Type) {
471                    this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
472                } else if (type == *fContext.fFloat_Type) {
473                    this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
474                } else if (type == *fContext.fDouble_Type) {
475                    this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
476                } else {
477                    ASSERT(false);
478                }
479                break;
480            case Type::kVector_Kind:
481                this->writeInstruction(SpvOpTypeVector, result,
482                                       this->getType(type.componentType(), layout),
483                                       type.columns(), fConstantBuffer);
484                break;
485            case Type::kMatrix_Kind:
486                this->writeInstruction(SpvOpTypeMatrix, result,
487                                       this->getType(index_type(fContext, type), layout),
488                                       type.columns(), fConstantBuffer);
489                break;
490            case Type::kStruct_Kind:
491                this->writeStruct(type, layout, result);
492                break;
493            case Type::kArray_Kind: {
494                if (type.columns() > 0) {
495                    IntLiteral count(fContext, -1, type.columns());
496                    this->writeInstruction(SpvOpTypeArray, result,
497                                           this->getType(type.componentType(), layout),
498                                           this->writeIntLiteral(count), fConstantBuffer);
499                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
500                                           (int32_t) layout.stride(type),
501                                           fDecorationBuffer);
502                } else {
503                    this->writeInstruction(SpvOpTypeRuntimeArray, result,
504                                           this->getType(type.componentType(), layout),
505                                           fConstantBuffer);
506                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
507                                           (int32_t) layout.stride(type),
508                                           fDecorationBuffer);
509                }
510                break;
511            }
512            case Type::kSampler_Kind: {
513                SpvId image = result;
514                if (SpvDimSubpassData != type.dimensions()) {
515                    image = this->nextId();
516                }
517                if (SpvDimBuffer == type.dimensions()) {
518                    fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
519                }
520                this->writeInstruction(SpvOpTypeImage, image,
521                                       this->getType(*fContext.fFloat_Type, layout),
522                                       type.dimensions(), type.isDepth(), type.isArrayed(),
523                                       type.isMultisampled(), type.isSampled() ? 1 : 2,
524                                       SpvImageFormatUnknown, fConstantBuffer);
525                fImageTypeMap[key] = image;
526                if (SpvDimSubpassData != type.dimensions()) {
527                    this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
528                }
529                break;
530            }
531            default:
532                if (type == *fContext.fVoid_Type) {
533                    this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
534                } else {
535                    ABORT("invalid type: %s", type.description().c_str());
536                }
537        }
538        fTypeMap[key] = result;
539        return result;
540    }
541    return entry->second;
542}
543
544SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
545    ASSERT(type.kind() == Type::kSampler_Kind);
546    this->getType(type);
547    String key = type.name() + to_string((int) fDefaultLayout.fStd);
548    ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
549    return fImageTypeMap[key];
550}
551
552SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
553    String key = function.fReturnType.description() + "(";
554    String separator;
555    for (size_t i = 0; i < function.fParameters.size(); i++) {
556        key += separator;
557        separator = ", ";
558        key += function.fParameters[i]->fType.description();
559    }
560    key += ")";
561    auto entry = fTypeMap.find(key);
562    if (entry == fTypeMap.end()) {
563        SpvId result = this->nextId();
564        int32_t length = 3 + (int32_t) function.fParameters.size();
565        SpvId returnType = this->getType(function.fReturnType);
566        std::vector<SpvId> parameterTypes;
567        for (size_t i = 0; i < function.fParameters.size(); i++) {
568            // glslang seems to treat all function arguments as pointers whether they need to be or
569            // not. I  was initially puzzled by this until I ran bizarre failures with certain
570            // patterns of function calls and control constructs, as exemplified by this minimal
571            // failure case:
572            //
573            // void sphere(float x) {
574            // }
575            //
576            // void map() {
577            //     sphere(1.0);
578            // }
579            //
580            // void main() {
581            //     for (int i = 0; i < 1; i++) {
582            //         map();
583            //     }
584            // }
585            //
586            // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
587            // crashes. Making it take a float* and storing the argument in a temporary variable,
588            // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
589            // the spec makes this make sense.
590//            if (is_out(function->fParameters[i])) {
591                parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
592                                                              SpvStorageClassFunction));
593//            } else {
594//                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
595//            }
596        }
597        this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
598        this->writeWord(result, fConstantBuffer);
599        this->writeWord(returnType, fConstantBuffer);
600        for (SpvId id : parameterTypes) {
601            this->writeWord(id, fConstantBuffer);
602        }
603        fTypeMap[key] = result;
604        return result;
605    }
606    return entry->second;
607}
608
609SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
610    return this->getPointerType(type, fDefaultLayout, storageClass);
611}
612
613SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
614                                         SpvStorageClass_ storageClass) {
615    Type type = this->getActualType(rawType);
616    String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
617    auto entry = fTypeMap.find(key);
618    if (entry == fTypeMap.end()) {
619        SpvId result = this->nextId();
620        this->writeInstruction(SpvOpTypePointer, result, storageClass,
621                               this->getType(type), fConstantBuffer);
622        fTypeMap[key] = result;
623        return result;
624    }
625    return entry->second;
626}
627
628SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
629    switch (expr.fKind) {
630        case Expression::kBinary_Kind:
631            return this->writeBinaryExpression((BinaryExpression&) expr, out);
632        case Expression::kBoolLiteral_Kind:
633            return this->writeBoolLiteral((BoolLiteral&) expr);
634        case Expression::kConstructor_Kind:
635            return this->writeConstructor((Constructor&) expr, out);
636        case Expression::kIntLiteral_Kind:
637            return this->writeIntLiteral((IntLiteral&) expr);
638        case Expression::kFieldAccess_Kind:
639            return this->writeFieldAccess(((FieldAccess&) expr), out);
640        case Expression::kFloatLiteral_Kind:
641            return this->writeFloatLiteral(((FloatLiteral&) expr));
642        case Expression::kFunctionCall_Kind:
643            return this->writeFunctionCall((FunctionCall&) expr, out);
644        case Expression::kPrefix_Kind:
645            return this->writePrefixExpression((PrefixExpression&) expr, out);
646        case Expression::kPostfix_Kind:
647            return this->writePostfixExpression((PostfixExpression&) expr, out);
648        case Expression::kSwizzle_Kind:
649            return this->writeSwizzle((Swizzle&) expr, out);
650        case Expression::kVariableReference_Kind:
651            return this->writeVariableReference((VariableReference&) expr, out);
652        case Expression::kTernary_Kind:
653            return this->writeTernaryExpression((TernaryExpression&) expr, out);
654        case Expression::kIndex_Kind:
655            return this->writeIndexExpression((IndexExpression&) expr, out);
656        default:
657            ABORT("unsupported expression: %s", expr.description().c_str());
658    }
659    return -1;
660}
661
662SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
663    auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
664    ASSERT(intrinsic != fIntrinsicMap.end());
665    int32_t intrinsicId;
666    if (c.fArguments.size() > 0) {
667        const Type& type = c.fArguments[0]->fType;
668        if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
669            intrinsicId = std::get<1>(intrinsic->second);
670        } else if (is_signed(fContext, type)) {
671            intrinsicId = std::get<2>(intrinsic->second);
672        } else if (is_unsigned(fContext, type)) {
673            intrinsicId = std::get<3>(intrinsic->second);
674        } else if (is_bool(fContext, type)) {
675            intrinsicId = std::get<4>(intrinsic->second);
676        } else {
677            intrinsicId = std::get<1>(intrinsic->second);
678        }
679    } else {
680        intrinsicId = std::get<1>(intrinsic->second);
681    }
682    switch (std::get<0>(intrinsic->second)) {
683        case kGLSL_STD_450_IntrinsicKind: {
684            SpvId result = this->nextId();
685            std::vector<SpvId> arguments;
686            for (size_t i = 0; i < c.fArguments.size(); i++) {
687                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
688            }
689            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
690            this->writeWord(this->getType(c.fType), out);
691            this->writeWord(result, out);
692            this->writeWord(fGLSLExtendedInstructions, out);
693            this->writeWord(intrinsicId, out);
694            for (SpvId id : arguments) {
695                this->writeWord(id, out);
696            }
697            return result;
698        }
699        case kSPIRV_IntrinsicKind: {
700            SpvId result = this->nextId();
701            std::vector<SpvId> arguments;
702            for (size_t i = 0; i < c.fArguments.size(); i++) {
703                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
704            }
705            if (c.fType != *fContext.fVoid_Type) {
706                this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
707                this->writeWord(this->getType(c.fType), out);
708                this->writeWord(result, out);
709            } else {
710                this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
711            }
712            for (SpvId id : arguments) {
713                this->writeWord(id, out);
714            }
715            return result;
716        }
717        case kSpecial_IntrinsicKind:
718            return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
719        default:
720            ABORT("unsupported intrinsic kind");
721    }
722}
723
724SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
725                                                OutputStream& out) {
726    SpvId result = this->nextId();
727    switch (kind) {
728        case kAtan_SpecialIntrinsic: {
729            std::vector<SpvId> arguments;
730            for (size_t i = 0; i < c.fArguments.size(); i++) {
731                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
732            }
733            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
734            this->writeWord(this->getType(c.fType), out);
735            this->writeWord(result, out);
736            this->writeWord(fGLSLExtendedInstructions, out);
737            this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
738            for (SpvId id : arguments) {
739                this->writeWord(id, out);
740            }
741            break;
742        }
743        case kSubpassLoad_SpecialIntrinsic: {
744            SpvId img = this->writeExpression(*c.fArguments[0], out);
745            std::vector<std::unique_ptr<Expression>> args;
746            args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
747            args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
748            Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
749            SpvId coords = this->writeConstantVector(ctor);
750            if (1 == c.fArguments.size()) {
751                this->writeInstruction(SpvOpImageRead,
752                                       this->getType(c.fType),
753                                       result,
754                                       img,
755                                       coords,
756                                       out);
757            } else {
758                ASSERT(2 == c.fArguments.size());
759                SpvId sample = this->writeExpression(*c.fArguments[1], out);
760                this->writeInstruction(SpvOpImageRead,
761                                       this->getType(c.fType),
762                                       result,
763                                       img,
764                                       coords,
765                                       SpvImageOperandsSampleMask,
766                                       sample,
767                                       out);
768            }
769            break;
770        }
771        case kTexelFetch_SpecialIntrinsic: {
772            ASSERT(c.fArguments.size() == 2);
773            SpvId image = this->nextId();
774            this->writeInstruction(SpvOpImage,
775                                   this->getImageType(c.fArguments[0]->fType),
776                                   image,
777                                   this->writeExpression(*c.fArguments[0], out),
778                                   out);
779            this->writeInstruction(SpvOpImageFetch,
780                                   this->getType(c.fType),
781                                   result,
782                                   image,
783                                   this->writeExpression(*c.fArguments[1], out),
784                                   out);
785            break;
786        }
787        case kTexture_SpecialIntrinsic: {
788            SpvOp_ op = SpvOpImageSampleImplicitLod;
789            switch (c.fArguments[0]->fType.dimensions()) {
790                case SpvDim1D:
791                    if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
792                        op = SpvOpImageSampleProjImplicitLod;
793                    } else {
794                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
795                    }
796                    break;
797                case SpvDim2D:
798                    if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
799                        op = SpvOpImageSampleProjImplicitLod;
800                    } else {
801                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
802                    }
803                    break;
804                case SpvDim3D:
805                    if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
806                        op = SpvOpImageSampleProjImplicitLod;
807                    } else {
808                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
809                    }
810                    break;
811                case SpvDimCube:   // fall through
812                case SpvDimRect:   // fall through
813                case SpvDimBuffer: // fall through
814                case SpvDimSubpassData:
815                    break;
816            }
817            SpvId type = this->getType(c.fType);
818            SpvId sampler = this->writeExpression(*c.fArguments[0], out);
819            SpvId uv = this->writeExpression(*c.fArguments[1], out);
820            if (c.fArguments.size() == 3) {
821                this->writeInstruction(op, type, result, sampler, uv,
822                                       SpvImageOperandsBiasMask,
823                                       this->writeExpression(*c.fArguments[2], out),
824                                       out);
825            } else {
826                ASSERT(c.fArguments.size() == 2);
827                if (fProgram.fSettings.fSharpenTextures) {
828                    FloatLiteral lodBias(fContext, -1, -0.5);
829                    this->writeInstruction(op, type, result, sampler, uv,
830                                           SpvImageOperandsBiasMask,
831                                           this->writeFloatLiteral(lodBias),
832                                           out);
833                } else {
834                    this->writeInstruction(op, type, result, sampler, uv,
835                                           out);
836                }
837            }
838            break;
839        }
840        case kMod_SpecialIntrinsic: {
841            ASSERT(c.fArguments.size() == 2);
842            SpvId arg1 = this->writeExpression(*c.fArguments[0], out);
843            SpvId arg2 = this->writeExpression(*c.fArguments[1], out);
844            if (c.fArguments[0]->fType != c.fArguments[1]->fType) {
845                // we have mod(vector, scalar), but SPIR-V wants mod(vector, vector)
846                ASSERT(c.fArguments[0]->fType.componentType() == c.fArguments[1]->fType);
847                SpvId scalar = arg2;
848                const Type& type = c.fArguments[0]->fType;
849                arg2 = this->nextId();
850                this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), out);
851                this->writeWord(this->getType(type), out);
852                this->writeWord(arg2, out);
853                for (int i = 0; i < type.columns(); i++) {
854                    this->writeWord(scalar, out);
855                }
856            }
857            const Type& operandType = c.fArguments[0]->fType;
858            SpvOp_ op;
859            if (is_float(fContext, operandType)) {
860                op = SpvOpFMod;
861            } else if (is_signed(fContext, operandType)) {
862                op = SpvOpSMod;
863            } else if (is_unsigned(fContext, operandType)) {
864                op = SpvOpUMod;
865            } else {
866                ASSERT(false);
867                return 0;
868            }
869            this->writeOpCode(op, 5, out);
870            this->writeWord(this->getType(operandType), out);
871            this->writeWord(result, out);
872            this->writeWord(arg1, out);
873            this->writeWord(arg2, out);
874        }
875    }
876    return result;
877}
878
// Emits a call to a user-defined function and returns the id of the call's result. Callees not
// found in fFunctionMap are handed to writeIntrinsicCall instead.
//
// Every argument is passed as a pointer. An out-parameter whose lvalue exposes a pointer is
// passed directly. Any other argument is copied into a fresh Function-storage OpVariable, which
// is declared in fVariableBuffer and initialized with OpStore before the call. For out-parameters
// that had to go through such a temporary, the temporary is loaded after the call and its value
// is stored back into the original lvalue.
SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
    const auto& entry = fFunctionMap.find(&c.fFunction);
    if (entry == fFunctionMap.end()) {
        return this->writeIntrinsicCall(c, out);
    }
    // stores (temporary variable, value type, lvalue) triples whose temporaries must be loaded
    // and written back into the lvalue after the function call is complete
    std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        // id of temporary variable that we will use to hold this argument, or 0 if it is being
        // passed directly
        SpvId tmpVar;
        // if we need a temporary var to store this argument, this is the value to store in the var
        SpvId tmpValueId;
        if (is_out(*c.fFunction.fParameters[i])) {
            std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
            SpvId ptr = lv->getPointer();
            if (ptr) {
                // the lvalue is addressable: pass its pointer straight through
                arguments.push_back(ptr);
                continue;
            } else {
                // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
                // copy it into a temp, call the function, read the value out of the temp, and then
                // update the lvalue.
                tmpValueId = lv->load(out);
                tmpVar = this->nextId();
                lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
                                  std::move(lv)));
            }
        } else {
            // see getFunctionType for an explanation of why we're always using pointer parameters
            tmpValueId = this->writeExpression(*c.fArguments[i], out);
            tmpVar = this->nextId();
        }
        // declare the temporary with the other function-local variables, then initialize it
        this->writeInstruction(SpvOpVariable,
                               this->getPointerType(c.fArguments[i]->fType,
                                                    SpvStorageClassFunction),
                               tmpVar,
                               SpvStorageClassFunction,
                               fVariableBuffer);
        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
        arguments.push_back(tmpVar);
    }
    SpvId result = this->nextId();
    this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
    this->writeWord(this->getType(c.fType), out);
    this->writeWord(result, out);
    this->writeWord(entry->second, out);
    for (SpvId id : arguments) {
        this->writeWord(id, out);
    }
    // now that the call is complete, we may need to update some lvalues with the new values of out
    // arguments
    for (const auto& tuple : lvalues) {
        SpvId load = this->nextId();
        this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
        std::get<2>(tuple)->store(load, out);
    }
    return result;
}
939
940SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
941    ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
942    SpvId result = this->nextId();
943    std::vector<SpvId> arguments;
944    for (size_t i = 0; i < c.fArguments.size(); i++) {
945        arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
946    }
947    SpvId type = this->getType(c.fType);
948    if (c.fArguments.size() == 1) {
949        // with a single argument, a vector will have all of its entries equal to the argument
950        this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
951        this->writeWord(type, fConstantBuffer);
952        this->writeWord(result, fConstantBuffer);
953        for (int i = 0; i < c.fType.columns(); i++) {
954            this->writeWord(arguments[0], fConstantBuffer);
955        }
956    } else {
957        this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
958                          fConstantBuffer);
959        this->writeWord(type, fConstantBuffer);
960        this->writeWord(result, fConstantBuffer);
961        for (SpvId id : arguments) {
962            this->writeWord(id, fConstantBuffer);
963        }
964    }
965    return result;
966}
967
968SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
969    ASSERT(c.fType.isFloat());
970    ASSERT(c.fArguments.size() == 1);
971    ASSERT(c.fArguments[0]->fType.isNumber());
972    SpvId result = this->nextId();
973    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
974    if (c.fArguments[0]->fType.isSigned()) {
975        this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
976                               out);
977    } else {
978        ASSERT(c.fArguments[0]->fType.isUnsigned());
979        this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
980                               out);
981    }
982    return result;
983}
984
985SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
986    ASSERT(c.fType.isSigned());
987    ASSERT(c.fArguments.size() == 1);
988    ASSERT(c.fArguments[0]->fType.isNumber());
989    SpvId result = this->nextId();
990    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
991    if (c.fArguments[0]->fType.isFloat()) {
992        this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
993                               out);
994    }
995    else {
996        ASSERT(c.fArguments[0]->fType.isUnsigned());
997        this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
998                               out);
999    }
1000    return result;
1001}
1002
1003SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1004    ASSERT(c.fType.isUnsigned());
1005    ASSERT(c.fArguments.size() == 1);
1006    ASSERT(c.fArguments[0]->fType.isNumber());
1007    SpvId result = this->nextId();
1008    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1009    if (c.fArguments[0]->fType.isFloat()) {
1010        this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1011                               out);
1012    } else {
1013        ASSERT(c.fArguments[0]->fType.isSigned());
1014        this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1015                               out);
1016    }
1017    return result;
1018}
1019
1020void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1021                                                 OutputStream& out) {
1022    FloatLiteral zero(fContext, -1, 0);
1023    SpvId zeroId = this->writeFloatLiteral(zero);
1024    std::vector<SpvId> columnIds;
1025    for (int column = 0; column < type.columns(); column++) {
1026        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1027                          out);
1028        this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1029                        out);
1030        SpvId columnId = this->nextId();
1031        this->writeWord(columnId, out);
1032        columnIds.push_back(columnId);
1033        for (int row = 0; row < type.columns(); row++) {
1034            this->writeWord(row == column ? diagonal : zeroId, out);
1035        }
1036    }
1037    this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1038                      out);
1039    this->writeWord(this->getType(type), out);
1040    this->writeWord(id, out);
1041    for (SpvId id : columnIds) {
1042        this->writeWord(id, out);
1043    }
1044}
1045
1046void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1047                                         const Type& dstType, OutputStream& out) {
1048    ASSERT(srcType.kind() == Type::kMatrix_Kind);
1049    ASSERT(dstType.kind() == Type::kMatrix_Kind);
1050    ASSERT(srcType.componentType() == dstType.componentType());
1051    SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1052                                                                           srcType.rows(),
1053                                                                           1));
1054    SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1055                                                                           dstType.rows(),
1056                                                                           1));
1057    SpvId zeroId;
1058    if (dstType.componentType() == *fContext.fFloat_Type) {
1059        FloatLiteral zero(fContext, -1, 0.0);
1060        zeroId = this->writeFloatLiteral(zero);
1061    } else if (dstType.componentType() == *fContext.fInt_Type) {
1062        IntLiteral zero(fContext, -1, 0);
1063        zeroId = this->writeIntLiteral(zero);
1064    } else {
1065        ABORT("unsupported matrix component type");
1066    }
1067    SpvId zeroColumn = 0;
1068    SpvId columns[4];
1069    for (int i = 0; i < dstType.columns(); i++) {
1070        if (i < srcType.columns()) {
1071            // we're still inside the src matrix, copy the column
1072            SpvId srcColumn = this->nextId();
1073            this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1074            SpvId dstColumn;
1075            if (srcType.rows() == dstType.rows()) {
1076                // columns are equal size, don't need to do anything
1077                dstColumn = srcColumn;
1078            }
1079            else if (dstType.rows() > srcType.rows()) {
1080                // dst column is bigger, need to zero-pad it
1081                dstColumn = this->nextId();
1082                int delta = dstType.rows() - srcType.rows();
1083                this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1084                this->writeWord(dstColumnType, out);
1085                this->writeWord(dstColumn, out);
1086                this->writeWord(srcColumn, out);
1087                for (int i = 0; i < delta; ++i) {
1088                    this->writeWord(zeroId, out);
1089                }
1090            }
1091            else {
1092                // dst column is smaller, need to swizzle the src column
1093                dstColumn = this->nextId();
1094                int count = dstType.rows();
1095                this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1096                this->writeWord(dstColumnType, out);
1097                this->writeWord(dstColumn, out);
1098                this->writeWord(srcColumn, out);
1099                this->writeWord(srcColumn, out);
1100                for (int i = 0; i < count; i++) {
1101                    this->writeWord(i, out);
1102                }
1103            }
1104            columns[i] = dstColumn;
1105        } else {
1106            // we're past the end of the src matrix, need a vector of zeroes
1107            if (!zeroColumn) {
1108                zeroColumn = this->nextId();
1109                this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1110                this->writeWord(dstColumnType, out);
1111                this->writeWord(zeroColumn, out);
1112                for (int i = 0; i < dstType.rows(); ++i) {
1113                    this->writeWord(zeroId, out);
1114                }
1115            }
1116            columns[i] = zeroColumn;
1117        }
1118    }
1119    this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1120    this->writeWord(this->getType(dstType), out);
1121    this->writeWord(id, out);
1122    for (int i = 0; i < dstType.columns(); i++) {
1123        this->writeWord(columns[i], out);
1124    }
1125}
1126
// Emits a matrix constructor and returns the id of the constructed matrix. Argument shapes:
//   - one scalar:  delegated to writeUniformScaleMatrix (scalar on the diagonal, zeros elsewhere)
//   - one matrix:  delegated to writeMatrixCopy (resized copy of the argument)
//   - one vector:  only float2x2(vec4)-style; components 0,1 form the first column and
//                  components 2,3 the second
//   - otherwise:   arguments are consumed in column order. A vector whose size equals the row
//                  count must begin a column on its own. Anything else is accumulated until
//                  exactly `rows` scalars have been gathered, and those pieces become one column.
//                  Pieces that straddle a column boundary are not handled and trip the ASSERT
//                  below.
SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
    ASSERT(c.fType.kind() == Type::kMatrix_Kind);
    // go ahead and write the arguments so we don't try to write new instructions in the middle of
    // an instruction
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
    }
    SpvId result = this->nextId();
    int rows = c.fType.rows();
    int columns = c.fType.columns();
    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
        this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
        this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
        // 2x2 matrix from a single 4-component vector: split it into two 2-component columns
        ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
        ASSERT(c.fArguments[0]->fType.columns() == 4);
        SpvId componentType = this->getType(c.fType.componentType());
        SpvId v[4];
        for (int i = 0; i < 4; ++i) {
            v[i] = this->nextId();
            this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
        }
        SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
        SpvId column1 = this->nextId();
        this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
        SpvId column2 = this->nextId();
        this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
        this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
                               column2, out);
    } else {
        std::vector<SpvId> columnIds;
        // ids of vectors and scalars we have written to the current column so far
        std::vector<SpvId> currentColumn;
        // the total number of scalars represented by currentColumn's entries
        int currentCount = 0;
        for (size_t i = 0; i < arguments.size(); i++) {
            if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
                    c.fArguments[i]->fType.columns() == c.fType.rows()) {
                // this is a complete column by itself
                ASSERT(currentCount == 0);
                columnIds.push_back(arguments[i]);
            } else {
                currentColumn.push_back(arguments[i]);
                currentCount += c.fArguments[i]->fType.columns();
                if (currentCount == rows) {
                    // enough scalars gathered: assemble the pieces into one column vector
                    currentCount = 0;
                    this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
                    this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
                                                                                     1)),
                                    out);
                    SpvId columnId = this->nextId();
                    this->writeWord(columnId, out);
                    columnIds.push_back(columnId);
                    for (SpvId id : currentColumn) {
                        this->writeWord(id, out);
                    }
                    currentColumn.clear();
                }
                ASSERT(currentCount < rows);
            }
        }
        ASSERT(columnIds.size() == (size_t) columns);
        this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
        this->writeWord(this->getType(c.fType), out);
        this->writeWord(result, out);
        for (SpvId id : columnIds) {
            this->writeWord(id, out);
        }
    }
    return result;
}
1200
1201SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1202    ASSERT(c.fType.kind() == Type::kVector_Kind);
1203    if (c.isConstant()) {
1204        return this->writeConstantVector(c);
1205    }
1206    // go ahead and write the arguments so we don't try to write new instructions in the middle of
1207    // an instruction
1208    std::vector<SpvId> arguments;
1209    for (size_t i = 0; i < c.fArguments.size(); i++) {
1210        if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1211            // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1212            // extract the components and convert them in that case manually. On top of that,
1213            // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1214            // doesn't handle vector arguments at all, so we always extract vector components and
1215            // pass them into OpCreateComposite individually.
1216            SpvId vec = this->writeExpression(*c.fArguments[i], out);
1217            SpvOp_ op = SpvOpUndef;
1218            const Type& src = c.fArguments[i]->fType.componentType();
1219            const Type& dst = c.fType.componentType();
1220            if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1221                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1222                    if (c.fArguments.size() == 1) {
1223                        return vec;
1224                    }
1225                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1226                    op = SpvOpConvertSToF;
1227                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1228                    op = SpvOpConvertUToF;
1229                } else {
1230                    ASSERT(false);
1231                }
1232            } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) {
1233                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1234                    op = SpvOpConvertFToS;
1235                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1236                    if (c.fArguments.size() == 1) {
1237                        return vec;
1238                    }
1239                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1240                    op = SpvOpBitcast;
1241                } else {
1242                    ASSERT(false);
1243                }
1244            } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) {
1245                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1246                    op = SpvOpConvertFToS;
1247                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1248                    op = SpvOpBitcast;
1249                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1250                    if (c.fArguments.size() == 1) {
1251                        return vec;
1252                    }
1253                } else {
1254                    ASSERT(false);
1255                }
1256            }
1257            for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1258                SpvId swizzle = this->nextId();
1259                this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1260                                       out);
1261                if (op != SpvOpUndef) {
1262                    SpvId cast = this->nextId();
1263                    this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1264                    arguments.push_back(cast);
1265                } else {
1266                    arguments.push_back(swizzle);
1267                }
1268            }
1269        } else {
1270            arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1271        }
1272    }
1273    SpvId result = this->nextId();
1274    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1275        this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1276        this->writeWord(this->getType(c.fType), out);
1277        this->writeWord(result, out);
1278        for (int i = 0; i < c.fType.columns(); i++) {
1279            this->writeWord(arguments[0], out);
1280        }
1281    } else {
1282        ASSERT(arguments.size() > 1);
1283        this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1284        this->writeWord(this->getType(c.fType), out);
1285        this->writeWord(result, out);
1286        for (SpvId id : arguments) {
1287            this->writeWord(id, out);
1288        }
1289    }
1290    return result;
1291}
1292
1293SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1294    ASSERT(c.fType.kind() == Type::kArray_Kind);
1295    // go ahead and write the arguments so we don't try to write new instructions in the middle of
1296    // an instruction
1297    std::vector<SpvId> arguments;
1298    for (size_t i = 0; i < c.fArguments.size(); i++) {
1299        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1300    }
1301    SpvId result = this->nextId();
1302    this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1303    this->writeWord(this->getType(c.fType), out);
1304    this->writeWord(result, out);
1305    for (SpvId id : arguments) {
1306        this->writeWord(id, out);
1307    }
1308    return result;
1309}
1310
1311SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1312    if (c.fArguments.size() == 1 &&
1313        this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1314        return this->writeExpression(*c.fArguments[0], out);
1315    }
1316    if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1317        return this->writeFloatConstructor(c, out);
1318    } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) {
1319        return this->writeIntConstructor(c, out);
1320    } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) {
1321        return this->writeUIntConstructor(c, out);
1322    }
1323    switch (c.fType.kind()) {
1324        case Type::kVector_Kind:
1325            return this->writeVectorConstructor(c, out);
1326        case Type::kMatrix_Kind:
1327            return this->writeMatrixConstructor(c, out);
1328        case Type::kArray_Kind:
1329            return this->writeArrayConstructor(c, out);
1330        default:
1331            ABORT("unsupported constructor: %s", c.description().c_str());
1332    }
1333}
1334
1335SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1336    if (modifiers.fFlags & Modifiers::kIn_Flag) {
1337        ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1338        return SpvStorageClassInput;
1339    } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1340        ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1341        return SpvStorageClassOutput;
1342    } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1343        if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1344            return SpvStorageClassPushConstant;
1345        }
1346        return SpvStorageClassUniform;
1347    } else {
1348        return SpvStorageClassFunction;
1349    }
1350}
1351
1352SpvStorageClass_ get_storage_class(const Expression& expr) {
1353    switch (expr.fKind) {
1354        case Expression::kVariableReference_Kind: {
1355            const Variable& var = ((VariableReference&) expr).fVariable;
1356            if (var.fStorage != Variable::kGlobal_Storage) {
1357                return SpvStorageClassFunction;
1358            }
1359            SpvStorageClass_ result = get_storage_class(var.fModifiers);
1360            if (result == SpvStorageClassFunction) {
1361                result = SpvStorageClassPrivate;
1362            }
1363            return result;
1364        }
1365        case Expression::kFieldAccess_Kind:
1366            return get_storage_class(*((FieldAccess&) expr).fBase);
1367        case Expression::kIndex_Kind:
1368            return get_storage_class(*((IndexExpression&) expr).fBase);
1369        default:
1370            return SpvStorageClassFunction;
1371    }
1372}
1373
1374std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1375    std::vector<SpvId> chain;
1376    switch (expr.fKind) {
1377        case Expression::kIndex_Kind: {
1378            IndexExpression& indexExpr = (IndexExpression&) expr;
1379            chain = this->getAccessChain(*indexExpr.fBase, out);
1380            chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1381            break;
1382        }
1383        case Expression::kFieldAccess_Kind: {
1384            FieldAccess& fieldExpr = (FieldAccess&) expr;
1385            chain = this->getAccessChain(*fieldExpr.fBase, out);
1386            IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1387            chain.push_back(this->writeIntLiteral(index));
1388            break;
1389        }
1390        default:
1391            chain.push_back(this->getLValue(expr, out)->getPointer());
1392    }
1393    return chain;
1394}
1395
1396class PointerLValue : public SPIRVCodeGenerator::LValue {
1397public:
1398    PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1399    : fGen(gen)
1400    , fPointer(pointer)
1401    , fType(type) {}
1402
1403    virtual SpvId getPointer() override {
1404        return fPointer;
1405    }
1406
1407    virtual SpvId load(OutputStream& out) override {
1408        SpvId result = fGen.nextId();
1409        fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1410        return result;
1411    }
1412
1413    virtual void store(SpvId value, OutputStream& out) override {
1414        fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1415    }
1416
1417private:
1418    SPIRVCodeGenerator& fGen;
1419    const SpvId fPointer;
1420    const SpvId fType;
1421};
1422
1423class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1424public:
1425    SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1426                  const Type& baseType, const Type& swizzleType)
1427    : fGen(gen)
1428    , fVecPointer(vecPointer)
1429    , fComponents(components)
1430    , fBaseType(baseType)
1431    , fSwizzleType(swizzleType) {}
1432
1433    virtual SpvId getPointer() override {
1434        return 0;
1435    }
1436
1437    virtual SpvId load(OutputStream& out) override {
1438        SpvId base = fGen.nextId();
1439        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1440        SpvId result = fGen.nextId();
1441        fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1442        fGen.writeWord(fGen.getType(fSwizzleType), out);
1443        fGen.writeWord(result, out);
1444        fGen.writeWord(base, out);
1445        fGen.writeWord(base, out);
1446        for (int component : fComponents) {
1447            fGen.writeWord(component, out);
1448        }
1449        return result;
1450    }
1451
1452    virtual void store(SpvId value, OutputStream& out) override {
1453        // use OpVectorShuffle to mix and match the vector components. We effectively create
1454        // a virtual vector out of the concatenation of the left and right vectors, and then
1455        // select components from this virtual vector to make the result vector. For
1456        // instance, given:
1457        // float3L = ...;
1458        // float3R = ...;
1459        // L.xz = R.xy;
1460        // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1461        // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1462        // (3, 1, 4).
1463        SpvId base = fGen.nextId();
1464        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1465        SpvId shuffle = fGen.nextId();
1466        fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1467        fGen.writeWord(fGen.getType(fBaseType), out);
1468        fGen.writeWord(shuffle, out);
1469        fGen.writeWord(base, out);
1470        fGen.writeWord(value, out);
1471        for (int i = 0; i < fBaseType.columns(); i++) {
1472            // current offset into the virtual vector, defaults to pulling the unmodified
1473            // value from the left side
1474            int offset = i;
1475            // check to see if we are writing this component
1476            for (size_t j = 0; j < fComponents.size(); j++) {
1477                if (fComponents[j] == i) {
1478                    // we're writing to this component, so adjust the offset to pull from
1479                    // the correct component of the right side instead of preserving the
1480                    // value from the left
1481                    offset = (int) (j + fBaseType.columns());
1482                    break;
1483                }
1484            }
1485            fGen.writeWord(offset, out);
1486        }
1487        fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1488    }
1489
1490private:
1491    SPIRVCodeGenerator& fGen;
1492    const SpvId fVecPointer;
1493    const std::vector<int>& fComponents;
1494    const Type& fBaseType;
1495    const Type& fSwizzleType;
1496};
1497
1498std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1499                                                                          OutputStream& out) {
1500    switch (expr.fKind) {
1501        case Expression::kVariableReference_Kind: {
1502            const Variable& var = ((VariableReference&) expr).fVariable;
1503            auto entry = fVariableMap.find(&var);
1504            ASSERT(entry != fVariableMap.end());
1505            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1506                                                                       *this,
1507                                                                       entry->second,
1508                                                                       this->getType(expr.fType)));
1509        }
1510        case Expression::kIndex_Kind: // fall through
1511        case Expression::kFieldAccess_Kind: {
1512            std::vector<SpvId> chain = this->getAccessChain(expr, out);
1513            SpvId member = this->nextId();
1514            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1515            this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1516            this->writeWord(member, out);
1517            for (SpvId idx : chain) {
1518                this->writeWord(idx, out);
1519            }
1520            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1521                                                                       *this,
1522                                                                       member,
1523                                                                       this->getType(expr.fType)));
1524        }
1525        case Expression::kSwizzle_Kind: {
1526            Swizzle& swizzle = (Swizzle&) expr;
1527            size_t count = swizzle.fComponents.size();
1528            SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1529            ASSERT(base);
1530            if (count == 1) {
1531                IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1532                SpvId member = this->nextId();
1533                this->writeInstruction(SpvOpAccessChain,
1534                                       this->getPointerType(swizzle.fType,
1535                                                            get_storage_class(*swizzle.fBase)),
1536                                       member,
1537                                       base,
1538                                       this->writeIntLiteral(index),
1539                                       out);
1540                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1541                                                                       *this,
1542                                                                       member,
1543                                                                       this->getType(expr.fType)));
1544            } else {
1545                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1546                                                                              *this,
1547                                                                              base,
1548                                                                              swizzle.fComponents,
1549                                                                              swizzle.fBase->fType,
1550                                                                              expr.fType));
1551            }
1552        }
1553        case Expression::kTernary_Kind: {
1554            TernaryExpression& t = (TernaryExpression&) expr;
1555            SpvId test = this->writeExpression(*t.fTest, out);
1556            SpvId end = this->nextId();
1557            SpvId ifTrueLabel = this->nextId();
1558            SpvId ifFalseLabel = this->nextId();
1559            this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1560            this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1561            this->writeLabel(ifTrueLabel, out);
1562            SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1563            ASSERT(ifTrue);
1564            this->writeInstruction(SpvOpBranch, end, out);
1565            ifTrueLabel = fCurrentBlock;
1566            SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1567            ASSERT(ifFalse);
1568            ifFalseLabel = fCurrentBlock;
1569            this->writeInstruction(SpvOpBranch, end, out);
1570            SpvId result = this->nextId();
1571            this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1572                       ifTrueLabel, ifFalse, ifFalseLabel, out);
1573            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1574                                                                       *this,
1575                                                                       result,
1576                                                                       this->getType(expr.fType)));
1577        }
1578        default:
1579            // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1580            // to the need to store values in temporary variables during function calls (see
1581            // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1582            // caught by IRGenerator
1583            SpvId result = this->nextId();
1584            SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1585            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1586                                   fVariableBuffer);
1587            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1588            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1589                                                                       *this,
1590                                                                       result,
1591                                                                       this->getType(expr.fType)));
1592    }
1593}
1594
1595SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1596    SpvId result = this->nextId();
1597    auto entry = fVariableMap.find(&ref.fVariable);
1598    ASSERT(entry != fVariableMap.end());
1599    SpvId var = entry->second;
1600    this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1601    if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1602        fProgram.fSettings.fFlipY) {
1603        // need to remap to a top-left coordinate system
1604        if (fRTHeightStructId == (SpvId) -1) {
1605            // height variable hasn't been written yet
1606            std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1607            ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1608            std::vector<Type::Field> fields;
1609            fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1610            StringFragment name("sksl_synthetic_uniforms");
1611            Type intfStruct(-1, name, fields);
1612            Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1613                          Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1614                          StringFragment());
1615            Variable* intfVar = new Variable(-1,
1616                                             Modifiers(layout, Modifiers::kUniform_Flag),
1617                                             name,
1618                                             intfStruct,
1619                                             Variable::kGlobal_Storage);
1620            fSynthetics.takeOwnership(intfVar);
1621            InterfaceBlock intf(-1, intfVar, name, String(""),
1622                                std::vector<std::unique_ptr<Expression>>(), st);
1623            fRTHeightStructId = this->writeInterfaceBlock(intf);
1624            fRTHeightFieldIndex = 0;
1625        }
1626        ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1627        // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1628        SpvId xId = this->nextId();
1629        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1630                               result, 0, out);
1631        IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1632        SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1633        SpvId heightPtr = this->nextId();
1634        this->writeOpCode(SpvOpAccessChain, 5, out);
1635        this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1636        this->writeWord(heightPtr, out);
1637        this->writeWord(fRTHeightStructId, out);
1638        this->writeWord(fieldIndexId, out);
1639        SpvId heightRead = this->nextId();
1640        this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1641                               heightPtr, out);
1642        SpvId rawYId = this->nextId();
1643        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1644                               result, 1, out);
1645        SpvId flippedYId = this->nextId();
1646        this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1647                               heightRead, rawYId, out);
1648        FloatLiteral zero(fContext, -1, 0.0);
1649        SpvId zeroId = writeFloatLiteral(zero);
1650        FloatLiteral one(fContext, -1, 1.0);
1651        SpvId oneId = writeFloatLiteral(one);
1652        SpvId flipped = this->nextId();
1653        this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1654        this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1655        this->writeWord(flipped, out);
1656        this->writeWord(xId, out);
1657        this->writeWord(flippedYId, out);
1658        this->writeWord(zeroId, out);
1659        this->writeWord(oneId, out);
1660        return flipped;
1661    }
1662    return result;
1663}
1664
1665SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1666    return getLValue(expr, out)->load(out);
1667}
1668
1669SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1670    return getLValue(f, out)->load(out);
1671}
1672
1673SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1674    SpvId base = this->writeExpression(*swizzle.fBase, out);
1675    SpvId result = this->nextId();
1676    size_t count = swizzle.fComponents.size();
1677    if (count == 1) {
1678        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1679                               swizzle.fComponents[0], out);
1680    } else {
1681        this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1682        this->writeWord(this->getType(swizzle.fType), out);
1683        this->writeWord(result, out);
1684        this->writeWord(base, out);
1685        this->writeWord(base, out);
1686        for (int component : swizzle.fComponents) {
1687            this->writeWord(component, out);
1688        }
1689    }
1690    return result;
1691}
1692
1693SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1694                                               const Type& operandType, SpvId lhs,
1695                                               SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1696                                               SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1697    SpvId result = this->nextId();
1698    if (is_float(fContext, operandType)) {
1699        this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1700    } else if (is_signed(fContext, operandType)) {
1701        this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1702    } else if (is_unsigned(fContext, operandType)) {
1703        this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1704    } else if (operandType == *fContext.fBool_Type) {
1705        this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1706    } else {
1707        ABORT("invalid operandType: %s", operandType.description().c_str());
1708    }
1709    return result;
1710}
1711
1712bool is_assignment(Token::Kind op) {
1713    switch (op) {
1714        case Token::EQ:           // fall through
1715        case Token::PLUSEQ:       // fall through
1716        case Token::MINUSEQ:      // fall through
1717        case Token::STAREQ:       // fall through
1718        case Token::SLASHEQ:      // fall through
1719        case Token::PERCENTEQ:    // fall through
1720        case Token::SHLEQ:        // fall through
1721        case Token::SHREQ:        // fall through
1722        case Token::BITWISEOREQ:  // fall through
1723        case Token::BITWISEXOREQ: // fall through
1724        case Token::BITWISEANDEQ: // fall through
1725        case Token::LOGICALOREQ:  // fall through
1726        case Token::LOGICALXOREQ: // fall through
1727        case Token::LOGICALANDEQ:
1728            return true;
1729        default:
1730            return false;
1731    }
1732}
1733
1734SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
1735    if (operandType.kind() == Type::kVector_Kind) {
1736        SpvId result = this->nextId();
1737        this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
1738        return result;
1739    }
1740    return id;
1741}
1742
1743SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1744                                                SpvOp_ floatOperator, SpvOp_ intOperator,
1745                                                OutputStream& out) {
1746    SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1747    ASSERT(operandType.kind() == Type::kMatrix_Kind);
1748    SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
1749                                                                         operandType.columns(),
1750                                                                         1));
1751    SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1752                                                                    operandType.columns(),
1753                                                                    1));
1754    SpvId boolType = this->getType(*fContext.fBool_Type);
1755    SpvId result = 0;
1756    for (int i = 0; i < operandType.rows(); i++) {
1757        SpvId rowL = this->nextId();
1758        this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
1759        SpvId rowR = this->nextId();
1760        this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
1761        SpvId compare = this->nextId();
1762        this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
1763        SpvId all = this->nextId();
1764        this->writeInstruction(SpvOpAll, boolType, all, compare, out);
1765        if (result != 0) {
1766            SpvId next = this->nextId();
1767            this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
1768            result = next;
1769        }
1770        else {
1771            result = all;
1772        }
1773    }
1774    return result;
1775}
1776
1777SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1778    // handle cases where we don't necessarily evaluate both LHS and RHS
1779    switch (b.fOperator) {
1780        case Token::EQ: {
1781            SpvId rhs = this->writeExpression(*b.fRight, out);
1782            this->getLValue(*b.fLeft, out)->store(rhs, out);
1783            return rhs;
1784        }
1785        case Token::LOGICALAND:
1786            return this->writeLogicalAnd(b, out);
1787        case Token::LOGICALOR:
1788            return this->writeLogicalOr(b, out);
1789        default:
1790            break;
1791    }
1792
1793    // "normal" operators
1794    const Type& resultType = b.fType;
1795    std::unique_ptr<LValue> lvalue;
1796    SpvId lhs;
1797    if (is_assignment(b.fOperator)) {
1798        lvalue = this->getLValue(*b.fLeft, out);
1799        lhs = lvalue->load(out);
1800    } else {
1801        lvalue = nullptr;
1802        lhs = this->writeExpression(*b.fLeft, out);
1803    }
1804    SpvId rhs = this->writeExpression(*b.fRight, out);
1805    if (b.fOperator == Token::COMMA) {
1806        return rhs;
1807    }
1808    Type tmp("<invalid>");
1809    // component type we are operating on: float, int, uint
1810    const Type* operandType;
1811    // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling
1812    // in SPIR-V
1813    if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1814        if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1815            b.fRight->fType.isNumber()) {
1816            // promote number to vector
1817            SpvId vec = this->nextId();
1818            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1819            this->writeWord(this->getType(resultType), out);
1820            this->writeWord(vec, out);
1821            for (int i = 0; i < resultType.columns(); i++) {
1822                this->writeWord(rhs, out);
1823            }
1824            rhs = vec;
1825            operandType = &b.fRight->fType;
1826        } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1827                   b.fLeft->fType.isNumber()) {
1828            // promote number to vector
1829            SpvId vec = this->nextId();
1830            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1831            this->writeWord(this->getType(resultType), out);
1832            this->writeWord(vec, out);
1833            for (int i = 0; i < resultType.columns(); i++) {
1834                this->writeWord(lhs, out);
1835            }
1836            lhs = vec;
1837            ASSERT(!lvalue);
1838            operandType = &b.fLeft->fType;
1839        } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
1840            SpvOp_ op;
1841            if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1842                op = SpvOpMatrixTimesMatrix;
1843            } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
1844                op = SpvOpMatrixTimesVector;
1845            } else {
1846                ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
1847                op = SpvOpMatrixTimesScalar;
1848            }
1849            SpvId result = this->nextId();
1850            this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
1851            if (b.fOperator == Token::STAREQ) {
1852                lvalue->store(result, out);
1853            } else {
1854                ASSERT(b.fOperator == Token::STAR);
1855            }
1856            return result;
1857        } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1858            SpvId result = this->nextId();
1859            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
1860                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
1861                                       lhs, rhs, out);
1862            } else {
1863                ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
1864                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
1865                                       lhs, out);
1866            }
1867            if (b.fOperator == Token::STAREQ) {
1868                lvalue->store(result, out);
1869            } else {
1870                ASSERT(b.fOperator == Token::STAR);
1871            }
1872            return result;
1873        } else {
1874            ABORT("unsupported binary expression: %s", b.description().c_str());
1875        }
1876    } else {
1877        tmp = this->getActualType(b.fLeft->fType);
1878        operandType = &tmp;
1879        ASSERT(*operandType == this->getActualType(b.fRight->fType));
1880    }
1881    switch (b.fOperator) {
1882        case Token::EQEQ: {
1883            if (operandType->kind() == Type::kMatrix_Kind) {
1884                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
1885                                                   SpvOpIEqual, out);
1886            }
1887            ASSERT(resultType == *fContext.fBool_Type);
1888            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1889                                                               SpvOpFOrdEqual, SpvOpIEqual,
1890                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
1891                                    *operandType, out);
1892        }
1893        case Token::NEQ:
1894            if (operandType->kind() == Type::kMatrix_Kind) {
1895                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
1896                                                   SpvOpINotEqual, out);
1897            }
1898            ASSERT(resultType == *fContext.fBool_Type);
1899            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1900                                                               SpvOpFOrdNotEqual, SpvOpINotEqual,
1901                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
1902                                                               out),
1903                                    *operandType, out);
1904        case Token::GT:
1905            ASSERT(resultType == *fContext.fBool_Type);
1906            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1907                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
1908                                              SpvOpUGreaterThan, SpvOpUndef, out);
1909        case Token::LT:
1910            ASSERT(resultType == *fContext.fBool_Type);
1911            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
1912                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
1913        case Token::GTEQ:
1914            ASSERT(resultType == *fContext.fBool_Type);
1915            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1916                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
1917                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
1918        case Token::LTEQ:
1919            ASSERT(resultType == *fContext.fBool_Type);
1920            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1921                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
1922                                              SpvOpULessThanEqual, SpvOpUndef, out);
1923        case Token::PLUS:
1924            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1925                                              SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1926        case Token::MINUS:
1927            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1928                                              SpvOpISub, SpvOpISub, SpvOpUndef, out);
1929        case Token::STAR:
1930            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
1931                b.fRight->fType.kind() == Type::kMatrix_Kind) {
1932                // matrix multiply
1933                SpvId result = this->nextId();
1934                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
1935                                       lhs, rhs, out);
1936                return result;
1937            }
1938            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
1939                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
1940        case Token::SLASH:
1941            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
1942                                              SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
1943        case Token::PERCENT:
1944            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
1945                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
1946        case Token::SHL:
1947            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
1948                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
1949                                              SpvOpUndef, out);
1950        case Token::SHR:
1951            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
1952                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
1953                                              SpvOpUndef, out);
1954        case Token::BITWISEAND:
1955            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
1956                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
1957        case Token::BITWISEOR:
1958            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
1959                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
1960        case Token::BITWISEXOR:
1961            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
1962                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
1963        case Token::PLUSEQ: {
1964            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1965                                                      SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1966            ASSERT(lvalue);
1967            lvalue->store(result, out);
1968            return result;
1969        }
1970        case Token::MINUSEQ: {
1971            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1972                                                      SpvOpISub, SpvOpISub, SpvOpUndef, out);
1973            ASSERT(lvalue);
1974            lvalue->store(result, out);
1975            return result;
1976        }
1977        case Token::STAREQ: {
1978            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
1979                b.fRight->fType.kind() == Type::kMatrix_Kind) {
1980                // matrix multiply
1981                SpvId result = this->nextId();
1982                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
1983                                       lhs, rhs, out);
1984                ASSERT(lvalue);
1985                lvalue->store(result, out);
1986                return result;
1987            }
1988            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
1989                                                      SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
1990            ASSERT(lvalue);
1991            lvalue->store(result, out);
1992            return result;
1993        }
1994        case Token::SLASHEQ: {
1995            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
1996                                                      SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
1997            ASSERT(lvalue);
1998            lvalue->store(result, out);
1999            return result;
2000        }
2001        case Token::PERCENTEQ: {
2002            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2003                                                      SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2004            ASSERT(lvalue);
2005            lvalue->store(result, out);
2006            return result;
2007        }
2008        case Token::SHLEQ: {
2009            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2010                                                      SpvOpUndef, SpvOpShiftLeftLogical,
2011                                                      SpvOpShiftLeftLogical, SpvOpUndef, out);
2012            ASSERT(lvalue);
2013            lvalue->store(result, out);
2014            return result;
2015        }
2016        case Token::SHREQ: {
2017            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2018                                                      SpvOpUndef, SpvOpShiftRightArithmetic,
2019                                                      SpvOpShiftRightLogical, SpvOpUndef, out);
2020            ASSERT(lvalue);
2021            lvalue->store(result, out);
2022            return result;
2023        }
2024        case Token::BITWISEANDEQ: {
2025            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2026                                                      SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2027                                                      SpvOpUndef, out);
2028            ASSERT(lvalue);
2029            lvalue->store(result, out);
2030            return result;
2031        }
2032        case Token::BITWISEOREQ: {
2033            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2034                                                      SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2035                                                      SpvOpUndef, out);
2036            ASSERT(lvalue);
2037            lvalue->store(result, out);
2038            return result;
2039        }
2040        case Token::BITWISEXOREQ: {
2041            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2042                                                      SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2043                                                      SpvOpUndef, out);
2044            ASSERT(lvalue);
2045            lvalue->store(result, out);
2046            return result;
2047        }
2048        default:
2049            ABORT("unsupported binary expression: %s", b.description().c_str());
2050    }
2051}
2052
2053SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2054    ASSERT(a.fOperator == Token::LOGICALAND);
2055    BoolLiteral falseLiteral(fContext, -1, false);
2056    SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2057    SpvId lhs = this->writeExpression(*a.fLeft, out);
2058    SpvId rhsLabel = this->nextId();
2059    SpvId end = this->nextId();
2060    SpvId lhsBlock = fCurrentBlock;
2061    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2062    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2063    this->writeLabel(rhsLabel, out);
2064    SpvId rhs = this->writeExpression(*a.fRight, out);
2065    SpvId rhsBlock = fCurrentBlock;
2066    this->writeInstruction(SpvOpBranch, end, out);
2067    this->writeLabel(end, out);
2068    SpvId result = this->nextId();
2069    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2070                           lhsBlock, rhs, rhsBlock, out);
2071    return result;
2072}
2073
2074SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2075    ASSERT(o.fOperator == Token::LOGICALOR);
2076    BoolLiteral trueLiteral(fContext, -1, true);
2077    SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2078    SpvId lhs = this->writeExpression(*o.fLeft, out);
2079    SpvId rhsLabel = this->nextId();
2080    SpvId end = this->nextId();
2081    SpvId lhsBlock = fCurrentBlock;
2082    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2083    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2084    this->writeLabel(rhsLabel, out);
2085    SpvId rhs = this->writeExpression(*o.fRight, out);
2086    SpvId rhsBlock = fCurrentBlock;
2087    this->writeInstruction(SpvOpBranch, end, out);
2088    this->writeLabel(end, out);
2089    SpvId result = this->nextId();
2090    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2091                           lhsBlock, rhs, rhsBlock, out);
2092    return result;
2093}
2094
2095SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2096    SpvId test = this->writeExpression(*t.fTest, out);
2097    if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2098        // both true and false are constants, can just use OpSelect
2099        SpvId result = this->nextId();
2100        SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2101        SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2102        this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2103                               out);
2104        return result;
2105    }
2106    // was originally using OpPhi to choose the result, but for some reason that is crashing on
2107    // Adreno. Switched to storing the result in a temp variable as glslang does.
2108    SpvId var = this->nextId();
2109    this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2110                           var, SpvStorageClassFunction, fVariableBuffer);
2111    SpvId trueLabel = this->nextId();
2112    SpvId falseLabel = this->nextId();
2113    SpvId end = this->nextId();
2114    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2115    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2116    this->writeLabel(trueLabel, out);
2117    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2118    this->writeInstruction(SpvOpBranch, end, out);
2119    this->writeLabel(falseLabel, out);
2120    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2121    this->writeInstruction(SpvOpBranch, end, out);
2122    this->writeLabel(end, out);
2123    SpvId result = this->nextId();
2124    this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2125    return result;
2126}
2127
2128std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2129    if (type.isInteger()) {
2130        return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type));
2131    }
2132    else if (type.isFloat()) {
2133        return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type));
2134    } else {
2135        ABORT("math is unsupported on type '%s'", type.name().c_str());
2136    }
2137}
2138
2139SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2140    if (p.fOperator == Token::MINUS) {
2141        SpvId result = this->nextId();
2142        SpvId typeId = this->getType(p.fType);
2143        SpvId expr = this->writeExpression(*p.fOperand, out);
2144        if (is_float(fContext, p.fType)) {
2145            this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2146        } else if (is_signed(fContext, p.fType)) {
2147            this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2148        } else {
2149            ABORT("unsupported prefix expression %s", p.description().c_str());
2150        };
2151        return result;
2152    }
2153    switch (p.fOperator) {
2154        case Token::PLUS:
2155            return this->writeExpression(*p.fOperand, out);
2156        case Token::PLUSPLUS: {
2157            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2158            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2159            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2160                                                      SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2161                                                      out);
2162            lv->store(result, out);
2163            return result;
2164        }
2165        case Token::MINUSMINUS: {
2166            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2167            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2168            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2169                                                      SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2170                                                      out);
2171            lv->store(result, out);
2172            return result;
2173        }
2174        case Token::LOGICALNOT: {
2175            ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2176            SpvId result = this->nextId();
2177            this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2178                                   this->writeExpression(*p.fOperand, out), out);
2179            return result;
2180        }
2181        case Token::BITWISENOT: {
2182            SpvId result = this->nextId();
2183            this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2184                                   this->writeExpression(*p.fOperand, out), out);
2185            return result;
2186        }
2187        default:
2188            ABORT("unsupported prefix expression: %s", p.description().c_str());
2189    }
2190}
2191
2192SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2193    std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2194    SpvId result = lv->load(out);
2195    SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2196    switch (p.fOperator) {
2197        case Token::PLUSPLUS: {
2198            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2199                                                    SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2200            lv->store(temp, out);
2201            return result;
2202        }
2203        case Token::MINUSMINUS: {
2204            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2205                                                    SpvOpISub, SpvOpISub, SpvOpUndef, out);
2206            lv->store(temp, out);
2207            return result;
2208        }
2209        default:
2210            ABORT("unsupported postfix expression %s", p.description().c_str());
2211    }
2212}
2213
2214SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2215    if (b.fValue) {
2216        if (fBoolTrue == 0) {
2217            fBoolTrue = this->nextId();
2218            this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2219                                   fConstantBuffer);
2220        }
2221        return fBoolTrue;
2222    } else {
2223        if (fBoolFalse == 0) {
2224            fBoolFalse = this->nextId();
2225            this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2226                                   fConstantBuffer);
2227        }
2228        return fBoolFalse;
2229    }
2230}
2231
2232SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2233    if (i.fType == *fContext.fInt_Type) {
2234        auto entry = fIntConstants.find(i.fValue);
2235        if (entry == fIntConstants.end()) {
2236            SpvId result = this->nextId();
2237            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2238                                   fConstantBuffer);
2239            fIntConstants[i.fValue] = result;
2240            return result;
2241        }
2242        return entry->second;
2243    } else {
2244        ASSERT(i.fType == *fContext.fUInt_Type);
2245        auto entry = fUIntConstants.find(i.fValue);
2246        if (entry == fUIntConstants.end()) {
2247            SpvId result = this->nextId();
2248            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2249                                   fConstantBuffer);
2250            fUIntConstants[i.fValue] = result;
2251            return result;
2252        }
2253        return entry->second;
2254    }
2255}
2256
2257SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2258    if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2259        float value = (float) f.fValue;
2260        auto entry = fFloatConstants.find(value);
2261        if (entry == fFloatConstants.end()) {
2262            SpvId result = this->nextId();
2263            uint32_t bits;
2264            ASSERT(sizeof(bits) == sizeof(value));
2265            memcpy(&bits, &value, sizeof(bits));
2266            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2267                                   fConstantBuffer);
2268            fFloatConstants[value] = result;
2269            return result;
2270        }
2271        return entry->second;
2272    } else {
2273        ASSERT(f.fType == *fContext.fDouble_Type);
2274        auto entry = fDoubleConstants.find(f.fValue);
2275        if (entry == fDoubleConstants.end()) {
2276            SpvId result = this->nextId();
2277            uint64_t bits;
2278            ASSERT(sizeof(bits) == sizeof(f.fValue));
2279            memcpy(&bits, &f.fValue, sizeof(bits));
2280            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2281                                   bits & 0xffffffff, bits >> 32, fConstantBuffer);
2282            fDoubleConstants[f.fValue] = result;
2283            return result;
2284        }
2285        return entry->second;
2286    }
2287}
2288
2289SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2290    SpvId result = fFunctionMap[&f];
2291    this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2292                           SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2293    this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2294    for (size_t i = 0; i < f.fParameters.size(); i++) {
2295        SpvId id = this->nextId();
2296        fVariableMap[f.fParameters[i]] = id;
2297        SpvId type;
2298        type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2299        this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2300    }
2301    return result;
2302}
2303
2304SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2305    fVariableBuffer.reset();
2306    SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2307    this->writeLabel(this->nextId(), out);
2308    if (f.fDeclaration.fName == "main") {
2309        write_stringstream(fGlobalInitializersBuffer, out);
2310    }
2311    StringStream bodyBuffer;
2312    this->writeBlock((Block&) *f.fBody, bodyBuffer);
2313    write_stringstream(fVariableBuffer, out);
2314    write_stringstream(bodyBuffer, out);
2315    if (fCurrentBlock) {
2316        if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2317            this->writeInstruction(SpvOpReturn, out);
2318        } else {
2319            this->writeInstruction(SpvOpUnreachable, out);
2320        }
2321    }
2322    this->writeInstruction(SpvOpFunctionEnd, out);
2323    return result;
2324}
2325
2326void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2327    if (layout.fLocation >= 0) {
2328        this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2329                               fDecorationBuffer);
2330    }
2331    if (layout.fBinding >= 0) {
2332        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2333                               fDecorationBuffer);
2334    }
2335    if (layout.fIndex >= 0) {
2336        this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2337                               fDecorationBuffer);
2338    }
2339    if (layout.fSet >= 0) {
2340        this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2341                               fDecorationBuffer);
2342    }
2343    if (layout.fInputAttachmentIndex >= 0) {
2344        this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2345                               layout.fInputAttachmentIndex, fDecorationBuffer);
2346        fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2347    }
2348    if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2349        layout.fBuiltin != SK_IN_BUILTIN) {
2350        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2351                               fDecorationBuffer);
2352    }
2353}
2354
2355void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2356    if (layout.fLocation >= 0) {
2357        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2358                               layout.fLocation, fDecorationBuffer);
2359    }
2360    if (layout.fBinding >= 0) {
2361        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2362                               layout.fBinding, fDecorationBuffer);
2363    }
2364    if (layout.fIndex >= 0) {
2365        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2366                               layout.fIndex, fDecorationBuffer);
2367    }
2368    if (layout.fSet >= 0) {
2369        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2370                               layout.fSet, fDecorationBuffer);
2371    }
2372    if (layout.fInputAttachmentIndex >= 0) {
2373        this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2374                               layout.fInputAttachmentIndex, fDecorationBuffer);
2375    }
2376    if (layout.fBuiltin >= 0) {
2377        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2378                               layout.fBuiltin, fDecorationBuffer);
2379    }
2380}
2381
2382SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2383    bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2384    bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2385                               Layout::kPushConstant_Flag));
2386    MemoryLayout layout = (pushConstant || isBuffer) ?
2387                          MemoryLayout(MemoryLayout::k430_Standard) :
2388                          fDefaultLayout;
2389    SpvId result = this->nextId();
2390    const Type* type = &intf.fVariable.fType;
2391    if (fProgram.fInputs.fRTHeight) {
2392        ASSERT(fRTHeightStructId == (SpvId) -1);
2393        ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2394        std::vector<Type::Field> fields = type->fields();
2395        fRTHeightStructId = result;
2396        fRTHeightFieldIndex = fields.size();
2397        fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2398        type = new Type(type->fOffset, type->name(), fields);
2399    }
2400    SpvId typeId = this->getType(*type, layout);
2401    if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2402        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2403    } else {
2404        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2405    }
2406    SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2407    SpvId ptrType = this->nextId();
2408    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2409    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2410    this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2411    fVariableMap[&intf.fVariable] = result;
2412    if (fProgram.fInputs.fRTHeight) {
2413        delete type;
2414    }
2415    return result;
2416}
2417
2418void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2419    if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2420        (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2421        this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2422    }
2423}
2424
2425#define BUILTIN_IGNORE 9999
2426void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2427                                         OutputStream& out) {
2428    for (size_t i = 0; i < decl.fVars.size(); i++) {
2429        if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
2430            continue;
2431        }
2432        const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
2433        const Variable* var = varDecl.fVar;
2434        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2435        // in the OpenGL backend.
2436        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2437                                           Modifiers::kWriteOnly_Flag |
2438                                           Modifiers::kCoherent_Flag |
2439                                           Modifiers::kVolatile_Flag |
2440                                           Modifiers::kRestrict_Flag)));
2441        if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2442            continue;
2443        }
2444        if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2445            kind != Program::kFragment_Kind) {
2446            continue;
2447        }
2448        if (!var->fReadCount && !var->fWriteCount &&
2449                !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2450                                            Modifiers::kOut_Flag |
2451                                            Modifiers::kUniform_Flag |
2452                                            Modifiers::kBuffer_Flag))) {
2453            // variable is dead and not an input / output var (the Vulkan debug layers complain if
2454            // we elide an interface var, even if it's dead)
2455            continue;
2456        }
2457        SpvStorageClass_ storageClass;
2458        if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2459            storageClass = SpvStorageClassInput;
2460        } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2461            storageClass = SpvStorageClassOutput;
2462        } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2463            if (var->fType.kind() == Type::kSampler_Kind) {
2464                storageClass = SpvStorageClassUniformConstant;
2465            } else {
2466                storageClass = SpvStorageClassUniform;
2467            }
2468        } else {
2469            storageClass = SpvStorageClassPrivate;
2470        }
2471        SpvId id = this->nextId();
2472        fVariableMap[var] = id;
2473        SpvId type = this->getPointerType(var->fType, storageClass);
2474        this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2475        this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2476        this->writePrecisionModifier(var->fModifiers, id);
2477        if (varDecl.fValue) {
2478            ASSERT(!fCurrentBlock);
2479            fCurrentBlock = -1;
2480            SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2481            this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2482            fCurrentBlock = 0;
2483        }
2484        this->writeLayout(var->fModifiers.fLayout, id);
2485        if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
2486            this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
2487        }
2488        if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
2489            this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
2490                                   fDecorationBuffer);
2491        }
2492    }
2493}
2494
2495void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2496    for (const auto& stmt : decl.fVars) {
2497        ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2498        VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2499        const Variable* var = varDecl.fVar;
2500        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2501        // in the OpenGL backend.
2502        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2503                                           Modifiers::kWriteOnly_Flag |
2504                                           Modifiers::kCoherent_Flag |
2505                                           Modifiers::kVolatile_Flag |
2506                                           Modifiers::kRestrict_Flag)));
2507        SpvId id = this->nextId();
2508        fVariableMap[var] = id;
2509        SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2510        this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2511        this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2512        if (varDecl.fValue) {
2513            SpvId value = this->writeExpression(*varDecl.fValue, out);
2514            this->writeInstruction(SpvOpStore, id, value, out);
2515        }
2516    }
2517}
2518
2519void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2520    switch (s.fKind) {
2521        case Statement::kNop_Kind:
2522            break;
2523        case Statement::kBlock_Kind:
2524            this->writeBlock((Block&) s, out);
2525            break;
2526        case Statement::kExpression_Kind:
2527            this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2528            break;
2529        case Statement::kReturn_Kind:
2530            this->writeReturnStatement((ReturnStatement&) s, out);
2531            break;
2532        case Statement::kVarDeclarations_Kind:
2533            this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2534            break;
2535        case Statement::kIf_Kind:
2536            this->writeIfStatement((IfStatement&) s, out);
2537            break;
2538        case Statement::kFor_Kind:
2539            this->writeForStatement((ForStatement&) s, out);
2540            break;
2541        case Statement::kWhile_Kind:
2542            this->writeWhileStatement((WhileStatement&) s, out);
2543            break;
2544        case Statement::kDo_Kind:
2545            this->writeDoStatement((DoStatement&) s, out);
2546            break;
2547        case Statement::kSwitch_Kind:
2548            this->writeSwitchStatement((SwitchStatement&) s, out);
2549            break;
2550        case Statement::kBreak_Kind:
2551            this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2552            break;
2553        case Statement::kContinue_Kind:
2554            this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2555            break;
2556        case Statement::kDiscard_Kind:
2557            this->writeInstruction(SpvOpKill, out);
2558            break;
2559        default:
2560            ABORT("unsupported statement: %s", s.description().c_str());
2561    }
2562}
2563
2564void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2565    for (size_t i = 0; i < b.fStatements.size(); i++) {
2566        this->writeStatement(*b.fStatements[i], out);
2567    }
2568}
2569
2570void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2571    SpvId test = this->writeExpression(*stmt.fTest, out);
2572    SpvId ifTrue = this->nextId();
2573    SpvId ifFalse = this->nextId();
2574    if (stmt.fIfFalse) {
2575        SpvId end = this->nextId();
2576        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2577        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2578        this->writeLabel(ifTrue, out);
2579        this->writeStatement(*stmt.fIfTrue, out);
2580        if (fCurrentBlock) {
2581            this->writeInstruction(SpvOpBranch, end, out);
2582        }
2583        this->writeLabel(ifFalse, out);
2584        this->writeStatement(*stmt.fIfFalse, out);
2585        if (fCurrentBlock) {
2586            this->writeInstruction(SpvOpBranch, end, out);
2587        }
2588        this->writeLabel(end, out);
2589    } else {
2590        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2591        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2592        this->writeLabel(ifTrue, out);
2593        this->writeStatement(*stmt.fIfTrue, out);
2594        if (fCurrentBlock) {
2595            this->writeInstruction(SpvOpBranch, ifFalse, out);
2596        }
2597        this->writeLabel(ifFalse, out);
2598    }
2599}
2600
2601void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2602    if (f.fInitializer) {
2603        this->writeStatement(*f.fInitializer, out);
2604    }
2605    SpvId header = this->nextId();
2606    SpvId start = this->nextId();
2607    SpvId body = this->nextId();
2608    SpvId next = this->nextId();
2609    fContinueTarget.push(next);
2610    SpvId end = this->nextId();
2611    fBreakTarget.push(end);
2612    this->writeInstruction(SpvOpBranch, header, out);
2613    this->writeLabel(header, out);
2614    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2615    this->writeInstruction(SpvOpBranch, start, out);
2616    this->writeLabel(start, out);
2617    if (f.fTest) {
2618        SpvId test = this->writeExpression(*f.fTest, out);
2619        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2620    }
2621    this->writeLabel(body, out);
2622    this->writeStatement(*f.fStatement, out);
2623    if (fCurrentBlock) {
2624        this->writeInstruction(SpvOpBranch, next, out);
2625    }
2626    this->writeLabel(next, out);
2627    if (f.fNext) {
2628        this->writeExpression(*f.fNext, out);
2629    }
2630    this->writeInstruction(SpvOpBranch, header, out);
2631    this->writeLabel(end, out);
2632    fBreakTarget.pop();
2633    fContinueTarget.pop();
2634}
2635
2636void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2637    // We believe the while loop code below will work, but Skia doesn't actually use them and
2638    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2639    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2640    // message, simply remove the error call below to see whether our while loop support actually
2641    // works.
2642    fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2643                  "see SkSLSPIRVCodeGenerator.cpp for details");
2644
2645    SpvId header = this->nextId();
2646    SpvId start = this->nextId();
2647    SpvId body = this->nextId();
2648    fContinueTarget.push(start);
2649    SpvId end = this->nextId();
2650    fBreakTarget.push(end);
2651    this->writeInstruction(SpvOpBranch, header, out);
2652    this->writeLabel(header, out);
2653    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2654    this->writeInstruction(SpvOpBranch, start, out);
2655    this->writeLabel(start, out);
2656    SpvId test = this->writeExpression(*w.fTest, out);
2657    this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2658    this->writeLabel(body, out);
2659    this->writeStatement(*w.fStatement, out);
2660    if (fCurrentBlock) {
2661        this->writeInstruction(SpvOpBranch, start, out);
2662    }
2663    this->writeLabel(end, out);
2664    fBreakTarget.pop();
2665    fContinueTarget.pop();
2666}
2667
2668void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2669    // We believe the do loop code below will work, but Skia doesn't actually use them and
2670    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2671    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2672    // message, simply remove the error call below to see whether our do loop support actually
2673    // works.
2674    fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2675                  "SkSLSPIRVCodeGenerator.cpp for details");
2676
2677    SpvId header = this->nextId();
2678    SpvId start = this->nextId();
2679    SpvId next = this->nextId();
2680    fContinueTarget.push(next);
2681    SpvId end = this->nextId();
2682    fBreakTarget.push(end);
2683    this->writeInstruction(SpvOpBranch, header, out);
2684    this->writeLabel(header, out);
2685    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2686    this->writeInstruction(SpvOpBranch, start, out);
2687    this->writeLabel(start, out);
2688    this->writeStatement(*d.fStatement, out);
2689    if (fCurrentBlock) {
2690        this->writeInstruction(SpvOpBranch, next, out);
2691    }
2692    this->writeLabel(next, out);
2693    SpvId test = this->writeExpression(*d.fTest, out);
2694    this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2695    this->writeLabel(end, out);
2696    fBreakTarget.pop();
2697    fContinueTarget.pop();
2698}
2699
2700void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2701    SpvId value = this->writeExpression(*s.fValue, out);
2702    std::vector<SpvId> labels;
2703    SpvId end = this->nextId();
2704    SpvId defaultLabel = end;
2705    fBreakTarget.push(end);
2706    int size = 3;
2707    for (const auto& c : s.fCases) {
2708        SpvId label = this->nextId();
2709        labels.push_back(label);
2710        if (c->fValue) {
2711            size += 2;
2712        } else {
2713            defaultLabel = label;
2714        }
2715    }
2716    labels.push_back(end);
2717    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2718    this->writeOpCode(SpvOpSwitch, size, out);
2719    this->writeWord(value, out);
2720    this->writeWord(defaultLabel, out);
2721    for (size_t i = 0; i < s.fCases.size(); ++i) {
2722        if (!s.fCases[i]->fValue) {
2723            continue;
2724        }
2725        ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2726        this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2727        this->writeWord(labels[i], out);
2728    }
2729    for (size_t i = 0; i < s.fCases.size(); ++i) {
2730        this->writeLabel(labels[i], out);
2731        for (const auto& stmt : s.fCases[i]->fStatements) {
2732            this->writeStatement(*stmt, out);
2733        }
2734        if (fCurrentBlock) {
2735            this->writeInstruction(SpvOpBranch, labels[i + 1], out);
2736        }
2737    }
2738    this->writeLabel(end, out);
2739    fBreakTarget.pop();
2740}
2741
2742void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
2743    if (r.fExpression) {
2744        this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2745                               out);
2746    } else {
2747        this->writeInstruction(SpvOpReturn, out);
2748    }
2749}
2750
2751void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
2752    ASSERT(fProgram.fKind == Program::kGeometry_Kind);
2753    int invocations = 1;
2754    for (size_t i = 0; i < fProgram.fElements.size(); i++) {
2755        if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) {
2756            const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers;
2757            if (m.fFlags & Modifiers::kIn_Flag) {
2758                if (m.fLayout.fInvocations != -1) {
2759                    invocations = m.fLayout.fInvocations;
2760                }
2761                SpvId input;
2762                switch (m.fLayout.fPrimitive) {
2763                    case Layout::kPoints_Primitive:
2764                        input = SpvExecutionModeInputPoints;
2765                        break;
2766                    case Layout::kLines_Primitive:
2767                        input = SpvExecutionModeInputLines;
2768                        break;
2769                    case Layout::kLinesAdjacency_Primitive:
2770                        input = SpvExecutionModeInputLinesAdjacency;
2771                        break;
2772                    case Layout::kTriangles_Primitive:
2773                        input = SpvExecutionModeTriangles;
2774                        break;
2775                    case Layout::kTrianglesAdjacency_Primitive:
2776                        input = SpvExecutionModeInputTrianglesAdjacency;
2777                        break;
2778                    default:
2779                        input = 0;
2780                        break;
2781                }
2782                if (input) {
2783                    this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
2784                }
2785            } else if (m.fFlags & Modifiers::kOut_Flag) {
2786                SpvId output;
2787                switch (m.fLayout.fPrimitive) {
2788                    case Layout::kPoints_Primitive:
2789                        output = SpvExecutionModeOutputPoints;
2790                        break;
2791                    case Layout::kLineStrip_Primitive:
2792                        output = SpvExecutionModeOutputLineStrip;
2793                        break;
2794                    case Layout::kTriangleStrip_Primitive:
2795                        output = SpvExecutionModeOutputTriangleStrip;
2796                        break;
2797                    default:
2798                        output = 0;
2799                        break;
2800                }
2801                if (output) {
2802                    this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
2803                }
2804                if (m.fLayout.fMaxVertices != -1) {
2805                    this->writeInstruction(SpvOpExecutionMode, entryPoint,
2806                                           SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
2807                                           out);
2808                }
2809            }
2810        }
2811    }
2812    this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
2813                           invocations, out);
2814}
2815
2816void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
2817    fGLSLExtendedInstructions = this->nextId();
2818    StringStream body;
2819    std::set<SpvId> interfaceVars;
2820    // assign IDs to functions, determine sk_in size
2821    int skInSize = -1;
2822    for (size_t i = 0; i < program.fElements.size(); i++) {
2823        switch (program.fElements[i]->fKind) {
2824            case ProgramElement::kFunction_Kind: {
2825                FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2826                fFunctionMap[&f.fDeclaration] = this->nextId();
2827                break;
2828            }
2829            case ProgramElement::kModifiers_Kind: {
2830                Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers;
2831                if (m.fFlags & Modifiers::kIn_Flag) {
2832                    switch (m.fLayout.fPrimitive) {
2833                        case Layout::kPoints_Primitive: // break
2834                        case Layout::kLines_Primitive:
2835                            skInSize = 1;
2836                            break;
2837                        case Layout::kLinesAdjacency_Primitive: // break
2838                            skInSize = 2;
2839                            break;
2840                        case Layout::kTriangles_Primitive: // break
2841                        case Layout::kTrianglesAdjacency_Primitive:
2842                            skInSize = 3;
2843                            break;
2844                        default:
2845                            break;
2846                    }
2847                }
2848                break;
2849            }
2850            default:
2851                break;
2852        }
2853    }
2854    for (size_t i = 0; i < program.fElements.size(); i++) {
2855        if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2856            InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2857            if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
2858                ASSERT(skInSize != -1);
2859                intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
2860            }
2861            SpvId id = this->writeInterfaceBlock(intf);
2862            if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
2863                (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
2864                interfaceVars.insert(id);
2865            }
2866        }
2867    }
2868    for (size_t i = 0; i < program.fElements.size(); i++) {
2869        if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2870            this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
2871                                  body);
2872        }
2873    }
2874    for (size_t i = 0; i < program.fElements.size(); i++) {
2875        if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2876            this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2877        }
2878    }
2879    const FunctionDeclaration* main = nullptr;
2880    for (auto entry : fFunctionMap) {
2881        if (entry.first->fName == "main") {
2882            main = entry.first;
2883        }
2884    }
2885    ASSERT(main);
2886    for (auto entry : fVariableMap) {
2887        const Variable* var = entry.first;
2888        if (var->fStorage == Variable::kGlobal_Storage &&
2889                ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2890                 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2891            interfaceVars.insert(entry.second);
2892        }
2893    }
2894    this->writeCapabilities(out);
2895    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2896    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2897    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
2898                      (int32_t) interfaceVars.size(), out);
2899    switch (program.fKind) {
2900        case Program::kVertex_Kind:
2901            this->writeWord(SpvExecutionModelVertex, out);
2902            break;
2903        case Program::kFragment_Kind:
2904            this->writeWord(SpvExecutionModelFragment, out);
2905            break;
2906        case Program::kGeometry_Kind:
2907            this->writeWord(SpvExecutionModelGeometry, out);
2908            break;
2909        default:
2910            ABORT("cannot write this kind of program to SPIR-V\n");
2911    }
2912    SpvId entryPoint = fFunctionMap[main];
2913    this->writeWord(entryPoint, out);
2914    this->writeString(main->fName.fChars, main->fName.fLength, out);
2915    for (int var : interfaceVars) {
2916        this->writeWord(var, out);
2917    }
2918    if (program.fKind == Program::kGeometry_Kind) {
2919        this->writeGeometryShaderExecutionMode(entryPoint, out);
2920    }
2921    if (program.fKind == Program::kFragment_Kind) {
2922        this->writeInstruction(SpvOpExecutionMode,
2923                               fFunctionMap[main],
2924                               SpvExecutionModeOriginUpperLeft,
2925                               out);
2926    }
2927    for (size_t i = 0; i < program.fElements.size(); i++) {
2928        if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
2929            this->writeInstruction(SpvOpSourceExtension,
2930                                   ((Extension&) *program.fElements[i]).fName.c_str(),
2931                                   out);
2932        }
2933    }
2934
2935    write_stringstream(fExtraGlobalsBuffer, out);
2936    write_stringstream(fNameBuffer, out);
2937    write_stringstream(fDecorationBuffer, out);
2938    write_stringstream(fConstantBuffer, out);
2939    write_stringstream(fExternalFunctionsBuffer, out);
2940    write_stringstream(body, out);
2941}
2942
2943bool SPIRVCodeGenerator::generateCode() {
2944    ASSERT(!fErrors.errorCount());
2945    this->writeWord(SpvMagicNumber, *fOut);
2946    this->writeWord(SpvVersion, *fOut);
2947    this->writeWord(SKSL_MAGIC, *fOut);
2948    StringStream buffer;
2949    this->writeInstructions(fProgram, buffer);
2950    this->writeWord(fIdCount, *fOut);
2951    this->writeWord(0, *fOut); // reserved, always zero
2952    write_stringstream(buffer, *fOut);
2953    return 0 == fErrors.errorCount();
2954}
2955
2956}
2957