1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLCompiler.h"
9
10#include "SkSLCFGGenerator.h"
11#include "SkSLCPPCodeGenerator.h"
12#include "SkSLGLSLCodeGenerator.h"
13#include "SkSLHCodeGenerator.h"
14#include "SkSLIRGenerator.h"
15#include "SkSLMetalCodeGenerator.h"
16#include "SkSLSPIRVCodeGenerator.h"
17#include "ir/SkSLEnum.h"
18#include "ir/SkSLExpression.h"
19#include "ir/SkSLExpressionStatement.h"
20#include "ir/SkSLIntLiteral.h"
21#include "ir/SkSLModifiersDeclaration.h"
22#include "ir/SkSLNop.h"
23#include "ir/SkSLSymbolTable.h"
24#include "ir/SkSLTernaryExpression.h"
25#include "ir/SkSLUnresolvedFunction.h"
26#include "ir/SkSLVarDeclarations.h"
27
28#ifdef SK_ENABLE_SPIRV_VALIDATION
29#include "spirv-tools/libspirv.hpp"
30#endif
31
// include the built-in shader symbols as static strings
//
// Each generated *.inc file below is #included into the initializer of a `static const char*`,
// so its contents presumably expand to one or more string literals (TODO confirm against the
// generator that produces the .inc files). SKSL_INCLUDE is parsed by the Compiler constructor;
// the other strings are not referenced in this part of the file and are presumably parsed later
// when a program of the matching kind is compiled.

// NOTE(review): STRINGIFY is not used by the code visible here; presumably the generated .inc
// files use it to turn raw SkSL source text into a string literal.
#define STRINGIFY(x) #x

// Built-in declarations shared by all program kinds (parsed in Compiler::Compiler).
static const char* SKSL_INCLUDE =
#include "sksl.inc"
;

// Presumably vertex-shader-specific built-ins, going by the file name.
static const char* SKSL_VERT_INCLUDE =
#include "sksl_vert.inc"
;

// Presumably fragment-shader-specific built-ins, going by the file name.
static const char* SKSL_FRAG_INCLUDE =
#include "sksl_frag.inc"
;

// Presumably geometry-shader-specific built-ins, going by the file name.
static const char* SKSL_GEOM_INCLUDE =
#include "sksl_geom.inc"
;

