1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLSPIRVCodeGenerator.h"
9
10#include "GLSL.std.450.h"
11
12#include "ir/SkSLExpressionStatement.h"
13#include "ir/SkSLExtension.h"
14#include "ir/SkSLIndexExpression.h"
15#include "ir/SkSLVariableReference.h"
16#include "SkSLCompiler.h"
17
18namespace SkSL {
19
// Generator identifier for SPIR-V output produced by SkSL (presumably written into the module
// header's generator word; the use site is not visible here).
// FIXME: we should probably register a magic number
static constexpr int32_t SKSL_MAGIC = 0x0;
21
22void SPIRVCodeGenerator::setupIntrinsics() {
23#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
24                                    GLSLstd450 ## x, GLSLstd450 ## x)
25#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
26                                                             GLSLstd450 ## ifFloat, \
27                                                             GLSLstd450 ## ifInt, \
28                                                             GLSLstd450 ## ifUInt, \
29                                                             SpvOpUndef)
30#define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \
31                                                           SpvOp ## x)
32#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
33                                   k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
34                                   k ## x ## _SpecialIntrinsic)
35    fIntrinsicMap[String("round")]         = ALL_GLSL(Round);
36    fIntrinsicMap[String("roundEven")]     = ALL_GLSL(RoundEven);
37    fIntrinsicMap[String("trunc")]         = ALL_GLSL(Trunc);
38    fIntrinsicMap[String("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
39    fIntrinsicMap[String("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
40    fIntrinsicMap[String("floor")]         = ALL_GLSL(Floor);
41    fIntrinsicMap[String("ceil")]          = ALL_GLSL(Ceil);
42    fIntrinsicMap[String("fract")]         = ALL_GLSL(Fract);
43    fIntrinsicMap[String("radians")]       = ALL_GLSL(Radians);
44    fIntrinsicMap[String("degrees")]       = ALL_GLSL(Degrees);
45    fIntrinsicMap[String("sin")]           = ALL_GLSL(Sin);
46    fIntrinsicMap[String("cos")]           = ALL_GLSL(Cos);
47    fIntrinsicMap[String("tan")]           = ALL_GLSL(Tan);
48    fIntrinsicMap[String("asin")]          = ALL_GLSL(Asin);
49    fIntrinsicMap[String("acos")]          = ALL_GLSL(Acos);
50    fIntrinsicMap[String("atan")]          = SPECIAL(Atan);
51    fIntrinsicMap[String("sinh")]          = ALL_GLSL(Sinh);
52    fIntrinsicMap[String("cosh")]          = ALL_GLSL(Cosh);
53    fIntrinsicMap[String("tanh")]          = ALL_GLSL(Tanh);
54    fIntrinsicMap[String("asinh")]         = ALL_GLSL(Asinh);
55    fIntrinsicMap[String("acosh")]         = ALL_GLSL(Acosh);
56    fIntrinsicMap[String("atanh")]         = ALL_GLSL(Atanh);
57    fIntrinsicMap[String("pow")]           = ALL_GLSL(Pow);
58    fIntrinsicMap[String("exp")]           = ALL_GLSL(Exp);
59    fIntrinsicMap[String("log")]           = ALL_GLSL(Log);
60    fIntrinsicMap[String("exp2")]          = ALL_GLSL(Exp2);
61    fIntrinsicMap[String("log2")]          = ALL_GLSL(Log2);
62    fIntrinsicMap[String("sqrt")]          = ALL_GLSL(Sqrt);
63    fIntrinsicMap[String("inverse")]       = ALL_GLSL(MatrixInverse);
64    fIntrinsicMap[String("transpose")]     = ALL_SPIRV(Transpose);
65    fIntrinsicMap[String("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66    fIntrinsicMap[String("determinant")]   = ALL_GLSL(Determinant);
67    fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse);
68    fIntrinsicMap[String("mod")]           = SPECIAL(Mod);
69    fIntrinsicMap[String("min")]           = SPECIAL(Min);
70    fIntrinsicMap[String("max")]           = SPECIAL(Max);
71    fIntrinsicMap[String("clamp")]         = SPECIAL(Clamp);
72    fIntrinsicMap[String("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
73                                                             SpvOpUndef, SpvOpUndef, SpvOpUndef);
74    fIntrinsicMap[String("mix")]           = SPECIAL(Mix);
75    fIntrinsicMap[String("step")]          = ALL_GLSL(Step);
76    fIntrinsicMap[String("smoothstep")]    = ALL_GLSL(SmoothStep);
77    fIntrinsicMap[String("fma")]           = ALL_GLSL(Fma);
78    fIntrinsicMap[String("frexp")]         = ALL_GLSL(Frexp);
79    fIntrinsicMap[String("ldexp")]         = ALL_GLSL(Ldexp);
80
81#define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \
82                   fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type)
83    PACK(Snorm4x8);
84    PACK(Unorm4x8);
85    PACK(Snorm2x16);
86    PACK(Unorm2x16);
87    PACK(Half2x16);
88    PACK(Double2x32);
89    fIntrinsicMap[String("length")]      = ALL_GLSL(Length);
90    fIntrinsicMap[String("distance")]    = ALL_GLSL(Distance);
91    fIntrinsicMap[String("cross")]       = ALL_GLSL(Cross);
92    fIntrinsicMap[String("normalize")]   = ALL_GLSL(Normalize);
93    fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward);
94    fIntrinsicMap[String("reflect")]     = ALL_GLSL(Reflect);
95    fIntrinsicMap[String("refract")]     = ALL_GLSL(Refract);
96    fIntrinsicMap[String("findLSB")]     = ALL_GLSL(FindILsb);
97    fIntrinsicMap[String("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
98    fIntrinsicMap[String("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
99                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
100    fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
101                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
102    fIntrinsicMap[String("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
103                                                           SpvOpUndef, SpvOpUndef, SpvOpUndef);
104    fIntrinsicMap[String("texture")]     = SPECIAL(Texture);
105    fIntrinsicMap[String("texelFetch")]  = SPECIAL(TexelFetch);
106    fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad);
107
108    fIntrinsicMap[String("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
109                                                                SpvOpUndef, SpvOpUndef, SpvOpAny);
110    fIntrinsicMap[String("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
111                                                                SpvOpUndef, SpvOpUndef, SpvOpAll);
112    fIntrinsicMap[String("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
113                                                                SpvOpFOrdEqual, SpvOpIEqual,
114                                                                SpvOpIEqual, SpvOpLogicalEqual);
115    fIntrinsicMap[String("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
116                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
117                                                                SpvOpINotEqual,
118                                                                SpvOpLogicalNotEqual);
119    fIntrinsicMap[String("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
120                                                                SpvOpFOrdLessThan, SpvOpSLessThan,
121                                                                SpvOpULessThan, SpvOpUndef);
122    fIntrinsicMap[String("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
123                                                                SpvOpFOrdLessThanEqual,
124                                                                SpvOpSLessThanEqual,
125                                                                SpvOpULessThanEqual,
126                                                                SpvOpUndef);
127    fIntrinsicMap[String("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
128                                                                SpvOpFOrdGreaterThan,
129                                                                SpvOpSGreaterThan,
130                                                                SpvOpUGreaterThan,
131                                                                SpvOpUndef);
132    fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
133                                                                SpvOpFOrdGreaterThanEqual,
134                                                                SpvOpSGreaterThanEqual,
135                                                                SpvOpUGreaterThanEqual,
136                                                                SpvOpUndef);
137    fIntrinsicMap[String("EmitVertex")]       = ALL_SPIRV(EmitVertex);
138    fIntrinsicMap[String("EndPrimitive")]     = ALL_SPIRV(EndPrimitive);
139// interpolateAt* not yet supported...
140}
141
142void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) {
143    out.write((const char*) &word, sizeof(word));
144}
145
146static bool is_float(const Context& context, const Type& type) {
147    if (type.kind() == Type::kVector_Kind) {
148        return is_float(context, type.componentType());
149    }
150    return type == *context.fFloat_Type || type == *context.fHalf_Type ||
151           type == *context.fDouble_Type;
152}
153
154static bool is_signed(const Context& context, const Type& type) {
155    if (type.kind() == Type::kVector_Kind) {
156        return is_signed(context, type.componentType());
157    }
158    return type == *context.fInt_Type || type == *context.fShort_Type;
159}
160
161static bool is_unsigned(const Context& context, const Type& type) {
162    if (type.kind() == Type::kVector_Kind) {
163        return is_unsigned(context, type.componentType());
164    }
165    return type == *context.fUInt_Type || type == *context.fUShort_Type;
166}
167
168static bool is_bool(const Context& context, const Type& type) {
169    if (type.kind() == Type::kVector_Kind) {
170        return is_bool(context, type.componentType());
171    }
172    return type == *context.fBool_Type;
173}
174
175static bool is_out(const Variable& var) {
176    return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
177}
178
179void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) {
180    ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer);
181    ASSERT(opCode != SpvOpUndef);
182    switch (opCode) {
183        case SpvOpReturn:      // fall through
184        case SpvOpReturnValue: // fall through
185        case SpvOpKill:        // fall through
186        case SpvOpBranch:      // fall through
187        case SpvOpBranchConditional:
188            ASSERT(fCurrentBlock);
189            fCurrentBlock = 0;
190            break;
191        case SpvOpConstant:          // fall through
192        case SpvOpConstantTrue:      // fall through
193        case SpvOpConstantFalse:     // fall through
194        case SpvOpConstantComposite: // fall through
195        case SpvOpTypeVoid:          // fall through
196        case SpvOpTypeInt:           // fall through
197        case SpvOpTypeFloat:         // fall through
198        case SpvOpTypeBool:          // fall through
199        case SpvOpTypeVector:        // fall through
200        case SpvOpTypeMatrix:        // fall through
201        case SpvOpTypeArray:         // fall through
202        case SpvOpTypePointer:       // fall through
203        case SpvOpTypeFunction:      // fall through
204        case SpvOpTypeRuntimeArray:  // fall through
205        case SpvOpTypeStruct:        // fall through
206        case SpvOpTypeImage:         // fall through
207        case SpvOpTypeSampledImage:  // fall through
208        case SpvOpVariable:          // fall through
209        case SpvOpFunction:          // fall through
210        case SpvOpFunctionParameter: // fall through
211        case SpvOpFunctionEnd:       // fall through
212        case SpvOpExecutionMode:     // fall through
213        case SpvOpMemoryModel:       // fall through
214        case SpvOpCapability:        // fall through
215        case SpvOpExtInstImport:     // fall through
216        case SpvOpEntryPoint:        // fall through
217        case SpvOpSource:            // fall through
218        case SpvOpSourceExtension:   // fall through
219        case SpvOpName:              // fall through
220        case SpvOpMemberName:        // fall through
221        case SpvOpDecorate:          // fall through
222        case SpvOpMemberDecorate:
223            break;
224        default:
225            ASSERT(fCurrentBlock);
226    }
227    this->writeWord((length << 16) | opCode, out);
228}
229
230void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) {
231    fCurrentBlock = label;
232    this->writeInstruction(SpvOpLabel, label, out);
233}
234
235void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) {
236    this->writeOpCode(opCode, 1, out);
237}
238
239void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) {
240    this->writeOpCode(opCode, 2, out);
241    this->writeWord(word1, out);
242}
243
244void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) {
245    out.write(string, length);
246    switch (length % 4) {
247        case 1:
248            out.write8(0);
249            // fall through
250        case 2:
251            out.write8(0);
252            // fall through
253        case 3:
254            out.write8(0);
255            break;
256        default:
257            this->writeWord(0, out);
258    }
259}
260
261void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) {
262    this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out);
263    this->writeString(string.fChars, string.fLength, out);
264}
265
266
267void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string,
268                                          OutputStream& out) {
269    this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out);
270    this->writeWord(word1, out);
271    this->writeString(string.fChars, string.fLength, out);
272}
273
274void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
275                                          StringFragment string, OutputStream& out) {
276    this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out);
277    this->writeWord(word1, out);
278    this->writeWord(word2, out);
279    this->writeString(string.fChars, string.fLength, out);
280}
281
282void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
283                                          OutputStream& out) {
284    this->writeOpCode(opCode, 3, out);
285    this->writeWord(word1, out);
286    this->writeWord(word2, out);
287}
288
289void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
290                                          int32_t word3, OutputStream& out) {
291    this->writeOpCode(opCode, 4, out);
292    this->writeWord(word1, out);
293    this->writeWord(word2, out);
294    this->writeWord(word3, out);
295}
296
297void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
298                                          int32_t word3, int32_t word4, OutputStream& out) {
299    this->writeOpCode(opCode, 5, out);
300    this->writeWord(word1, out);
301    this->writeWord(word2, out);
302    this->writeWord(word3, out);
303    this->writeWord(word4, out);
304}
305
306void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
307                                          int32_t word3, int32_t word4, int32_t word5,
308                                          OutputStream& out) {
309    this->writeOpCode(opCode, 6, out);
310    this->writeWord(word1, out);
311    this->writeWord(word2, out);
312    this->writeWord(word3, out);
313    this->writeWord(word4, out);
314    this->writeWord(word5, out);
315}
316
317void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
318                                          int32_t word3, int32_t word4, int32_t word5,
319                                          int32_t word6, OutputStream& out) {
320    this->writeOpCode(opCode, 7, out);
321    this->writeWord(word1, out);
322    this->writeWord(word2, out);
323    this->writeWord(word3, out);
324    this->writeWord(word4, out);
325    this->writeWord(word5, out);
326    this->writeWord(word6, out);
327}
328
329void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
330                                          int32_t word3, int32_t word4, int32_t word5,
331                                          int32_t word6, int32_t word7, OutputStream& out) {
332    this->writeOpCode(opCode, 8, out);
333    this->writeWord(word1, out);
334    this->writeWord(word2, out);
335    this->writeWord(word3, out);
336    this->writeWord(word4, out);
337    this->writeWord(word5, out);
338    this->writeWord(word6, out);
339    this->writeWord(word7, out);
340}
341
342void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
343                                          int32_t word3, int32_t word4, int32_t word5,
344                                          int32_t word6, int32_t word7, int32_t word8,
345                                          OutputStream& out) {
346    this->writeOpCode(opCode, 9, out);
347    this->writeWord(word1, out);
348    this->writeWord(word2, out);
349    this->writeWord(word3, out);
350    this->writeWord(word4, out);
351    this->writeWord(word5, out);
352    this->writeWord(word6, out);
353    this->writeWord(word7, out);
354    this->writeWord(word8, out);
355}
356
357void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) {
358    for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
359        if (fCapabilities & bit) {
360            this->writeInstruction(SpvOpCapability, (SpvId) i, out);
361        }
362    }
363    if (fProgram.fKind == Program::kGeometry_Kind) {
364        this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out);
365    }
366}
367
368SpvId SPIRVCodeGenerator::nextId() {
369    return fIdCount++;
370}
371
// Emits the OpTypeStruct for `type` (result id `resultId`) into fConstantBuffer, with member
// types resolved under `memoryLayout`. It also writes:
//  - OpName / OpMemberName debug names into fNameBuffer;
//  - per-member Offset decorations (non-builtin fields only) and, for matrix fields, ColMajor
//    and MatrixStride decorations into fDecorationBuffer.
// Explicit layout offsets that are smaller than the running offset, or are not a multiple of the
// field's alignment, are reported through fErrors.
void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
                                     SpvId resultId) {
    this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
    // go ahead and write all of the field types, so we don't inadvertently write them while we're
    // in the middle of writing the struct instruction
    std::vector<SpvId> types;
    for (const auto& f : type.fields()) {
        types.push_back(this->getType(*f.fType, memoryLayout));
    }
    // Word count: opcode word + result id + one member type id per field.
    this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
    this->writeWord(resultId, fConstantBuffer);
    for (SpvId id : types) {
        this->writeWord(id, fConstantBuffer);
    }
    // Running byte offset at which the next field is placed.
    size_t offset = 0;
    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
        size_t size = memoryLayout.size(*type.fields()[i].fType);
        size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
        const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
        if (fieldLayout.fOffset >= 0) {
            // Explicit offset: must not overlap earlier fields and must respect alignment.
            if (fieldLayout.fOffset < (int) offset) {
                fErrors.error(type.fOffset,
                              "offset of field '" + type.fields()[i].fName + "' must be at "
                              "least " + to_string((int) offset));
            }
            if (fieldLayout.fOffset % alignment) {
                fErrors.error(type.fOffset,
                              "offset of field '" + type.fields()[i].fName + "' must be a multiple"
                              " of " + to_string((int) alignment));
            }
            offset = fieldLayout.fOffset;
        } else {
            // Implicit offset: round the running offset up to the field's alignment.
            size_t mod = offset % alignment;
            if (mod) {
                offset += alignment - mod;
            }
        }
        this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer);
        this->writeLayout(fieldLayout, resultId, i);
        // Builtin fields do not receive an Offset decoration.
        if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
                                   (SpvId) offset, fDecorationBuffer);
        }
        if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
                                   fDecorationBuffer);
            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
                                   (SpvId) memoryLayout.stride(*type.fields()[i].fType),
                                   fDecorationBuffer);
        }
        offset += size;
        // After an array or struct member, round the running offset up to that member's
        // alignment before the next field is placed.
        Type::Kind kind = type.fields()[i].fType->kind();
        if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
            offset += alignment - offset % alignment;
        }
    }
}
429
430Type SPIRVCodeGenerator::getActualType(const Type& type) {
431    if (type == *fContext.fHalf_Type) {
432        return *fContext.fFloat_Type;
433    }
434    if (type == *fContext.fShort_Type) {
435        return *fContext.fInt_Type;
436    }
437    if (type == *fContext.fUShort_Type) {
438        return *fContext.fUInt_Type;
439    }
440    if (type.kind() == Type::kMatrix_Kind || type.kind() == Type::kVector_Kind) {
441        if (type.componentType() == *fContext.fHalf_Type) {
442            return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows());
443        }
444        if (type.componentType() == *fContext.fShort_Type) {
445            return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows());
446        }
447        if (type.componentType() == *fContext.fUShort_Type) {
448            return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows());
449        }
450    }
451    return type;
452}
453
454SpvId SPIRVCodeGenerator::getType(const Type& type) {
455    return this->getType(type, fDefaultLayout);
456}
457
// Returns the SPIR-V type id for `rawType` laid out per `layout`. On first use it emits the type
// declaration into fConstantBuffer, along with any types it depends on, which are emitted first
// because they are evaluated as call arguments. The id is cached in fTypeMap. `rawType` is first
// mapped through getActualType(), so half/short/ushort share ids with float/int/uint.
// NOTE(review): the cache key includes layout.fStd for every kind of type, so a non-aggregate
// type (e.g. a vector) requested under two different layouts is declared twice with identical
// operands. The SPIR-V spec disallows that for non-aggregate, non-pointer types. Consider keying
// only structs/arrays on the layout, keeping the fImageTypeMap key used by getImageType()
// consistent.
SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) {
    Type type = this->getActualType(rawType);
    String key = type.name() + to_string((int) layout.fStd);
    auto entry = fTypeMap.find(key);
    if (entry == fTypeMap.end()) {
        SpvId result = this->nextId();
        switch (type.kind()) {
            case Type::kScalar_Kind:
                // OpTypeInt operands: (width, signedness); OpTypeFloat operand: (width).
                if (type == *fContext.fBool_Type) {
                    this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
                } else if (type == *fContext.fInt_Type) {
                    this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
                } else if (type == *fContext.fUInt_Type) {
                    this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
                } else if (type == *fContext.fFloat_Type) {
                    this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
                } else if (type == *fContext.fDouble_Type) {
                    this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
                } else {
                    ASSERT(false);
                }
                break;
            case Type::kVector_Kind:
                this->writeInstruction(SpvOpTypeVector, result,
                                       this->getType(type.componentType(), layout),
                                       type.columns(), fConstantBuffer);
                break;
            case Type::kMatrix_Kind:
                // OpTypeMatrix takes the column type and column count; index_type() (defined
                // elsewhere) presumably yields the column vector type.
                this->writeInstruction(SpvOpTypeMatrix, result,
                                       this->getType(index_type(fContext, type), layout),
                                       type.columns(), fConstantBuffer);
                break;
            case Type::kStruct_Kind:
                this->writeStruct(type, layout, result);
                break;
            case Type::kArray_Kind: {
                // Sized arrays take their length as an int constant; otherwise a runtime array
                // is declared. Both receive an ArrayStride decoration from `layout`.
                if (type.columns() > 0) {
                    IntLiteral count(fContext, -1, type.columns());
                    this->writeInstruction(SpvOpTypeArray, result,
                                           this->getType(type.componentType(), layout),
                                           this->writeIntLiteral(count), fConstantBuffer);
                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
                                           (int32_t) layout.stride(type),
                                           fDecorationBuffer);
                } else {
                    this->writeInstruction(SpvOpTypeRuntimeArray, result,
                                           this->getType(type.componentType(), layout),
                                           fConstantBuffer);
                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
                                           (int32_t) layout.stride(type),
                                           fDecorationBuffer);
                }
                break;
            }
            case Type::kSampler_Kind: {
                // Samplers become an OpTypeImage (float sampled type) wrapped in an
                // OpTypeSampledImage. Subpass inputs are only the image, so for them the image
                // id is the result id. The image id is recorded for getImageType().
                SpvId image = result;
                if (SpvDimSubpassData != type.dimensions()) {
                    image = this->nextId();
                }
                if (SpvDimBuffer == type.dimensions()) {
                    fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer);
                }
                // "Sampled" operand: 1 = used with a sampler, 2 = used without one.
                this->writeInstruction(SpvOpTypeImage, image,
                                       this->getType(*fContext.fFloat_Type, layout),
                                       type.dimensions(), type.isDepth(), type.isArrayed(),
                                       type.isMultisampled(), type.isSampled() ? 1 : 2,
                                       SpvImageFormatUnknown, fConstantBuffer);
                fImageTypeMap[key] = image;
                if (SpvDimSubpassData != type.dimensions()) {
                    this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
                }
                break;
            }
            default:
                // Only void is expected to reach here; any other kind aborts.
                if (type == *fContext.fVoid_Type) {
                    this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
                } else {
                    ABORT("invalid type: %s", type.description().c_str());
                }
        }
        fTypeMap[key] = result;
        return result;
    }
    return entry->second;
}
543
544SpvId SPIRVCodeGenerator::getImageType(const Type& type) {
545    ASSERT(type.kind() == Type::kSampler_Kind);
546    this->getType(type);
547    String key = type.name() + to_string((int) fDefaultLayout.fStd);
548    ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end());
549    return fImageTypeMap[key];
550}
551
552SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
553    String key = function.fReturnType.description() + "(";
554    String separator;
555    for (size_t i = 0; i < function.fParameters.size(); i++) {
556        key += separator;
557        separator = ", ";
558        key += function.fParameters[i]->fType.description();
559    }
560    key += ")";
561    auto entry = fTypeMap.find(key);
562    if (entry == fTypeMap.end()) {
563        SpvId result = this->nextId();
564        int32_t length = 3 + (int32_t) function.fParameters.size();
565        SpvId returnType = this->getType(function.fReturnType);
566        std::vector<SpvId> parameterTypes;
567        for (size_t i = 0; i < function.fParameters.size(); i++) {
568            // glslang seems to treat all function arguments as pointers whether they need to be or
569            // not. I  was initially puzzled by this until I ran bizarre failures with certain
570            // patterns of function calls and control constructs, as exemplified by this minimal
571            // failure case:
572            //
573            // void sphere(float x) {
574            // }
575            //
576            // void map() {
577            //     sphere(1.0);
578            // }
579            //
580            // void main() {
581            //     for (int i = 0; i < 1; i++) {
582            //         map();
583            //     }
584            // }
585            //
586            // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
587            // crashes. Making it take a float* and storing the argument in a temporary variable,
588            // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
589            // the spec makes this make sense.
590//            if (is_out(function->fParameters[i])) {
591                parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
592                                                              SpvStorageClassFunction));
593//            } else {
594//                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
595//            }
596        }
597        this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
598        this->writeWord(result, fConstantBuffer);
599        this->writeWord(returnType, fConstantBuffer);
600        for (SpvId id : parameterTypes) {
601            this->writeWord(id, fConstantBuffer);
602        }
603        fTypeMap[key] = result;
604        return result;
605    }
606    return entry->second;
607}
608
609SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
610    return this->getPointerType(type, fDefaultLayout, storageClass);
611}
612
613SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout,
614                                         SpvStorageClass_ storageClass) {
615    Type type = this->getActualType(rawType);
616    String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
617    auto entry = fTypeMap.find(key);
618    if (entry == fTypeMap.end()) {
619        SpvId result = this->nextId();
620        this->writeInstruction(SpvOpTypePointer, result, storageClass,
621                               this->getType(type), fConstantBuffer);
622        fTypeMap[key] = result;
623        return result;
624    }
625    return entry->second;
626}
627
628SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) {
629    switch (expr.fKind) {
630        case Expression::kBinary_Kind:
631            return this->writeBinaryExpression((BinaryExpression&) expr, out);
632        case Expression::kBoolLiteral_Kind:
633            return this->writeBoolLiteral((BoolLiteral&) expr);
634        case Expression::kConstructor_Kind:
635            return this->writeConstructor((Constructor&) expr, out);
636        case Expression::kIntLiteral_Kind:
637            return this->writeIntLiteral((IntLiteral&) expr);
638        case Expression::kFieldAccess_Kind:
639            return this->writeFieldAccess(((FieldAccess&) expr), out);
640        case Expression::kFloatLiteral_Kind:
641            return this->writeFloatLiteral(((FloatLiteral&) expr));
642        case Expression::kFunctionCall_Kind:
643            return this->writeFunctionCall((FunctionCall&) expr, out);
644        case Expression::kPrefix_Kind:
645            return this->writePrefixExpression((PrefixExpression&) expr, out);
646        case Expression::kPostfix_Kind:
647            return this->writePostfixExpression((PostfixExpression&) expr, out);
648        case Expression::kSwizzle_Kind:
649            return this->writeSwizzle((Swizzle&) expr, out);
650        case Expression::kVariableReference_Kind:
651            return this->writeVariableReference((VariableReference&) expr, out);
652        case Expression::kTernary_Kind:
653            return this->writeTernaryExpression((TernaryExpression&) expr, out);
654        case Expression::kIndex_Kind:
655            return this->writeIndexExpression((IndexExpression&) expr, out);
656        default:
657            ABORT("unsupported expression: %s", expr.description().c_str());
658    }
659    return -1;
660}
661
662SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) {
663    auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
664    ASSERT(intrinsic != fIntrinsicMap.end());
665    int32_t intrinsicId;
666    if (c.fArguments.size() > 0) {
667        const Type& type = c.fArguments[0]->fType;
668        if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
669            intrinsicId = std::get<1>(intrinsic->second);
670        } else if (is_signed(fContext, type)) {
671            intrinsicId = std::get<2>(intrinsic->second);
672        } else if (is_unsigned(fContext, type)) {
673            intrinsicId = std::get<3>(intrinsic->second);
674        } else if (is_bool(fContext, type)) {
675            intrinsicId = std::get<4>(intrinsic->second);
676        } else {
677            intrinsicId = std::get<1>(intrinsic->second);
678        }
679    } else {
680        intrinsicId = std::get<1>(intrinsic->second);
681    }
682    switch (std::get<0>(intrinsic->second)) {
683        case kGLSL_STD_450_IntrinsicKind: {
684            SpvId result = this->nextId();
685            std::vector<SpvId> arguments;
686            for (size_t i = 0; i < c.fArguments.size(); i++) {
687                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
688            }
689            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
690            this->writeWord(this->getType(c.fType), out);
691            this->writeWord(result, out);
692            this->writeWord(fGLSLExtendedInstructions, out);
693            this->writeWord(intrinsicId, out);
694            for (SpvId id : arguments) {
695                this->writeWord(id, out);
696            }
697            return result;
698        }
699        case kSPIRV_IntrinsicKind: {
700            SpvId result = this->nextId();
701            std::vector<SpvId> arguments;
702            for (size_t i = 0; i < c.fArguments.size(); i++) {
703                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
704            }
705            if (c.fType != *fContext.fVoid_Type) {
706                this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
707                this->writeWord(this->getType(c.fType), out);
708                this->writeWord(result, out);
709            } else {
710                this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out);
711            }
712            for (SpvId id : arguments) {
713                this->writeWord(id, out);
714            }
715            return result;
716        }
717        case kSpecial_IntrinsicKind:
718            return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
719        default:
720            ABORT("unsupported intrinsic kind");
721    }
722}
723
724std::vector<SpvId> SPIRVCodeGenerator::vectorize(
725                                               const std::vector<std::unique_ptr<Expression>>& args,
726                                               OutputStream& out) {
727    int vectorSize = 0;
728    for (const auto& a : args) {
729        if (a->fType.kind() == Type::kVector_Kind) {
730            if (vectorSize) {
731                ASSERT(a->fType.columns() == vectorSize);
732            }
733            else {
734                vectorSize = a->fType.columns();
735            }
736        }
737    }
738    std::vector<SpvId> result;
739    for (const auto& a : args) {
740        SpvId raw = this->writeExpression(*a, out);
741        if (vectorSize && a->fType.kind() == Type::kScalar_Kind) {
742            SpvId vector = this->nextId();
743            this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out);
744            this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out);
745            this->writeWord(vector, out);
746            for (int i = 0; i < vectorSize; i++) {
747                this->writeWord(raw, out);
748            }
749            result.push_back(vector);
750        } else {
751            result.push_back(raw);
752        }
753    }
754    return result;
755}
756
757void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst,
758                                                      SpvId signedInst, SpvId unsignedInst,
759                                                      const std::vector<SpvId>& args,
760                                                      OutputStream& out) {
761    this->writeOpCode(SpvOpExtInst, 5 + args.size(), out);
762    this->writeWord(this->getType(type), out);
763    this->writeWord(id, out);
764    this->writeWord(fGLSLExtendedInstructions, out);
765
766    if (is_float(fContext, type)) {
767        this->writeWord(floatInst, out);
768    } else if (is_signed(fContext, type)) {
769        this->writeWord(signedInst, out);
770    } else if (is_unsigned(fContext, type)) {
771        this->writeWord(unsignedInst, out);
772    } else {
773        ASSERT(false);
774    }
775    for (SpvId a : args) {
776        this->writeWord(a, out);
777    }
778}
779
// Emits code for intrinsics registered with SPECIAL(...) in setupIntrinsics, i.e. those that need
// custom lowering rather than a single table-driven instruction.
//
// A result id is allocated up front and returned for every kind, except that kMod returns 0 when
// its first operand is neither float, signed, nor unsigned.
SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
                                                OutputStream& out) {
    SpvId result = this->nextId();
    switch (kind) {
        case kAtan_SpecialIntrinsic: {
            // Two arguments select GLSLstd450Atan2; otherwise GLSLstd450Atan is used.
            std::vector<SpvId> arguments;
            for (size_t i = 0; i < c.fArguments.size(); i++) {
                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
            }
            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
            this->writeWord(this->getType(c.fType), out);
            this->writeWord(result, out);
            this->writeWord(fGLSLExtendedInstructions, out);
            this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
            for (SpvId id : arguments) {
                this->writeWord(id, out);
            }
            break;
        }
        case kSubpassLoad_SpecialIntrinsic: {
            // Reads the image with OpImageRead at the constant coordinate float2(0, 0). An optional
            // second argument is passed as the Sample image operand.
            SpvId img = this->writeExpression(*c.fArguments[0], out);
            std::vector<std::unique_ptr<Expression>> args;
            args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
            args.emplace_back(new FloatLiteral(fContext, -1, 0.0));
            Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args));
            SpvId coords = this->writeConstantVector(ctor);
            if (1 == c.fArguments.size()) {
                this->writeInstruction(SpvOpImageRead,
                                       this->getType(c.fType),
                                       result,
                                       img,
                                       coords,
                                       out);
            } else {
                ASSERT(2 == c.fArguments.size());
                SpvId sample = this->writeExpression(*c.fArguments[1], out);
                this->writeInstruction(SpvOpImageRead,
                                       this->getType(c.fType),
                                       result,
                                       img,
                                       coords,
                                       SpvImageOperandsSampleMask,
                                       sample,
                                       out);
            }
            break;
        }
        case kTexelFetch_SpecialIntrinsic: {
            // Extracts the image from the first argument with OpImage, then fetches at the
            // coordinates given by the second argument.
            ASSERT(c.fArguments.size() == 2);
            SpvId image = this->nextId();
            this->writeInstruction(SpvOpImage,
                                   this->getImageType(c.fArguments[0]->fType),
                                   image,
                                   this->writeExpression(*c.fArguments[0], out),
                                   out);
            this->writeInstruction(SpvOpImageFetch,
                                   this->getType(c.fType),
                                   result,
                                   image,
                                   this->writeExpression(*c.fArguments[1], out),
                                   out);
            break;
        }
        case kTexture_SpecialIntrinsic: {
            // For 1D/2D/3D samplers, a coordinate with one more component than the dimensionality
            // selects the projective sampling variant.
            SpvOp_ op = SpvOpImageSampleImplicitLod;
            switch (c.fArguments[0]->fType.dimensions()) {
                case SpvDim1D:
                    if (c.fArguments[1]->fType == *fContext.fFloat2_Type) {
                        op = SpvOpImageSampleProjImplicitLod;
                    } else {
                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
                    }
                    break;
                case SpvDim2D:
                    if (c.fArguments[1]->fType == *fContext.fFloat3_Type) {
                        op = SpvOpImageSampleProjImplicitLod;
                    } else {
                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type);
                    }
                    break;
                case SpvDim3D:
                    if (c.fArguments[1]->fType == *fContext.fFloat4_Type) {
                        op = SpvOpImageSampleProjImplicitLod;
                    } else {
                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type);
                    }
                    break;
                case SpvDimCube:   // fall through
                case SpvDimRect:   // fall through
                case SpvDimBuffer: // fall through
                case SpvDimSubpassData:
                    break;
            }
            SpvId type = this->getType(c.fType);
            SpvId sampler = this->writeExpression(*c.fArguments[0], out);
            SpvId uv = this->writeExpression(*c.fArguments[1], out);
            if (c.fArguments.size() == 3) {
                // An explicit third argument is passed as the Bias image operand.
                this->writeInstruction(op, type, result, sampler, uv,
                                       SpvImageOperandsBiasMask,
                                       this->writeExpression(*c.fArguments[2], out),
                                       out);
            } else {
                ASSERT(c.fArguments.size() == 2);
                if (fProgram.fSettings.fSharpenTextures) {
                    // With fSharpenTextures enabled, a Bias of -0.5 is applied.
                    FloatLiteral lodBias(fContext, -1, -0.5);
                    this->writeInstruction(op, type, result, sampler, uv,
                                           SpvImageOperandsBiasMask,
                                           this->writeFloatLiteral(lodBias),
                                           out);
                } else {
                    this->writeInstruction(op, type, result, sampler, uv,
                                           out);
                }
            }
            break;
        }
        case kMod_SpecialIntrinsic: {
            // Scalars are splatted to the vector width (vectorize). The opcode (FMod/SMod/UMod) is
            // chosen from the first argument's type.
            // NOTE(review): the result type is also taken from the first argument rather than
            // c.fType; this matches only when the first argument has the call's full width.
            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
            ASSERT(args.size() == 2);
            const Type& operandType = c.fArguments[0]->fType;
            SpvOp_ op;
            if (is_float(fContext, operandType)) {
                op = SpvOpFMod;
            } else if (is_signed(fContext, operandType)) {
                op = SpvOpSMod;
            } else if (is_unsigned(fContext, operandType)) {
                op = SpvOpUMod;
            } else {
                ASSERT(false);
                return 0;
            }
            this->writeOpCode(op, 5, out);
            this->writeWord(this->getType(operandType), out);
            this->writeWord(result, out);
            this->writeWord(args[0], out);
            this->writeWord(args[1], out);
            break;
        }
        // clamp/max/min/mix: vectorize the operands, then emit the GLSL.std.450 variant matching
        // c.fType's numeric category.
        case kClamp_SpecialIntrinsic: {
            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
            ASSERT(args.size() == 3);
            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp,
                                               GLSLstd450UClamp, args, out);
            break;
        }
        case kMax_SpecialIntrinsic: {
            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
            ASSERT(args.size() == 2);
            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax,
                                               GLSLstd450UMax, args, out);
            break;
        }
        case kMin_SpecialIntrinsic: {
            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
            ASSERT(args.size() == 2);
            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin,
                                               GLSLstd450UMin, args, out);
            break;
        }
        case kMix_SpecialIntrinsic: {
            // Only a float variant is provided; SpvOpUndef fills the signed/unsigned slots.
            std::vector<SpvId> args = this->vectorize(c.fArguments, out);
            ASSERT(args.size() == 3);
            this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef,
                                               SpvOpUndef, args, out);
            break;
        }
    }
    return result;
}
949
950SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) {
951    const auto& entry = fFunctionMap.find(&c.fFunction);
952    if (entry == fFunctionMap.end()) {
953        return this->writeIntrinsicCall(c, out);
954    }
955    // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
956    std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
957    std::vector<SpvId> arguments;
958    for (size_t i = 0; i < c.fArguments.size(); i++) {
959        // id of temporary variable that we will use to hold this argument, or 0 if it is being
960        // passed directly
961        SpvId tmpVar;
962        // if we need a temporary var to store this argument, this is the value to store in the var
963        SpvId tmpValueId;
964        if (is_out(*c.fFunction.fParameters[i])) {
965            std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
966            SpvId ptr = lv->getPointer();
967            if (ptr) {
968                arguments.push_back(ptr);
969                continue;
970            } else {
971                // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
972                // copy it into a temp, call the function, read the value out of the temp, and then
973                // update the lvalue.
974                tmpValueId = lv->load(out);
975                tmpVar = this->nextId();
976                lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
977                                  std::move(lv)));
978            }
979        } else {
980            // see getFunctionType for an explanation of why we're always using pointer parameters
981            tmpValueId = this->writeExpression(*c.fArguments[i], out);
982            tmpVar = this->nextId();
983        }
984        this->writeInstruction(SpvOpVariable,
985                               this->getPointerType(c.fArguments[i]->fType,
986                                                    SpvStorageClassFunction),
987                               tmpVar,
988                               SpvStorageClassFunction,
989                               fVariableBuffer);
990        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
991        arguments.push_back(tmpVar);
992    }
993    SpvId result = this->nextId();
994    this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
995    this->writeWord(this->getType(c.fType), out);
996    this->writeWord(result, out);
997    this->writeWord(entry->second, out);
998    for (SpvId id : arguments) {
999        this->writeWord(id, out);
1000    }
1001    // now that the call is complete, we may need to update some lvalues with the new values of out
1002    // arguments
1003    for (const auto& tuple : lvalues) {
1004        SpvId load = this->nextId();
1005        this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
1006        std::get<2>(tuple)->store(load, out);
1007    }
1008    return result;
1009}
1010
1011SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1012    ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1013    SpvId result = this->nextId();
1014    std::vector<SpvId> arguments;
1015    for (size_t i = 0; i < c.fArguments.size(); i++) {
1016        arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1017    }
1018    SpvId type = this->getType(c.fType);
1019    if (c.fArguments.size() == 1) {
1020        // with a single argument, a vector will have all of its entries equal to the argument
1021        this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1022        this->writeWord(type, fConstantBuffer);
1023        this->writeWord(result, fConstantBuffer);
1024        for (int i = 0; i < c.fType.columns(); i++) {
1025            this->writeWord(arguments[0], fConstantBuffer);
1026        }
1027    } else {
1028        this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1029                          fConstantBuffer);
1030        this->writeWord(type, fConstantBuffer);
1031        this->writeWord(result, fConstantBuffer);
1032        for (SpvId id : arguments) {
1033            this->writeWord(id, fConstantBuffer);
1034        }
1035    }
1036    return result;
1037}
1038
1039SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) {
1040    ASSERT(c.fType.isFloat());
1041    ASSERT(c.fArguments.size() == 1);
1042    ASSERT(c.fArguments[0]->fType.isNumber());
1043    SpvId result = this->nextId();
1044    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1045    if (c.fArguments[0]->fType.isSigned()) {
1046        this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1047                               out);
1048    } else {
1049        ASSERT(c.fArguments[0]->fType.isUnsigned());
1050        this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1051                               out);
1052    }
1053    return result;
1054}
1055
1056SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) {
1057    ASSERT(c.fType.isSigned());
1058    ASSERT(c.fArguments.size() == 1);
1059    ASSERT(c.fArguments[0]->fType.isNumber());
1060    SpvId result = this->nextId();
1061    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1062    if (c.fArguments[0]->fType.isFloat()) {
1063        this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1064                               out);
1065    }
1066    else {
1067        ASSERT(c.fArguments[0]->fType.isUnsigned());
1068        this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1069                               out);
1070    }
1071    return result;
1072}
1073
1074SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) {
1075    ASSERT(c.fType.isUnsigned());
1076    ASSERT(c.fArguments.size() == 1);
1077    ASSERT(c.fArguments[0]->fType.isNumber());
1078    SpvId result = this->nextId();
1079    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1080    if (c.fArguments[0]->fType.isFloat()) {
1081        this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter,
1082                               out);
1083    } else {
1084        ASSERT(c.fArguments[0]->fType.isSigned());
1085        this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter,
1086                               out);
1087    }
1088    return result;
1089}
1090
1091void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1092                                                 OutputStream& out) {
1093    FloatLiteral zero(fContext, -1, 0);
1094    SpvId zeroId = this->writeFloatLiteral(zero);
1095    std::vector<SpvId> columnIds;
1096    for (int column = 0; column < type.columns(); column++) {
1097        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1098                          out);
1099        this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1100                        out);
1101        SpvId columnId = this->nextId();
1102        this->writeWord(columnId, out);
1103        columnIds.push_back(columnId);
1104        for (int row = 0; row < type.columns(); row++) {
1105            this->writeWord(row == column ? diagonal : zeroId, out);
1106        }
1107    }
1108    this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1109                      out);
1110    this->writeWord(this->getType(type), out);
1111    this->writeWord(id, out);
1112    for (SpvId id : columnIds) {
1113        this->writeWord(id, out);
1114    }
1115}
1116
1117void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1118                                         const Type& dstType, OutputStream& out) {
1119    ASSERT(srcType.kind() == Type::kMatrix_Kind);
1120    ASSERT(dstType.kind() == Type::kMatrix_Kind);
1121    ASSERT(srcType.componentType() == dstType.componentType());
1122    SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext,
1123                                                                           srcType.rows(),
1124                                                                           1));
1125    SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext,
1126                                                                           dstType.rows(),
1127                                                                           1));
1128    SpvId zeroId;
1129    if (dstType.componentType() == *fContext.fFloat_Type) {
1130        FloatLiteral zero(fContext, -1, 0.0);
1131        zeroId = this->writeFloatLiteral(zero);
1132    } else if (dstType.componentType() == *fContext.fInt_Type) {
1133        IntLiteral zero(fContext, -1, 0);
1134        zeroId = this->writeIntLiteral(zero);
1135    } else {
1136        ABORT("unsupported matrix component type");
1137    }
1138    SpvId zeroColumn = 0;
1139    SpvId columns[4];
1140    for (int i = 0; i < dstType.columns(); i++) {
1141        if (i < srcType.columns()) {
1142            // we're still inside the src matrix, copy the column
1143            SpvId srcColumn = this->nextId();
1144            this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out);
1145            SpvId dstColumn;
1146            if (srcType.rows() == dstType.rows()) {
1147                // columns are equal size, don't need to do anything
1148                dstColumn = srcColumn;
1149            }
1150            else if (dstType.rows() > srcType.rows()) {
1151                // dst column is bigger, need to zero-pad it
1152                dstColumn = this->nextId();
1153                int delta = dstType.rows() - srcType.rows();
1154                this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out);
1155                this->writeWord(dstColumnType, out);
1156                this->writeWord(dstColumn, out);
1157                this->writeWord(srcColumn, out);
1158                for (int i = 0; i < delta; ++i) {
1159                    this->writeWord(zeroId, out);
1160                }
1161            }
1162            else {
1163                // dst column is smaller, need to swizzle the src column
1164                dstColumn = this->nextId();
1165                int count = dstType.rows();
1166                this->writeOpCode(SpvOpVectorShuffle, 5 + count, out);
1167                this->writeWord(dstColumnType, out);
1168                this->writeWord(dstColumn, out);
1169                this->writeWord(srcColumn, out);
1170                this->writeWord(srcColumn, out);
1171                for (int i = 0; i < count; i++) {
1172                    this->writeWord(i, out);
1173                }
1174            }
1175            columns[i] = dstColumn;
1176        } else {
1177            // we're past the end of the src matrix, need a vector of zeroes
1178            if (!zeroColumn) {
1179                zeroColumn = this->nextId();
1180                this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out);
1181                this->writeWord(dstColumnType, out);
1182                this->writeWord(zeroColumn, out);
1183                for (int i = 0; i < dstType.rows(); ++i) {
1184                    this->writeWord(zeroId, out);
1185                }
1186            }
1187            columns[i] = zeroColumn;
1188        }
1189    }
1190    this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out);
1191    this->writeWord(this->getType(dstType), out);
1192    this->writeWord(id, out);
1193    for (int i = 0; i < dstType.columns(); i++) {
1194        this->writeWord(columns[i], out);
1195    }
1196}
1197
// Emits a matrix constructor and returns the id of the constructed matrix.
//   - single scalar argument: delegated to writeUniformScaleMatrix
//   - single matrix argument: delegated to writeMatrixCopy
//   - single vector argument: only a 2x2 matrix from a 4-component vector is supported; the vector
//     is split into two 2-component columns
//   - otherwise: arguments are packed into columns in order. A vector whose size equals the row
//     count becomes a column by itself; other arguments accumulate until they fill exactly one
//     column. An argument that would straddle a column boundary is not supported (the ASSERTs
//     below catch it).
SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) {
    ASSERT(c.fType.kind() == Type::kMatrix_Kind);
    // go ahead and write the arguments so we don't try to write new instructions in the middle of
    // an instruction
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
    }
    SpvId result = this->nextId();
    int rows = c.fType.rows();
    int columns = c.fType.columns();
    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
        this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
        this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) {
        // e.g. a 2x2 matrix from a 4-vector: components 0,1 form column 1 and 2,3 form column 2
        ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2);
        ASSERT(c.fArguments[0]->fType.columns() == 4);
        SpvId componentType = this->getType(c.fType.componentType());
        SpvId v[4];
        for (int i = 0; i < 4; ++i) {
            v[i] = this->nextId();
            this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out);
        }
        SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1));
        SpvId column1 = this->nextId();
        this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out);
        SpvId column2 = this->nextId();
        this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out);
        this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1,
                               column2, out);
    } else {
        std::vector<SpvId> columnIds;
        // ids of vectors and scalars we have written to the current column so far
        std::vector<SpvId> currentColumn;
        // the total number of scalars represented by currentColumn's entries
        int currentCount = 0;
        for (size_t i = 0; i < arguments.size(); i++) {
            if (c.fArguments[i]->fType.kind() == Type::kVector_Kind &&
                    c.fArguments[i]->fType.columns() == c.fType.rows()) {
                // this is a complete column by itself
                ASSERT(currentCount == 0);
                columnIds.push_back(arguments[i]);
            } else {
                currentColumn.push_back(arguments[i]);
                // scalars contribute 1 (their columns()), vectors their component count
                currentCount += c.fArguments[i]->fType.columns();
                if (currentCount == rows) {
                    // the pending pieces exactly fill one column; emit it and start the next
                    currentCount = 0;
                    this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out);
                    this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
                                                                                     1)),
                                    out);
                    SpvId columnId = this->nextId();
                    this->writeWord(columnId, out);
                    columnIds.push_back(columnId);
                    for (SpvId id : currentColumn) {
                        this->writeWord(id, out);
                    }
                    currentColumn.clear();
                }
                ASSERT(currentCount < rows);
            }
        }
        ASSERT(columnIds.size() == (size_t) columns);
        this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
        this->writeWord(this->getType(c.fType), out);
        this->writeWord(result, out);
        for (SpvId id : columnIds) {
            this->writeWord(id, out);
        }
    }
    return result;
}
1271
1272SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) {
1273    ASSERT(c.fType.kind() == Type::kVector_Kind);
1274    if (c.isConstant()) {
1275        return this->writeConstantVector(c);
1276    }
1277    // go ahead and write the arguments so we don't try to write new instructions in the middle of
1278    // an instruction
1279    std::vector<SpvId> arguments;
1280    for (size_t i = 0; i < c.fArguments.size(); i++) {
1281        if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
1282            // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to
1283            // extract the components and convert them in that case manually. On top of that,
1284            // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite
1285            // doesn't handle vector arguments at all, so we always extract vector components and
1286            // pass them into OpCreateComposite individually.
1287            SpvId vec = this->writeExpression(*c.fArguments[i], out);
1288            SpvOp_ op = SpvOpUndef;
1289            const Type& src = c.fArguments[i]->fType.componentType();
1290            const Type& dst = c.fType.componentType();
1291            if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) {
1292                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1293                    if (c.fArguments.size() == 1) {
1294                        return vec;
1295                    }
1296                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1297                    op = SpvOpConvertSToF;
1298                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1299                    op = SpvOpConvertUToF;
1300                } else {
1301                    ASSERT(false);
1302                }
1303            } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) {
1304                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1305                    op = SpvOpConvertFToS;
1306                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1307                    if (c.fArguments.size() == 1) {
1308                        return vec;
1309                    }
1310                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1311                    op = SpvOpBitcast;
1312                } else {
1313                    ASSERT(false);
1314                }
1315            } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) {
1316                if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) {
1317                    op = SpvOpConvertFToS;
1318                } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) {
1319                    op = SpvOpBitcast;
1320                } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) {
1321                    if (c.fArguments.size() == 1) {
1322                        return vec;
1323                    }
1324                } else {
1325                    ASSERT(false);
1326                }
1327            }
1328            for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) {
1329                SpvId swizzle = this->nextId();
1330                this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j,
1331                                       out);
1332                if (op != SpvOpUndef) {
1333                    SpvId cast = this->nextId();
1334                    this->writeInstruction(op, this->getType(dst), cast, swizzle, out);
1335                    arguments.push_back(cast);
1336                } else {
1337                    arguments.push_back(swizzle);
1338                }
1339            }
1340        } else {
1341            arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1342        }
1343    }
1344    SpvId result = this->nextId();
1345    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1346        this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1347        this->writeWord(this->getType(c.fType), out);
1348        this->writeWord(result, out);
1349        for (int i = 0; i < c.fType.columns(); i++) {
1350            this->writeWord(arguments[0], out);
1351        }
1352    } else {
1353        ASSERT(arguments.size() > 1);
1354        this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out);
1355        this->writeWord(this->getType(c.fType), out);
1356        this->writeWord(result, out);
1357        for (SpvId id : arguments) {
1358            this->writeWord(id, out);
1359        }
1360    }
1361    return result;
1362}
1363
1364SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) {
1365    ASSERT(c.fType.kind() == Type::kArray_Kind);
1366    // go ahead and write the arguments so we don't try to write new instructions in the middle of
1367    // an instruction
1368    std::vector<SpvId> arguments;
1369    for (size_t i = 0; i < c.fArguments.size(); i++) {
1370        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1371    }
1372    SpvId result = this->nextId();
1373    this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1374    this->writeWord(this->getType(c.fType), out);
1375    this->writeWord(result, out);
1376    for (SpvId id : arguments) {
1377        this->writeWord(id, out);
1378    }
1379    return result;
1380}
1381
1382SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) {
1383    if (c.fArguments.size() == 1 &&
1384        this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) {
1385        return this->writeExpression(*c.fArguments[0], out);
1386    }
1387    if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) {
1388        return this->writeFloatConstructor(c, out);
1389    } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) {
1390        return this->writeIntConstructor(c, out);
1391    } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) {
1392        return this->writeUIntConstructor(c, out);
1393    }
1394    switch (c.fType.kind()) {
1395        case Type::kVector_Kind:
1396            return this->writeVectorConstructor(c, out);
1397        case Type::kMatrix_Kind:
1398            return this->writeMatrixConstructor(c, out);
1399        case Type::kArray_Kind:
1400            return this->writeArrayConstructor(c, out);
1401        default:
1402            ABORT("unsupported constructor: %s", c.description().c_str());
1403    }
1404}
1405
1406SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1407    if (modifiers.fFlags & Modifiers::kIn_Flag) {
1408        ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1409        return SpvStorageClassInput;
1410    } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1411        ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag));
1412        return SpvStorageClassOutput;
1413    } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1414        if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) {
1415            return SpvStorageClassPushConstant;
1416        }
1417        return SpvStorageClassUniform;
1418    } else {
1419        return SpvStorageClassFunction;
1420    }
1421}
1422
1423SpvStorageClass_ get_storage_class(const Expression& expr) {
1424    switch (expr.fKind) {
1425        case Expression::kVariableReference_Kind: {
1426            const Variable& var = ((VariableReference&) expr).fVariable;
1427            if (var.fStorage != Variable::kGlobal_Storage) {
1428                return SpvStorageClassFunction;
1429            }
1430            SpvStorageClass_ result = get_storage_class(var.fModifiers);
1431            if (result == SpvStorageClassFunction) {
1432                result = SpvStorageClassPrivate;
1433            }
1434            return result;
1435        }
1436        case Expression::kFieldAccess_Kind:
1437            return get_storage_class(*((FieldAccess&) expr).fBase);
1438        case Expression::kIndex_Kind:
1439            return get_storage_class(*((IndexExpression&) expr).fBase);
1440        default:
1441            return SpvStorageClassFunction;
1442    }
1443}
1444
1445std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) {
1446    std::vector<SpvId> chain;
1447    switch (expr.fKind) {
1448        case Expression::kIndex_Kind: {
1449            IndexExpression& indexExpr = (IndexExpression&) expr;
1450            chain = this->getAccessChain(*indexExpr.fBase, out);
1451            chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1452            break;
1453        }
1454        case Expression::kFieldAccess_Kind: {
1455            FieldAccess& fieldExpr = (FieldAccess&) expr;
1456            chain = this->getAccessChain(*fieldExpr.fBase, out);
1457            IntLiteral index(fContext, -1, fieldExpr.fFieldIndex);
1458            chain.push_back(this->writeIntLiteral(index));
1459            break;
1460        }
1461        default:
1462            chain.push_back(this->getLValue(expr, out)->getPointer());
1463    }
1464    return chain;
1465}
1466
1467class PointerLValue : public SPIRVCodeGenerator::LValue {
1468public:
1469    PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1470    : fGen(gen)
1471    , fPointer(pointer)
1472    , fType(type) {}
1473
1474    virtual SpvId getPointer() override {
1475        return fPointer;
1476    }
1477
1478    virtual SpvId load(OutputStream& out) override {
1479        SpvId result = fGen.nextId();
1480        fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1481        return result;
1482    }
1483
1484    virtual void store(SpvId value, OutputStream& out) override {
1485        fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1486    }
1487
1488private:
1489    SPIRVCodeGenerator& fGen;
1490    const SpvId fPointer;
1491    const SpvId fType;
1492};
1493
1494class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1495public:
1496    SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1497                  const Type& baseType, const Type& swizzleType)
1498    : fGen(gen)
1499    , fVecPointer(vecPointer)
1500    , fComponents(components)
1501    , fBaseType(baseType)
1502    , fSwizzleType(swizzleType) {}
1503
1504    virtual SpvId getPointer() override {
1505        return 0;
1506    }
1507
1508    virtual SpvId load(OutputStream& out) override {
1509        SpvId base = fGen.nextId();
1510        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1511        SpvId result = fGen.nextId();
1512        fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1513        fGen.writeWord(fGen.getType(fSwizzleType), out);
1514        fGen.writeWord(result, out);
1515        fGen.writeWord(base, out);
1516        fGen.writeWord(base, out);
1517        for (int component : fComponents) {
1518            fGen.writeWord(component, out);
1519        }
1520        return result;
1521    }
1522
1523    virtual void store(SpvId value, OutputStream& out) override {
1524        // use OpVectorShuffle to mix and match the vector components. We effectively create
1525        // a virtual vector out of the concatenation of the left and right vectors, and then
1526        // select components from this virtual vector to make the result vector. For
1527        // instance, given:
1528        // float3L = ...;
1529        // float3R = ...;
1530        // L.xz = R.xy;
1531        // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1532        // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1533        // (3, 1, 4).
1534        SpvId base = fGen.nextId();
1535        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1536        SpvId shuffle = fGen.nextId();
1537        fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1538        fGen.writeWord(fGen.getType(fBaseType), out);
1539        fGen.writeWord(shuffle, out);
1540        fGen.writeWord(base, out);
1541        fGen.writeWord(value, out);
1542        for (int i = 0; i < fBaseType.columns(); i++) {
1543            // current offset into the virtual vector, defaults to pulling the unmodified
1544            // value from the left side
1545            int offset = i;
1546            // check to see if we are writing this component
1547            for (size_t j = 0; j < fComponents.size(); j++) {
1548                if (fComponents[j] == i) {
1549                    // we're writing to this component, so adjust the offset to pull from
1550                    // the correct component of the right side instead of preserving the
1551                    // value from the left
1552                    offset = (int) (j + fBaseType.columns());
1553                    break;
1554                }
1555            }
1556            fGen.writeWord(offset, out);
1557        }
1558        fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1559    }
1560
1561private:
1562    SPIRVCodeGenerator& fGen;
1563    const SpvId fVecPointer;
1564    const std::vector<int>& fComponents;
1565    const Type& fBaseType;
1566    const Type& fSwizzleType;
1567};
1568
1569std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
1570                                                                          OutputStream& out) {
1571    switch (expr.fKind) {
1572        case Expression::kVariableReference_Kind: {
1573            const Variable& var = ((VariableReference&) expr).fVariable;
1574            auto entry = fVariableMap.find(&var);
1575            ASSERT(entry != fVariableMap.end());
1576            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1577                                                                       *this,
1578                                                                       entry->second,
1579                                                                       this->getType(expr.fType)));
1580        }
1581        case Expression::kIndex_Kind: // fall through
1582        case Expression::kFieldAccess_Kind: {
1583            std::vector<SpvId> chain = this->getAccessChain(expr, out);
1584            SpvId member = this->nextId();
1585            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
1586            this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
1587            this->writeWord(member, out);
1588            for (SpvId idx : chain) {
1589                this->writeWord(idx, out);
1590            }
1591            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1592                                                                       *this,
1593                                                                       member,
1594                                                                       this->getType(expr.fType)));
1595        }
1596        case Expression::kSwizzle_Kind: {
1597            Swizzle& swizzle = (Swizzle&) expr;
1598            size_t count = swizzle.fComponents.size();
1599            SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
1600            ASSERT(base);
1601            if (count == 1) {
1602                IntLiteral index(fContext, -1, swizzle.fComponents[0]);
1603                SpvId member = this->nextId();
1604                this->writeInstruction(SpvOpAccessChain,
1605                                       this->getPointerType(swizzle.fType,
1606                                                            get_storage_class(*swizzle.fBase)),
1607                                       member,
1608                                       base,
1609                                       this->writeIntLiteral(index),
1610                                       out);
1611                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1612                                                                       *this,
1613                                                                       member,
1614                                                                       this->getType(expr.fType)));
1615            } else {
1616                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
1617                                                                              *this,
1618                                                                              base,
1619                                                                              swizzle.fComponents,
1620                                                                              swizzle.fBase->fType,
1621                                                                              expr.fType));
1622            }
1623        }
1624        case Expression::kTernary_Kind: {
1625            TernaryExpression& t = (TernaryExpression&) expr;
1626            SpvId test = this->writeExpression(*t.fTest, out);
1627            SpvId end = this->nextId();
1628            SpvId ifTrueLabel = this->nextId();
1629            SpvId ifFalseLabel = this->nextId();
1630            this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
1631            this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out);
1632            this->writeLabel(ifTrueLabel, out);
1633            SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer();
1634            ASSERT(ifTrue);
1635            this->writeInstruction(SpvOpBranch, end, out);
1636            ifTrueLabel = fCurrentBlock;
1637            SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer();
1638            ASSERT(ifFalse);
1639            ifFalseLabel = fCurrentBlock;
1640            this->writeInstruction(SpvOpBranch, end, out);
1641            SpvId result = this->nextId();
1642            this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue,
1643                       ifTrueLabel, ifFalse, ifFalseLabel, out);
1644            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1645                                                                       *this,
1646                                                                       result,
1647                                                                       this->getType(expr.fType)));
1648        }
1649        default:
1650            // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
1651            // to the need to store values in temporary variables during function calls (see
1652            // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
1653            // caught by IRGenerator
1654            SpvId result = this->nextId();
1655            SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
1656            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
1657                                   fVariableBuffer);
1658            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
1659            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
1660                                                                       *this,
1661                                                                       result,
1662                                                                       this->getType(expr.fType)));
1663    }
1664}
1665
1666SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) {
1667    SpvId result = this->nextId();
1668    auto entry = fVariableMap.find(&ref.fVariable);
1669    ASSERT(entry != fVariableMap.end());
1670    SpvId var = entry->second;
1671    this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
1672    if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
1673        fProgram.fSettings.fFlipY) {
1674        // need to remap to a top-left coordinate system
1675        if (fRTHeightStructId == (SpvId) -1) {
1676            // height variable hasn't been written yet
1677            std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
1678            ASSERT(fRTHeightFieldIndex == (SpvId) -1);
1679            std::vector<Type::Field> fields;
1680            fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get());
1681            StringFragment name("sksl_synthetic_uniforms");
1682            Type intfStruct(-1, name, fields);
1683            Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified,
1684                          Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key,
1685                          StringFragment());
1686            Variable* intfVar = new Variable(-1,
1687                                             Modifiers(layout, Modifiers::kUniform_Flag),
1688                                             name,
1689                                             intfStruct,
1690                                             Variable::kGlobal_Storage);
1691            fSynthetics.takeOwnership(intfVar);
1692            InterfaceBlock intf(-1, intfVar, name, String(""),
1693                                std::vector<std::unique_ptr<Expression>>(), st);
1694            fRTHeightStructId = this->writeInterfaceBlock(intf);
1695            fRTHeightFieldIndex = 0;
1696        }
1697        ASSERT(fRTHeightFieldIndex != (SpvId) -1);
1698        // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
1699        SpvId xId = this->nextId();
1700        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
1701                               result, 0, out);
1702        IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex);
1703        SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
1704        SpvId heightPtr = this->nextId();
1705        this->writeOpCode(SpvOpAccessChain, 5, out);
1706        this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
1707        this->writeWord(heightPtr, out);
1708        this->writeWord(fRTHeightStructId, out);
1709        this->writeWord(fieldIndexId, out);
1710        SpvId heightRead = this->nextId();
1711        this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
1712                               heightPtr, out);
1713        SpvId rawYId = this->nextId();
1714        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
1715                               result, 1, out);
1716        SpvId flippedYId = this->nextId();
1717        this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
1718                               heightRead, rawYId, out);
1719        FloatLiteral zero(fContext, -1, 0.0);
1720        SpvId zeroId = writeFloatLiteral(zero);
1721        FloatLiteral one(fContext, -1, 1.0);
1722        SpvId oneId = writeFloatLiteral(one);
1723        SpvId flipped = this->nextId();
1724        this->writeOpCode(SpvOpCompositeConstruct, 7, out);
1725        this->writeWord(this->getType(*fContext.fFloat4_Type), out);
1726        this->writeWord(flipped, out);
1727        this->writeWord(xId, out);
1728        this->writeWord(flippedYId, out);
1729        this->writeWord(zeroId, out);
1730        this->writeWord(oneId, out);
1731        return flipped;
1732    }
1733    return result;
1734}
1735
1736SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) {
1737    return getLValue(expr, out)->load(out);
1738}
1739
1740SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) {
1741    return getLValue(f, out)->load(out);
1742}
1743
1744SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) {
1745    SpvId base = this->writeExpression(*swizzle.fBase, out);
1746    SpvId result = this->nextId();
1747    size_t count = swizzle.fComponents.size();
1748    if (count == 1) {
1749        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1750                               swizzle.fComponents[0], out);
1751    } else {
1752        this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1753        this->writeWord(this->getType(swizzle.fType), out);
1754        this->writeWord(result, out);
1755        this->writeWord(base, out);
1756        this->writeWord(base, out);
1757        for (int component : swizzle.fComponents) {
1758            this->writeWord(component, out);
1759        }
1760    }
1761    return result;
1762}
1763
1764SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1765                                               const Type& operandType, SpvId lhs,
1766                                               SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1767                                               SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) {
1768    SpvId result = this->nextId();
1769    if (is_float(fContext, operandType)) {
1770        this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1771    } else if (is_signed(fContext, operandType)) {
1772        this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1773    } else if (is_unsigned(fContext, operandType)) {
1774        this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1775    } else if (operandType == *fContext.fBool_Type) {
1776        this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1777    } else {
1778        ABORT("invalid operandType: %s", operandType.description().c_str());
1779    }
1780    return result;
1781}
1782
1783bool is_assignment(Token::Kind op) {
1784    switch (op) {
1785        case Token::EQ:           // fall through
1786        case Token::PLUSEQ:       // fall through
1787        case Token::MINUSEQ:      // fall through
1788        case Token::STAREQ:       // fall through
1789        case Token::SLASHEQ:      // fall through
1790        case Token::PERCENTEQ:    // fall through
1791        case Token::SHLEQ:        // fall through
1792        case Token::SHREQ:        // fall through
1793        case Token::BITWISEOREQ:  // fall through
1794        case Token::BITWISEXOREQ: // fall through
1795        case Token::BITWISEANDEQ: // fall through
1796        case Token::LOGICALOREQ:  // fall through
1797        case Token::LOGICALXOREQ: // fall through
1798        case Token::LOGICALANDEQ:
1799            return true;
1800        default:
1801            return false;
1802    }
1803}
1804
1805SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) {
1806    if (operandType.kind() == Type::kVector_Kind) {
1807        SpvId result = this->nextId();
1808        this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
1809        return result;
1810    }
1811    return id;
1812}
1813
1814SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs,
1815                                                SpvOp_ floatOperator, SpvOp_ intOperator,
1816                                                OutputStream& out) {
1817    SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator;
1818    ASSERT(operandType.kind() == Type::kMatrix_Kind);
1819    SpvId rowType = this->getType(operandType.componentType().toCompound(fContext,
1820                                                                         operandType.columns(),
1821                                                                         1));
1822    SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext,
1823                                                                    operandType.columns(),
1824                                                                    1));
1825    SpvId boolType = this->getType(*fContext.fBool_Type);
1826    SpvId result = 0;
1827    for (int i = 0; i < operandType.rows(); i++) {
1828        SpvId rowL = this->nextId();
1829        this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out);
1830        SpvId rowR = this->nextId();
1831        this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out);
1832        SpvId compare = this->nextId();
1833        this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out);
1834        SpvId all = this->nextId();
1835        this->writeInstruction(SpvOpAll, boolType, all, compare, out);
1836        if (result != 0) {
1837            SpvId next = this->nextId();
1838            this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out);
1839            result = next;
1840        }
1841        else {
1842            result = all;
1843        }
1844    }
1845    return result;
1846}
1847
1848SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) {
1849    // handle cases where we don't necessarily evaluate both LHS and RHS
1850    switch (b.fOperator) {
1851        case Token::EQ: {
1852            SpvId rhs = this->writeExpression(*b.fRight, out);
1853            this->getLValue(*b.fLeft, out)->store(rhs, out);
1854            return rhs;
1855        }
1856        case Token::LOGICALAND:
1857            return this->writeLogicalAnd(b, out);
1858        case Token::LOGICALOR:
1859            return this->writeLogicalOr(b, out);
1860        default:
1861            break;
1862    }
1863
1864    // "normal" operators
1865    const Type& resultType = b.fType;
1866    std::unique_ptr<LValue> lvalue;
1867    SpvId lhs;
1868    if (is_assignment(b.fOperator)) {
1869        lvalue = this->getLValue(*b.fLeft, out);
1870        lhs = lvalue->load(out);
1871    } else {
1872        lvalue = nullptr;
1873        lhs = this->writeExpression(*b.fLeft, out);
1874    }
1875    SpvId rhs = this->writeExpression(*b.fRight, out);
1876    if (b.fOperator == Token::COMMA) {
1877        return rhs;
1878    }
1879    Type tmp("<invalid>");
1880    // component type we are operating on: float, int, uint
1881    const Type* operandType;
1882    // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling
1883    // in SPIR-V
1884    if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) {
1885        if (b.fLeft->fType.kind() == Type::kVector_Kind &&
1886            b.fRight->fType.isNumber()) {
1887            // promote number to vector
1888            SpvId vec = this->nextId();
1889            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1890            this->writeWord(this->getType(resultType), out);
1891            this->writeWord(vec, out);
1892            for (int i = 0; i < resultType.columns(); i++) {
1893                this->writeWord(rhs, out);
1894            }
1895            rhs = vec;
1896            operandType = &b.fRight->fType;
1897        } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
1898                   b.fLeft->fType.isNumber()) {
1899            // promote number to vector
1900            SpvId vec = this->nextId();
1901            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
1902            this->writeWord(this->getType(resultType), out);
1903            this->writeWord(vec, out);
1904            for (int i = 0; i < resultType.columns(); i++) {
1905                this->writeWord(lhs, out);
1906            }
1907            lhs = vec;
1908            ASSERT(!lvalue);
1909            operandType = &b.fLeft->fType;
1910        } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
1911            SpvOp_ op;
1912            if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1913                op = SpvOpMatrixTimesMatrix;
1914            } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
1915                op = SpvOpMatrixTimesVector;
1916            } else {
1917                ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
1918                op = SpvOpMatrixTimesScalar;
1919            }
1920            SpvId result = this->nextId();
1921            this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
1922            if (b.fOperator == Token::STAREQ) {
1923                lvalue->store(result, out);
1924            } else {
1925                ASSERT(b.fOperator == Token::STAR);
1926            }
1927            return result;
1928        } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
1929            SpvId result = this->nextId();
1930            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
1931                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
1932                                       lhs, rhs, out);
1933            } else {
1934                ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
1935                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
1936                                       lhs, out);
1937            }
1938            if (b.fOperator == Token::STAREQ) {
1939                lvalue->store(result, out);
1940            } else {
1941                ASSERT(b.fOperator == Token::STAR);
1942            }
1943            return result;
1944        } else {
1945            ABORT("unsupported binary expression: %s", b.description().c_str());
1946        }
1947    } else {
1948        tmp = this->getActualType(b.fLeft->fType);
1949        operandType = &tmp;
1950        ASSERT(*operandType == this->getActualType(b.fRight->fType));
1951    }
1952    switch (b.fOperator) {
1953        case Token::EQEQ: {
1954            if (operandType->kind() == Type::kMatrix_Kind) {
1955                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual,
1956                                                   SpvOpIEqual, out);
1957            }
1958            ASSERT(resultType == *fContext.fBool_Type);
1959            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1960                                                               SpvOpFOrdEqual, SpvOpIEqual,
1961                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
1962                                    *operandType, out);
1963        }
1964        case Token::NEQ:
1965            if (operandType->kind() == Type::kMatrix_Kind) {
1966                return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual,
1967                                                   SpvOpINotEqual, out);
1968            }
1969            ASSERT(resultType == *fContext.fBool_Type);
1970            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1971                                                               SpvOpFOrdNotEqual, SpvOpINotEqual,
1972                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
1973                                                               out),
1974                                    *operandType, out);
1975        case Token::GT:
1976            ASSERT(resultType == *fContext.fBool_Type);
1977            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1978                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
1979                                              SpvOpUGreaterThan, SpvOpUndef, out);
1980        case Token::LT:
1981            ASSERT(resultType == *fContext.fBool_Type);
1982            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
1983                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
1984        case Token::GTEQ:
1985            ASSERT(resultType == *fContext.fBool_Type);
1986            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1987                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
1988                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
1989        case Token::LTEQ:
1990            ASSERT(resultType == *fContext.fBool_Type);
1991            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
1992                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
1993                                              SpvOpULessThanEqual, SpvOpUndef, out);
1994        case Token::PLUS:
1995            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
1996                                              SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
1997        case Token::MINUS:
1998            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
1999                                              SpvOpISub, SpvOpISub, SpvOpUndef, out);
2000        case Token::STAR:
2001            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2002                b.fRight->fType.kind() == Type::kMatrix_Kind) {
2003                // matrix multiply
2004                SpvId result = this->nextId();
2005                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2006                                       lhs, rhs, out);
2007                return result;
2008            }
2009            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2010                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2011        case Token::SLASH:
2012            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2013                                              SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2014        case Token::PERCENT:
2015            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2016                                              SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2017        case Token::SHL:
2018            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2019                                              SpvOpShiftLeftLogical, SpvOpShiftLeftLogical,
2020                                              SpvOpUndef, out);
2021        case Token::SHR:
2022            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2023                                              SpvOpShiftRightArithmetic, SpvOpShiftRightLogical,
2024                                              SpvOpUndef, out);
2025        case Token::BITWISEAND:
2026            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2027                                              SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out);
2028        case Token::BITWISEOR:
2029            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2030                                              SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out);
2031        case Token::BITWISEXOR:
2032            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef,
2033                                              SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out);
2034        case Token::PLUSEQ: {
2035            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2036                                                      SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2037            ASSERT(lvalue);
2038            lvalue->store(result, out);
2039            return result;
2040        }
2041        case Token::MINUSEQ: {
2042            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2043                                                      SpvOpISub, SpvOpISub, SpvOpUndef, out);
2044            ASSERT(lvalue);
2045            lvalue->store(result, out);
2046            return result;
2047        }
2048        case Token::STAREQ: {
2049            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2050                b.fRight->fType.kind() == Type::kMatrix_Kind) {
2051                // matrix multiply
2052                SpvId result = this->nextId();
2053                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2054                                       lhs, rhs, out);
2055                ASSERT(lvalue);
2056                lvalue->store(result, out);
2057                return result;
2058            }
2059            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2060                                                      SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2061            ASSERT(lvalue);
2062            lvalue->store(result, out);
2063            return result;
2064        }
2065        case Token::SLASHEQ: {
2066            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2067                                                      SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2068            ASSERT(lvalue);
2069            lvalue->store(result, out);
2070            return result;
2071        }
2072        case Token::PERCENTEQ: {
2073            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod,
2074                                                      SpvOpSMod, SpvOpUMod, SpvOpUndef, out);
2075            ASSERT(lvalue);
2076            lvalue->store(result, out);
2077            return result;
2078        }
2079        case Token::SHLEQ: {
2080            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2081                                                      SpvOpUndef, SpvOpShiftLeftLogical,
2082                                                      SpvOpShiftLeftLogical, SpvOpUndef, out);
2083            ASSERT(lvalue);
2084            lvalue->store(result, out);
2085            return result;
2086        }
2087        case Token::SHREQ: {
2088            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2089                                                      SpvOpUndef, SpvOpShiftRightArithmetic,
2090                                                      SpvOpShiftRightLogical, SpvOpUndef, out);
2091            ASSERT(lvalue);
2092            lvalue->store(result, out);
2093            return result;
2094        }
2095        case Token::BITWISEANDEQ: {
2096            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2097                                                      SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd,
2098                                                      SpvOpUndef, out);
2099            ASSERT(lvalue);
2100            lvalue->store(result, out);
2101            return result;
2102        }
2103        case Token::BITWISEOREQ: {
2104            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2105                                                      SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr,
2106                                                      SpvOpUndef, out);
2107            ASSERT(lvalue);
2108            lvalue->store(result, out);
2109            return result;
2110        }
2111        case Token::BITWISEXOREQ: {
2112            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2113                                                      SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor,
2114                                                      SpvOpUndef, out);
2115            ASSERT(lvalue);
2116            lvalue->store(result, out);
2117            return result;
2118        }
2119        default:
2120            ABORT("unsupported binary expression: %s", b.description().c_str());
2121    }
2122}
2123
2124SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) {
2125    ASSERT(a.fOperator == Token::LOGICALAND);
2126    BoolLiteral falseLiteral(fContext, -1, false);
2127    SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2128    SpvId lhs = this->writeExpression(*a.fLeft, out);
2129    SpvId rhsLabel = this->nextId();
2130    SpvId end = this->nextId();
2131    SpvId lhsBlock = fCurrentBlock;
2132    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2133    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2134    this->writeLabel(rhsLabel, out);
2135    SpvId rhs = this->writeExpression(*a.fRight, out);
2136    SpvId rhsBlock = fCurrentBlock;
2137    this->writeInstruction(SpvOpBranch, end, out);
2138    this->writeLabel(end, out);
2139    SpvId result = this->nextId();
2140    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2141                           lhsBlock, rhs, rhsBlock, out);
2142    return result;
2143}
2144
2145SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) {
2146    ASSERT(o.fOperator == Token::LOGICALOR);
2147    BoolLiteral trueLiteral(fContext, -1, true);
2148    SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2149    SpvId lhs = this->writeExpression(*o.fLeft, out);
2150    SpvId rhsLabel = this->nextId();
2151    SpvId end = this->nextId();
2152    SpvId lhsBlock = fCurrentBlock;
2153    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2154    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2155    this->writeLabel(rhsLabel, out);
2156    SpvId rhs = this->writeExpression(*o.fRight, out);
2157    SpvId rhsBlock = fCurrentBlock;
2158    this->writeInstruction(SpvOpBranch, end, out);
2159    this->writeLabel(end, out);
2160    SpvId result = this->nextId();
2161    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2162                           lhsBlock, rhs, rhsBlock, out);
2163    return result;
2164}
2165
2166SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) {
2167    SpvId test = this->writeExpression(*t.fTest, out);
2168    if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2169        // both true and false are constants, can just use OpSelect
2170        SpvId result = this->nextId();
2171        SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2172        SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2173        this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2174                               out);
2175        return result;
2176    }
2177    // was originally using OpPhi to choose the result, but for some reason that is crashing on
2178    // Adreno. Switched to storing the result in a temp variable as glslang does.
2179    SpvId var = this->nextId();
2180    this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2181                           var, SpvStorageClassFunction, fVariableBuffer);
2182    SpvId trueLabel = this->nextId();
2183    SpvId falseLabel = this->nextId();
2184    SpvId end = this->nextId();
2185    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2186    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2187    this->writeLabel(trueLabel, out);
2188    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2189    this->writeInstruction(SpvOpBranch, end, out);
2190    this->writeLabel(falseLabel, out);
2191    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2192    this->writeInstruction(SpvOpBranch, end, out);
2193    this->writeLabel(end, out);
2194    SpvId result = this->nextId();
2195    this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2196    return result;
2197}
2198
2199std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2200    if (type.isInteger()) {
2201        return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type));
2202    }
2203    else if (type.isFloat()) {
2204        return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type));
2205    } else {
2206        ABORT("math is unsupported on type '%s'", type.name().c_str());
2207    }
2208}
2209
2210SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) {
2211    if (p.fOperator == Token::MINUS) {
2212        SpvId result = this->nextId();
2213        SpvId typeId = this->getType(p.fType);
2214        SpvId expr = this->writeExpression(*p.fOperand, out);
2215        if (is_float(fContext, p.fType)) {
2216            this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2217        } else if (is_signed(fContext, p.fType)) {
2218            this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2219        } else {
2220            ABORT("unsupported prefix expression %s", p.description().c_str());
2221        };
2222        return result;
2223    }
2224    switch (p.fOperator) {
2225        case Token::PLUS:
2226            return this->writeExpression(*p.fOperand, out);
2227        case Token::PLUSPLUS: {
2228            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2229            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2230            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2231                                                      SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2232                                                      out);
2233            lv->store(result, out);
2234            return result;
2235        }
2236        case Token::MINUSMINUS: {
2237            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2238            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2239            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2240                                                      SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2241                                                      out);
2242            lv->store(result, out);
2243            return result;
2244        }
2245        case Token::LOGICALNOT: {
2246            ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2247            SpvId result = this->nextId();
2248            this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2249                                   this->writeExpression(*p.fOperand, out), out);
2250            return result;
2251        }
2252        case Token::BITWISENOT: {
2253            SpvId result = this->nextId();
2254            this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2255                                   this->writeExpression(*p.fOperand, out), out);
2256            return result;
2257        }
2258        default:
2259            ABORT("unsupported prefix expression: %s", p.description().c_str());
2260    }
2261}
2262
2263SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) {
2264    std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2265    SpvId result = lv->load(out);
2266    SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2267    switch (p.fOperator) {
2268        case Token::PLUSPLUS: {
2269            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2270                                                    SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2271            lv->store(temp, out);
2272            return result;
2273        }
2274        case Token::MINUSMINUS: {
2275            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2276                                                    SpvOpISub, SpvOpISub, SpvOpUndef, out);
2277            lv->store(temp, out);
2278            return result;
2279        }
2280        default:
2281            ABORT("unsupported postfix expression %s", p.description().c_str());
2282    }
2283}
2284
2285SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2286    if (b.fValue) {
2287        if (fBoolTrue == 0) {
2288            fBoolTrue = this->nextId();
2289            this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2290                                   fConstantBuffer);
2291        }
2292        return fBoolTrue;
2293    } else {
2294        if (fBoolFalse == 0) {
2295            fBoolFalse = this->nextId();
2296            this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2297                                   fConstantBuffer);
2298        }
2299        return fBoolFalse;
2300    }
2301}
2302
2303SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2304    if (i.fType == *fContext.fInt_Type) {
2305        auto entry = fIntConstants.find(i.fValue);
2306        if (entry == fIntConstants.end()) {
2307            SpvId result = this->nextId();
2308            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2309                                   fConstantBuffer);
2310            fIntConstants[i.fValue] = result;
2311            return result;
2312        }
2313        return entry->second;
2314    } else {
2315        ASSERT(i.fType == *fContext.fUInt_Type);
2316        auto entry = fUIntConstants.find(i.fValue);
2317        if (entry == fUIntConstants.end()) {
2318            SpvId result = this->nextId();
2319            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2320                                   fConstantBuffer);
2321            fUIntConstants[i.fValue] = result;
2322            return result;
2323        }
2324        return entry->second;
2325    }
2326}
2327
2328SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2329    if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) {
2330        float value = (float) f.fValue;
2331        auto entry = fFloatConstants.find(value);
2332        if (entry == fFloatConstants.end()) {
2333            SpvId result = this->nextId();
2334            uint32_t bits;
2335            ASSERT(sizeof(bits) == sizeof(value));
2336            memcpy(&bits, &value, sizeof(bits));
2337            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2338                                   fConstantBuffer);
2339            fFloatConstants[value] = result;
2340            return result;
2341        }
2342        return entry->second;
2343    } else {
2344        ASSERT(f.fType == *fContext.fDouble_Type);
2345        auto entry = fDoubleConstants.find(f.fValue);
2346        if (entry == fDoubleConstants.end()) {
2347            SpvId result = this->nextId();
2348            uint64_t bits;
2349            ASSERT(sizeof(bits) == sizeof(f.fValue));
2350            memcpy(&bits, &f.fValue, sizeof(bits));
2351            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2352                                   bits & 0xffffffff, bits >> 32, fConstantBuffer);
2353            fDoubleConstants[f.fValue] = result;
2354            return result;
2355        }
2356        return entry->second;
2357    }
2358}
2359
2360SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) {
2361    SpvId result = fFunctionMap[&f];
2362    this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2363                           SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2364    this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer);
2365    for (size_t i = 0; i < f.fParameters.size(); i++) {
2366        SpvId id = this->nextId();
2367        fVariableMap[f.fParameters[i]] = id;
2368        SpvId type;
2369        type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2370        this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2371    }
2372    return result;
2373}
2374
2375SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) {
2376    fVariableBuffer.reset();
2377    SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2378    this->writeLabel(this->nextId(), out);
2379    if (f.fDeclaration.fName == "main") {
2380        write_stringstream(fGlobalInitializersBuffer, out);
2381    }
2382    StringStream bodyBuffer;
2383    this->writeBlock((Block&) *f.fBody, bodyBuffer);
2384    write_stringstream(fVariableBuffer, out);
2385    write_stringstream(bodyBuffer, out);
2386    if (fCurrentBlock) {
2387        if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) {
2388            this->writeInstruction(SpvOpReturn, out);
2389        } else {
2390            this->writeInstruction(SpvOpUnreachable, out);
2391        }
2392    }
2393    this->writeInstruction(SpvOpFunctionEnd, out);
2394    return result;
2395}
2396
2397void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2398    if (layout.fLocation >= 0) {
2399        this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2400                               fDecorationBuffer);
2401    }
2402    if (layout.fBinding >= 0) {
2403        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2404                               fDecorationBuffer);
2405    }
2406    if (layout.fIndex >= 0) {
2407        this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2408                               fDecorationBuffer);
2409    }
2410    if (layout.fSet >= 0) {
2411        this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2412                               fDecorationBuffer);
2413    }
2414    if (layout.fInputAttachmentIndex >= 0) {
2415        this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2416                               layout.fInputAttachmentIndex, fDecorationBuffer);
2417        fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment);
2418    }
2419    if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN &&
2420        layout.fBuiltin != SK_IN_BUILTIN) {
2421        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2422                               fDecorationBuffer);
2423    }
2424}
2425
2426void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2427    if (layout.fLocation >= 0) {
2428        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2429                               layout.fLocation, fDecorationBuffer);
2430    }
2431    if (layout.fBinding >= 0) {
2432        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2433                               layout.fBinding, fDecorationBuffer);
2434    }
2435    if (layout.fIndex >= 0) {
2436        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2437                               layout.fIndex, fDecorationBuffer);
2438    }
2439    if (layout.fSet >= 0) {
2440        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2441                               layout.fSet, fDecorationBuffer);
2442    }
2443    if (layout.fInputAttachmentIndex >= 0) {
2444        this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2445                               layout.fInputAttachmentIndex, fDecorationBuffer);
2446    }
2447    if (layout.fBuiltin >= 0) {
2448        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2449                               layout.fBuiltin, fDecorationBuffer);
2450    }
2451}
2452
2453SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2454    bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag));
2455    bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags &
2456                               Layout::kPushConstant_Flag));
2457    MemoryLayout layout = (pushConstant || isBuffer) ?
2458                          MemoryLayout(MemoryLayout::k430_Standard) :
2459                          fDefaultLayout;
2460    SpvId result = this->nextId();
2461    const Type* type = &intf.fVariable.fType;
2462    if (fProgram.fInputs.fRTHeight) {
2463        ASSERT(fRTHeightStructId == (SpvId) -1);
2464        ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2465        std::vector<Type::Field> fields = type->fields();
2466        fRTHeightStructId = result;
2467        fRTHeightFieldIndex = fields.size();
2468        fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2469        type = new Type(type->fOffset, type->name(), fields);
2470    }
2471    SpvId typeId = this->getType(*type, layout);
2472    if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) {
2473        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer);
2474    } else {
2475        this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2476    }
2477    SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2478    SpvId ptrType = this->nextId();
2479    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2480    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2481    this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2482    fVariableMap[&intf.fVariable] = result;
2483    if (fProgram.fInputs.fRTHeight) {
2484        delete type;
2485    }
2486    return result;
2487}
2488
2489void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) {
2490    if ((modifiers.fFlags & Modifiers::kLowp_Flag) |
2491        (modifiers.fFlags & Modifiers::kMediump_Flag)) {
2492        this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer);
2493    }
2494}
2495
// Sentinel builtin value: a global whose layout builtin equals this is skipped entirely by
// writeGlobalVars (no OpVariable is emitted for it).
#define BUILTIN_IGNORE 9999
// Emits module-scope OpVariables (into fConstantBuffer) for the globals declared in 'decl'.
// Each variable also gets its OpName, a precision decoration, layout decorations, and
// Flat/NoPerspective decorations. Globals that are never read or written and are not part of the
// shader interface (in/out/uniform/buffer) are skipped. Initializer expressions are written into
// fGlobalInitializersBuffer.
// NOTE(review): the 'out' parameter is not used by this function.
void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
                                         OutputStream& out) {
    for (size_t i = 0; i < decl.fVars.size(); i++) {
        // Nop entries are skipped; everything else is treated as a VarDeclaration.
        if (decl.fVars[i]->fKind == Statement::kNop_Kind) {
            continue;
        }
        const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i];
        const Variable* var = varDecl.fVar;
        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
        // in the OpenGL backend.
        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
                                           Modifiers::kWriteOnly_Flag |
                                           Modifiers::kCoherent_Flag |
                                           Modifiers::kVolatile_Flag |
                                           Modifiers::kRestrict_Flag)));
        if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
            continue;
        }
        // The SK_FRAGCOLOR_BUILTIN variable is only declared for fragment programs. Elsewhere it
        // is skipped, and the ASSERT checks that fFragColorIsInOut is not set.
        if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
            kind != Program::kFragment_Kind) {
            ASSERT(!fProgram.fSettings.fFragColorIsInOut);
            continue;
        }
        if (!var->fReadCount && !var->fWriteCount &&
                !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
                                            Modifiers::kOut_Flag |
                                            Modifiers::kUniform_Flag |
                                            Modifiers::kBuffer_Flag))) {
            // variable is dead and not an input / output var (the Vulkan debug layers complain if
            // we elide an interface var, even if it's dead)
            continue;
        }
        // Storage class from the qualifiers:
        //   in -> Input, out -> Output,
        //   uniform -> UniformConstant for samplers, otherwise Uniform,
        //   anything else -> Private.
        SpvStorageClass_ storageClass;
        if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
            storageClass = SpvStorageClassInput;
        } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
            storageClass = SpvStorageClassOutput;
        } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
            if (var->fType.kind() == Type::kSampler_Kind) {
                storageClass = SpvStorageClassUniformConstant;
            } else {
                storageClass = SpvStorageClassUniform;
            }
        } else {
            storageClass = SpvStorageClassPrivate;
        }
        SpvId id = this->nextId();
        fVariableMap[var] = id;
        SpvId type = this->getPointerType(var->fType, storageClass);
        this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
        this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
        this->writePrecisionModifier(var->fModifiers, id);
        if (varDecl.fValue) {
            // The initializer is generated into fGlobalInitializersBuffer; writeFunction copies
            // that buffer to the start of main. We must not be inside a block here. While the
            // expression is written, fCurrentBlock holds a dummy non-zero value (presumably so
            // expression writers behave as if inside a block). It is reset to 0 afterwards.
            ASSERT(!fCurrentBlock);
            fCurrentBlock = -1;
            SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
            this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
            fCurrentBlock = 0;
        }
        this->writeLayout(var->fModifiers.fLayout, id);
        // interpolation qualifiers map directly onto SPIR-V decorations
        if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) {
            this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer);
        }
        if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) {
            this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective,
                                   fDecorationBuffer);
        }
    }
}
2566
2567void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) {
2568    for (const auto& stmt : decl.fVars) {
2569        ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind);
2570        VarDeclaration& varDecl = (VarDeclaration&) *stmt;
2571        const Variable* var = varDecl.fVar;
2572        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2573        // in the OpenGL backend.
2574        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2575                                           Modifiers::kWriteOnly_Flag |
2576                                           Modifiers::kCoherent_Flag |
2577                                           Modifiers::kVolatile_Flag |
2578                                           Modifiers::kRestrict_Flag)));
2579        SpvId id = this->nextId();
2580        fVariableMap[var] = id;
2581        SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2582        this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2583        this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer);
2584        if (varDecl.fValue) {
2585            SpvId value = this->writeExpression(*varDecl.fValue, out);
2586            this->writeInstruction(SpvOpStore, id, value, out);
2587        }
2588    }
2589}
2590
2591void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) {
2592    switch (s.fKind) {
2593        case Statement::kNop_Kind:
2594            break;
2595        case Statement::kBlock_Kind:
2596            this->writeBlock((Block&) s, out);
2597            break;
2598        case Statement::kExpression_Kind:
2599            this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2600            break;
2601        case Statement::kReturn_Kind:
2602            this->writeReturnStatement((ReturnStatement&) s, out);
2603            break;
2604        case Statement::kVarDeclarations_Kind:
2605            this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2606            break;
2607        case Statement::kIf_Kind:
2608            this->writeIfStatement((IfStatement&) s, out);
2609            break;
2610        case Statement::kFor_Kind:
2611            this->writeForStatement((ForStatement&) s, out);
2612            break;
2613        case Statement::kWhile_Kind:
2614            this->writeWhileStatement((WhileStatement&) s, out);
2615            break;
2616        case Statement::kDo_Kind:
2617            this->writeDoStatement((DoStatement&) s, out);
2618            break;
2619        case Statement::kSwitch_Kind:
2620            this->writeSwitchStatement((SwitchStatement&) s, out);
2621            break;
2622        case Statement::kBreak_Kind:
2623            this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2624            break;
2625        case Statement::kContinue_Kind:
2626            this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2627            break;
2628        case Statement::kDiscard_Kind:
2629            this->writeInstruction(SpvOpKill, out);
2630            break;
2631        default:
2632            ABORT("unsupported statement: %s", s.description().c_str());
2633    }
2634}
2635
2636void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) {
2637    for (size_t i = 0; i < b.fStatements.size(); i++) {
2638        this->writeStatement(*b.fStatements[i], out);
2639    }
2640}
2641
2642void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) {
2643    SpvId test = this->writeExpression(*stmt.fTest, out);
2644    SpvId ifTrue = this->nextId();
2645    SpvId ifFalse = this->nextId();
2646    if (stmt.fIfFalse) {
2647        SpvId end = this->nextId();
2648        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2649        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2650        this->writeLabel(ifTrue, out);
2651        this->writeStatement(*stmt.fIfTrue, out);
2652        if (fCurrentBlock) {
2653            this->writeInstruction(SpvOpBranch, end, out);
2654        }
2655        this->writeLabel(ifFalse, out);
2656        this->writeStatement(*stmt.fIfFalse, out);
2657        if (fCurrentBlock) {
2658            this->writeInstruction(SpvOpBranch, end, out);
2659        }
2660        this->writeLabel(end, out);
2661    } else {
2662        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2663        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2664        this->writeLabel(ifTrue, out);
2665        this->writeStatement(*stmt.fIfTrue, out);
2666        if (fCurrentBlock) {
2667            this->writeInstruction(SpvOpBranch, ifFalse, out);
2668        }
2669        this->writeLabel(ifFalse, out);
2670    }
2671}
2672
2673void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) {
2674    if (f.fInitializer) {
2675        this->writeStatement(*f.fInitializer, out);
2676    }
2677    SpvId header = this->nextId();
2678    SpvId start = this->nextId();
2679    SpvId body = this->nextId();
2680    SpvId next = this->nextId();
2681    fContinueTarget.push(next);
2682    SpvId end = this->nextId();
2683    fBreakTarget.push(end);
2684    this->writeInstruction(SpvOpBranch, header, out);
2685    this->writeLabel(header, out);
2686    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2687    this->writeInstruction(SpvOpBranch, start, out);
2688    this->writeLabel(start, out);
2689    if (f.fTest) {
2690        SpvId test = this->writeExpression(*f.fTest, out);
2691        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2692    }
2693    this->writeLabel(body, out);
2694    this->writeStatement(*f.fStatement, out);
2695    if (fCurrentBlock) {
2696        this->writeInstruction(SpvOpBranch, next, out);
2697    }
2698    this->writeLabel(next, out);
2699    if (f.fNext) {
2700        this->writeExpression(*f.fNext, out);
2701    }
2702    this->writeInstruction(SpvOpBranch, header, out);
2703    this->writeLabel(end, out);
2704    fBreakTarget.pop();
2705    fContinueTarget.pop();
2706}
2707
2708void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) {
2709    // We believe the while loop code below will work, but Skia doesn't actually use them and
2710    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2711    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2712    // message, simply remove the error call below to see whether our while loop support actually
2713    // works.
2714    fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, "
2715                  "see SkSLSPIRVCodeGenerator.cpp for details");
2716
2717    SpvId header = this->nextId();
2718    SpvId start = this->nextId();
2719    SpvId body = this->nextId();
2720    fContinueTarget.push(start);
2721    SpvId end = this->nextId();
2722    fBreakTarget.push(end);
2723    this->writeInstruction(SpvOpBranch, header, out);
2724    this->writeLabel(header, out);
2725    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2726    this->writeInstruction(SpvOpBranch, start, out);
2727    this->writeLabel(start, out);
2728    SpvId test = this->writeExpression(*w.fTest, out);
2729    this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2730    this->writeLabel(body, out);
2731    this->writeStatement(*w.fStatement, out);
2732    if (fCurrentBlock) {
2733        this->writeInstruction(SpvOpBranch, start, out);
2734    }
2735    this->writeLabel(end, out);
2736    fBreakTarget.pop();
2737    fContinueTarget.pop();
2738}
2739
2740void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) {
2741    // We believe the do loop code below will work, but Skia doesn't actually use them and
2742    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2743    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2744    // message, simply remove the error call below to see whether our do loop support actually
2745    // works.
2746    fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see "
2747                  "SkSLSPIRVCodeGenerator.cpp for details");
2748
2749    SpvId header = this->nextId();
2750    SpvId start = this->nextId();
2751    SpvId next = this->nextId();
2752    fContinueTarget.push(next);
2753    SpvId end = this->nextId();
2754    fBreakTarget.push(end);
2755    this->writeInstruction(SpvOpBranch, header, out);
2756    this->writeLabel(header, out);
2757    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2758    this->writeInstruction(SpvOpBranch, start, out);
2759    this->writeLabel(start, out);
2760    this->writeStatement(*d.fStatement, out);
2761    if (fCurrentBlock) {
2762        this->writeInstruction(SpvOpBranch, next, out);
2763    }
2764    this->writeLabel(next, out);
2765    SpvId test = this->writeExpression(*d.fTest, out);
2766    this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2767    this->writeLabel(end, out);
2768    fBreakTarget.pop();
2769    fContinueTarget.pop();
2770}
2771
2772void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) {
2773    SpvId value = this->writeExpression(*s.fValue, out);
2774    std::vector<SpvId> labels;
2775    SpvId end = this->nextId();
2776    SpvId defaultLabel = end;
2777    fBreakTarget.push(end);
2778    int size = 3;
2779    for (const auto& c : s.fCases) {
2780        SpvId label = this->nextId();
2781        labels.push_back(label);
2782        if (c->fValue) {
2783            size += 2;
2784        } else {
2785            defaultLabel = label;
2786        }
2787    }
2788    labels.push_back(end);
2789    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2790    this->writeOpCode(SpvOpSwitch, size, out);
2791    this->writeWord(value, out);
2792    this->writeWord(defaultLabel, out);
2793    for (size_t i = 0; i < s.fCases.size(); ++i) {
2794        if (!s.fCases[i]->fValue) {
2795            continue;
2796        }
2797        ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind);
2798        this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out);
2799        this->writeWord(labels[i], out);
2800    }
2801    for (size_t i = 0; i < s.fCases.size(); ++i) {
2802        this->writeLabel(labels[i], out);
2803        for (const auto& stmt : s.fCases[i]->fStatements) {
2804            this->writeStatement(*stmt, out);
2805        }
2806        if (fCurrentBlock) {
2807            this->writeInstruction(SpvOpBranch, labels[i + 1], out);
2808        }
2809    }
2810    this->writeLabel(end, out);
2811    fBreakTarget.pop();
2812}
2813
2814void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) {
2815    if (r.fExpression) {
2816        this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2817                               out);
2818    } else {
2819        this->writeInstruction(SpvOpReturn, out);
2820    }
2821}
2822
2823void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) {
2824    ASSERT(fProgram.fKind == Program::kGeometry_Kind);
2825    int invocations = 1;
2826    for (size_t i = 0; i < fProgram.fElements.size(); i++) {
2827        if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) {
2828            const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers;
2829            if (m.fFlags & Modifiers::kIn_Flag) {
2830                if (m.fLayout.fInvocations != -1) {
2831                    invocations = m.fLayout.fInvocations;
2832                }
2833                SpvId input;
2834                switch (m.fLayout.fPrimitive) {
2835                    case Layout::kPoints_Primitive:
2836                        input = SpvExecutionModeInputPoints;
2837                        break;
2838                    case Layout::kLines_Primitive:
2839                        input = SpvExecutionModeInputLines;
2840                        break;
2841                    case Layout::kLinesAdjacency_Primitive:
2842                        input = SpvExecutionModeInputLinesAdjacency;
2843                        break;
2844                    case Layout::kTriangles_Primitive:
2845                        input = SpvExecutionModeTriangles;
2846                        break;
2847                    case Layout::kTrianglesAdjacency_Primitive:
2848                        input = SpvExecutionModeInputTrianglesAdjacency;
2849                        break;
2850                    default:
2851                        input = 0;
2852                        break;
2853                }
2854                if (input) {
2855                    this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out);
2856                }
2857            } else if (m.fFlags & Modifiers::kOut_Flag) {
2858                SpvId output;
2859                switch (m.fLayout.fPrimitive) {
2860                    case Layout::kPoints_Primitive:
2861                        output = SpvExecutionModeOutputPoints;
2862                        break;
2863                    case Layout::kLineStrip_Primitive:
2864                        output = SpvExecutionModeOutputLineStrip;
2865                        break;
2866                    case Layout::kTriangleStrip_Primitive:
2867                        output = SpvExecutionModeOutputTriangleStrip;
2868                        break;
2869                    default:
2870                        output = 0;
2871                        break;
2872                }
2873                if (output) {
2874                    this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out);
2875                }
2876                if (m.fLayout.fMaxVertices != -1) {
2877                    this->writeInstruction(SpvOpExecutionMode, entryPoint,
2878                                           SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices,
2879                                           out);
2880                }
2881            }
2882        }
2883    }
2884    this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations,
2885                           invocations, out);
2886}
2887
2888void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) {
2889    fGLSLExtendedInstructions = this->nextId();
2890    StringStream body;
2891    std::set<SpvId> interfaceVars;
2892    // assign IDs to functions, determine sk_in size
2893    int skInSize = -1;
2894    for (size_t i = 0; i < program.fElements.size(); i++) {
2895        switch (program.fElements[i]->fKind) {
2896            case ProgramElement::kFunction_Kind: {
2897                FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
2898                fFunctionMap[&f.fDeclaration] = this->nextId();
2899                break;
2900            }
2901            case ProgramElement::kModifiers_Kind: {
2902                Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers;
2903                if (m.fFlags & Modifiers::kIn_Flag) {
2904                    switch (m.fLayout.fPrimitive) {
2905                        case Layout::kPoints_Primitive: // break
2906                        case Layout::kLines_Primitive:
2907                            skInSize = 1;
2908                            break;
2909                        case Layout::kLinesAdjacency_Primitive: // break
2910                            skInSize = 2;
2911                            break;
2912                        case Layout::kTriangles_Primitive: // break
2913                        case Layout::kTrianglesAdjacency_Primitive:
2914                            skInSize = 3;
2915                            break;
2916                        default:
2917                            break;
2918                    }
2919                }
2920                break;
2921            }
2922            default:
2923                break;
2924        }
2925    }
2926    for (size_t i = 0; i < program.fElements.size(); i++) {
2927        if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
2928            InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
2929            if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) {
2930                ASSERT(skInSize != -1);
2931                intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize));
2932            }
2933            SpvId id = this->writeInterfaceBlock(intf);
2934            if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
2935                (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
2936                interfaceVars.insert(id);
2937            }
2938        }
2939    }
2940    for (size_t i = 0; i < program.fElements.size(); i++) {
2941        if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
2942            this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
2943                                  body);
2944        }
2945    }
2946    for (size_t i = 0; i < program.fElements.size(); i++) {
2947        if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
2948            this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
2949        }
2950    }
2951    const FunctionDeclaration* main = nullptr;
2952    for (auto entry : fFunctionMap) {
2953        if (entry.first->fName == "main") {
2954            main = entry.first;
2955        }
2956    }
2957    ASSERT(main);
2958    for (auto entry : fVariableMap) {
2959        const Variable* var = entry.first;
2960        if (var->fStorage == Variable::kGlobal_Storage &&
2961                ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
2962                 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
2963            interfaceVars.insert(entry.second);
2964        }
2965    }
2966    this->writeCapabilities(out);
2967    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
2968    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
2969    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) +
2970                      (int32_t) interfaceVars.size(), out);
2971    switch (program.fKind) {
2972        case Program::kVertex_Kind:
2973            this->writeWord(SpvExecutionModelVertex, out);
2974            break;
2975        case Program::kFragment_Kind:
2976            this->writeWord(SpvExecutionModelFragment, out);
2977            break;
2978        case Program::kGeometry_Kind:
2979            this->writeWord(SpvExecutionModelGeometry, out);
2980            break;
2981        default:
2982            ABORT("cannot write this kind of program to SPIR-V\n");
2983    }
2984    SpvId entryPoint = fFunctionMap[main];
2985    this->writeWord(entryPoint, out);
2986    this->writeString(main->fName.fChars, main->fName.fLength, out);
2987    for (int var : interfaceVars) {
2988        this->writeWord(var, out);
2989    }
2990    if (program.fKind == Program::kGeometry_Kind) {
2991        this->writeGeometryShaderExecutionMode(entryPoint, out);
2992    }
2993    if (program.fKind == Program::kFragment_Kind) {
2994        this->writeInstruction(SpvOpExecutionMode,
2995                               fFunctionMap[main],
2996                               SpvExecutionModeOriginUpperLeft,
2997                               out);
2998    }
2999    for (size_t i = 0; i < program.fElements.size(); i++) {
3000        if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
3001            this->writeInstruction(SpvOpSourceExtension,
3002                                   ((Extension&) *program.fElements[i]).fName.c_str(),
3003                                   out);
3004        }
3005    }
3006
3007    write_stringstream(fExtraGlobalsBuffer, out);
3008    write_stringstream(fNameBuffer, out);
3009    write_stringstream(fDecorationBuffer, out);
3010    write_stringstream(fConstantBuffer, out);
3011    write_stringstream(fExternalFunctionsBuffer, out);
3012    write_stringstream(body, out);
3013}
3014
3015bool SPIRVCodeGenerator::generateCode() {
3016    ASSERT(!fErrors.errorCount());
3017    this->writeWord(SpvMagicNumber, *fOut);
3018    this->writeWord(SpvVersion, *fOut);
3019    this->writeWord(SKSL_MAGIC, *fOut);
3020    StringStream buffer;
3021    this->writeInstructions(fProgram, buffer);
3022    this->writeWord(fIdCount, *fOut);
3023    this->writeWord(0, *fOut); // reserved, always zero
3024    write_stringstream(buffer, *fOut);
3025    return 0 == fErrors.errorCount();
3026}
3027
3028}
3029