SkSLSPIRVCodeGenerator.cpp revision 8feeff929e57ea63914213f3b14d8f00b287a0ad
1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLSPIRVCodeGenerator.h"
9
10#include "string.h"
11
12#include "GLSL.std.450.h"
13
14#include "ir/SkSLExpressionStatement.h"
15#include "ir/SkSLExtension.h"
16#include "ir/SkSLIndexExpression.h"
17#include "ir/SkSLVariableReference.h"
18#include "SkSLCompiler.h"
19
20namespace SkSL {
21
// When set to 1, writeWord() prints each word as "(value) " text instead of emitting raw binary
// words, which makes the generated stream readable while debugging.
#define SPIRV_DEBUG 0

// Magic number identifying SkSL as the producer of a module; presumably placed in the SPIR-V
// header's generator field (its use is outside this view -- confirm).
// FIXME: we should probably register a magic number
static const int32_t SKSL_MAGIC = 0;
25
26void SPIRVCodeGenerator::setupIntrinsics() {
27#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
28                                    GLSLstd450 ## x, GLSLstd450 ## x)
29#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
30                                                             GLSLstd450 ## ifFloat, \
31                                                             GLSLstd450 ## ifInt, \
32                                                             GLSLstd450 ## ifUInt, \
33                                                             SpvOpUndef)
34#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
35                                   k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
36                                   k ## x ## _SpecialIntrinsic)
37    fIntrinsicMap[SkString("round")]         = ALL_GLSL(Round);
38    fIntrinsicMap[SkString("roundEven")]     = ALL_GLSL(RoundEven);
39    fIntrinsicMap[SkString("trunc")]         = ALL_GLSL(Trunc);
40    fIntrinsicMap[SkString("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
41    fIntrinsicMap[SkString("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
42    fIntrinsicMap[SkString("floor")]         = ALL_GLSL(Floor);
43    fIntrinsicMap[SkString("ceil")]          = ALL_GLSL(Ceil);
44    fIntrinsicMap[SkString("fract")]         = ALL_GLSL(Fract);
45    fIntrinsicMap[SkString("radians")]       = ALL_GLSL(Radians);
46    fIntrinsicMap[SkString("degrees")]       = ALL_GLSL(Degrees);
47    fIntrinsicMap[SkString("sin")]           = ALL_GLSL(Sin);
48    fIntrinsicMap[SkString("cos")]           = ALL_GLSL(Cos);
49    fIntrinsicMap[SkString("tan")]           = ALL_GLSL(Tan);
50    fIntrinsicMap[SkString("asin")]          = ALL_GLSL(Asin);
51    fIntrinsicMap[SkString("acos")]          = ALL_GLSL(Acos);
52    fIntrinsicMap[SkString("atan")]          = SPECIAL(Atan);
53    fIntrinsicMap[SkString("sinh")]          = ALL_GLSL(Sinh);
54    fIntrinsicMap[SkString("cosh")]          = ALL_GLSL(Cosh);
55    fIntrinsicMap[SkString("tanh")]          = ALL_GLSL(Tanh);
56    fIntrinsicMap[SkString("asinh")]         = ALL_GLSL(Asinh);
57    fIntrinsicMap[SkString("acosh")]         = ALL_GLSL(Acosh);
58    fIntrinsicMap[SkString("atanh")]         = ALL_GLSL(Atanh);
59    fIntrinsicMap[SkString("pow")]           = ALL_GLSL(Pow);
60    fIntrinsicMap[SkString("exp")]           = ALL_GLSL(Exp);
61    fIntrinsicMap[SkString("log")]           = ALL_GLSL(Log);
62    fIntrinsicMap[SkString("exp2")]          = ALL_GLSL(Exp2);
63    fIntrinsicMap[SkString("log2")]          = ALL_GLSL(Log2);
64    fIntrinsicMap[SkString("sqrt")]          = ALL_GLSL(Sqrt);
65    fIntrinsicMap[SkString("inversesqrt")]   = ALL_GLSL(InverseSqrt);
66    fIntrinsicMap[SkString("determinant")]   = ALL_GLSL(Determinant);
67    fIntrinsicMap[SkString("matrixInverse")] = ALL_GLSL(MatrixInverse);
68    fIntrinsicMap[SkString("mod")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod,
69                                                               SpvOpSMod, SpvOpUMod, SpvOpUndef);
70    fIntrinsicMap[SkString("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
71    fIntrinsicMap[SkString("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
72    fIntrinsicMap[SkString("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
73    fIntrinsicMap[SkString("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
74                                                               SpvOpUndef, SpvOpUndef, SpvOpUndef);
75    fIntrinsicMap[SkString("mix")]           = ALL_GLSL(FMix);
76    fIntrinsicMap[SkString("step")]          = ALL_GLSL(Step);
77    fIntrinsicMap[SkString("smoothstep")]    = ALL_GLSL(SmoothStep);
78    fIntrinsicMap[SkString("fma")]           = ALL_GLSL(Fma);
79    fIntrinsicMap[SkString("frexp")]         = ALL_GLSL(Frexp);
80    fIntrinsicMap[SkString("ldexp")]         = ALL_GLSL(Ldexp);
81
82#define PACK(type) fIntrinsicMap[SkString("pack" #type)] = ALL_GLSL(Pack ## type); \
83                   fIntrinsicMap[SkString("unpack" #type)] = ALL_GLSL(Unpack ## type)
84    PACK(Snorm4x8);
85    PACK(Unorm4x8);
86    PACK(Snorm2x16);
87    PACK(Unorm2x16);
88    PACK(Half2x16);
89    PACK(Double2x32);
90    fIntrinsicMap[SkString("length")]      = ALL_GLSL(Length);
91    fIntrinsicMap[SkString("distance")]    = ALL_GLSL(Distance);
92    fIntrinsicMap[SkString("cross")]       = ALL_GLSL(Cross);
93    fIntrinsicMap[SkString("normalize")]   = ALL_GLSL(Normalize);
94    fIntrinsicMap[SkString("faceForward")] = ALL_GLSL(FaceForward);
95    fIntrinsicMap[SkString("reflect")]     = ALL_GLSL(Reflect);
96    fIntrinsicMap[SkString("refract")]     = ALL_GLSL(Refract);
97    fIntrinsicMap[SkString("findLSB")]     = ALL_GLSL(FindILsb);
98    fIntrinsicMap[SkString("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
99    fIntrinsicMap[SkString("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
100                                                             SpvOpUndef, SpvOpUndef, SpvOpUndef);
101    fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
102                                                             SpvOpUndef, SpvOpUndef, SpvOpUndef);
103    fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
104                                                             SpvOpUndef, SpvOpUndef, SpvOpUndef);
105    fIntrinsicMap[SkString("texture")]     = SPECIAL(Texture);
106
107    fIntrinsicMap[SkString("subpassLoad")] = SPECIAL(SubpassLoad);
108
109    fIntrinsicMap[SkString("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
110                                                                  SpvOpUndef, SpvOpUndef, SpvOpAny);
111    fIntrinsicMap[SkString("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
112                                                                  SpvOpUndef, SpvOpUndef, SpvOpAll);
113    fIntrinsicMap[SkString("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
114                                                                  SpvOpFOrdEqual, SpvOpIEqual,
115                                                                  SpvOpIEqual, SpvOpLogicalEqual);
116    fIntrinsicMap[SkString("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
117                                                                  SpvOpFOrdNotEqual, SpvOpINotEqual,
118                                                                  SpvOpINotEqual,
119                                                                  SpvOpLogicalNotEqual);
120    fIntrinsicMap[SkString("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
121                                                                  SpvOpSLessThan, SpvOpULessThan,
122                                                                  SpvOpFOrdLessThan, SpvOpUndef);
123    fIntrinsicMap[SkString("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
124                                                                  SpvOpSLessThanEqual,
125                                                                  SpvOpULessThanEqual,
126                                                                  SpvOpFOrdLessThanEqual,
127                                                                  SpvOpUndef);
128    fIntrinsicMap[SkString("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
129                                                                  SpvOpSGreaterThan,
130                                                                  SpvOpUGreaterThan,
131                                                                  SpvOpFOrdGreaterThan,
132                                                                  SpvOpUndef);
133    fIntrinsicMap[SkString("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
134                                                                  SpvOpSGreaterThanEqual,
135                                                                  SpvOpUGreaterThanEqual,
136                                                                  SpvOpFOrdGreaterThanEqual,
137                                                                  SpvOpUndef);
138
139// interpolateAt* not yet supported...
140}
141
142void SPIRVCodeGenerator::writeWord(int32_t word, SkWStream& out) {
143#if SPIRV_DEBUG
144    out << "(" << word << ") ";
145#else
146    out.write((const char*) &word, sizeof(word));
147#endif
148}
149
150static bool is_float(const Context& context, const Type& type) {
151    if (type.kind() == Type::kVector_Kind) {
152        return is_float(context, type.componentType());
153    }
154    return type == *context.fFloat_Type || type == *context.fDouble_Type;
155}
156
157static bool is_signed(const Context& context, const Type& type) {
158    if (type.kind() == Type::kVector_Kind) {
159        return is_signed(context, type.componentType());
160    }
161    return type == *context.fInt_Type;
162}
163
164static bool is_unsigned(const Context& context, const Type& type) {
165    if (type.kind() == Type::kVector_Kind) {
166        return is_unsigned(context, type.componentType());
167    }
168    return type == *context.fUInt_Type;
169}
170
171static bool is_bool(const Context& context, const Type& type) {
172    if (type.kind() == Type::kVector_Kind) {
173        return is_bool(context, type.componentType());
174    }
175    return type == *context.fBool_Type;
176}
177
178static bool is_out(const Variable& var) {
179    return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
180}
181
182#if SPIRV_DEBUG
183static SkString opcode_text(SpvOp_ opCode) {
184    switch (opCode) {
185        case SpvOpNop:
186            return SkString("Nop");
187        case SpvOpUndef:
188            return SkString("Undef");
189        case SpvOpSourceContinued:
190            return SkString("SourceContinued");
191        case SpvOpSource:
192            return SkString("Source");
193        case SpvOpSourceExtension:
194            return SkString("SourceExtension");
195        case SpvOpName:
196            return SkString("Name");
197        case SpvOpMemberName:
198            return SkString("MemberName");
199        case SpvOpString:
200            return SkString("String");
201        case SpvOpLine:
202            return SkString("Line");
203        case SpvOpExtension:
204            return SkString("Extension");
205        case SpvOpExtInstImport:
206            return SkString("ExtInstImport");
207        case SpvOpExtInst:
208            return SkString("ExtInst");
209        case SpvOpMemoryModel:
210            return SkString("MemoryModel");
211        case SpvOpEntryPoint:
212            return SkString("EntryPoint");
213        case SpvOpExecutionMode:
214            return SkString("ExecutionMode");
215        case SpvOpCapability:
216            return SkString("Capability");
217        case SpvOpTypeVoid:
218            return SkString("TypeVoid");
219        case SpvOpTypeBool:
220            return SkString("TypeBool");
221        case SpvOpTypeInt:
222            return SkString("TypeInt");
223        case SpvOpTypeFloat:
224            return SkString("TypeFloat");
225        case SpvOpTypeVector:
226            return SkString("TypeVector");
227        case SpvOpTypeMatrix:
228            return SkString("TypeMatrix");
229        case SpvOpTypeImage:
230            return SkString("TypeImage");
231        case SpvOpTypeSampler:
232            return SkString("TypeSampler");
233        case SpvOpTypeSampledImage:
234            return SkString("TypeSampledImage");
235        case SpvOpTypeArray:
236            return SkString("TypeArray");
237        case SpvOpTypeRuntimeArray:
238            return SkString("TypeRuntimeArray");
239        case SpvOpTypeStruct:
240            return SkString("TypeStruct");
241        case SpvOpTypeOpaque:
242            return SkString("TypeOpaque");
243        case SpvOpTypePointer:
244            return SkString("TypePointer");
245        case SpvOpTypeFunction:
246            return SkString("TypeFunction");
247        case SpvOpTypeEvent:
248            return SkString("TypeEvent");
249        case SpvOpTypeDeviceEvent:
250            return SkString("TypeDeviceEvent");
251        case SpvOpTypeReserveId:
252            return SkString("TypeReserveId");
253        case SpvOpTypeQueue:
254            return SkString("TypeQueue");
255        case SpvOpTypePipe:
256            return SkString("TypePipe");
257        case SpvOpTypeForwardPointer:
258            return SkString("TypeForwardPointer");
259        case SpvOpConstantTrue:
260            return SkString("ConstantTrue");
261        case SpvOpConstantFalse:
262            return SkString("ConstantFalse");
263        case SpvOpConstant:
264            return SkString("Constant");
265        case SpvOpConstantComposite:
266            return SkString("ConstantComposite");
267        case SpvOpConstantSampler:
268            return SkString("ConstantSampler");
269        case SpvOpConstantNull:
270            return SkString("ConstantNull");
271        case SpvOpSpecConstantTrue:
272            return SkString("SpecConstantTrue");
273        case SpvOpSpecConstantFalse:
274            return SkString("SpecConstantFalse");
275        case SpvOpSpecConstant:
276            return SkString("SpecConstant");
277        case SpvOpSpecConstantComposite:
278            return SkString("SpecConstantComposite");
279        case SpvOpSpecConstantOp:
280            return SkString("SpecConstantOp");
281        case SpvOpFunction:
282            return SkString("Function");
283        case SpvOpFunctionParameter:
284            return SkString("FunctionParameter");
285        case SpvOpFunctionEnd:
286            return SkString("FunctionEnd");
287        case SpvOpFunctionCall:
288            return SkString("FunctionCall");
289        case SpvOpVariable:
290            return SkString("Variable");
291        case SpvOpImageTexelPointer:
292            return SkString("ImageTexelPointer");
293        case SpvOpLoad:
294            return SkString("Load");
295        case SpvOpStore:
296            return SkString("Store");
297        case SpvOpCopyMemory:
298            return SkString("CopyMemory");
299        case SpvOpCopyMemorySized:
300            return SkString("CopyMemorySized");
301        case SpvOpAccessChain:
302            return SkString("AccessChain");
303        case SpvOpInBoundsAccessChain:
304            return SkString("InBoundsAccessChain");
305        case SpvOpPtrAccessChain:
306            return SkString("PtrAccessChain");
307        case SpvOpArrayLength:
308            return SkString("ArrayLength");
309        case SpvOpGenericPtrMemSemantics:
310            return SkString("GenericPtrMemSemantics");
311        case SpvOpInBoundsPtrAccessChain:
312            return SkString("InBoundsPtrAccessChain");
313        case SpvOpDecorate:
314            return SkString("Decorate");
315        case SpvOpMemberDecorate:
316            return SkString("MemberDecorate");
317        case SpvOpDecorationGroup:
318            return SkString("DecorationGroup");
319        case SpvOpGroupDecorate:
320            return SkString("GroupDecorate");
321        case SpvOpGroupMemberDecorate:
322            return SkString("GroupMemberDecorate");
323        case SpvOpVectorExtractDynamic:
324            return SkString("VectorExtractDynamic");
325        case SpvOpVectorInsertDynamic:
326            return SkString("VectorInsertDynamic");
327        case SpvOpVectorShuffle:
328            return SkString("VectorShuffle");
329        case SpvOpCompositeConstruct:
330            return SkString("CompositeConstruct");
331        case SpvOpCompositeExtract:
332            return SkString("CompositeExtract");
333        case SpvOpCompositeInsert:
334            return SkString("CompositeInsert");
335        case SpvOpCopyObject:
336            return SkString("CopyObject");
337        case SpvOpTranspose:
338            return SkString("Transpose");
339        case SpvOpSampledImage:
340            return SkString("SampledImage");
341        case SpvOpImageSampleImplicitLod:
342            return SkString("ImageSampleImplicitLod");
343        case SpvOpImageSampleExplicitLod:
344            return SkString("ImageSampleExplicitLod");
345        case SpvOpImageSampleDrefImplicitLod:
346            return SkString("ImageSampleDrefImplicitLod");
347        case SpvOpImageSampleDrefExplicitLod:
348            return SkString("ImageSampleDrefExplicitLod");
349        case SpvOpImageSampleProjImplicitLod:
350            return SkString("ImageSampleProjImplicitLod");
351        case SpvOpImageSampleProjExplicitLod:
352            return SkString("ImageSampleProjExplicitLod");
353        case SpvOpImageSampleProjDrefImplicitLod:
354            return SkString("ImageSampleProjDrefImplicitLod");
355        case SpvOpImageSampleProjDrefExplicitLod:
356            return SkString("ImageSampleProjDrefExplicitLod");
357        case SpvOpImageFetch:
358            return SkString("ImageFetch");
359        case SpvOpImageGather:
360            return SkString("ImageGather");
361        case SpvOpImageDrefGather:
362            return SkString("ImageDrefGather");
363        case SpvOpImageRead:
364            return SkString("ImageRead");
365        case SpvOpImageWrite:
366            return SkString("ImageWrite");
367        case SpvOpImage:
368            return SkString("Image");
369        case SpvOpImageQueryFormat:
370            return SkString("ImageQueryFormat");
371        case SpvOpImageQueryOrder:
372            return SkString("ImageQueryOrder");
373        case SpvOpImageQuerySizeLod:
374            return SkString("ImageQuerySizeLod");
375        case SpvOpImageQuerySize:
376            return SkString("ImageQuerySize");
377        case SpvOpImageQueryLod:
378            return SkString("ImageQueryLod");
379        case SpvOpImageQueryLevels:
380            return SkString("ImageQueryLevels");
381        case SpvOpImageQuerySamples:
382            return SkString("ImageQuerySamples");
383        case SpvOpConvertFToU:
384            return SkString("ConvertFToU");
385        case SpvOpConvertFToS:
386            return SkString("ConvertFToS");
387        case SpvOpConvertSToF:
388            return SkString("ConvertSToF");
389        case SpvOpConvertUToF:
390            return SkString("ConvertUToF");
391        case SpvOpUConvert:
392            return SkString("UConvert");
393        case SpvOpSConvert:
394            return SkString("SConvert");
395        case SpvOpFConvert:
396            return SkString("FConvert");
397        case SpvOpQuantizeToF16:
398            return SkString("QuantizeToF16");
399        case SpvOpConvertPtrToU:
400            return SkString("ConvertPtrToU");
401        case SpvOpSatConvertSToU:
402            return SkString("SatConvertSToU");
403        case SpvOpSatConvertUToS:
404            return SkString("SatConvertUToS");
405        case SpvOpConvertUToPtr:
406            return SkString("ConvertUToPtr");
407        case SpvOpPtrCastToGeneric:
408            return SkString("PtrCastToGeneric");
409        case SpvOpGenericCastToPtr:
410            return SkString("GenericCastToPtr");
411        case SpvOpGenericCastToPtrExplicit:
412            return SkString("GenericCastToPtrExplicit");
413        case SpvOpBitcast:
414            return SkString("Bitcast");
415        case SpvOpSNegate:
416            return SkString("SNegate");
417        case SpvOpFNegate:
418            return SkString("FNegate");
419        case SpvOpIAdd:
420            return SkString("IAdd");
421        case SpvOpFAdd:
422            return SkString("FAdd");
423        case SpvOpISub:
424            return SkString("ISub");
425        case SpvOpFSub:
426            return SkString("FSub");
427        case SpvOpIMul:
428            return SkString("IMul");
429        case SpvOpFMul:
430            return SkString("FMul");
431        case SpvOpUDiv:
432            return SkString("UDiv");
433        case SpvOpSDiv:
434            return SkString("SDiv");
435        case SpvOpFDiv:
436            return SkString("FDiv");
437        case SpvOpUMod:
438            return SkString("UMod");
439        case SpvOpSRem:
440            return SkString("SRem");
441        case SpvOpSMod:
442            return SkString("SMod");
443        case SpvOpFRem:
444            return SkString("FRem");
445        case SpvOpFMod:
446            return SkString("FMod");
447        case SpvOpVectorTimesScalar:
448            return SkString("VectorTimesScalar");
449        case SpvOpMatrixTimesScalar:
450            return SkString("MatrixTimesScalar");
451        case SpvOpVectorTimesMatrix:
452            return SkString("VectorTimesMatrix");
453        case SpvOpMatrixTimesVector:
454            return SkString("MatrixTimesVector");
455        case SpvOpMatrixTimesMatrix:
456            return SkString("MatrixTimesMatrix");
457        case SpvOpOuterProduct:
458            return SkString("OuterProduct");
459        case SpvOpDot:
460            return SkString("Dot");
461        case SpvOpIAddCarry:
462            return SkString("IAddCarry");
463        case SpvOpISubBorrow:
464            return SkString("ISubBorrow");
465        case SpvOpUMulExtended:
466            return SkString("UMulExtended");
467        case SpvOpSMulExtended:
468            return SkString("SMulExtended");
469        case SpvOpAny:
470            return SkString("Any");
471        case SpvOpAll:
472            return SkString("All");
473        case SpvOpIsNan:
474            return SkString("IsNan");
475        case SpvOpIsInf:
476            return SkString("IsInf");
477        case SpvOpIsFinite:
478            return SkString("IsFinite");
479        case SpvOpIsNormal:
480            return SkString("IsNormal");
481        case SpvOpSignBitSet:
482            return SkString("SignBitSet");
483        case SpvOpLessOrGreater:
484            return SkString("LessOrGreater");
485        case SpvOpOrdered:
486            return SkString("Ordered");
487        case SpvOpUnordered:
488            return SkString("Unordered");
489        case SpvOpLogicalEqual:
490            return SkString("LogicalEqual");
491        case SpvOpLogicalNotEqual:
492            return SkString("LogicalNotEqual");
493        case SpvOpLogicalOr:
494            return SkString("LogicalOr");
495        case SpvOpLogicalAnd:
496            return SkString("LogicalAnd");
497        case SpvOpLogicalNot:
498            return SkString("LogicalNot");
499        case SpvOpSelect:
500            return SkString("Select");
501        case SpvOpIEqual:
502            return SkString("IEqual");
503        case SpvOpINotEqual:
504            return SkString("INotEqual");
505        case SpvOpUGreaterThan:
506            return SkString("UGreaterThan");
507        case SpvOpSGreaterThan:
508            return SkString("SGreaterThan");
509        case SpvOpUGreaterThanEqual:
510            return SkString("UGreaterThanEqual");
511        case SpvOpSGreaterThanEqual:
512            return SkString("SGreaterThanEqual");
513        case SpvOpULessThan:
514            return SkString("ULessThan");
515        case SpvOpSLessThan:
516            return SkString("SLessThan");
517        case SpvOpULessThanEqual:
518            return SkString("ULessThanEqual");
519        case SpvOpSLessThanEqual:
520            return SkString("SLessThanEqual");
521        case SpvOpFOrdEqual:
522            return SkString("FOrdEqual");
523        case SpvOpFUnordEqual:
524            return SkString("FUnordEqual");
525        case SpvOpFOrdNotEqual:
526            return SkString("FOrdNotEqual");
527        case SpvOpFUnordNotEqual:
528            return SkString("FUnordNotEqual");
529        case SpvOpFOrdLessThan:
530            return SkString("FOrdLessThan");
531        case SpvOpFUnordLessThan:
532            return SkString("FUnordLessThan");
533        case SpvOpFOrdGreaterThan:
534            return SkString("FOrdGreaterThan");
535        case SpvOpFUnordGreaterThan:
536            return SkString("FUnordGreaterThan");
537        case SpvOpFOrdLessThanEqual:
538            return SkString("FOrdLessThanEqual");
539        case SpvOpFUnordLessThanEqual:
540            return SkString("FUnordLessThanEqual");
541        case SpvOpFOrdGreaterThanEqual:
542            return SkString("FOrdGreaterThanEqual");
543        case SpvOpFUnordGreaterThanEqual:
544            return SkString("FUnordGreaterThanEqual");
545        case SpvOpShiftRightLogical:
546            return SkString("ShiftRightLogical");
547        case SpvOpShiftRightArithmetic:
548            return SkString("ShiftRightArithmetic");
549        case SpvOpShiftLeftLogical:
550            return SkString("ShiftLeftLogical");
551        case SpvOpBitwiseOr:
552            return SkString("BitwiseOr");
553        case SpvOpBitwiseXor:
554            return SkString("BitwiseXor");
555        case SpvOpBitwiseAnd:
556            return SkString("BitwiseAnd");
557        case SpvOpNot:
558            return SkString("Not");
559        case SpvOpBitFieldInsert:
560            return SkString("BitFieldInsert");
561        case SpvOpBitFieldSExtract:
562            return SkString("BitFieldSExtract");
563        case SpvOpBitFieldUExtract:
564            return SkString("BitFieldUExtract");
565        case SpvOpBitReverse:
566            return SkString("BitReverse");
567        case SpvOpBitCount:
568            return SkString("BitCount");
569        case SpvOpDPdx:
570            return SkString("DPdx");
571        case SpvOpDPdy:
572            return SkString("DPdy");
573        case SpvOpFwidth:
574            return SkString("Fwidth");
575        case SpvOpDPdxFine:
576            return SkString("DPdxFine");
577        case SpvOpDPdyFine:
578            return SkString("DPdyFine");
579        case SpvOpFwidthFine:
580            return SkString("FwidthFine");
581        case SpvOpDPdxCoarse:
582            return SkString("DPdxCoarse");
583        case SpvOpDPdyCoarse:
584            return SkString("DPdyCoarse");
585        case SpvOpFwidthCoarse:
586            return SkString("FwidthCoarse");
587        case SpvOpEmitVertex:
588            return SkString("EmitVertex");
589        case SpvOpEndPrimitive:
590            return SkString("EndPrimitive");
591        case SpvOpEmitStreamVertex:
592            return SkString("EmitStreamVertex");
593        case SpvOpEndStreamPrimitive:
594            return SkString("EndStreamPrimitive");
595        case SpvOpControlBarrier:
596            return SkString("ControlBarrier");
597        case SpvOpMemoryBarrier:
598            return SkString("MemoryBarrier");
599        case SpvOpAtomicLoad:
600            return SkString("AtomicLoad");
601        case SpvOpAtomicStore:
602            return SkString("AtomicStore");
603        case SpvOpAtomicExchange:
604            return SkString("AtomicExchange");
605        case SpvOpAtomicCompareExchange:
606            return SkString("AtomicCompareExchange");
607        case SpvOpAtomicCompareExchangeWeak:
608            return SkString("AtomicCompareExchangeWeak");
609        case SpvOpAtomicIIncrement:
610            return SkString("AtomicIIncrement");
611        case SpvOpAtomicIDecrement:
612            return SkString("AtomicIDecrement");
613        case SpvOpAtomicIAdd:
614            return SkString("AtomicIAdd");
615        case SpvOpAtomicISub:
616            return SkString("AtomicISub");
617        case SpvOpAtomicSMin:
618            return SkString("AtomicSMin");
619        case SpvOpAtomicUMin:
620            return SkString("AtomicUMin");
621        case SpvOpAtomicSMax:
622            return SkString("AtomicSMax");
623        case SpvOpAtomicUMax:
624            return SkString("AtomicUMax");
625        case SpvOpAtomicAnd:
626            return SkString("AtomicAnd");
627        case SpvOpAtomicOr:
628            return SkString("AtomicOr");
629        case SpvOpAtomicXor:
630            return SkString("AtomicXor");
631        case SpvOpPhi:
632            return SkString("Phi");
633        case SpvOpLoopMerge:
634            return SkString("LoopMerge");
635        case SpvOpSelectionMerge:
636            return SkString("SelectionMerge");
637        case SpvOpLabel:
638            return SkString("Label");
639        case SpvOpBranch:
640            return SkString("Branch");
641        case SpvOpBranchConditional:
642            return SkString("BranchConditional");
643        case SpvOpSwitch:
644            return SkString("Switch");
645        case SpvOpKill:
646            return SkString("Kill");
647        case SpvOpReturn:
648            return SkString("Return");
649        case SpvOpReturnValue:
650            return SkString("ReturnValue");
651        case SpvOpUnreachable:
652            return SkString("Unreachable");
653        case SpvOpLifetimeStart:
654            return SkString("LifetimeStart");
655        case SpvOpLifetimeStop:
656            return SkString("LifetimeStop");
657        case SpvOpGroupAsyncCopy:
658            return SkString("GroupAsyncCopy");
659        case SpvOpGroupWaitEvents:
660            return SkString("GroupWaitEvents");
661        case SpvOpGroupAll:
662            return SkString("GroupAll");
663        case SpvOpGroupAny:
664            return SkString("GroupAny");
665        case SpvOpGroupBroadcast:
666            return SkString("GroupBroadcast");
667        case SpvOpGroupIAdd:
668            return SkString("GroupIAdd");
669        case SpvOpGroupFAdd:
670            return SkString("GroupFAdd");
671        case SpvOpGroupFMin:
672            return SkString("GroupFMin");
673        case SpvOpGroupUMin:
674            return SkString("GroupUMin");
675        case SpvOpGroupSMin:
676            return SkString("GroupSMin");
677        case SpvOpGroupFMax:
678            return SkString("GroupFMax");
679        case SpvOpGroupUMax:
680            return SkString("GroupUMax");
681        case SpvOpGroupSMax:
682            return SkString("GroupSMax");
683        case SpvOpReadPipe:
684            return SkString("ReadPipe");
685        case SpvOpWritePipe:
686            return SkString("WritePipe");
687        case SpvOpReservedReadPipe:
688            return SkString("ReservedReadPipe");
689        case SpvOpReservedWritePipe:
690            return SkString("ReservedWritePipe");
691        case SpvOpReserveReadPipePackets:
692            return SkString("ReserveReadPipePackets");
693        case SpvOpReserveWritePipePackets:
694            return SkString("ReserveWritePipePackets");
695        case SpvOpCommitReadPipe:
696            return SkString("CommitReadPipe");
697        case SpvOpCommitWritePipe:
698            return SkString("CommitWritePipe");
699        case SpvOpIsValidReserveId:
700            return SkString("IsValidReserveId");
701        case SpvOpGetNumPipePackets:
702            return SkString("GetNumPipePackets");
703        case SpvOpGetMaxPipePackets:
704            return SkString("GetMaxPipePackets");
705        case SpvOpGroupReserveReadPipePackets:
706            return SkString("GroupReserveReadPipePackets");
707        case SpvOpGroupReserveWritePipePackets:
708            return SkString("GroupReserveWritePipePackets");
709        case SpvOpGroupCommitReadPipe:
710            return SkString("GroupCommitReadPipe");
711        case SpvOpGroupCommitWritePipe:
712            return SkString("GroupCommitWritePipe");
713        case SpvOpEnqueueMarker:
714            return SkString("EnqueueMarker");
715        case SpvOpEnqueueKernel:
716            return SkString("EnqueueKernel");
717        case SpvOpGetKernelNDrangeSubGroupCount:
718            return SkString("GetKernelNDrangeSubGroupCount");
719        case SpvOpGetKernelNDrangeMaxSubGroupSize:
720            return SkString("GetKernelNDrangeMaxSubGroupSize");
721        case SpvOpGetKernelWorkGroupSize:
722            return SkString("GetKernelWorkGroupSize");
723        case SpvOpGetKernelPreferredWorkGroupSizeMultiple:
724            return SkString("GetKernelPreferredWorkGroupSizeMultiple");
725        case SpvOpRetainEvent:
726            return SkString("RetainEvent");
727        case SpvOpReleaseEvent:
728            return SkString("ReleaseEvent");
729        case SpvOpCreateUserEvent:
730            return SkString("CreateUserEvent");
731        case SpvOpIsValidEvent:
732            return SkString("IsValidEvent");
733        case SpvOpSetUserEventStatus:
734            return SkString("SetUserEventStatus");
735        case SpvOpCaptureEventProfilingInfo:
736            return SkString("CaptureEventProfilingInfo");
737        case SpvOpGetDefaultQueue:
738            return SkString("GetDefaultQueue");
739        case SpvOpBuildNDRange:
740            return SkString("BuildNDRange");
741        case SpvOpImageSparseSampleImplicitLod:
742            return SkString("ImageSparseSampleImplicitLod");
743        case SpvOpImageSparseSampleExplicitLod:
744            return SkString("ImageSparseSampleExplicitLod");
745        case SpvOpImageSparseSampleDrefImplicitLod:
746            return SkString("ImageSparseSampleDrefImplicitLod");
747        case SpvOpImageSparseSampleDrefExplicitLod:
748            return SkString("ImageSparseSampleDrefExplicitLod");
749        case SpvOpImageSparseSampleProjImplicitLod:
750            return SkString("ImageSparseSampleProjImplicitLod");
751        case SpvOpImageSparseSampleProjExplicitLod:
752            return SkString("ImageSparseSampleProjExplicitLod");
753        case SpvOpImageSparseSampleProjDrefImplicitLod:
754            return SkString("ImageSparseSampleProjDrefImplicitLod");
755        case SpvOpImageSparseSampleProjDrefExplicitLod:
756            return SkString("ImageSparseSampleProjDrefExplicitLod");
757        case SpvOpImageSparseFetch:
758            return SkString("ImageSparseFetch");
759        case SpvOpImageSparseGather:
760            return SkString("ImageSparseGather");
761        case SpvOpImageSparseDrefGather:
762            return SkString("ImageSparseDrefGather");
763        case SpvOpImageSparseTexelsResident:
764            return SkString("ImageSparseTexelsResident");
765        case SpvOpNoLine:
766            return SkString("NoLine");
767        case SpvOpAtomicFlagTestAndSet:
768            return SkString("AtomicFlagTestAndSet");
769        case SpvOpAtomicFlagClear:
770            return SkString("AtomicFlagClear");
771        case SpvOpImageSparseRead:
772            return SkString("ImageSparseRead");
773        default:
774            ABORT("unsupported SPIR-V op");
775    }
776}
777#endif
778
// Writes the leading word of a SPIR-V instruction: the total instruction length in words
// (including this word) in the high 16 bits and the opcode in the low 16 bits. Every
// writeInstruction() overload calls this first with the number of words it will emit.
//
// Also performs basic-block bookkeeping on fCurrentBlock:
//  - block terminators (return, kill, branch) require an open block and then close it;
//  - type, constant, variable, function and module-level declarations may be written
//    whether or not a block is open;
//  - every other opcode asserts that a block is currently open.
// When SPIRV_DEBUG is enabled the opcode's name is written as text instead of the binary
// word (`length` is not emitted in that mode).
void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, SkWStream& out) {
    ASSERT(opCode != SpvOpUndef);
    switch (opCode) {
        // Block terminators: must be emitted inside a block, and end it.
        case SpvOpReturn:      // fall through
        case SpvOpReturnValue: // fall through
        case SpvOpKill:        // fall through
        case SpvOpBranch:      // fall through
        case SpvOpBranchConditional:
            ASSERT(fCurrentBlock);
            fCurrentBlock = 0;
            break;
        // Declarations and module-level instructions: legal outside of any block.
        case SpvOpConstant:          // fall through
        case SpvOpConstantTrue:      // fall through
        case SpvOpConstantFalse:     // fall through
        case SpvOpConstantComposite: // fall through
        case SpvOpTypeVoid:          // fall through
        case SpvOpTypeInt:           // fall through
        case SpvOpTypeFloat:         // fall through
        case SpvOpTypeBool:          // fall through
        case SpvOpTypeVector:        // fall through
        case SpvOpTypeMatrix:        // fall through
        case SpvOpTypeArray:         // fall through
        case SpvOpTypePointer:       // fall through
        case SpvOpTypeFunction:      // fall through
        case SpvOpTypeRuntimeArray:  // fall through
        case SpvOpTypeStruct:        // fall through
        case SpvOpTypeImage:         // fall through
        case SpvOpTypeSampledImage:  // fall through
        case SpvOpVariable:          // fall through
        case SpvOpFunction:          // fall through
        case SpvOpFunctionParameter: // fall through
        case SpvOpFunctionEnd:       // fall through
        case SpvOpExecutionMode:     // fall through
        case SpvOpMemoryModel:       // fall through
        case SpvOpCapability:        // fall through
        case SpvOpExtInstImport:     // fall through
        case SpvOpEntryPoint:        // fall through
        case SpvOpSource:            // fall through
        case SpvOpSourceExtension:   // fall through
        case SpvOpName:              // fall through
        case SpvOpMemberName:        // fall through
        case SpvOpDecorate:          // fall through
        case SpvOpMemberDecorate:
            break;
        // Everything else (including OpLabel, which writeLabel() opens a block for first)
        // must be inside a block.
        default:
            ASSERT(fCurrentBlock);
    }
#if SPIRV_DEBUG
    out << std::endl << opcode_text(opCode) << " ";
#else
    // Word count in the upper half-word, opcode in the lower half-word.
    this->writeWord((length << 16) | opCode, out);
#endif
}
832
833void SPIRVCodeGenerator::writeLabel(SpvId label, SkWStream& out) {
834    fCurrentBlock = label;
835    this->writeInstruction(SpvOpLabel, label, out);
836}
837
838void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, SkWStream& out) {
839    this->writeOpCode(opCode, 1, out);
840}
841
842void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, SkWStream& out) {
843    this->writeOpCode(opCode, 2, out);
844    this->writeWord(word1, out);
845}
846
847void SPIRVCodeGenerator::writeString(const char* string, SkWStream& out) {
848    size_t length = strlen(string);
849    out.writeText(string);
850    switch (length % 4) {
851        case 1:
852            out.write8(0);
853            // fall through
854        case 2:
855            out.write8(0);
856            // fall through
857        case 3:
858            out.write8(0);
859            break;
860        default:
861            this->writeWord(0, out);
862    }
863}
864
865void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, SkWStream& out) {
866    int32_t length = (int32_t) strlen(string);
867    this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
868    this->writeString(string, out);
869}
870
871
872void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
873                                          SkWStream& out) {
874    int32_t length = (int32_t) strlen(string);
875    this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
876    this->writeWord(word1, out);
877    this->writeString(string, out);
878}
879
880void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
881                                          const char* string, SkWStream& out) {
882    int32_t length = (int32_t) strlen(string);
883    this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
884    this->writeWord(word1, out);
885    this->writeWord(word2, out);
886    this->writeString(string, out);
887}
888
889void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
890                                          SkWStream& out) {
891    this->writeOpCode(opCode, 3, out);
892    this->writeWord(word1, out);
893    this->writeWord(word2, out);
894}
895
896void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
897                                          int32_t word3, SkWStream& out) {
898    this->writeOpCode(opCode, 4, out);
899    this->writeWord(word1, out);
900    this->writeWord(word2, out);
901    this->writeWord(word3, out);
902}
903
904void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
905                                          int32_t word3, int32_t word4, SkWStream& out) {
906    this->writeOpCode(opCode, 5, out);
907    this->writeWord(word1, out);
908    this->writeWord(word2, out);
909    this->writeWord(word3, out);
910    this->writeWord(word4, out);
911}
912
913void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
914                                          int32_t word3, int32_t word4, int32_t word5,
915                                          SkWStream& out) {
916    this->writeOpCode(opCode, 6, out);
917    this->writeWord(word1, out);
918    this->writeWord(word2, out);
919    this->writeWord(word3, out);
920    this->writeWord(word4, out);
921    this->writeWord(word5, out);
922}
923
924void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
925                                          int32_t word3, int32_t word4, int32_t word5,
926                                          int32_t word6, SkWStream& out) {
927    this->writeOpCode(opCode, 7, out);
928    this->writeWord(word1, out);
929    this->writeWord(word2, out);
930    this->writeWord(word3, out);
931    this->writeWord(word4, out);
932    this->writeWord(word5, out);
933    this->writeWord(word6, out);
934}
935
936void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
937                                          int32_t word3, int32_t word4, int32_t word5,
938                                          int32_t word6, int32_t word7, SkWStream& out) {
939    this->writeOpCode(opCode, 8, out);
940    this->writeWord(word1, out);
941    this->writeWord(word2, out);
942    this->writeWord(word3, out);
943    this->writeWord(word4, out);
944    this->writeWord(word5, out);
945    this->writeWord(word6, out);
946    this->writeWord(word7, out);
947}
948
949void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
950                                          int32_t word3, int32_t word4, int32_t word5,
951                                          int32_t word6, int32_t word7, int32_t word8,
952                                          SkWStream& out) {
953    this->writeOpCode(opCode, 9, out);
954    this->writeWord(word1, out);
955    this->writeWord(word2, out);
956    this->writeWord(word3, out);
957    this->writeWord(word4, out);
958    this->writeWord(word5, out);
959    this->writeWord(word6, out);
960    this->writeWord(word7, out);
961    this->writeWord(word8, out);
962}
963
964void SPIRVCodeGenerator::writeCapabilities(SkWStream& out) {
965    for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
966        if (fCapabilities & bit) {
967            this->writeInstruction(SpvOpCapability, (SpvId) i, out);
968        }
969    }
970}
971
972SpvId SPIRVCodeGenerator::nextId() {
973    return fIdCount++;
974}
975
976void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
977                                     SpvId resultId) {
978    this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
979    // go ahead and write all of the field types, so we don't inadvertently write them while we're
980    // in the middle of writing the struct instruction
981    std::vector<SpvId> types;
982    for (const auto& f : type.fields()) {
983        types.push_back(this->getType(*f.fType, memoryLayout));
984    }
985    this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
986    this->writeWord(resultId, fConstantBuffer);
987    for (SpvId id : types) {
988        this->writeWord(id, fConstantBuffer);
989    }
990    size_t offset = 0;
991    for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
992        size_t size = memoryLayout.size(*type.fields()[i].fType);
993        size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
994        const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
995        if (fieldLayout.fOffset >= 0) {
996            if (fieldLayout.fOffset < (int) offset) {
997                fErrors.error(type.fPosition,
998                              "offset of field '" + type.fields()[i].fName + "' must be at "
999                              "least " + to_string((int) offset));
1000            }
1001            if (fieldLayout.fOffset % alignment) {
1002                fErrors.error(type.fPosition,
1003                              "offset of field '" + type.fields()[i].fName + "' must be a multiple"
1004                              " of " + to_string((int) alignment));
1005            }
1006            offset = fieldLayout.fOffset;
1007        } else {
1008            size_t mod = offset % alignment;
1009            if (mod) {
1010                offset += alignment - mod;
1011            }
1012        }
1013        this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
1014                               fNameBuffer);
1015        this->writeLayout(fieldLayout, resultId, i);
1016        if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
1017            this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
1018                                   (SpvId) offset, fDecorationBuffer);
1019        }
1020        if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
1021            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
1022                                   fDecorationBuffer);
1023            this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
1024                                   (SpvId) memoryLayout.stride(*type.fields()[i].fType),
1025                                   fDecorationBuffer);
1026        }
1027        offset += size;
1028        Type::Kind kind = type.fields()[i].fType->kind();
1029        if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
1030            offset += alignment - offset % alignment;
1031        }
1032    }
1033}
1034
1035SpvId SPIRVCodeGenerator::getType(const Type& type) {
1036    return this->getType(type, fDefaultLayout);
1037}
1038
1039SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) {
1040    SkString key = type.name() + to_string((int) layout.fStd);
1041    auto entry = fTypeMap.find(key);
1042    if (entry == fTypeMap.end()) {
1043        SpvId result = this->nextId();
1044        switch (type.kind()) {
1045            case Type::kScalar_Kind:
1046                if (type == *fContext.fBool_Type) {
1047                    this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
1048                } else if (type == *fContext.fInt_Type) {
1049                    this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
1050                } else if (type == *fContext.fUInt_Type) {
1051                    this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
1052                } else if (type == *fContext.fFloat_Type) {
1053                    this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
1054                } else if (type == *fContext.fDouble_Type) {
1055                    this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
1056                } else {
1057                    ASSERT(false);
1058                }
1059                break;
1060            case Type::kVector_Kind:
1061                this->writeInstruction(SpvOpTypeVector, result,
1062                                       this->getType(type.componentType(), layout),
1063                                       type.columns(), fConstantBuffer);
1064                break;
1065            case Type::kMatrix_Kind:
1066                this->writeInstruction(SpvOpTypeMatrix, result,
1067                                       this->getType(index_type(fContext, type), layout),
1068                                       type.columns(), fConstantBuffer);
1069                break;
1070            case Type::kStruct_Kind:
1071                this->writeStruct(type, layout, result);
1072                break;
1073            case Type::kArray_Kind: {
1074                if (type.columns() > 0) {
1075                    IntLiteral count(fContext, Position(), type.columns());
1076                    this->writeInstruction(SpvOpTypeArray, result,
1077                                           this->getType(type.componentType(), layout),
1078                                           this->writeIntLiteral(count), fConstantBuffer);
1079                    this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
1080                                           (int32_t) layout.stride(type),
1081                                           fDecorationBuffer);
1082                } else {
1083                    ABORT("runtime-sized arrays are not yet supported");
1084                    this->writeInstruction(SpvOpTypeRuntimeArray, result,
1085                                           this->getType(type.componentType(), layout),
1086                                           fConstantBuffer);
1087                }
1088                break;
1089            }
1090            case Type::kSampler_Kind: {
1091                SpvId image = result;
1092                if (SpvDimSubpassData != type.dimensions()) {
1093                    image = this->nextId();
1094                }
1095                this->writeInstruction(SpvOpTypeImage, image,
1096                                       this->getType(*fContext.fFloat_Type, layout),
1097                                       type.dimensions(), type.isDepth(), type.isArrayed(),
1098                                       type.isMultisampled(), type.isSampled() ? 1 : 2,
1099                                       SpvImageFormatUnknown, fConstantBuffer);
1100                if (SpvDimSubpassData != type.dimensions()) {
1101                    this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
1102                }
1103                break;
1104            }
1105            default:
1106                if (type == *fContext.fVoid_Type) {
1107                    this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
1108                } else {
1109                    ABORT("invalid type: %s", type.description().c_str());
1110                }
1111        }
1112        fTypeMap[key] = result;
1113        return result;
1114    }
1115    return entry->second;
1116}
1117
1118SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
1119    SkString key = function.fReturnType.description() + "(";
1120    SkString separator;
1121    for (size_t i = 0; i < function.fParameters.size(); i++) {
1122        key += separator;
1123        separator = ", ";
1124        key += function.fParameters[i]->fType.description();
1125    }
1126    key += ")";
1127    auto entry = fTypeMap.find(key);
1128    if (entry == fTypeMap.end()) {
1129        SpvId result = this->nextId();
1130        int32_t length = 3 + (int32_t) function.fParameters.size();
1131        SpvId returnType = this->getType(function.fReturnType);
1132        std::vector<SpvId> parameterTypes;
1133        for (size_t i = 0; i < function.fParameters.size(); i++) {
1134            // glslang seems to treat all function arguments as pointers whether they need to be or
1135            // not. I  was initially puzzled by this until I ran bizarre failures with certain
1136            // patterns of function calls and control constructs, as exemplified by this minimal
1137            // failure case:
1138            //
1139            // void sphere(float x) {
1140            // }
1141            //
1142            // void map() {
1143            //     sphere(1.0);
1144            // }
1145            //
1146            // void main() {
1147            //     for (int i = 0; i < 1; i++) {
1148            //         map();
1149            //     }
1150            // }
1151            //
1152            // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
1153            // crashes. Making it take a float* and storing the argument in a temporary variable,
1154            // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
1155            // the spec makes this make sense.
1156//            if (is_out(function->fParameters[i])) {
1157                parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
1158                                                              SpvStorageClassFunction));
1159//            } else {
1160//                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
1161//            }
1162        }
1163        this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
1164        this->writeWord(result, fConstantBuffer);
1165        this->writeWord(returnType, fConstantBuffer);
1166        for (SpvId id : parameterTypes) {
1167            this->writeWord(id, fConstantBuffer);
1168        }
1169        fTypeMap[key] = result;
1170        return result;
1171    }
1172    return entry->second;
1173}
1174
1175SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
1176    return this->getPointerType(type, fDefaultLayout, storageClass);
1177}
1178
1179SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
1180                                         SpvStorageClass_ storageClass) {
1181    SkString key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
1182    auto entry = fTypeMap.find(key);
1183    if (entry == fTypeMap.end()) {
1184        SpvId result = this->nextId();
1185        this->writeInstruction(SpvOpTypePointer, result, storageClass,
1186                               this->getType(type), fConstantBuffer);
1187        fTypeMap[key] = result;
1188        return result;
1189    }
1190    return entry->second;
1191}
1192
1193SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, SkWStream& out) {
1194    switch (expr.fKind) {
1195        case Expression::kBinary_Kind:
1196            return this->writeBinaryExpression((BinaryExpression&) expr, out);
1197        case Expression::kBoolLiteral_Kind:
1198            return this->writeBoolLiteral((BoolLiteral&) expr);
1199        case Expression::kConstructor_Kind:
1200            return this->writeConstructor((Constructor&) expr, out);
1201        case Expression::kIntLiteral_Kind:
1202            return this->writeIntLiteral((IntLiteral&) expr);
1203        case Expression::kFieldAccess_Kind:
1204            return this->writeFieldAccess(((FieldAccess&) expr), out);
1205        case Expression::kFloatLiteral_Kind:
1206            return this->writeFloatLiteral(((FloatLiteral&) expr));
1207        case Expression::kFunctionCall_Kind:
1208            return this->writeFunctionCall((FunctionCall&) expr, out);
1209        case Expression::kPrefix_Kind:
1210            return this->writePrefixExpression((PrefixExpression&) expr, out);
1211        case Expression::kPostfix_Kind:
1212            return this->writePostfixExpression((PostfixExpression&) expr, out);
1213        case Expression::kSwizzle_Kind:
1214            return this->writeSwizzle((Swizzle&) expr, out);
1215        case Expression::kVariableReference_Kind:
1216            return this->writeVariableReference((VariableReference&) expr, out);
1217        case Expression::kTernary_Kind:
1218            return this->writeTernaryExpression((TernaryExpression&) expr, out);
1219        case Expression::kIndex_Kind:
1220            return this->writeIndexExpression((IndexExpression&) expr, out);
1221        default:
1222            ABORT("unsupported expression: %s", expr.description().c_str());
1223    }
1224    return -1;
1225}
1226
1227SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, SkWStream& out) {
1228    auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
1229    ASSERT(intrinsic != fIntrinsicMap.end());
1230    const Type& type = c.fArguments[0]->fType;
1231    int32_t intrinsicId;
1232    if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
1233        intrinsicId = std::get<1>(intrinsic->second);
1234    } else if (is_signed(fContext, type)) {
1235        intrinsicId = std::get<2>(intrinsic->second);
1236    } else if (is_unsigned(fContext, type)) {
1237        intrinsicId = std::get<3>(intrinsic->second);
1238    } else if (is_bool(fContext, type)) {
1239        intrinsicId = std::get<4>(intrinsic->second);
1240    } else {
1241        ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
1242              type.description().c_str());
1243    }
1244    switch (std::get<0>(intrinsic->second)) {
1245        case kGLSL_STD_450_IntrinsicKind: {
1246            SpvId result = this->nextId();
1247            std::vector<SpvId> arguments;
1248            for (size_t i = 0; i < c.fArguments.size(); i++) {
1249                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1250            }
1251            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1252            this->writeWord(this->getType(c.fType), out);
1253            this->writeWord(result, out);
1254            this->writeWord(fGLSLExtendedInstructions, out);
1255            this->writeWord(intrinsicId, out);
1256            for (SpvId id : arguments) {
1257                this->writeWord(id, out);
1258            }
1259            return result;
1260        }
1261        case kSPIRV_IntrinsicKind: {
1262            SpvId result = this->nextId();
1263            std::vector<SpvId> arguments;
1264            for (size_t i = 0; i < c.fArguments.size(); i++) {
1265                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1266            }
1267            this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
1268            this->writeWord(this->getType(c.fType), out);
1269            this->writeWord(result, out);
1270            for (SpvId id : arguments) {
1271                this->writeWord(id, out);
1272            }
1273            return result;
1274        }
1275        case kSpecial_IntrinsicKind:
1276            return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
1277        default:
1278            ABORT("unsupported intrinsic kind");
1279    }
1280}
1281
1282SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
1283                                                SkWStream& out) {
1284    SpvId result = this->nextId();
1285    switch (kind) {
1286        case kAtan_SpecialIntrinsic: {
1287            std::vector<SpvId> arguments;
1288            for (size_t i = 0; i < c.fArguments.size(); i++) {
1289                arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1290            }
1291            this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
1292            this->writeWord(this->getType(c.fType), out);
1293            this->writeWord(result, out);
1294            this->writeWord(fGLSLExtendedInstructions, out);
1295            this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
1296            for (SpvId id : arguments) {
1297                this->writeWord(id, out);
1298            }
1299            return result;
1300        }
1301        case kTexture_SpecialIntrinsic: {
1302            SpvOp_ op = SpvOpImageSampleImplicitLod;
1303            switch (c.fArguments[0]->fType.dimensions()) {
1304                case SpvDim1D:
1305                    if (c.fArguments[1]->fType == *fContext.fVec2_Type) {
1306                        op = SpvOpImageSampleProjImplicitLod;
1307                    } else {
1308                        ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
1309                    }
1310                    break;
1311                case SpvDim2D:
1312                    if (c.fArguments[1]->fType == *fContext.fVec3_Type) {
1313                        op = SpvOpImageSampleProjImplicitLod;
1314                    } else {
1315                        ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type);
1316                    }
1317                    break;
1318                case SpvDim3D:
1319                    if (c.fArguments[1]->fType == *fContext.fVec4_Type) {
1320                        op = SpvOpImageSampleProjImplicitLod;
1321                    } else {
1322                        ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type);
1323                    }
1324                    break;
1325                case SpvDimCube:   // fall through
1326                case SpvDimRect:   // fall through
1327                case SpvDimBuffer: // fall through
1328                case SpvDimSubpassData:
1329                    break;
1330            }
1331            SpvId type = this->getType(c.fType);
1332            SpvId sampler = this->writeExpression(*c.fArguments[0], out);
1333            SpvId uv = this->writeExpression(*c.fArguments[1], out);
1334            if (c.fArguments.size() == 3) {
1335                this->writeInstruction(op, type, result, sampler, uv,
1336                                       SpvImageOperandsBiasMask,
1337                                       this->writeExpression(*c.fArguments[2], out),
1338                                       out);
1339            } else {
1340                ASSERT(c.fArguments.size() == 2);
1341                this->writeInstruction(op, type, result, sampler, uv,
1342                                       out);
1343            }
1344            break;
1345        }
1346        case kSubpassLoad_SpecialIntrinsic: {
1347            SpvId img = this->writeExpression(*c.fArguments[0], out);
1348            std::vector<std::unique_ptr<Expression>> args;
1349            args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1350            args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
1351            Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args));
1352            SpvId coords = this->writeConstantVector(ctor);
1353            if (1 == c.fArguments.size()) {
1354                this->writeInstruction(SpvOpImageRead,
1355                                       this->getType(c.fType),
1356                                       result,
1357                                       img,
1358                                       coords,
1359                                       out);
1360            } else {
1361                SkASSERT(2 == c.fArguments.size());
1362                SpvId sample = this->writeExpression(*c.fArguments[1], out);
1363                this->writeInstruction(SpvOpImageRead,
1364                                       this->getType(c.fType),
1365                                       result,
1366                                       img,
1367                                       coords,
1368                                       SpvImageOperandsSampleMask,
1369                                       sample,
1370                                       out);
1371            }
1372            break;
1373        }
1374    }
1375    return result;
1376}
1377
// Emits a call to a user-defined function and returns the id of the call's result. Functions
// not found in fFunctionMap are handed off to writeIntrinsicCall. Every argument is passed as a
// pointer: either an existing pointer to an out-parameter's lvalue, or a fresh Function-storage
// temporary (declared in fVariableBuffer) initialized with the argument's value.
SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, SkWStream& out) {
    const auto& entry = fFunctionMap.find(&c.fFunction);
    if (entry == fFunctionMap.end()) {
        return this->writeIntrinsicCall(c, out);
    }
    // stores (variable, type, lvalue) triples to extract and save after the function call is
    // complete
    std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        // id of the temporary variable that we will use to hold this argument; not used (and left
        // unset) when the argument's own pointer is passed directly (see the `continue` below)
        SpvId tmpVar;
        // if we need a temporary var to store this argument, this is the value to store in the var
        SpvId tmpValueId;
        // NOTE(review): is_out is defined elsewhere; presumably true for both out and inout
        // parameters -- confirm.
        if (is_out(*c.fFunction.fParameters[i])) {
            std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
            SpvId ptr = lv->getPointer();
            if (ptr) {
                arguments.push_back(ptr);
                continue;
            } else {
                // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
                // copy it into a temp, call the function, read the value out of the temp, and then
                // update the lvalue.
                tmpValueId = lv->load(out);
                tmpVar = this->nextId();
                lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
                                  std::move(lv)));
            }
        } else {
            // see getFunctionType for an explanation of why we're always using pointer parameters
            tmpValueId = this->writeExpression(*c.fArguments[i], out);
            tmpVar = this->nextId();
        }
        // the temporary's OpVariable goes to fVariableBuffer; its initializing store goes to out
        this->writeInstruction(SpvOpVariable,
                               this->getPointerType(c.fArguments[i]->fType,
                                                    SpvStorageClassFunction),
                               tmpVar,
                               SpvStorageClassFunction,
                               fVariableBuffer);
        this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
        arguments.push_back(tmpVar);
    }
    SpvId result = this->nextId();
    // OpFunctionCall: opcode word, result type, result id, function id, then one word per argument
    this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
    this->writeWord(this->getType(c.fType), out);
    this->writeWord(result, out);
    this->writeWord(entry->second, out);
    for (SpvId id : arguments) {
        this->writeWord(id, out);
    }
    // now that the call is complete, we may need to update some lvalues with the new values of out
    // arguments
    for (const auto& tuple : lvalues) {
        SpvId load = this->nextId();
        this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
        std::get<2>(tuple)->store(load, out);
    }
    return result;
}
1437}
1438
1439SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
1440    ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
1441    SpvId result = this->nextId();
1442    std::vector<SpvId> arguments;
1443    for (size_t i = 0; i < c.fArguments.size(); i++) {
1444        arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
1445    }
1446    SpvId type = this->getType(c.fType);
1447    if (c.fArguments.size() == 1) {
1448        // with a single argument, a vector will have all of its entries equal to the argument
1449        this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
1450        this->writeWord(type, fConstantBuffer);
1451        this->writeWord(result, fConstantBuffer);
1452        for (int i = 0; i < c.fType.columns(); i++) {
1453            this->writeWord(arguments[0], fConstantBuffer);
1454        }
1455    } else {
1456        this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
1457                          fConstantBuffer);
1458        this->writeWord(type, fConstantBuffer);
1459        this->writeWord(result, fConstantBuffer);
1460        for (SpvId id : arguments) {
1461            this->writeWord(id, fConstantBuffer);
1462        }
1463    }
1464    return result;
1465}
1466
1467SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, SkWStream& out) {
1468    ASSERT(c.fType == *fContext.fFloat_Type);
1469    ASSERT(c.fArguments.size() == 1);
1470    ASSERT(c.fArguments[0]->fType.isNumber());
1471    SpvId result = this->nextId();
1472    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1473    if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1474        this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
1475                               out);
1476    } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1477        this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
1478                               out);
1479    } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1480        return parameter;
1481    }
1482    return result;
1483}
1484
1485SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, SkWStream& out) {
1486    ASSERT(c.fType == *fContext.fInt_Type);
1487    ASSERT(c.fArguments.size() == 1);
1488    ASSERT(c.fArguments[0]->fType.isNumber());
1489    SpvId result = this->nextId();
1490    SpvId parameter = this->writeExpression(*c.fArguments[0], out);
1491    if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
1492        this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
1493                               out);
1494    } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
1495        this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter,
1496                               out);
1497    } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
1498        return parameter;
1499    }
1500    return result;
1501}
1502
1503void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
1504                                                 SkWStream& out) {
1505    FloatLiteral zero(fContext, Position(), 0);
1506    SpvId zeroId = this->writeFloatLiteral(zero);
1507    std::vector<SpvId> columnIds;
1508    for (int column = 0; column < type.columns(); column++) {
1509        this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
1510                          out);
1511        this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
1512                        out);
1513        SpvId columnId = this->nextId();
1514        this->writeWord(columnId, out);
1515        columnIds.push_back(columnId);
1516        for (int row = 0; row < type.columns(); row++) {
1517            this->writeWord(row == column ? diagonal : zeroId, out);
1518        }
1519    }
1520    this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
1521                      out);
1522    this->writeWord(this->getType(type), out);
1523    this->writeWord(id, out);
1524    for (SpvId id : columnIds) {
1525        this->writeWord(id, out);
1526    }
1527}
1528
1529void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
1530                                         const Type& dstType, SkWStream& out) {
1531    ABORT("unimplemented");
1532}
1533
// Emits a matrix constructor and returns the id of the constructed matrix. Three argument shapes
// are handled:
//   - a single scalar: a uniform-scale (diagonal) matrix via writeUniformScaleMatrix;
//   - a single matrix: delegated to writeMatrixCopy;
//   - otherwise a sequence of vectors and/or scalars, packed column by column.
SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, SkWStream& out) {
    ASSERT(c.fType.kind() == Type::kMatrix_Kind);
    // go ahead and write the arguments so we don't try to write new instructions in the middle of
    // an instruction
    std::vector<SpvId> arguments;
    for (size_t i = 0; i < c.fArguments.size(); i++) {
        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
    }
    SpvId result = this->nextId();
    int rows = c.fType.rows();
    int columns = c.fType.columns();
    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
        this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
    } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
        this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
    } else {
        // ids of the column vectors that make up the final matrix
        std::vector<SpvId> columnIds;
        // number of scalar components already written into the column currently being built
        int currentCount = 0;
        for (size_t i = 0; i < arguments.size(); i++) {
            if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
                // a vector argument is used as-is for one whole column; it may only appear at a
                // column boundary. NOTE(review): its length is assumed (not checked) to equal
                // `rows`.
                ASSERT(currentCount == 0);
                columnIds.push_back(arguments[i]);
                currentCount = 0;
            } else {
                ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
                if (currentCount == 0) {
                    // start a new column: write the OpCompositeConstruct header now. Its `rows`
                    // component words are appended by this and the following scalar arguments.
                    // That only works because all argument instructions were emitted up front.
                    this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
                    this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
                                                                                     1)),
                                    out);
                    SpvId id = this->nextId();
                    this->writeWord(id, out);
                    columnIds.push_back(id);
                }
                this->writeWord(arguments[i], out);
                currentCount = (currentCount + 1) % rows;
            }
        }
        ASSERT(columnIds.size() == (size_t) columns);
        // combine the columns into the matrix itself
        this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
        this->writeWord(this->getType(c.fType), out);
        this->writeWord(result, out);
        for (SpvId id : columnIds) {
            this->writeWord(id, out);
        }
    }
    return result;
}
1581}
1582
1583SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, SkWStream& out) {
1584    ASSERT(c.fType.kind() == Type::kVector_Kind);
1585    if (c.isConstant()) {
1586        return this->writeConstantVector(c);
1587    }
1588    // go ahead and write the arguments so we don't try to write new instructions in the middle of
1589    // an instruction
1590    std::vector<SpvId> arguments;
1591    for (size_t i = 0; i < c.fArguments.size(); i++) {
1592        arguments.push_back(this->writeExpression(*c.fArguments[i], out));
1593    }
1594    SpvId result = this->nextId();
1595    if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
1596        this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
1597        this->writeWord(this->getType(c.fType), out);
1598        this->writeWord(result, out);
1599        for (int i = 0; i < c.fType.columns(); i++) {
1600            this->writeWord(arguments[0], out);
1601        }
1602    } else {
1603        this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
1604        this->writeWord(this->getType(c.fType), out);
1605        this->writeWord(result, out);
1606        for (SpvId id : arguments) {
1607            this->writeWord(id, out);
1608        }
1609    }
1610    return result;
1611}
1612
1613SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, SkWStream& out) {
1614    if (c.fType == *fContext.fFloat_Type) {
1615        return this->writeFloatConstructor(c, out);
1616    } else if (c.fType == *fContext.fInt_Type) {
1617        return this->writeIntConstructor(c, out);
1618    }
1619    switch (c.fType.kind()) {
1620        case Type::kVector_Kind:
1621            return this->writeVectorConstructor(c, out);
1622        case Type::kMatrix_Kind:
1623            return this->writeMatrixConstructor(c, out);
1624        default:
1625            ABORT("unsupported constructor: %s", c.description().c_str());
1626    }
1627}
1628
1629SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
1630    if (modifiers.fFlags & Modifiers::kIn_Flag) {
1631        ASSERT(!modifiers.fLayout.fPushConstant);
1632        return SpvStorageClassInput;
1633    } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
1634        ASSERT(!modifiers.fLayout.fPushConstant);
1635        return SpvStorageClassOutput;
1636    } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
1637        if (modifiers.fLayout.fPushConstant) {
1638            return SpvStorageClassPushConstant;
1639        }
1640        return SpvStorageClassUniform;
1641    } else {
1642        return SpvStorageClassFunction;
1643    }
1644}
1645
1646SpvStorageClass_ get_storage_class(const Expression& expr) {
1647    switch (expr.fKind) {
1648        case Expression::kVariableReference_Kind:
1649            return get_storage_class(((VariableReference&) expr).fVariable.fModifiers);
1650        case Expression::kFieldAccess_Kind:
1651            return get_storage_class(*((FieldAccess&) expr).fBase);
1652        case Expression::kIndex_Kind:
1653            return get_storage_class(*((IndexExpression&) expr).fBase);
1654        default:
1655            return SpvStorageClassFunction;
1656    }
1657}
1658
1659std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, SkWStream& out) {
1660    std::vector<SpvId> chain;
1661    switch (expr.fKind) {
1662        case Expression::kIndex_Kind: {
1663            IndexExpression& indexExpr = (IndexExpression&) expr;
1664            chain = this->getAccessChain(*indexExpr.fBase, out);
1665            chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
1666            break;
1667        }
1668        case Expression::kFieldAccess_Kind: {
1669            FieldAccess& fieldExpr = (FieldAccess&) expr;
1670            chain = this->getAccessChain(*fieldExpr.fBase, out);
1671            IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
1672            chain.push_back(this->writeIntLiteral(index));
1673            break;
1674        }
1675        default:
1676            chain.push_back(this->getLValue(expr, out)->getPointer());
1677    }
1678    return chain;
1679}
1680
1681class PointerLValue : public SPIRVCodeGenerator::LValue {
1682public:
1683    PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
1684    : fGen(gen)
1685    , fPointer(pointer)
1686    , fType(type) {}
1687
1688    virtual SpvId getPointer() override {
1689        return fPointer;
1690    }
1691
1692    virtual SpvId load(SkWStream& out) override {
1693        SpvId result = fGen.nextId();
1694        fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
1695        return result;
1696    }
1697
1698    virtual void store(SpvId value, SkWStream& out) override {
1699        fGen.writeInstruction(SpvOpStore, fPointer, value, out);
1700    }
1701
1702private:
1703    SPIRVCodeGenerator& fGen;
1704    const SpvId fPointer;
1705    const SpvId fType;
1706};
1707
1708class SwizzleLValue : public SPIRVCodeGenerator::LValue {
1709public:
1710    SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
1711                  const Type& baseType, const Type& swizzleType)
1712    : fGen(gen)
1713    , fVecPointer(vecPointer)
1714    , fComponents(components)
1715    , fBaseType(baseType)
1716    , fSwizzleType(swizzleType) {}
1717
1718    virtual SpvId getPointer() override {
1719        return 0;
1720    }
1721
1722    virtual SpvId load(SkWStream& out) override {
1723        SpvId base = fGen.nextId();
1724        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1725        SpvId result = fGen.nextId();
1726        fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
1727        fGen.writeWord(fGen.getType(fSwizzleType), out);
1728        fGen.writeWord(result, out);
1729        fGen.writeWord(base, out);
1730        fGen.writeWord(base, out);
1731        for (int component : fComponents) {
1732            fGen.writeWord(component, out);
1733        }
1734        return result;
1735    }
1736
1737    virtual void store(SpvId value, SkWStream& out) override {
1738        // use OpVectorShuffle to mix and match the vector components. We effectively create
1739        // a virtual vector out of the concatenation of the left and right vectors, and then
1740        // select components from this virtual vector to make the result vector. For
1741        // instance, given:
1742        // vec3 L = ...;
1743        // vec3 R = ...;
1744        // L.xz = R.xy;
1745        // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
1746        // our result vector to look like (R.x, L.y, R.y), so we need to select indices
1747        // (3, 1, 4).
1748        SpvId base = fGen.nextId();
1749        fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
1750        SpvId shuffle = fGen.nextId();
1751        fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
1752        fGen.writeWord(fGen.getType(fBaseType), out);
1753        fGen.writeWord(shuffle, out);
1754        fGen.writeWord(base, out);
1755        fGen.writeWord(value, out);
1756        for (int i = 0; i < fBaseType.columns(); i++) {
1757            // current offset into the virtual vector, defaults to pulling the unmodified
1758            // value from the left side
1759            int offset = i;
1760            // check to see if we are writing this component
1761            for (size_t j = 0; j < fComponents.size(); j++) {
1762                if (fComponents[j] == i) {
1763                    // we're writing to this component, so adjust the offset to pull from
1764                    // the correct component of the right side instead of preserving the
1765                    // value from the left
1766                    offset = (int) (j + fBaseType.columns());
1767                    break;
1768                }
1769            }
1770            fGen.writeWord(offset, out);
1771        }
1772        fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
1773    }
1774
1775private:
1776    SPIRVCodeGenerator& fGen;
1777    const SpvId fVecPointer;
1778    const std::vector<int>& fComponents;
1779    const Type& fBaseType;
1780    const Type& fSwizzleType;
1781};
1782
// Returns an LValue through which `expr` can be loaded from and stored to, emitting whatever
// address computation is needed into `out`:
//   - variable references use the variable's id from fVariableMap as the pointer;
//   - index and field accesses emit an OpAccessChain built by getAccessChain;
//   - single-component swizzles emit an OpAccessChain to that component;
//   - multi-component swizzles return a SwizzleLValue (no pointer of its own);
//   - anything else is spilled into a fresh Function-storage temporary.
std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
                                                                          SkWStream& out) {
    switch (expr.fKind) {
        case Expression::kVariableReference_Kind: {
            const Variable& var = ((VariableReference&) expr).fVariable;
            auto entry = fVariableMap.find(&var);
            ASSERT(entry != fVariableMap.end());
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       entry->second,
                                                                       this->getType(expr.fType)));
        }
        case Expression::kIndex_Kind: // fall through
        case Expression::kFieldAccess_Kind: {
            std::vector<SpvId> chain = this->getAccessChain(expr, out);
            SpvId member = this->nextId();
            // OpAccessChain: opcode word, result type, result id, then the chain (base pointer
            // followed by indices)
            this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
            this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
            this->writeWord(member, out);
            for (SpvId idx : chain) {
                this->writeWord(idx, out);
            }
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       member,
                                                                       this->getType(expr.fType)));
        }

        case Expression::kSwizzle_Kind: {
            Swizzle& swizzle = (Swizzle&) expr;
            size_t count = swizzle.fComponents.size();
            // the swizzled base must itself be pointer-addressable; a base whose LValue has no
            // pointer (e.g. another multi-component swizzle) trips this assert
            SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
            ASSERT(base);
            if (count == 1) {
                // a single component is addressable directly: access-chain into the vector
                IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
                SpvId member = this->nextId();
                this->writeInstruction(SpvOpAccessChain,
                                       this->getPointerType(swizzle.fType,
                                                            get_storage_class(*swizzle.fBase)),
                                       member,
                                       base,
                                       this->writeIntLiteral(index),
                                       out);
                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       member,
                                                                       this->getType(expr.fType)));
            } else {
                // SwizzleLValue keeps references to swizzle.fComponents and the types, so it must
                // not outlive `expr`
                return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
                                                                              *this,
                                                                              base,
                                                                              swizzle.fComponents,
                                                                              swizzle.fBase->fType,
                                                                              expr.fType));
            }
        }

        default:
            // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
            // to the need to store values in temporary variables during function calls (see
            // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
            // caught by IRGenerator
            SpvId result = this->nextId();
            SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
            // the OpVariable declaration goes to fVariableBuffer; the initializing store to out
            this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
                                   fVariableBuffer);
            this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
            return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
                                                                       *this,
                                                                       result,
                                                                       this->getType(expr.fType)));
    }
}
1855}
1856
// Emits a load of the referenced variable and returns the id of the loaded value.
// Special case: when the variable is the SK_FRAGCOORD_BUILTIN and fProgram.fSettings.fFlipY is
// set, the loaded value is remapped to a top-left origin. The result is
//     vec4(fragCoord.x, rtHeight - fragCoord.y, 0.0, 1.0)
// where rtHeight is read from a synthetic uniform block. The original z and w components are
// discarded.
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, SkWStream& out) {
    SpvId result = this->nextId();
    auto entry = fVariableMap.find(&ref.fVariable);
    ASSERT(entry != fVariableMap.end());
    SpvId var = entry->second;
    this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
    if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
        fProgram.fSettings.fFlipY) {
        // need to remap to a top-left coordinate system
        if (fRTHeightStructId == (SpvId) -1) {
            // height variable hasn't been written yet. Lazily synthesize a uniform interface
            // block "sksl_synthetic_uniforms" containing a single float field named
            // SKSL_RTHEIGHT_NAME. This is done at most once; later references reuse
            // fRTHeightStructId / fRTHeightFieldIndex.
            std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors));
            ASSERT(fRTHeightFieldIndex == (SpvId) -1);
            std::vector<Type::Field> fields;
            fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME),
                                fContext.fFloat_Type.get());
            SkString name("sksl_synthetic_uniforms");
            Type intfStruct(Position(), name, fields);
            // NOTE(review): the meaning of each positional Layout argument is defined by Layout's
            // constructor (not visible here); confirm before changing any of them.
            Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false, Layout::Format::kUnspecified,
                          false, Layout::kUnspecified_Primitive, -1, -1);
            Variable* intfVar = new Variable(Position(),
                                             Modifiers(layout, Modifiers::kUniform_Flag),
                                             name,
                                             intfStruct,
                                             Variable::kGlobal_Storage);
            // fSynthetics takes ownership so the variable lives as long as the generator needs it
            fSynthetics.takeOwnership(intfVar);
            InterfaceBlock intf(Position(), intfVar, name, SkString(""),
                                std::vector<std::unique_ptr<Expression>>(), st);
            fRTHeightStructId = this->writeInterfaceBlock(intf);
            fRTHeightFieldIndex = 0;
        }
        ASSERT(fRTHeightFieldIndex != (SpvId) -1);
        // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
        SpvId xId = this->nextId();
        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
                               result, 0, out);
        IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex);
        SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
        SpvId heightPtr = this->nextId();
        // pointer to the height field: OpAccessChain(uniform block, field index)
        this->writeOpCode(SpvOpAccessChain, 5, out);
        this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
        this->writeWord(heightPtr, out);
        this->writeWord(fRTHeightStructId, out);
        this->writeWord(fieldIndexId, out);
        SpvId heightRead = this->nextId();
        this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
                               heightPtr, out);
        SpvId rawYId = this->nextId();
        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
                               result, 1, out);
        SpvId flippedYId = this->nextId();
        this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
                               heightRead, rawYId, out);
        FloatLiteral zero(fContext, Position(), 0.0);
        SpvId zeroId = writeFloatLiteral(zero);
        FloatLiteral one(fContext, Position(), 1.0);
        SpvId oneId = writeFloatLiteral(one);
        SpvId flipped = this->nextId();
        this->writeOpCode(SpvOpCompositeConstruct, 7, out);
        this->writeWord(this->getType(*fContext.fVec4_Type), out);
        this->writeWord(flipped, out);
        this->writeWord(xId, out);
        this->writeWord(flippedYId, out);
        this->writeWord(zeroId, out);
        this->writeWord(oneId, out);
        return flipped;
    }
    return result;
}
1925}
1926
1927SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, SkWStream& out) {
1928    return getLValue(expr, out)->load(out);
1929}
1930
1931SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, SkWStream& out) {
1932    return getLValue(f, out)->load(out);
1933}
1934
1935SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, SkWStream& out) {
1936    SpvId base = this->writeExpression(*swizzle.fBase, out);
1937    SpvId result = this->nextId();
1938    size_t count = swizzle.fComponents.size();
1939    if (count == 1) {
1940        this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
1941                               swizzle.fComponents[0], out);
1942    } else {
1943        this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
1944        this->writeWord(this->getType(swizzle.fType), out);
1945        this->writeWord(result, out);
1946        this->writeWord(base, out);
1947        this->writeWord(base, out);
1948        for (int component : swizzle.fComponents) {
1949            this->writeWord(component, out);
1950        }
1951    }
1952    return result;
1953}
1954
1955SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
1956                                               const Type& operandType, SpvId lhs,
1957                                               SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
1958                                               SpvOp_ ifUInt, SpvOp_ ifBool, SkWStream& out) {
1959    SpvId result = this->nextId();
1960    if (is_float(fContext, operandType)) {
1961        this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
1962    } else if (is_signed(fContext, operandType)) {
1963        this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
1964    } else if (is_unsigned(fContext, operandType)) {
1965        this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
1966    } else if (operandType == *fContext.fBool_Type) {
1967        this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
1968    } else {
1969        ABORT("invalid operandType: %s", operandType.description().c_str());
1970    }
1971    return result;
1972}
1973
1974bool is_assignment(Token::Kind op) {
1975    switch (op) {
1976        case Token::EQ:           // fall through
1977        case Token::PLUSEQ:       // fall through
1978        case Token::MINUSEQ:      // fall through
1979        case Token::STAREQ:       // fall through
1980        case Token::SLASHEQ:      // fall through
1981        case Token::PERCENTEQ:    // fall through
1982        case Token::SHLEQ:        // fall through
1983        case Token::SHREQ:        // fall through
1984        case Token::BITWISEOREQ:  // fall through
1985        case Token::BITWISEXOREQ: // fall through
1986        case Token::BITWISEANDEQ: // fall through
1987        case Token::LOGICALOREQ:  // fall through
1988        case Token::LOGICALXOREQ: // fall through
1989        case Token::LOGICALANDEQ:
1990            return true;
1991        default:
1992            return false;
1993    }
1994}
1995
1996SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SkWStream& out) {
1997    if (operandType.kind() == Type::kVector_Kind) {
1998        SpvId result = this->nextId();
1999        this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
2000        return result;
2001    }
2002    return id;
2003}
2004
2005SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, SkWStream& out) {
2006    // handle cases where we don't necessarily evaluate both LHS and RHS
2007    switch (b.fOperator) {
2008        case Token::EQ: {
2009            SpvId rhs = this->writeExpression(*b.fRight, out);
2010            this->getLValue(*b.fLeft, out)->store(rhs, out);
2011            return rhs;
2012        }
2013        case Token::LOGICALAND:
2014            return this->writeLogicalAnd(b, out);
2015        case Token::LOGICALOR:
2016            return this->writeLogicalOr(b, out);
2017        default:
2018            break;
2019    }
2020
2021    // "normal" operators
2022    const Type& resultType = b.fType;
2023    std::unique_ptr<LValue> lvalue;
2024    SpvId lhs;
2025    if (is_assignment(b.fOperator)) {
2026        lvalue = this->getLValue(*b.fLeft, out);
2027        lhs = lvalue->load(out);
2028    } else {
2029        lvalue = nullptr;
2030        lhs = this->writeExpression(*b.fLeft, out);
2031    }
2032    SpvId rhs = this->writeExpression(*b.fRight, out);
2033    // component type we are operating on: float, int, uint
2034    const Type* operandType;
2035    // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
2036    // in SPIR-V
2037    if (b.fLeft->fType != b.fRight->fType) {
2038        if (b.fLeft->fType.kind() == Type::kVector_Kind &&
2039            b.fRight->fType.isNumber()) {
2040            // promote number to vector
2041            SpvId vec = this->nextId();
2042            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2043            this->writeWord(this->getType(resultType), out);
2044            this->writeWord(vec, out);
2045            for (int i = 0; i < resultType.columns(); i++) {
2046                this->writeWord(rhs, out);
2047            }
2048            rhs = vec;
2049            operandType = &b.fRight->fType;
2050        } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
2051                   b.fLeft->fType.isNumber()) {
2052            // promote number to vector
2053            SpvId vec = this->nextId();
2054            this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
2055            this->writeWord(this->getType(resultType), out);
2056            this->writeWord(vec, out);
2057            for (int i = 0; i < resultType.columns(); i++) {
2058                this->writeWord(lhs, out);
2059            }
2060            lhs = vec;
2061            ASSERT(!lvalue);
2062            operandType = &b.fLeft->fType;
2063        } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
2064            SpvOp_ op;
2065            if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2066                op = SpvOpMatrixTimesMatrix;
2067            } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
2068                op = SpvOpMatrixTimesVector;
2069            } else {
2070                ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
2071                op = SpvOpMatrixTimesScalar;
2072            }
2073            SpvId result = this->nextId();
2074            this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
2075            if (b.fOperator == Token::STAREQ) {
2076                lvalue->store(result, out);
2077            } else {
2078                ASSERT(b.fOperator == Token::STAR);
2079            }
2080            return result;
2081        } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
2082            SpvId result = this->nextId();
2083            if (b.fLeft->fType.kind() == Type::kVector_Kind) {
2084                this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
2085                                       lhs, rhs, out);
2086            } else {
2087                ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
2088                this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
2089                                       lhs, out);
2090            }
2091            if (b.fOperator == Token::STAREQ) {
2092                lvalue->store(result, out);
2093            } else {
2094                ASSERT(b.fOperator == Token::STAR);
2095            }
2096            return result;
2097        } else {
2098            ABORT("unsupported binary expression: %s", b.description().c_str());
2099        }
2100    } else {
2101        operandType = &b.fLeft->fType;
2102        ASSERT(*operandType == b.fRight->fType);
2103    }
2104    switch (b.fOperator) {
2105        case Token::EQEQ: {
2106            ASSERT(resultType == *fContext.fBool_Type);
2107            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2108                                                               SpvOpFOrdEqual, SpvOpIEqual,
2109                                                               SpvOpIEqual, SpvOpLogicalEqual, out),
2110                                    *operandType, out);
2111        }
2112        case Token::NEQ:
2113            ASSERT(resultType == *fContext.fBool_Type);
2114            return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2115                                                               SpvOpFOrdNotEqual, SpvOpINotEqual,
2116                                                               SpvOpINotEqual, SpvOpLogicalNotEqual,
2117                                                               out),
2118                                    *operandType, out);
2119        case Token::GT:
2120            ASSERT(resultType == *fContext.fBool_Type);
2121            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2122                                              SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
2123                                              SpvOpUGreaterThan, SpvOpUndef, out);
2124        case Token::LT:
2125            ASSERT(resultType == *fContext.fBool_Type);
2126            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
2127                                              SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
2128        case Token::GTEQ:
2129            ASSERT(resultType == *fContext.fBool_Type);
2130            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2131                                              SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
2132                                              SpvOpUGreaterThanEqual, SpvOpUndef, out);
2133        case Token::LTEQ:
2134            ASSERT(resultType == *fContext.fBool_Type);
2135            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
2136                                              SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
2137                                              SpvOpULessThanEqual, SpvOpUndef, out);
2138        case Token::PLUS:
2139            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2140                                              SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2141        case Token::MINUS:
2142            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2143                                              SpvOpISub, SpvOpISub, SpvOpUndef, out);
2144        case Token::STAR:
2145            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2146                b.fRight->fType.kind() == Type::kMatrix_Kind) {
2147                // matrix multiply
2148                SpvId result = this->nextId();
2149                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2150                                       lhs, rhs, out);
2151                return result;
2152            }
2153            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2154                                              SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2155        case Token::SLASH:
2156            return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2157                                              SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2158        case Token::PLUSEQ: {
2159            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
2160                                                      SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2161            ASSERT(lvalue);
2162            lvalue->store(result, out);
2163            return result;
2164        }
2165        case Token::MINUSEQ: {
2166            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
2167                                                      SpvOpISub, SpvOpISub, SpvOpUndef, out);
2168            ASSERT(lvalue);
2169            lvalue->store(result, out);
2170            return result;
2171        }
2172        case Token::STAREQ: {
2173            if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
2174                b.fRight->fType.kind() == Type::kMatrix_Kind) {
2175                // matrix multiply
2176                SpvId result = this->nextId();
2177                this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
2178                                       lhs, rhs, out);
2179                ASSERT(lvalue);
2180                lvalue->store(result, out);
2181                return result;
2182            }
2183            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
2184                                                      SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
2185            ASSERT(lvalue);
2186            lvalue->store(result, out);
2187            return result;
2188        }
2189        case Token::SLASHEQ: {
2190            SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
2191                                                      SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
2192            ASSERT(lvalue);
2193            lvalue->store(result, out);
2194            return result;
2195        }
2196        default:
2197            // FIXME: missing support for some operators (bitwise, &&=, ||=, shift...)
2198            ABORT("unsupported binary expression: %s", b.description().c_str());
2199    }
2200}
2201
2202SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, SkWStream& out) {
2203    ASSERT(a.fOperator == Token::LOGICALAND);
2204    BoolLiteral falseLiteral(fContext, Position(), false);
2205    SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
2206    SpvId lhs = this->writeExpression(*a.fLeft, out);
2207    SpvId rhsLabel = this->nextId();
2208    SpvId end = this->nextId();
2209    SpvId lhsBlock = fCurrentBlock;
2210    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2211    this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
2212    this->writeLabel(rhsLabel, out);
2213    SpvId rhs = this->writeExpression(*a.fRight, out);
2214    SpvId rhsBlock = fCurrentBlock;
2215    this->writeInstruction(SpvOpBranch, end, out);
2216    this->writeLabel(end, out);
2217    SpvId result = this->nextId();
2218    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
2219                           lhsBlock, rhs, rhsBlock, out);
2220    return result;
2221}
2222
2223SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, SkWStream& out) {
2224    ASSERT(o.fOperator == Token::LOGICALOR);
2225    BoolLiteral trueLiteral(fContext, Position(), true);
2226    SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
2227    SpvId lhs = this->writeExpression(*o.fLeft, out);
2228    SpvId rhsLabel = this->nextId();
2229    SpvId end = this->nextId();
2230    SpvId lhsBlock = fCurrentBlock;
2231    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2232    this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
2233    this->writeLabel(rhsLabel, out);
2234    SpvId rhs = this->writeExpression(*o.fRight, out);
2235    SpvId rhsBlock = fCurrentBlock;
2236    this->writeInstruction(SpvOpBranch, end, out);
2237    this->writeLabel(end, out);
2238    SpvId result = this->nextId();
2239    this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
2240                           lhsBlock, rhs, rhsBlock, out);
2241    return result;
2242}
2243
2244SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, SkWStream& out) {
2245    SpvId test = this->writeExpression(*t.fTest, out);
2246    if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
2247        // both true and false are constants, can just use OpSelect
2248        SpvId result = this->nextId();
2249        SpvId trueId = this->writeExpression(*t.fIfTrue, out);
2250        SpvId falseId = this->writeExpression(*t.fIfFalse, out);
2251        this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
2252                               out);
2253        return result;
2254    }
2255    // was originally using OpPhi to choose the result, but for some reason that is crashing on
2256    // Adreno. Switched to storing the result in a temp variable as glslang does.
2257    SpvId var = this->nextId();
2258    this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
2259                           var, SpvStorageClassFunction, fVariableBuffer);
2260    SpvId trueLabel = this->nextId();
2261    SpvId falseLabel = this->nextId();
2262    SpvId end = this->nextId();
2263    this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2264    this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
2265    this->writeLabel(trueLabel, out);
2266    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
2267    this->writeInstruction(SpvOpBranch, end, out);
2268    this->writeLabel(falseLabel, out);
2269    this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
2270    this->writeInstruction(SpvOpBranch, end, out);
2271    this->writeLabel(end, out);
2272    SpvId result = this->nextId();
2273    this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
2274    return result;
2275}
2276
2277std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
2278    if (type == *context.fInt_Type) {
2279        return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
2280    }
2281    else if (type == *context.fFloat_Type) {
2282        return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
2283    } else {
2284        ABORT("math is unsupported on type '%s'")
2285    }
2286}
2287
2288SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, SkWStream& out) {
2289    if (p.fOperator == Token::MINUS) {
2290        SpvId result = this->nextId();
2291        SpvId typeId = this->getType(p.fType);
2292        SpvId expr = this->writeExpression(*p.fOperand, out);
2293        if (is_float(fContext, p.fType)) {
2294            this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
2295        } else if (is_signed(fContext, p.fType)) {
2296            this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
2297        } else {
2298            ABORT("unsupported prefix expression %s", p.description().c_str());
2299        };
2300        return result;
2301    }
2302    switch (p.fOperator) {
2303        case Token::PLUS:
2304            return this->writeExpression(*p.fOperand, out);
2305        case Token::PLUSPLUS: {
2306            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2307            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2308            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2309                                                      SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
2310                                                      out);
2311            lv->store(result, out);
2312            return result;
2313        }
2314        case Token::MINUSMINUS: {
2315            std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2316            SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2317            SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
2318                                                      SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
2319                                                      out);
2320            lv->store(result, out);
2321            return result;
2322        }
2323        case Token::LOGICALNOT: {
2324            ASSERT(p.fOperand->fType == *fContext.fBool_Type);
2325            SpvId result = this->nextId();
2326            this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
2327                                   this->writeExpression(*p.fOperand, out), out);
2328            return result;
2329        }
2330        case Token::BITWISENOT: {
2331            SpvId result = this->nextId();
2332            this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
2333                                   this->writeExpression(*p.fOperand, out), out);
2334            return result;
2335        }
2336        default:
2337            ABORT("unsupported prefix expression: %s", p.description().c_str());
2338    }
2339}
2340
2341SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, SkWStream& out) {
2342    std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
2343    SpvId result = lv->load(out);
2344    SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
2345    switch (p.fOperator) {
2346        case Token::PLUSPLUS: {
2347            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
2348                                                    SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
2349            lv->store(temp, out);
2350            return result;
2351        }
2352        case Token::MINUSMINUS: {
2353            SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
2354                                                    SpvOpISub, SpvOpISub, SpvOpUndef, out);
2355            lv->store(temp, out);
2356            return result;
2357        }
2358        default:
2359            ABORT("unsupported postfix expression %s", p.description().c_str());
2360    }
2361}
2362
2363SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
2364    if (b.fValue) {
2365        if (fBoolTrue == 0) {
2366            fBoolTrue = this->nextId();
2367            this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
2368                                   fConstantBuffer);
2369        }
2370        return fBoolTrue;
2371    } else {
2372        if (fBoolFalse == 0) {
2373            fBoolFalse = this->nextId();
2374            this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
2375                                   fConstantBuffer);
2376        }
2377        return fBoolFalse;
2378    }
2379}
2380
2381SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
2382    if (i.fType == *fContext.fInt_Type) {
2383        auto entry = fIntConstants.find(i.fValue);
2384        if (entry == fIntConstants.end()) {
2385            SpvId result = this->nextId();
2386            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2387                                   fConstantBuffer);
2388            fIntConstants[i.fValue] = result;
2389            return result;
2390        }
2391        return entry->second;
2392    } else {
2393        ASSERT(i.fType == *fContext.fUInt_Type);
2394        auto entry = fUIntConstants.find(i.fValue);
2395        if (entry == fUIntConstants.end()) {
2396            SpvId result = this->nextId();
2397            this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
2398                                   fConstantBuffer);
2399            fUIntConstants[i.fValue] = result;
2400            return result;
2401        }
2402        return entry->second;
2403    }
2404}
2405
2406SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
2407    if (f.fType == *fContext.fFloat_Type) {
2408        float value = (float) f.fValue;
2409        auto entry = fFloatConstants.find(value);
2410        if (entry == fFloatConstants.end()) {
2411            SpvId result = this->nextId();
2412            uint32_t bits;
2413            ASSERT(sizeof(bits) == sizeof(value));
2414            memcpy(&bits, &value, sizeof(bits));
2415            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
2416                                   fConstantBuffer);
2417            fFloatConstants[value] = result;
2418            return result;
2419        }
2420        return entry->second;
2421    } else {
2422        ASSERT(f.fType == *fContext.fDouble_Type);
2423        auto entry = fDoubleConstants.find(f.fValue);
2424        if (entry == fDoubleConstants.end()) {
2425            SpvId result = this->nextId();
2426            uint64_t bits;
2427            ASSERT(sizeof(bits) == sizeof(f.fValue));
2428            memcpy(&bits, &f.fValue, sizeof(bits));
2429            this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
2430                                   bits & 0xffffffff, bits >> 32, fConstantBuffer);
2431            fDoubleConstants[f.fValue] = result;
2432            return result;
2433        }
2434        return entry->second;
2435    }
2436}
2437
2438SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, SkWStream& out) {
2439    SpvId result = fFunctionMap[&f];
2440    this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
2441                           SpvFunctionControlMaskNone, this->getFunctionType(f), out);
2442    this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
2443    for (size_t i = 0; i < f.fParameters.size(); i++) {
2444        SpvId id = this->nextId();
2445        fVariableMap[f.fParameters[i]] = id;
2446        SpvId type;
2447        type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
2448        this->writeInstruction(SpvOpFunctionParameter, type, id, out);
2449    }
2450    return result;
2451}
2452
2453SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, SkWStream& out) {
2454    SpvId result = this->writeFunctionStart(f.fDeclaration, out);
2455    this->writeLabel(this->nextId(), out);
2456    if (f.fDeclaration.fName == "main") {
2457        write_data(*fGlobalInitializersBuffer.detachAsData(), out);
2458    }
2459    SkDynamicMemoryWStream bodyBuffer;
2460    this->writeBlock(*f.fBody, bodyBuffer);
2461    write_data(*fVariableBuffer.detachAsData(), out);
2462    write_data(*bodyBuffer.detachAsData(), out);
2463    if (fCurrentBlock) {
2464        this->writeInstruction(SpvOpReturn, out);
2465    }
2466    this->writeInstruction(SpvOpFunctionEnd, out);
2467    return result;
2468}
2469
2470void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
2471    if (layout.fLocation >= 0) {
2472        this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
2473                               fDecorationBuffer);
2474    }
2475    if (layout.fBinding >= 0) {
2476        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
2477                               fDecorationBuffer);
2478    }
2479    if (layout.fIndex >= 0) {
2480        this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
2481                               fDecorationBuffer);
2482    }
2483    if (layout.fSet >= 0) {
2484        this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
2485                               fDecorationBuffer);
2486    }
2487    if (layout.fInputAttachmentIndex >= 0) {
2488        this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
2489                               layout.fInputAttachmentIndex, fDecorationBuffer);
2490    }
2491    if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
2492        this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
2493                               fDecorationBuffer);
2494    }
2495}
2496
2497void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
2498    if (layout.fLocation >= 0) {
2499        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
2500                               layout.fLocation, fDecorationBuffer);
2501    }
2502    if (layout.fBinding >= 0) {
2503        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
2504                               layout.fBinding, fDecorationBuffer);
2505    }
2506    if (layout.fIndex >= 0) {
2507        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
2508                               layout.fIndex, fDecorationBuffer);
2509    }
2510    if (layout.fSet >= 0) {
2511        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
2512                               layout.fSet, fDecorationBuffer);
2513    }
2514    if (layout.fInputAttachmentIndex >= 0) {
2515        this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
2516                               layout.fInputAttachmentIndex, fDecorationBuffer);
2517    }
2518    if (layout.fBuiltin >= 0) {
2519        this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
2520                               layout.fBuiltin, fDecorationBuffer);
2521    }
2522}
2523
2524SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
2525    MemoryLayout layout = intf.fVariable.fModifiers.fLayout.fPushConstant ?
2526                          MemoryLayout(MemoryLayout::k430_Standard) :
2527                          fDefaultLayout;
2528    SpvId result = this->nextId();
2529    const Type* type = &intf.fVariable.fType;
2530    if (fProgram.fInputs.fRTHeight) {
2531        ASSERT(fRTHeightStructId == (SpvId) -1);
2532        ASSERT(fRTHeightFieldIndex == (SpvId) -1);
2533        std::vector<Type::Field> fields = type->fields();
2534        fRTHeightStructId = result;
2535        fRTHeightFieldIndex = fields.size();
2536        fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
2537        type = new Type(type->fPosition, type->name(), fields);
2538    }
2539    SpvId typeId = this->getType(*type, layout);
2540    this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
2541    SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
2542    SpvId ptrType = this->nextId();
2543    this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
2544    this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
2545    this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
2546    fVariableMap[&intf.fVariable] = result;
2547    if (fProgram.fInputs.fRTHeight) {
2548        delete type;
2549    }
2550    return result;
2551}
2552
2553#define BUILTIN_IGNORE 9999
2554void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
2555                                         SkWStream& out) {
2556    for (size_t i = 0; i < decl.fVars.size(); i++) {
2557        const VarDeclaration& varDecl = decl.fVars[i];
2558        const Variable* var = varDecl.fVar;
2559        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2560        // in the OpenGL backend.
2561        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2562                                           Modifiers::kWriteOnly_Flag |
2563                                           Modifiers::kCoherent_Flag |
2564                                           Modifiers::kVolatile_Flag |
2565                                           Modifiers::kRestrict_Flag)));
2566        if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
2567            continue;
2568        }
2569        if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
2570            kind != Program::kFragment_Kind) {
2571            continue;
2572        }
2573        if (!var->fReadCount && !var->fWriteCount &&
2574                !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
2575                                            Modifiers::kOut_Flag |
2576                                            Modifiers::kUniform_Flag))) {
2577            // variable is dead and not an input / output var (the Vulkan debug layers complain if
2578            // we elide an interface var, even if it's dead)
2579            continue;
2580        }
2581        SpvStorageClass_ storageClass;
2582        if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
2583            storageClass = SpvStorageClassInput;
2584        } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
2585            storageClass = SpvStorageClassOutput;
2586        } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
2587            if (var->fType.kind() == Type::kSampler_Kind) {
2588                storageClass = SpvStorageClassUniformConstant;
2589            } else {
2590                storageClass = SpvStorageClassUniform;
2591            }
2592        } else {
2593            storageClass = SpvStorageClassPrivate;
2594        }
2595        SpvId id = this->nextId();
2596        fVariableMap[var] = id;
2597        SpvId type = this->getPointerType(var->fType, storageClass);
2598        this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
2599        this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2600        if (var->fType.kind() == Type::kMatrix_Kind) {
2601            this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
2602                                   fDecorationBuffer);
2603            this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
2604                                   (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer);
2605        }
2606        if (varDecl.fValue) {
2607            ASSERT(!fCurrentBlock);
2608            fCurrentBlock = -1;
2609            SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
2610            this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
2611            fCurrentBlock = 0;
2612        }
2613        this->writeLayout(var->fModifiers.fLayout, id);
2614    }
2615}
2616
2617void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, SkWStream& out) {
2618    for (const auto& varDecl : decl.fVars) {
2619        const Variable* var = varDecl.fVar;
2620        // These haven't been implemented in our SPIR-V generator yet and we only currently use them
2621        // in the OpenGL backend.
2622        ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
2623                                           Modifiers::kWriteOnly_Flag |
2624                                           Modifiers::kCoherent_Flag |
2625                                           Modifiers::kVolatile_Flag |
2626                                           Modifiers::kRestrict_Flag)));
2627        SpvId id = this->nextId();
2628        fVariableMap[var] = id;
2629        SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
2630        this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
2631        this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
2632        if (varDecl.fValue) {
2633            SpvId value = this->writeExpression(*varDecl.fValue, out);
2634            this->writeInstruction(SpvOpStore, id, value, out);
2635        }
2636    }
2637}
2638
2639void SPIRVCodeGenerator::writeStatement(const Statement& s, SkWStream& out) {
2640    switch (s.fKind) {
2641        case Statement::kBlock_Kind:
2642            this->writeBlock((Block&) s, out);
2643            break;
2644        case Statement::kExpression_Kind:
2645            this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
2646            break;
2647        case Statement::kReturn_Kind:
2648            this->writeReturnStatement((ReturnStatement&) s, out);
2649            break;
2650        case Statement::kVarDeclarations_Kind:
2651            this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
2652            break;
2653        case Statement::kIf_Kind:
2654            this->writeIfStatement((IfStatement&) s, out);
2655            break;
2656        case Statement::kFor_Kind:
2657            this->writeForStatement((ForStatement&) s, out);
2658            break;
2659        case Statement::kWhile_Kind:
2660            this->writeWhileStatement((WhileStatement&) s, out);
2661            break;
2662        case Statement::kDo_Kind:
2663            this->writeDoStatement((DoStatement&) s, out);
2664            break;
2665        case Statement::kBreak_Kind:
2666            this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
2667            break;
2668        case Statement::kContinue_Kind:
2669            this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
2670            break;
2671        case Statement::kDiscard_Kind:
2672            this->writeInstruction(SpvOpKill, out);
2673            break;
2674        default:
2675            ABORT("unsupported statement: %s", s.description().c_str());
2676    }
2677}
2678
2679void SPIRVCodeGenerator::writeBlock(const Block& b, SkWStream& out) {
2680    for (size_t i = 0; i < b.fStatements.size(); i++) {
2681        this->writeStatement(*b.fStatements[i], out);
2682    }
2683}
2684
2685void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, SkWStream& out) {
2686    SpvId test = this->writeExpression(*stmt.fTest, out);
2687    SpvId ifTrue = this->nextId();
2688    SpvId ifFalse = this->nextId();
2689    if (stmt.fIfFalse) {
2690        SpvId end = this->nextId();
2691        this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
2692        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2693        this->writeLabel(ifTrue, out);
2694        this->writeStatement(*stmt.fIfTrue, out);
2695        if (fCurrentBlock) {
2696            this->writeInstruction(SpvOpBranch, end, out);
2697        }
2698        this->writeLabel(ifFalse, out);
2699        this->writeStatement(*stmt.fIfFalse, out);
2700        if (fCurrentBlock) {
2701            this->writeInstruction(SpvOpBranch, end, out);
2702        }
2703        this->writeLabel(end, out);
2704    } else {
2705        this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
2706        this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
2707        this->writeLabel(ifTrue, out);
2708        this->writeStatement(*stmt.fIfTrue, out);
2709        if (fCurrentBlock) {
2710            this->writeInstruction(SpvOpBranch, ifFalse, out);
2711        }
2712        this->writeLabel(ifFalse, out);
2713    }
2714}
2715
2716void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, SkWStream& out) {
2717    if (f.fInitializer) {
2718        this->writeStatement(*f.fInitializer, out);
2719    }
2720    SpvId header = this->nextId();
2721    SpvId start = this->nextId();
2722    SpvId body = this->nextId();
2723    SpvId next = this->nextId();
2724    fContinueTarget.push(next);
2725    SpvId end = this->nextId();
2726    fBreakTarget.push(end);
2727    this->writeInstruction(SpvOpBranch, header, out);
2728    this->writeLabel(header, out);
2729    this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
2730    this->writeInstruction(SpvOpBranch, start, out);
2731    this->writeLabel(start, out);
2732    if (f.fTest) {
2733        SpvId test = this->writeExpression(*f.fTest, out);
2734        this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2735    }
2736    this->writeLabel(body, out);
2737    this->writeStatement(*f.fStatement, out);
2738    if (fCurrentBlock) {
2739        this->writeInstruction(SpvOpBranch, next, out);
2740    }
2741    this->writeLabel(next, out);
2742    if (f.fNext) {
2743        this->writeExpression(*f.fNext, out);
2744    }
2745    this->writeInstruction(SpvOpBranch, header, out);
2746    this->writeLabel(end, out);
2747    fBreakTarget.pop();
2748    fContinueTarget.pop();
2749}
2750
2751void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, SkWStream& out) {
2752    // We believe the while loop code below will work, but Skia doesn't actually use them and
2753    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2754    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2755    // message, simply remove the error call below to see whether our while loop support actually
2756    // works.
2757    fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, "
2758                  "see SkSLSPIRVCodeGenerator.cpp for details");
2759
2760    SpvId header = this->nextId();
2761    SpvId start = this->nextId();
2762    SpvId body = this->nextId();
2763    fContinueTarget.push(start);
2764    SpvId end = this->nextId();
2765    fBreakTarget.push(end);
2766    this->writeInstruction(SpvOpBranch, header, out);
2767    this->writeLabel(header, out);
2768    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2769    this->writeInstruction(SpvOpBranch, start, out);
2770    this->writeLabel(start, out);
2771    SpvId test = this->writeExpression(*w.fTest, out);
2772    this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
2773    this->writeLabel(body, out);
2774    this->writeStatement(*w.fStatement, out);
2775    if (fCurrentBlock) {
2776        this->writeInstruction(SpvOpBranch, start, out);
2777    }
2778    this->writeLabel(end, out);
2779    fBreakTarget.pop();
2780    fContinueTarget.pop();
2781}
2782
2783void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, SkWStream& out) {
2784    // We believe the do loop code below will work, but Skia doesn't actually use them and
2785    // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
2786    // the time being, we just fail with an error due to the lack of testing. If you encounter this
2787    // message, simply remove the error call below to see whether our do loop support actually
2788    // works.
2789    fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see "
2790                  "SkSLSPIRVCodeGenerator.cpp for details");
2791
2792    SpvId header = this->nextId();
2793    SpvId start = this->nextId();
2794    SpvId next = this->nextId();
2795    fContinueTarget.push(next);
2796    SpvId end = this->nextId();
2797    fBreakTarget.push(end);
2798    this->writeInstruction(SpvOpBranch, header, out);
2799    this->writeLabel(header, out);
2800    this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
2801    this->writeInstruction(SpvOpBranch, start, out);
2802    this->writeLabel(start, out);
2803    this->writeStatement(*d.fStatement, out);
2804    if (fCurrentBlock) {
2805        this->writeInstruction(SpvOpBranch, next, out);
2806    }
2807    this->writeLabel(next, out);
2808    SpvId test = this->writeExpression(*d.fTest, out);
2809    this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
2810    this->writeLabel(end, out);
2811    fBreakTarget.pop();
2812    fContinueTarget.pop();
2813}
2814
2815void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, SkWStream& out) {
2816    if (r.fExpression) {
2817        this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
2818                               out);
2819    } else {
2820        this->writeInstruction(SpvOpReturn, out);
2821    }
2822}
2823
// Writes everything that follows the SPIR-V module header into `out`.
//
// Global variables and function definitions are first generated into the local `body` stream.
// The module preamble is then written directly to `out`: capabilities, the GLSL.std.450
// import, the memory model, the entry point, the execution mode and source extensions.
// The accumulated member buffers follow, and finally `body`.
void SPIRVCodeGenerator::writeInstructions(const Program& program, SkWStream& out) {
    // Reserve the ID for the "GLSL.std.450" OpExtInstImport that is written further below.
    fGLSLExtendedInstructions = this->nextId();
    SkDynamicMemoryWStream body;
    // IDs passed as interface operands of OpEntryPoint. A std::set keeps them unique and in
    // ascending order.
    std::set<SpvId> interfaceVars;
    // assign IDs to functions
    // (all function IDs are assigned before any body is written, presumably so that a call can
    // refer to a function defined later in the program)
    for (size_t i = 0; i < program.fElements.size(); i++) {
        if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
            FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
            fFunctionMap[&f.fDeclaration] = this->nextId();
        }
    }
    // Interface blocks. Those whose variable is declared `in` or `out` are added to the
    // entry-point interface list.
    for (size_t i = 0; i < program.fElements.size(); i++) {
        if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
            InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
            SpvId id = this->writeInterfaceBlock(intf);
            if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
                (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
                interfaceVars.insert(id);
            }
        }
    }
    // Global variable declarations, handed to writeGlobalVars with `body` as the output stream.
    for (size_t i = 0; i < program.fElements.size(); i++) {
        if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
            this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
                                  body);
        }
    }
    // Function definitions, emitted into `body`.
    for (size_t i = 0; i < program.fElements.size(); i++) {
        if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
            this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
        }
    }
    // Locate the entry point by name. If several declarations were named "main", the last one
    // visited in fFunctionMap's iteration order would be used.
    const FunctionDeclaration* main = nullptr;
    for (auto entry : fFunctionMap) {
        if (entry.first->fName == "main") {
            main = entry.first;
        }
    }
    // NOTE(review): if ASSERT is compiled out, a program without "main" would dereference a null
    // `main` below. Confirm that the front end guarantees an entry point exists.
    ASSERT(main);
    // Every global `in`/`out` variable that has been assigned an ID is also an interface variable.
    for (auto entry : fVariableMap) {
        const Variable* var = entry.first;
        if (var->fStorage == Variable::kGlobal_Storage &&
                ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
                 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
            interfaceVars.insert(entry.second);
        }
    }
    this->writeCapabilities(out);
    this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
    this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
    // OpEntryPoint word count has three parts:
    // - 3 fixed words: opcode/word count, execution model, function ID;
    // - the NUL-terminated name, padded to whole 4-byte words;
    // - one word per interface ID.
    this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
                      (int32_t) interfaceVars.size(), out);
    // Execution-model operand, chosen from the program kind.
    // NOTE(review): there is no default case. Any other Program kind would omit this operand
    // and leave the instruction one word shorter than the count written above.
    switch (program.fKind) {
        case Program::kVertex_Kind:
            this->writeWord(SpvExecutionModelVertex, out);
            break;
        case Program::kFragment_Kind:
            this->writeWord(SpvExecutionModelFragment, out);
            break;
        case Program::kGeometry_Kind:
            this->writeWord(SpvExecutionModelGeometry, out);
            break;
    }
    this->writeWord(fFunctionMap[main], out);
    this->writeString(main->fName.c_str(), out);
    for (int var : interfaceVars) {
        this->writeWord(var, out);
    }
    // Fragment programs declare the OriginUpperLeft execution mode on the entry point.
    if (program.fKind == Program::kFragment_Kind) {
        this->writeInstruction(SpvOpExecutionMode,
                               fFunctionMap[main],
                               SpvExecutionModeOriginUpperLeft,
                               out);
    }
    // Each program-level extension is recorded as an OpSourceExtension debug instruction.
    for (size_t i = 0; i < program.fElements.size(); i++) {
        if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
            this->writeInstruction(SpvOpSourceExtension,
                                   ((Extension&) *program.fElements[i]).fName.c_str(),
                                   out);
        }
    }

    // Flush the per-section member buffers after the preamble, then the globals/functions body.
    write_data(*fExtraGlobalsBuffer.detachAsData(), out);
    write_data(*fNameBuffer.detachAsData(), out);
    write_data(*fDecorationBuffer.detachAsData(), out);
    write_data(*fConstantBuffer.detachAsData(), out);
    write_data(*fExternalFunctionsBuffer.detachAsData(), out);
    write_data(*body.detachAsData(), out);
}
2913
2914bool SPIRVCodeGenerator::generateCode() {
2915    ASSERT(!fErrors.errorCount());
2916    this->writeWord(SpvMagicNumber, *fOut);
2917    this->writeWord(SpvVersion, *fOut);
2918    this->writeWord(SKSL_MAGIC, *fOut);
2919    SkDynamicMemoryWStream buffer;
2920    this->writeInstructions(fProgram, buffer);
2921    this->writeWord(fIdCount, *fOut);
2922    this->writeWord(0, *fOut); // reserved, always zero
2923    write_data(*buffer.detachAsData(), *fOut);
2924    return 0 == fErrors.errorCount();
2925}
2926
2927}
2928