// Fragment-processor built-ins. The two includes are expected to yield adjacent string literals,
// which the C++ compiler concatenates in order (enums first, then the fp declarations).
static const char* SKSL_FP_INCLUDE =
#include "sksl_enums.inc"
#include "sksl_fp.inc"
;
56
57namespace SkSL {
58
59Compiler::Compiler(Flags flags)
60: fFlags(flags)
61, fErrorCount(0) {
62    auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
63    auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
64    fIRGenerator = new IRGenerator(&fContext, symbols, *this);
65    fTypes = types;
66    #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
67                                                   fContext.f ## t ## _Type.get())
68    ADD_TYPE(Void);
69    ADD_TYPE(Float);
70    ADD_TYPE(Float2);
71    ADD_TYPE(Float3);
72    ADD_TYPE(Float4);
73    ADD_TYPE(Half);
74    ADD_TYPE(Half2);
75    ADD_TYPE(Half3);
76    ADD_TYPE(Half4);
77    ADD_TYPE(Double);
78    ADD_TYPE(Double2);
79    ADD_TYPE(Double3);
80    ADD_TYPE(Double4);
81    ADD_TYPE(Int);
82    ADD_TYPE(Int2);
83    ADD_TYPE(Int3);
84    ADD_TYPE(Int4);
85    ADD_TYPE(UInt);
86    ADD_TYPE(UInt2);
87    ADD_TYPE(UInt3);
88    ADD_TYPE(UInt4);
89    ADD_TYPE(Short);
90    ADD_TYPE(Short2);
91    ADD_TYPE(Short3);
92    ADD_TYPE(Short4);
93    ADD_TYPE(UShort);
94    ADD_TYPE(UShort2);
95    ADD_TYPE(UShort3);
96    ADD_TYPE(UShort4);
97    ADD_TYPE(Bool);
98    ADD_TYPE(Bool2);
99    ADD_TYPE(Bool3);
100    ADD_TYPE(Bool4);
101    ADD_TYPE(Float2x2);
102    ADD_TYPE(Float2x3);
103    ADD_TYPE(Float2x4);
104    ADD_TYPE(Float3x2);
105    ADD_TYPE(Float3x3);
106    ADD_TYPE(Float3x4);
107    ADD_TYPE(Float4x2);
108    ADD_TYPE(Float4x3);
109    ADD_TYPE(Float4x4);
110    ADD_TYPE(Half2x2);
111    ADD_TYPE(Half2x3);
112    ADD_TYPE(Half2x4);
113    ADD_TYPE(Half3x2);
114    ADD_TYPE(Half3x3);
115    ADD_TYPE(Half3x4);
116    ADD_TYPE(Half4x2);
117    ADD_TYPE(Half4x3);
118    ADD_TYPE(Half4x4);
119    ADD_TYPE(Double2x2);
120    ADD_TYPE(Double2x3);
121    ADD_TYPE(Double2x4);
122    ADD_TYPE(Double3x2);
123    ADD_TYPE(Double3x3);
124    ADD_TYPE(Double3x4);
125    ADD_TYPE(Double4x2);
126    ADD_TYPE(Double4x3);
127    ADD_TYPE(Double4x4);
128    ADD_TYPE(GenType);
129    ADD_TYPE(GenHType);
130    ADD_TYPE(GenDType);
131    ADD_TYPE(GenIType);
132    ADD_TYPE(GenUType);
133    ADD_TYPE(GenBType);
134    ADD_TYPE(Mat);
135    ADD_TYPE(Vec);
136    ADD_TYPE(GVec);
137    ADD_TYPE(GVec2);
138    ADD_TYPE(GVec3);
139    ADD_TYPE(GVec4);
140    ADD_TYPE(HVec);
141    ADD_TYPE(DVec);
142    ADD_TYPE(IVec);
143    ADD_TYPE(UVec);
144    ADD_TYPE(SVec);
145    ADD_TYPE(USVec);
146    ADD_TYPE(BVec);
147
148    ADD_TYPE(Sampler1D);
149    ADD_TYPE(Sampler2D);
150    ADD_TYPE(Sampler3D);
151    ADD_TYPE(SamplerExternalOES);
152    ADD_TYPE(SamplerCube);
153    ADD_TYPE(Sampler2DRect);
154    ADD_TYPE(Sampler1DArray);
155    ADD_TYPE(Sampler2DArray);
156    ADD_TYPE(SamplerCubeArray);
157    ADD_TYPE(SamplerBuffer);
158    ADD_TYPE(Sampler2DMS);
159    ADD_TYPE(Sampler2DMSArray);
160
161    ADD_TYPE(ISampler2D);
162
163    ADD_TYPE(Image2D);
164    ADD_TYPE(IImage2D);
165
166    ADD_TYPE(SubpassInput);
167    ADD_TYPE(SubpassInputMS);
168
169    ADD_TYPE(GSampler1D);
170    ADD_TYPE(GSampler2D);
171    ADD_TYPE(GSampler3D);
172    ADD_TYPE(GSamplerCube);
173    ADD_TYPE(GSampler2DRect);
174    ADD_TYPE(GSampler1DArray);
175    ADD_TYPE(GSampler2DArray);
176    ADD_TYPE(GSamplerCubeArray);
177    ADD_TYPE(GSamplerBuffer);
178    ADD_TYPE(GSampler2DMS);
179    ADD_TYPE(GSampler2DMSArray);
180
181    ADD_TYPE(Sampler1DShadow);
182    ADD_TYPE(Sampler2DShadow);
183    ADD_TYPE(SamplerCubeShadow);
184    ADD_TYPE(Sampler2DRectShadow);
185    ADD_TYPE(Sampler1DArrayShadow);
186    ADD_TYPE(Sampler2DArrayShadow);
187    ADD_TYPE(SamplerCubeArrayShadow);
188    ADD_TYPE(GSampler2DArrayShadow);
189    ADD_TYPE(GSamplerCubeArrayShadow);
190    ADD_TYPE(FragmentProcessor);
191
192    StringFragment skCapsName("sk_Caps");
193    Variable* skCaps = new Variable(-1, Modifiers(), skCapsName,
194                                    *fContext.fSkCaps_Type, Variable::kGlobal_Storage);
195    fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));
196
197    StringFragment skArgsName("sk_Args");
198    Variable* skArgs = new Variable(-1, Modifiers(), skArgsName,
199                                    *fContext.fSkArgs_Type, Variable::kGlobal_Storage);
200    fIRGenerator->fSymbolTable->add(skArgsName, std::unique_ptr<Symbol>(skArgs));
201
202    std::vector<std::unique_ptr<ProgramElement>> ignored;
203    fIRGenerator->convertProgram(Program::kFragment_Kind, SKSL_INCLUDE, strlen(SKSL_INCLUDE),
204                                 *fTypes, &ignored);
205    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
206    if (fErrorCount) {
207        printf("Unexpected errors: %s\n", fErrorText.c_str());
208    }
209    ASSERT(!fErrorCount);
210}
211
212Compiler::~Compiler() {
213    delete fIRGenerator;
214}
215
216// add the definition created by assigning to the lvalue to the definition set
217void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
218                             DefinitionMap* definitions) {
219    switch (lvalue->fKind) {
220        case Expression::kVariableReference_Kind: {
221            const Variable& var = ((VariableReference*) lvalue)->fVariable;
222            if (var.fStorage == Variable::kLocal_Storage) {
223                (*definitions)[&var] = expr;
224            }
225            break;
226        }
227        case Expression::kSwizzle_Kind:
228            // We consider the variable written to as long as at least some of its components have
229            // been written to. This will lead to some false negatives (we won't catch it if you
230            // write to foo.x and then read foo.y), but being stricter could lead to false positives
231            // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
232            // but since we pass foo as a whole it is flagged as an error) unless we perform a much
233            // more complicated whole-program analysis. This is probably good enough.
234            this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
235                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
236                                definitions);
237            break;
238        case Expression::kIndex_Kind:
239            // see comments in Swizzle
240            this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
241                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
242                                definitions);
243            break;
244        case Expression::kFieldAccess_Kind:
245            // see comments in Swizzle
246            this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
247                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
248                                definitions);
249            break;
250        case Expression::kTernary_Kind:
251            // To simplify analysis, we just pretend that we write to both sides of the ternary.
252            // This allows for false positives (meaning we fail to detect that a variable might not
253            // have been assigned), but is preferable to false negatives.
254            this->addDefinition(((TernaryExpression*) lvalue)->fIfTrue.get(),
255                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
256                                definitions);
257            this->addDefinition(((TernaryExpression*) lvalue)->fIfFalse.get(),
258                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
259                                definitions);
260            break;
261        default:
262            // not an lvalue, can't happen
263            ASSERT(false);
264    }
265}
266
267// add local variables defined by this node to the set
268void Compiler::addDefinitions(const BasicBlock::Node& node,
269                              DefinitionMap* definitions) {
270    switch (node.fKind) {
271        case BasicBlock::Node::kExpression_Kind: {
272            ASSERT(node.expression());
273            const Expression* expr = (Expression*) node.expression()->get();
274            switch (expr->fKind) {
275                case Expression::kBinary_Kind: {
276                    BinaryExpression* b = (BinaryExpression*) expr;
277                    if (b->fOperator == Token::EQ) {
278                        this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
279                    } else if (Compiler::IsAssignment(b->fOperator)) {
280                        this->addDefinition(
281                                       b->fLeft.get(),
282                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
283                                       definitions);
284
285                    }
286                    break;
287                }
288                case Expression::kPrefix_Kind: {
289                    const PrefixExpression* p = (PrefixExpression*) expr;
290                    if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
291                        this->addDefinition(
292                                       p->fOperand.get(),
293                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
294                                       definitions);
295                    }
296                    break;
297                }
298                case Expression::kPostfix_Kind: {
299                    const PostfixExpression* p = (PostfixExpression*) expr;
300                    if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
301                        this->addDefinition(
302                                       p->fOperand.get(),
303                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
304                                       definitions);
305                    }
306                    break;
307                }
308                case Expression::kVariableReference_Kind: {
309                    const VariableReference* v = (VariableReference*) expr;
310                    if (v->fRefKind != VariableReference::kRead_RefKind) {
311                        this->addDefinition(
312                                       v,
313                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
314                                       definitions);
315                    }
316                }
317                default:
318                    break;
319            }
320            break;
321        }
322        case BasicBlock::Node::kStatement_Kind: {
323            const Statement* stmt = (Statement*) node.statement()->get();
324            if (stmt->fKind == Statement::kVarDeclaration_Kind) {
325                VarDeclaration& vd = (VarDeclaration&) *stmt;
326                if (vd.fValue) {
327                    (*definitions)[vd.fVar] = &vd.fValue;
328                }
329            }
330            break;
331        }
332    }
333}
334
335void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
336    BasicBlock& block = cfg->fBlocks[blockId];
337
338    // compute definitions after this block
339    DefinitionMap after = block.fBefore;
340    for (const BasicBlock::Node& n : block.fNodes) {
341        this->addDefinitions(n, &after);
342    }
343
344    // propagate definitions to exits
345    for (BlockId exitId : block.fExits) {
346        BasicBlock& exit = cfg->fBlocks[exitId];
347        for (const auto& pair : after) {
348            std::unique_ptr<Expression>* e1 = pair.second;
349            auto found = exit.fBefore.find(pair.first);
350            if (found == exit.fBefore.end()) {
351                // exit has no definition for it, just copy it
352                workList->insert(exitId);
353                exit.fBefore[pair.first] = e1;
354            } else {
355                // exit has a (possibly different) value already defined
356                std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
357                if (e1 != e2) {
358                    // definition has changed, merge and add exit block to worklist
359                    workList->insert(exitId);
360                    if (e1 && e2) {
361                        exit.fBefore[pair.first] =
362                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
363                    } else {
364                        exit.fBefore[pair.first] = nullptr;
365                    }
366                }
367            }
368        }
369    }
370}
371
372// returns a map which maps all local variables in the function to null, indicating that their value
373// is initially unknown
374static DefinitionMap compute_start_state(const CFG& cfg) {
375    DefinitionMap result;
376    for (const auto& block : cfg.fBlocks) {
377        for (const auto& node : block.fNodes) {
378            if (node.fKind == BasicBlock::Node::kStatement_Kind) {
379                ASSERT(node.statement());
380                const Statement* s = node.statement()->get();
381                if (s->fKind == Statement::kVarDeclarations_Kind) {
382                    const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
383                    for (const auto& decl : vd->fDeclaration->fVars) {
384                        if (decl->fKind == Statement::kVarDeclaration_Kind) {
385                            result[((VarDeclaration&) *decl).fVar] = nullptr;
386                        }
387                    }
388                }
389            }
390        }
391    }
392    return result;
393}
394
395/**
396 * Returns true if assigning to this lvalue has no effect.
397 */
398static bool is_dead(const Expression& lvalue) {
399    switch (lvalue.fKind) {
400        case Expression::kVariableReference_Kind:
401            return ((VariableReference&) lvalue).fVariable.dead();
402        case Expression::kSwizzle_Kind:
403            return is_dead(*((Swizzle&) lvalue).fBase);
404        case Expression::kFieldAccess_Kind:
405            return is_dead(*((FieldAccess&) lvalue).fBase);
406        case Expression::kIndex_Kind: {
407            const IndexExpression& idx = (IndexExpression&) lvalue;
408            return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
409        }
410        case Expression::kTernary_Kind: {
411            const TernaryExpression& t = (TernaryExpression&) lvalue;
412            return !t.fTest->hasSideEffects() && is_dead(*t.fIfTrue) && is_dead(*t.fIfFalse);
413        }
414        default:
415            ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
416    }
417}
418
419/**
420 * Returns true if this is an assignment which can be collapsed down to just the right hand side due
421 * to a dead target and lack of side effects on the left hand side.
422 */
423static bool dead_assignment(const BinaryExpression& b) {
424    if (!Compiler::IsAssignment(b.fOperator)) {
425        return false;
426    }
427    return is_dead(*b.fLeft);
428}
429
430void Compiler::computeDataFlow(CFG* cfg) {
431    cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
432    std::set<BlockId> workList;
433    for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
434        workList.insert(i);
435    }
436    while (workList.size()) {
437        BlockId next = *workList.begin();
438        workList.erase(workList.begin());
439        this->scanCFG(cfg, next, &workList);
440    }
441}
442
443/**
444 * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the
445 * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to
446 * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will
447 * need to be regenerated).
448 */
449bool try_replace_expression(BasicBlock* b,
450                            std::vector<BasicBlock::Node>::iterator* iter,
451                            std::unique_ptr<Expression>* newExpression) {
452    std::unique_ptr<Expression>* target = (*iter)->expression();
453    if (!b->tryRemoveExpression(iter)) {
454        *target = std::move(*newExpression);
455        return false;
456    }
457    *target = std::move(*newExpression);
458    return b->tryInsertExpression(iter, target);
459}
460
461/**
462 * Returns true if the expression is a constant numeric literal with the specified value, or a
463 * constant vector with all elements equal to the specified value.
464 */
465bool is_constant(const Expression& expr, double value) {
466    switch (expr.fKind) {
467        case Expression::kIntLiteral_Kind:
468            return ((IntLiteral&) expr).fValue == value;
469        case Expression::kFloatLiteral_Kind:
470            return ((FloatLiteral&) expr).fValue == value;
471        case Expression::kConstructor_Kind: {
472            Constructor& c = (Constructor&) expr;
473            if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) {
474                for (int i = 0; i < c.fType.columns(); ++i) {
475                    if (!is_constant(c.getVecComponent(i), value)) {
476                        return false;
477                    }
478                }
479                return true;
480            }
481            return false;
482        }
483        default:
484            return false;
485    }
486}
487
488/**
489 * Collapses the binary expression pointed to by iter down to just the right side (in both the IR
490 * and CFG structures).
491 */
492void delete_left(BasicBlock* b,
493                 std::vector<BasicBlock::Node>::iterator* iter,
494                 bool* outUpdated,
495                 bool* outNeedsRescan) {
496    *outUpdated = true;
497    std::unique_ptr<Expression>* target = (*iter)->expression();
498    ASSERT((*target)->fKind == Expression::kBinary_Kind);
499    BinaryExpression& bin = (BinaryExpression&) **target;
500    ASSERT(!bin.fLeft->hasSideEffects());
501    bool result;
502    if (bin.fOperator == Token::EQ) {
503        result = b->tryRemoveLValueBefore(iter, bin.fLeft.get());
504    } else {
505        result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get());
506    }
507    *target = std::move(bin.fRight);
508    if (!result) {
509        *outNeedsRescan = true;
510        return;
511    }
512    if (*iter == b->fNodes.begin()) {
513        *outNeedsRescan = true;
514        return;
515    }
516    --(*iter);
517    if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
518        (*iter)->expression() != &bin.fRight) {
519        *outNeedsRescan = true;
520        return;
521    }
522    *iter = b->fNodes.erase(*iter);
523    ASSERT((*iter)->expression() == target);
524}
525
526/**
527 * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and
528 * CFG structures).
529 */
530void delete_right(BasicBlock* b,
531                  std::vector<BasicBlock::Node>::iterator* iter,
532                  bool* outUpdated,
533                  bool* outNeedsRescan) {
534    *outUpdated = true;
535    std::unique_ptr<Expression>* target = (*iter)->expression();
536    ASSERT((*target)->fKind == Expression::kBinary_Kind);
537    BinaryExpression& bin = (BinaryExpression&) **target;
538    ASSERT(!bin.fRight->hasSideEffects());
539    if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) {
540        *target = std::move(bin.fLeft);
541        *outNeedsRescan = true;
542        return;
543    }
544    *target = std::move(bin.fLeft);
545    if (*iter == b->fNodes.begin()) {
546        *outNeedsRescan = true;
547        return;
548    }
549    --(*iter);
550    if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
551        (*iter)->expression() != &bin.fLeft)) {
552        *outNeedsRescan = true;
553        return;
554    }
555    *iter = b->fNodes.erase(*iter);
556    ASSERT((*iter)->expression() == target);
557}
558
559/**
560 * Constructs the specified type using a single argument.
561 */
562static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) {
563    std::vector<std::unique_ptr<Expression>> args;
564    args.push_back(std::move(v));
565    auto result = std::unique_ptr<Expression>(new Constructor(-1, type, std::move(args)));
566    return result;
567}
568
569/**
570 * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an
571 * expression x, deletes the expression pointed to by iter and replaces it with <type>(x).
572 */
573static void vectorize(BasicBlock* b,
574                      std::vector<BasicBlock::Node>::iterator* iter,
575                      const Type& type,
576                      std::unique_ptr<Expression>* otherExpression,
577                      bool* outUpdated,
578                      bool* outNeedsRescan) {
579    ASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind);
580    ASSERT(type.kind() == Type::kVector_Kind);
581    ASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind);
582    *outUpdated = true;
583    std::unique_ptr<Expression>* target = (*iter)->expression();
584    if (!b->tryRemoveExpression(iter)) {
585        *target = construct(type, std::move(*otherExpression));
586        *outNeedsRescan = true;
587    } else {
588        *target = construct(type, std::move(*otherExpression));
589        if (!b->tryInsertExpression(iter, target)) {
590            *outNeedsRescan = true;
591        }
592    }
593}
594
595/**
596 * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the
597 * left to yield vec<n>(x).
598 */
599static void vectorize_left(BasicBlock* b,
600                           std::vector<BasicBlock::Node>::iterator* iter,
601                           bool* outUpdated,
602                           bool* outNeedsRescan) {
603    BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
604    vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan);
605}
606
607/**
608 * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the
609 * right to yield vec<n>(y).
610 */
611static void vectorize_right(BasicBlock* b,
612                            std::vector<BasicBlock::Node>::iterator* iter,
613                            bool* outUpdated,
614                            bool* outNeedsRescan) {
615    BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
616    vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan);
617}
618
619// Mark that an expression which we were writing to is no longer being written to
620void clear_write(const Expression& expr) {
621    switch (expr.fKind) {
622        case Expression::kVariableReference_Kind: {
623            ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind);
624            break;
625        }
626        case Expression::kFieldAccess_Kind:
627            clear_write(*((FieldAccess&) expr).fBase);
628            break;
629        case Expression::kSwizzle_Kind:
630            clear_write(*((Swizzle&) expr).fBase);
631            break;
632        case Expression::kIndex_Kind:
633            clear_write(*((IndexExpression&) expr).fBase);
634            break;
635        default:
636            ABORT("shouldn't be writing to this kind of expression\n");
637            break;
638    }
639}
640
641void Compiler::simplifyExpression(DefinitionMap& definitions,
642                                  BasicBlock& b,
643                                  std::vector<BasicBlock::Node>::iterator* iter,
644                                  std::unordered_set<const Variable*>* undefinedVariables,
645                                  bool* outUpdated,
646                                  bool* outNeedsRescan) {
647    Expression* expr = (*iter)->expression()->get();
648    ASSERT(expr);
649    if ((*iter)->fConstantPropagation) {
650        std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
651        if (optimized) {
652            *outUpdated = true;
653            if (!try_replace_expression(&b, iter, &optimized)) {
654                *outNeedsRescan = true;
655                return;
656            }
657            ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
658            expr = (*iter)->expression()->get();
659        }
660    }
661    switch (expr->fKind) {
662        case Expression::kVariableReference_Kind: {
663            const Variable& var = ((VariableReference*) expr)->fVariable;
664            if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
665                (*undefinedVariables).find(&var) == (*undefinedVariables).end()) {
666                (*undefinedVariables).insert(&var);
667                this->error(expr->fOffset,
668                            "'" + var.fName + "' has not been assigned");
669            }
670            break;
671        }
672        case Expression::kTernary_Kind: {
673            TernaryExpression* t = (TernaryExpression*) expr;
674            if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
675                // ternary has a constant test, replace it with either the true or
676                // false branch
677                if (((BoolLiteral&) *t->fTest).fValue) {
678                    (*iter)->setExpression(std::move(t->fIfTrue));
679                } else {
680                    (*iter)->setExpression(std::move(t->fIfFalse));
681                }
682                *outUpdated = true;
683                *outNeedsRescan = true;
684            }
685            break;
686        }
687        case Expression::kBinary_Kind: {
688            BinaryExpression* bin = (BinaryExpression*) expr;
689            if (dead_assignment(*bin)) {
690                delete_left(&b, iter, outUpdated, outNeedsRescan);
691                break;
692            }
693            // collapse useless expressions like x * 1 or x + 0
694            if (((bin->fLeft->fType.kind()  != Type::kScalar_Kind) &&
695                 (bin->fLeft->fType.kind()  != Type::kVector_Kind)) ||
696                ((bin->fRight->fType.kind() != Type::kScalar_Kind) &&
697                 (bin->fRight->fType.kind() != Type::kVector_Kind))) {
698                break;
699            }
700            switch (bin->fOperator) {
701                case Token::STAR:
702                    if (is_constant(*bin->fLeft, 1)) {
703                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
704                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
705                            // float4(1) * x -> float4(x)
706                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
707                        } else {
708                            // 1 * x -> x
709                            // 1 * float4(x) -> float4(x)
710                            // float4(1) * float4(x) -> float4(x)
711                            delete_left(&b, iter, outUpdated, outNeedsRescan);
712                        }
713                    }
714                    else if (is_constant(*bin->fLeft, 0)) {
715                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
716                            bin->fRight->fType.kind() == Type::kVector_Kind &&
717                            !bin->fRight->hasSideEffects()) {
718                            // 0 * float4(x) -> float4(0)
719                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
720                        } else {
721                            // 0 * x -> 0
722                            // float4(0) * x -> float4(0)
723                            // float4(0) * float4(x) -> float4(0)
724                            if (!bin->fRight->hasSideEffects()) {
725                                delete_right(&b, iter, outUpdated, outNeedsRescan);
726                            }
727                        }
728                    }
729                    else if (is_constant(*bin->fRight, 1)) {
730                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
731                            bin->fRight->fType.kind() == Type::kVector_Kind) {
732                            // x * float4(1) -> float4(x)
733                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
734                        } else {
735                            // x * 1 -> x
736                            // float4(x) * 1 -> float4(x)
737                            // float4(x) * float4(1) -> float4(x)
738                            delete_right(&b, iter, outUpdated, outNeedsRescan);
739                        }
740                    }
741                    else if (is_constant(*bin->fRight, 0)) {
742                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
743                            bin->fRight->fType.kind() == Type::kScalar_Kind &&
744                            !bin->fLeft->hasSideEffects()) {
745                            // float4(x) * 0 -> float4(0)
746                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
747                        } else {
748                            // x * 0 -> 0
749                            // x * float4(0) -> float4(0)
750                            // float4(x) * float4(0) -> float4(0)
751                            if (!bin->fLeft->hasSideEffects()) {
752                                delete_left(&b, iter, outUpdated, outNeedsRescan);
753                            }
754                        }
755                    }
756                    break;
757                case Token::PLUS:
758                    if (is_constant(*bin->fLeft, 0)) {
759                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
760                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
761                            // float4(0) + x -> float4(x)
762                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
763                        } else {
764                            // 0 + x -> x
765                            // 0 + float4(x) -> float4(x)
766                            // float4(0) + float4(x) -> float4(x)
767                            delete_left(&b, iter, outUpdated, outNeedsRescan);
768                        }
769                    } else if (is_constant(*bin->fRight, 0)) {
770                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
771                            bin->fRight->fType.kind() == Type::kVector_Kind) {
772                            // x + float4(0) -> float4(x)
773                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
774                        } else {
775                            // x + 0 -> x
776                            // float4(x) + 0 -> float4(x)
777                            // float4(x) + float4(0) -> float4(x)
778                            delete_right(&b, iter, outUpdated, outNeedsRescan);
779                        }
780                    }
781                    break;
782                case Token::MINUS:
783                    if (is_constant(*bin->fRight, 0)) {
784                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
785                            bin->fRight->fType.kind() == Type::kVector_Kind) {
786                            // x - float4(0) -> float4(x)
787                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
788                        } else {
789                            // x - 0 -> x
790                            // float4(x) - 0 -> float4(x)
791                            // float4(x) - float4(0) -> float4(x)
792                            delete_right(&b, iter, outUpdated, outNeedsRescan);
793                        }
794                    }
795                    break;
796                case Token::SLASH:
797                    if (is_constant(*bin->fRight, 1)) {
798                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
799                            bin->fRight->fType.kind() == Type::kVector_Kind) {
800                            // x / float4(1) -> float4(x)
801                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
802                        } else {
803                            // x / 1 -> x
804                            // float4(x) / 1 -> float4(x)
805                            // float4(x) / float4(1) -> float4(x)
806                            delete_right(&b, iter, outUpdated, outNeedsRescan);
807                        }
808                    } else if (is_constant(*bin->fLeft, 0)) {
809                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
810                            bin->fRight->fType.kind() == Type::kVector_Kind &&
811                            !bin->fRight->hasSideEffects()) {
812                            // 0 / float4(x) -> float4(0)
813                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
814                        } else {
815                            // 0 / x -> 0
816                            // float4(0) / x -> float4(0)
817                            // float4(0) / float4(x) -> float4(0)
818                            if (!bin->fRight->hasSideEffects()) {
819                                delete_right(&b, iter, outUpdated, outNeedsRescan);
820                            }
821                        }
822                    }
823                    break;
824                case Token::PLUSEQ:
825                    if (is_constant(*bin->fRight, 0)) {
826                        clear_write(*bin->fLeft);
827                        delete_right(&b, iter, outUpdated, outNeedsRescan);
828                    }
829                    break;
830                case Token::MINUSEQ:
831                    if (is_constant(*bin->fRight, 0)) {
832                        clear_write(*bin->fLeft);
833                        delete_right(&b, iter, outUpdated, outNeedsRescan);
834                    }
835                    break;
836                case Token::STAREQ:
837                    if (is_constant(*bin->fRight, 1)) {
838                        clear_write(*bin->fLeft);
839                        delete_right(&b, iter, outUpdated, outNeedsRescan);
840                    }
841                    break;
842                case Token::SLASHEQ:
843                    if (is_constant(*bin->fRight, 1)) {
844                        clear_write(*bin->fLeft);
845                        delete_right(&b, iter, outUpdated, outNeedsRescan);
846                    }
847                    break;
848                default:
849                    break;
850            }
851        }
852        default:
853            break;
854    }
855}
856
857// returns true if this statement could potentially execute a break at the current level (we ignore
858// nested loops and switches, since any breaks inside of them will merely break the loop / switch)
859static bool contains_break(Statement& s) {
860    switch (s.fKind) {
861        case Statement::kBlock_Kind:
862            for (const auto& sub : ((Block&) s).fStatements) {
863                if (contains_break(*sub)) {
864                    return true;
865                }
866            }
867            return false;
868        case Statement::kBreak_Kind:
869            return true;
870        case Statement::kIf_Kind: {
871            const IfStatement& i = (IfStatement&) s;
872            return contains_break(*i.fIfTrue) || (i.fIfFalse && contains_break(*i.fIfFalse));
873        }
874        default:
875            return false;
876    }
877}
878
879// Returns a block containing all of the statements that will be run if the given case matches
880// (which, owing to the statements being owned by unique_ptrs, means the switch itself will be
881// broken by this call and must then be discarded).
882// Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as
883// when break statements appear inside conditionals.
884static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
885    bool capturing = false;
886    std::vector<std::unique_ptr<Statement>*> statementPtrs;
887    for (const auto& current : s->fCases) {
888        if (current.get() == c) {
889            capturing = true;
890        }
891        if (capturing) {
892            for (auto& stmt : current->fStatements) {
893                if (stmt->fKind == Statement::kBreak_Kind) {
894                    capturing = false;
895                    break;
896                }
897                if (contains_break(*stmt)) {
898                    return nullptr;
899                }
900                statementPtrs.push_back(&stmt);
901            }
902            if (!capturing) {
903                break;
904            }
905        }
906    }
907    std::vector<std::unique_ptr<Statement>> statements;
908    for (const auto& s : statementPtrs) {
909        statements.push_back(std::move(*s));
910    }
911    return std::unique_ptr<Statement>(new Block(-1, std::move(statements), s->fSymbols));
912}
913
// Applies local, statement-level simplifications to the statement at **iter in basic block `b`:
//  - a dead variable declaration whose initializer (if any) has no side effects becomes a Nop;
//  - an if whose test is a boolean literal collapses to the branch that would run (or a Nop);
//  - an empty else branch is dropped, and an if with an empty body and no else is reduced to an
//    expression statement for its test (when the test has side effects) or to a Nop;
//  - a switch on a constant value is replaced by a block holding the statements of the matching
//    case (or of the default case), or by a Nop when nothing matches and there is no default;
//  - an expression statement with no side effects becomes a Nop.
// Sets *outUpdated whenever the IR is changed, and *outNeedsRescan when the change means the CFG
// must be rebuilt before further simplification. `definitions` and `undefinedVariables` are not
// read by this function.
void Compiler::simplifyStatement(DefinitionMap& definitions,
                                 BasicBlock& b,
                                 std::vector<BasicBlock::Node>::iterator* iter,
                                 std::unordered_set<const Variable*>* undefinedVariables,
                                 bool* outUpdated,
                                 bool* outNeedsRescan) {
    Statement* stmt = (*iter)->statement()->get();
    switch (stmt->fKind) {
        case Statement::kVarDeclaration_Kind: {
            const auto& varDecl = (VarDeclaration&) *stmt;
            // A variable reported dead() can be dropped unless evaluating its initializer has
            // side effects.
            if (varDecl.fVar->dead() &&
                (!varDecl.fValue ||
                 !varDecl.fValue->hasSideEffects())) {
                if (varDecl.fValue) {
                    ASSERT((*iter)->statement()->get() == stmt);
                    // Remove the initializer's nodes from this block; if that cannot be done in
                    // place, the CFG has to be rebuilt.
                    if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
                        *outNeedsRescan = true;
                    }
                }
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        case Statement::kIf_Kind: {
            IfStatement& i = (IfStatement&) *stmt;
            if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
                // constant if, collapse down to a single branch
                // NOTE(review): if setStatement destroys the replaced statement, `i` dangles
                // after these calls; it is not touched again before the break below.
                if (((BoolLiteral&) *i.fTest).fValue) {
                    ASSERT(i.fIfTrue);
                    (*iter)->setStatement(std::move(i.fIfTrue));
                } else {
                    if (i.fIfFalse) {
                        (*iter)->setStatement(std::move(i.fIfFalse));
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                *outUpdated = true;
                *outNeedsRescan = true;
                break;
            }
            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
                // else block doesn't do anything, remove it
                i.fIfFalse.reset();
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
                // if block doesn't do anything, no else block
                if (i.fTest->hasSideEffects()) {
                    // test has side effects, keep it
                    (*iter)->setStatement(std::unique_ptr<Statement>(
                                                      new ExpressionStatement(std::move(i.fTest))));
                } else {
                    // no if, no else, no test side effects, kill the whole if
                    // statement
                    (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kSwitch_Kind: {
            SwitchStatement& s = (SwitchStatement&) *stmt;
            if (s.fValue->isConstant()) {
                // switch is constant, replace it with the case that matches
                bool found = false;
                SwitchCase* defaultCase = nullptr;
                for (const auto& c : s.fCases) {
                    // A case without a value is the default; remember it in case nothing matches.
                    if (!c->fValue) {
                        defaultCase = c.get();
                        continue;
                    }
                    ASSERT(c->fValue->fKind == s.fValue->fKind);
                    found = c->fValue->compareConstant(fContext, *s.fValue);
                    if (found) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
                        if (newBlock) {
                            // NOTE(review): the switch (and thus s.fCases) may be destroyed by
                            // setStatement; the loop is exited immediately without touching it.
                            (*iter)->setStatement(std::move(newBlock));
                            break;
                        } else {
                            // block_for_case found a conditional break; leave the switch as is
                            // and report it if the switch was required to be static.
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    }
                }
                if (!found) {
                    // no matching case. use default if it exists, or kill the whole thing
                    if (defaultCase) {
                        std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
                        if (newBlock) {
                            (*iter)->setStatement(std::move(newBlock));
                        } else {
                            if (s.fIsStatic && !(fFlags & kPermitInvalidStaticTests_Flag)) {
                                this->error(s.fOffset,
                                            "static switch contains non-static conditional break");
                                s.fIsStatic = false;
                            }
                            return; // can't simplify
                        }
                    } else {
                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                    }
                }
                *outUpdated = true;
                *outNeedsRescan = true;
            }
            break;
        }
        case Statement::kExpression_Kind: {
            ExpressionStatement& e = (ExpressionStatement&) *stmt;
            ASSERT((*iter)->statement()->get() == &e);
            if (!e.fExpression->hasSideEffects()) {
                // Expression statement with no side effects, kill it
                // (first drop its expression nodes from the block, or request a CFG rebuild)
                if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
                    *outNeedsRescan = true;
                }
                ASSERT((*iter)->statement()->get() == stmt);
                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
                *outUpdated = true;
            }
            break;
        }
        default:
            break;
    }
}
1047
// Runs control-flow analysis and optimization over function `f`:
//  1. reports "unreachable" for every non-start block that has nodes but no entrances, then
//     returns early if the error count is nonzero (this includes errors recorded earlier);
//  2. repeatedly simplifies every expression and statement until a full pass changes nothing,
//     rebuilding the CFG and its data flow whenever a simplification requests a rescan;
//  3. reports any if/switch still marked static (unless kPermitInvalidStaticTests_Flag is set)
//     and strips Nop entries from variable declaration statements;
//  4. reports an error if the function is non-void and its exit block has any entrances.
void Compiler::scanCFG(FunctionDefinition& f) {
    CFG cfg = CFGGenerator().getCFG(f);
    this->computeDataFlow(&cfg);

    // check for unreachable code
    for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
        if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
            cfg.fBlocks[i].fNodes.size()) {
            // NOTE(review): offset stays uninitialized if a Node kind other than the two below
            // is ever added; keep this switch in sync with BasicBlock::Node.
            int offset;
            switch (cfg.fBlocks[i].fNodes[0].fKind) {
                case BasicBlock::Node::kStatement_Kind:
                    offset = (*cfg.fBlocks[i].fNodes[0].statement())->fOffset;
                    break;
                case BasicBlock::Node::kExpression_Kind:
                    offset = (*cfg.fBlocks[i].fNodes[0].expression())->fOffset;
                    break;
            }
            this->error(offset, String("unreachable"));
        }
    }
    if (fErrorCount) {
        return;
    }

    // check for dead code & undefined variables, perform constant propagation
    std::unordered_set<const Variable*> undefinedVariables;
    bool updated;
    bool needsRescan = false;
    do {
        // A simplification in the previous pass invalidated the CFG; rebuild it from the IR.
        if (needsRescan) {
            cfg = CFGGenerator().getCFG(f);
            this->computeDataFlow(&cfg);
            needsRescan = false;
        }

        updated = false;
        for (BasicBlock& b : cfg.fBlocks) {
            DefinitionMap definitions = b.fBefore;

            for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
                if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
                    this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
                                             &needsRescan);
                } else {
                    this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
                                             &needsRescan);
                }
                // The CFG may no longer match the IR, so stop before touching *iter again.
                if (needsRescan) {
                    break;
                }
                this->addDefinitions(*iter, &definitions);
            }
        }
    } while (updated);
    ASSERT(!needsRescan);

    // verify static ifs & switches, clean up dead variable decls
    // (needsRescan is asserted false above, so the `!needsRescan` loop condition never stops this
    // loop; iter is advanced manually because the var-decl case may erase the current node)
    for (BasicBlock& b : cfg.fBlocks) {
        DefinitionMap definitions = b.fBefore;

        for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan;) {
            if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
                const Statement& s = **iter->statement();
                switch (s.fKind) {
                    case Statement::kIf_Kind:
                        if (((const IfStatement&) s).fIsStatic &&
                            !(fFlags & kPermitInvalidStaticTests_Flag)) {
                            this->error(s.fOffset, "static if has non-static test");
                        }
                        ++iter;
                        break;
                    case Statement::kSwitch_Kind:
                        if (((const SwitchStatement&) s).fIsStatic &&
                             !(fFlags & kPermitInvalidStaticTests_Flag)) {
                            this->error(s.fOffset, "static switch has non-static test");
                        }
                        ++iter;
                        break;
                    case Statement::kVarDeclarations_Kind: {
                        VarDeclarations& decls = *((VarDeclarationsStatement&) s).fDeclaration;
                        // Drop individual declarations that were replaced with Nops.
                        for (auto varIter = decls.fVars.begin(); varIter != decls.fVars.end();) {
                            if ((*varIter)->fKind == Statement::kNop_Kind) {
                                varIter = decls.fVars.erase(varIter);
                            } else {
                                ++varIter;
                            }
                        }
                        // NOTE(review): this erases only the CFG node; the now-empty
                        // VarDeclarationsStatement stays in the function body. Confirm the code
                        // generators skip empty declaration statements.
                        if (!decls.fVars.size()) {
                            iter = b.fNodes.erase(iter);
                        } else {
                            ++iter;
                        }
                        break;
                    }
                    default:
                        ++iter;
                        break;
                }
            } else {
                ++iter;
            }
        }
    }

    // check for missing return
    if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
        if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
            this->error(f.fOffset, String("function can exit without returning a value"));
        }
    }
}
1159
1160std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text,
1161                                                  const Program::Settings& settings) {
1162    fErrorText = "";
1163    fErrorCount = 0;
1164    fIRGenerator->start(&settings);
1165    std::vector<std::unique_ptr<ProgramElement>> elements;
1166    switch (kind) {
1167        case Program::kVertex_Kind:
1168            fIRGenerator->convertProgram(kind, SKSL_VERT_INCLUDE, strlen(SKSL_VERT_INCLUDE),
1169                                         *fTypes, &elements);
1170            break;
1171        case Program::kFragment_Kind:
1172            fIRGenerator->convertProgram(kind, SKSL_FRAG_INCLUDE, strlen(SKSL_FRAG_INCLUDE),
1173                                         *fTypes, &elements);
1174            break;
1175        case Program::kGeometry_Kind:
1176            fIRGenerator->convertProgram(kind, SKSL_GEOM_INCLUDE, strlen(SKSL_GEOM_INCLUDE),
1177                                         *fTypes, &elements);
1178            break;
1179        case Program::kFragmentProcessor_Kind:
1180            fIRGenerator->convertProgram(kind, SKSL_FP_INCLUDE, strlen(SKSL_FP_INCLUDE), *fTypes,
1181                                         &elements);
1182            break;
1183    }
1184    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1185    for (auto& element : elements) {
1186        if (element->fKind == ProgramElement::kEnum_Kind) {
1187            ((Enum&) *element).fBuiltin = true;
1188        }
1189    }
1190    std::unique_ptr<String> textPtr(new String(std::move(text)));
1191    fSource = textPtr.get();
1192    fIRGenerator->convertProgram(kind, textPtr->c_str(), textPtr->size(), *fTypes, &elements);
1193    if (!fErrorCount) {
1194        for (auto& element : elements) {
1195            if (element->fKind == ProgramElement::kFunction_Kind) {
1196                this->scanCFG((FunctionDefinition&) *element);
1197            }
1198        }
1199    }
1200    auto result = std::unique_ptr<Program>(new Program(kind,
1201                                                       std::move(textPtr),
1202                                                       settings,
1203                                                       &fContext,
1204                                                       std::move(elements),
1205                                                       fIRGenerator->fSymbolTable,
1206                                                       fIRGenerator->fInputs));
1207    fIRGenerator->finish();
1208    fSource = nullptr;
1209    this->writeErrorCount();
1210    if (fErrorCount) {
1211        return nullptr;
1212    }
1213    return result;
1214}
1215
// Generates SPIR-V for `program` into `out` and returns whether code generation succeeded.
// fSource points at the program text while the generator runs, so reported errors get line
// numbers. With SK_ENABLE_SPIRV_VALIDATION defined, the output is first produced into a scratch
// buffer and, if generation succeeded, run through the SPIRV-Tools validator for the Vulkan 1.0
// environment (validator messages are logged with SkDebugf and the result is checked with
// ASSERT_RESULT) before being copied to `out`; nothing is written to `out` on failure.
// Without validation, the generator writes straight to `out`.
bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
#ifdef SK_ENABLE_SPIRV_VALIDATION
    StringStream buffer;
    fSource = program.fSource.get();
    SPIRVCodeGenerator cg(&fContext, &program, this, &buffer);
    bool result = cg.generateCode();
    fSource = nullptr;
    if (result) {
        spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
        const String& data = buffer.str();
        // SPIR-V is a stream of 32-bit words.
        ASSERT(0 == data.size() % 4);
        auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
            SkDebugf("SPIR-V validation error: %s\n", m);
        };
        tools.SetMessageConsumer(dumpmsg);
        // Verify that the SPIR-V we produced is valid. If this assert fails, check the logs prior
        // to the failure to see the validation errors.
        ASSERT_RESULT(tools.Validate((const uint32_t*) data.c_str(), data.size() / 4));
        out.write(data.c_str(), data.size());
    }
#else
    fSource = program.fSource.get();
    SPIRVCodeGenerator cg(&fContext, &program, this, &out);
    bool result = cg.generateCode();
    fSource = nullptr;
#endif
    this->writeErrorCount();
    return result;
}
1245
1246bool Compiler::toSPIRV(const Program& program, String* out) {
1247    StringStream buffer;
1248    bool result = this->toSPIRV(program, buffer);
1249    if (result) {
1250        *out = buffer.str();
1251    }
1252    return result;
1253}
1254
1255bool Compiler::toGLSL(const Program& program, OutputStream& out) {
1256    fSource = program.fSource.get();
1257    GLSLCodeGenerator cg(&fContext, &program, this, &out);
1258    bool result = cg.generateCode();
1259    fSource = nullptr;
1260    this->writeErrorCount();
1261    return result;
1262}
1263
1264bool Compiler::toGLSL(const Program& program, String* out) {
1265    StringStream buffer;
1266    bool result = this->toGLSL(program, buffer);
1267    if (result) {
1268        *out = buffer.str();
1269    }
1270    return result;
1271}
1272
1273bool Compiler::toMetal(const Program& program, OutputStream& out) {
1274    MetalCodeGenerator cg(&fContext, &program, this, &out);
1275    bool result = cg.generateCode();
1276    this->writeErrorCount();
1277    return result;
1278}
1279
1280bool Compiler::toCPP(const Program& program, String name, OutputStream& out) {
1281    fSource = program.fSource.get();
1282    CPPCodeGenerator cg(&fContext, &program, this, name, &out);
1283    bool result = cg.generateCode();
1284    fSource = nullptr;
1285    this->writeErrorCount();
1286    return result;
1287}
1288
1289bool Compiler::toH(const Program& program, String name, OutputStream& out) {
1290    fSource = program.fSource.get();
1291    HCodeGenerator cg(&fContext, &program, this, name, &out);
1292    bool result = cg.generateCode();
1293    fSource = nullptr;
1294    this->writeErrorCount();
1295    return result;
1296}
1297
1298const char* Compiler::OperatorName(Token::Kind kind) {
1299    switch (kind) {
1300        case Token::PLUS:         return "+";
1301        case Token::MINUS:        return "-";
1302        case Token::STAR:         return "*";
1303        case Token::SLASH:        return "/";
1304        case Token::PERCENT:      return "%";
1305        case Token::SHL:          return "<<";
1306        case Token::SHR:          return ">>";
1307        case Token::LOGICALNOT:   return "!";
1308        case Token::LOGICALAND:   return "&&";
1309        case Token::LOGICALOR:    return "||";
1310        case Token::LOGICALXOR:   return "^^";
1311        case Token::BITWISENOT:   return "~";
1312        case Token::BITWISEAND:   return "&";
1313        case Token::BITWISEOR:    return "|";
1314        case Token::BITWISEXOR:   return "^";
1315        case Token::EQ:           return "=";
1316        case Token::EQEQ:         return "==";
1317        case Token::NEQ:          return "!=";
1318        case Token::LT:           return "<";
1319        case Token::GT:           return ">";
1320        case Token::LTEQ:         return "<=";
1321        case Token::GTEQ:         return ">=";
1322        case Token::PLUSEQ:       return "+=";
1323        case Token::MINUSEQ:      return "-=";
1324        case Token::STAREQ:       return "*=";
1325        case Token::SLASHEQ:      return "/=";
1326        case Token::PERCENTEQ:    return "%=";
1327        case Token::SHLEQ:        return "<<=";
1328        case Token::SHREQ:        return ">>=";
1329        case Token::LOGICALANDEQ: return "&&=";
1330        case Token::LOGICALOREQ:  return "||=";
1331        case Token::LOGICALXOREQ: return "^^=";
1332        case Token::BITWISEANDEQ: return "&=";
1333        case Token::BITWISEOREQ:  return "|=";
1334        case Token::BITWISEXOREQ: return "^=";
1335        case Token::PLUSPLUS:     return "++";
1336        case Token::MINUSMINUS:   return "--";
1337        case Token::COMMA:        return ",";
1338        default:
1339            ABORT("unsupported operator: %d\n", kind);
1340    }
1341}
1342
1343
1344bool Compiler::IsAssignment(Token::Kind op) {
1345    switch (op) {
1346        case Token::EQ:           // fall through
1347        case Token::PLUSEQ:       // fall through
1348        case Token::MINUSEQ:      // fall through
1349        case Token::STAREQ:       // fall through
1350        case Token::SLASHEQ:      // fall through
1351        case Token::PERCENTEQ:    // fall through
1352        case Token::SHLEQ:        // fall through
1353        case Token::SHREQ:        // fall through
1354        case Token::BITWISEOREQ:  // fall through
1355        case Token::BITWISEXOREQ: // fall through
1356        case Token::BITWISEANDEQ: // fall through
1357        case Token::LOGICALOREQ:  // fall through
1358        case Token::LOGICALXOREQ: // fall through
1359        case Token::LOGICALANDEQ:
1360            return true;
1361        default:
1362            return false;
1363    }
1364}
1365
1366Position Compiler::position(int offset) {
1367    ASSERT(fSource);
1368    int line = 1;
1369    int column = 1;
1370    for (int i = 0; i < offset; i++) {
1371        if ((*fSource)[i] == '\n') {
1372            ++line;
1373            column = 1;
1374        }
1375        else {
1376            ++column;
1377        }
1378    }
1379    return Position(line, column);
1380}
1381
1382void Compiler::error(int offset, String msg) {
1383    fErrorCount++;
1384    Position pos = this->position(offset);
1385    fErrorText += "error: " + to_string(pos.fLine) + ": " + msg.c_str() + "\n";
1386}
1387
1388String Compiler::errorText() {
1389    String result = fErrorText;
1390    return result;
1391}
1392
1393void Compiler::writeErrorCount() {
1394    if (fErrorCount) {
1395        fErrorText += to_string(fErrorCount) + " error";
1396        if (fErrorCount > 1) {
1397            fErrorText += "s";
1398        }
1399        fErrorText += "\n";
1400    }
1401}
1402
1403} // namespace
1404