SkSLCompiler.cpp revision 84cda40bd7e98f4e19574c6e946395e244901408
1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLCompiler.h"
9
10#include "ast/SkSLASTPrecision.h"
11#include "SkSLCFGGenerator.h"
12#include "SkSLGLSLCodeGenerator.h"
13#include "SkSLIRGenerator.h"
14#include "SkSLParser.h"
15#include "SkSLSPIRVCodeGenerator.h"
16#include "ir/SkSLExpression.h"
17#include "ir/SkSLExpressionStatement.h"
18#include "ir/SkSLIntLiteral.h"
19#include "ir/SkSLModifiersDeclaration.h"
20#include "ir/SkSLNop.h"
21#include "ir/SkSLSymbolTable.h"
22#include "ir/SkSLTernaryExpression.h"
23#include "ir/SkSLUnresolvedFunction.h"
24#include "ir/SkSLVarDeclarations.h"
25
26#ifdef SK_ENABLE_SPIRV_VALIDATION
27#include "spirv-tools/libspirv.hpp"
28#endif
29
// STRINGIFY turns its argument into a string literal.
// NOTE(review): presumably the *.include files below wrap their contents in STRINGIFY(...) so
// that each #include expands to a single string literal — TODO confirm against those files.
30#define STRINGIFY(x) #x
31
32// Include the built-in shader symbols as static strings. SKSL_INCLUDE holds the declarations
// shared by every program and is compiled by the Compiler constructor; the vertex, fragment and
// geometry variants are presumably selected per program kind elsewhere in this file.
33
34static const char* SKSL_INCLUDE =
35#include "sksl.include"
36;
37
// Vertex-stage built-ins.
38static const char* SKSL_VERT_INCLUDE =
39#include "sksl_vert.include"
40;
41
// Fragment-stage built-ins.
42static const char* SKSL_FRAG_INCLUDE =
43#include "sksl_frag.include"
44;
45
// Geometry-stage built-ins.
46static const char* SKSL_GEOM_INCLUDE =
47#include "sksl_geom.include"
48;
49
50namespace SkSL {
51
52Compiler::Compiler()
53: fErrorCount(0) {
54    auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this));
55    auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this));
56    fIRGenerator = new IRGenerator(&fContext, symbols, *this);
57    fTypes = types;
58    #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
59                                                   fContext.f ## t ## _Type.get())
60    ADD_TYPE(Void);
61    ADD_TYPE(Float);
62    ADD_TYPE(Vec2);
63    ADD_TYPE(Vec3);
64    ADD_TYPE(Vec4);
65    ADD_TYPE(Double);
66    ADD_TYPE(DVec2);
67    ADD_TYPE(DVec3);
68    ADD_TYPE(DVec4);
69    ADD_TYPE(Int);
70    ADD_TYPE(IVec2);
71    ADD_TYPE(IVec3);
72    ADD_TYPE(IVec4);
73    ADD_TYPE(UInt);
74    ADD_TYPE(UVec2);
75    ADD_TYPE(UVec3);
76    ADD_TYPE(UVec4);
77    ADD_TYPE(Bool);
78    ADD_TYPE(BVec2);
79    ADD_TYPE(BVec3);
80    ADD_TYPE(BVec4);
81    ADD_TYPE(Mat2x2);
82    types->addWithoutOwnership(String("mat2x2"), fContext.fMat2x2_Type.get());
83    ADD_TYPE(Mat2x3);
84    ADD_TYPE(Mat2x4);
85    ADD_TYPE(Mat3x2);
86    ADD_TYPE(Mat3x3);
87    types->addWithoutOwnership(String("mat3x3"), fContext.fMat3x3_Type.get());
88    ADD_TYPE(Mat3x4);
89    ADD_TYPE(Mat4x2);
90    ADD_TYPE(Mat4x3);
91    ADD_TYPE(Mat4x4);
92    types->addWithoutOwnership(String("mat4x4"), fContext.fMat4x4_Type.get());
93    ADD_TYPE(GenType);
94    ADD_TYPE(GenDType);
95    ADD_TYPE(GenIType);
96    ADD_TYPE(GenUType);
97    ADD_TYPE(GenBType);
98    ADD_TYPE(Mat);
99    ADD_TYPE(Vec);
100    ADD_TYPE(GVec);
101    ADD_TYPE(GVec2);
102    ADD_TYPE(GVec3);
103    ADD_TYPE(GVec4);
104    ADD_TYPE(DVec);
105    ADD_TYPE(IVec);
106    ADD_TYPE(UVec);
107    ADD_TYPE(BVec);
108
109    ADD_TYPE(Sampler1D);
110    ADD_TYPE(Sampler2D);
111    ADD_TYPE(Sampler3D);
112    ADD_TYPE(SamplerExternalOES);
113    ADD_TYPE(SamplerCube);
114    ADD_TYPE(Sampler2DRect);
115    ADD_TYPE(Sampler1DArray);
116    ADD_TYPE(Sampler2DArray);
117    ADD_TYPE(SamplerCubeArray);
118    ADD_TYPE(SamplerBuffer);
119    ADD_TYPE(Sampler2DMS);
120    ADD_TYPE(Sampler2DMSArray);
121
122    ADD_TYPE(ISampler2D);
123
124    ADD_TYPE(Image2D);
125    ADD_TYPE(IImage2D);
126
127    ADD_TYPE(SubpassInput);
128    ADD_TYPE(SubpassInputMS);
129
130    ADD_TYPE(GSampler1D);
131    ADD_TYPE(GSampler2D);
132    ADD_TYPE(GSampler3D);
133    ADD_TYPE(GSamplerCube);
134    ADD_TYPE(GSampler2DRect);
135    ADD_TYPE(GSampler1DArray);
136    ADD_TYPE(GSampler2DArray);
137    ADD_TYPE(GSamplerCubeArray);
138    ADD_TYPE(GSamplerBuffer);
139    ADD_TYPE(GSampler2DMS);
140    ADD_TYPE(GSampler2DMSArray);
141
142    ADD_TYPE(Sampler1DShadow);
143    ADD_TYPE(Sampler2DShadow);
144    ADD_TYPE(SamplerCubeShadow);
145    ADD_TYPE(Sampler2DRectShadow);
146    ADD_TYPE(Sampler1DArrayShadow);
147    ADD_TYPE(Sampler2DArrayShadow);
148    ADD_TYPE(SamplerCubeArrayShadow);
149    ADD_TYPE(GSampler2DArrayShadow);
150    ADD_TYPE(GSamplerCubeArrayShadow);
151
152    String skCapsName("sk_Caps");
153    Variable* skCaps = new Variable(Position(), Modifiers(), skCapsName,
154                                    *fContext.fSkCaps_Type, Variable::kGlobal_Storage);
155    fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps));
156
157    Modifiers::Flag ignored1;
158    std::vector<std::unique_ptr<ProgramElement>> ignored2;
159    this->internalConvertProgram(String(SKSL_INCLUDE), &ignored1, &ignored2);
160    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
161    ASSERT(!fErrorCount);
162}
163
164Compiler::~Compiler() {
165    delete fIRGenerator;
166}
167
168// add the definition created by assigning to the lvalue to the definition set
169void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr,
170                             DefinitionMap* definitions) {
171    switch (lvalue->fKind) {
172        case Expression::kVariableReference_Kind: {
173            const Variable& var = ((VariableReference*) lvalue)->fVariable;
174            if (var.fStorage == Variable::kLocal_Storage) {
175                (*definitions)[&var] = expr;
176            }
177            break;
178        }
179        case Expression::kSwizzle_Kind:
180            // We consider the variable written to as long as at least some of its components have
181            // been written to. This will lead to some false negatives (we won't catch it if you
182            // write to foo.x and then read foo.y), but being stricter could lead to false positives
183            // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
184            // but since we pass foo as a whole it is flagged as an error) unless we perform a much
185            // more complicated whole-program analysis. This is probably good enough.
186            this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
187                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
188                                definitions);
189            break;
190        case Expression::kIndex_Kind:
191            // see comments in Swizzle
192            this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
193                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
194                                definitions);
195            break;
196        case Expression::kFieldAccess_Kind:
197            // see comments in Swizzle
198            this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
199                                (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
200                                definitions);
201            break;
202        default:
203            // not an lvalue, can't happen
204            ASSERT(false);
205    }
206}
207
208// add local variables defined by this node to the set
209void Compiler::addDefinitions(const BasicBlock::Node& node,
210                              DefinitionMap* definitions) {
211    switch (node.fKind) {
212        case BasicBlock::Node::kExpression_Kind: {
213            ASSERT(node.expression());
214            const Expression* expr = (Expression*) node.expression()->get();
215            switch (expr->fKind) {
216                case Expression::kBinary_Kind: {
217                    BinaryExpression* b = (BinaryExpression*) expr;
218                    if (b->fOperator == Token::EQ) {
219                        this->addDefinition(b->fLeft.get(), &b->fRight, definitions);
220                    } else if (Token::IsAssignment(b->fOperator)) {
221                        this->addDefinition(
222                                       b->fLeft.get(),
223                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
224                                       definitions);
225
226                    }
227                    break;
228                }
229                case Expression::kPrefix_Kind: {
230                    const PrefixExpression* p = (PrefixExpression*) expr;
231                    if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
232                        this->addDefinition(
233                                       p->fOperand.get(),
234                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
235                                       definitions);
236                    }
237                    break;
238                }
239                case Expression::kPostfix_Kind: {
240                    const PostfixExpression* p = (PostfixExpression*) expr;
241                    if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) {
242                        this->addDefinition(
243                                       p->fOperand.get(),
244                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
245                                       definitions);
246                    }
247                    break;
248                }
249                case Expression::kVariableReference_Kind: {
250                    const VariableReference* v = (VariableReference*) expr;
251                    if (v->fRefKind != VariableReference::kRead_RefKind) {
252                        this->addDefinition(
253                                       v,
254                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression,
255                                       definitions);
256                    }
257                }
258                default:
259                    break;
260            }
261            break;
262        }
263        case BasicBlock::Node::kStatement_Kind: {
264            const Statement* stmt = (Statement*) node.statement()->get();
265            if (stmt->fKind == Statement::kVarDeclaration_Kind) {
266                VarDeclaration& vd = (VarDeclaration&) *stmt;
267                if (vd.fValue) {
268                    (*definitions)[vd.fVar] = &vd.fValue;
269                }
270            }
271            break;
272        }
273    }
274}
275
276void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
277    BasicBlock& block = cfg->fBlocks[blockId];
278
279    // compute definitions after this block
280    DefinitionMap after = block.fBefore;
281    for (const BasicBlock::Node& n : block.fNodes) {
282        this->addDefinitions(n, &after);
283    }
284
285    // propagate definitions to exits
286    for (BlockId exitId : block.fExits) {
287        BasicBlock& exit = cfg->fBlocks[exitId];
288        for (const auto& pair : after) {
289            std::unique_ptr<Expression>* e1 = pair.second;
290            auto found = exit.fBefore.find(pair.first);
291            if (found == exit.fBefore.end()) {
292                // exit has no definition for it, just copy it
293                workList->insert(exitId);
294                exit.fBefore[pair.first] = e1;
295            } else {
296                // exit has a (possibly different) value already defined
297                std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first];
298                if (e1 != e2) {
299                    // definition has changed, merge and add exit block to worklist
300                    workList->insert(exitId);
301                    if (e1 && e2) {
302                        exit.fBefore[pair.first] =
303                                       (std::unique_ptr<Expression>*) &fContext.fDefined_Expression;
304                    } else {
305                        exit.fBefore[pair.first] = nullptr;
306                    }
307                }
308            }
309        }
310    }
311}
312
313// returns a map which maps all local variables in the function to null, indicating that their value
314// is initially unknown
315static DefinitionMap compute_start_state(const CFG& cfg) {
316    DefinitionMap result;
317    for (const auto& block : cfg.fBlocks) {
318        for (const auto& node : block.fNodes) {
319            if (node.fKind == BasicBlock::Node::kStatement_Kind) {
320                ASSERT(node.statement());
321                const Statement* s = node.statement()->get();
322                if (s->fKind == Statement::kVarDeclarations_Kind) {
323                    const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
324                    for (const auto& decl : vd->fDeclaration->fVars) {
325                        result[((VarDeclaration&) *decl).fVar] = nullptr;
326                    }
327                }
328            }
329        }
330    }
331    return result;
332}
333
334/**
335 * Returns true if assigning to this lvalue has no effect.
336 */
337static bool is_dead(const Expression& lvalue) {
338    switch (lvalue.fKind) {
339        case Expression::kVariableReference_Kind:
340            return ((VariableReference&) lvalue).fVariable.dead();
341        case Expression::kSwizzle_Kind:
342            return is_dead(*((Swizzle&) lvalue).fBase);
343        case Expression::kFieldAccess_Kind:
344            return is_dead(*((FieldAccess&) lvalue).fBase);
345        case Expression::kIndex_Kind: {
346            const IndexExpression& idx = (IndexExpression&) lvalue;
347            return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects();
348        }
349        default:
350            ABORT("invalid lvalue: %s\n", lvalue.description().c_str());
351    }
352}
353
354/**
355 * Returns true if this is an assignment which can be collapsed down to just the right hand side due
356 * to a dead target and lack of side effects on the left hand side.
357 */
358static bool dead_assignment(const BinaryExpression& b) {
359    if (!Token::IsAssignment(b.fOperator)) {
360        return false;
361    }
362    return is_dead(*b.fLeft);
363}
364
365void Compiler::computeDataFlow(CFG* cfg) {
366    cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg);
367    std::set<BlockId> workList;
368    for (BlockId i = 0; i < cfg->fBlocks.size(); i++) {
369        workList.insert(i);
370    }
371    while (workList.size()) {
372        BlockId next = *workList.begin();
373        workList.erase(workList.begin());
374        this->scanCFG(cfg, next, &workList);
375    }
376}
377
378/**
379 * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the
380 * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to
381 * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will
382 * need to be regenerated).
383 */
384bool try_replace_expression(BasicBlock* b,
385                            std::vector<BasicBlock::Node>::iterator* iter,
386                            std::unique_ptr<Expression>* newExpression) {
387    std::unique_ptr<Expression>* target = (*iter)->expression();
388    if (!b->tryRemoveExpression(iter)) {
389        *target = std::move(*newExpression);
390        return false;
391    }
392    *target = std::move(*newExpression);
393    return b->tryInsertExpression(iter, target);
394}
395
396/**
397 * Returns true if the expression is a constant numeric literal with the specified value, or a
398 * constant vector with all elements equal to the specified value.
399 */
400bool is_constant(const Expression& expr, double value) {
401    switch (expr.fKind) {
402        case Expression::kIntLiteral_Kind:
403            return ((IntLiteral&) expr).fValue == value;
404        case Expression::kFloatLiteral_Kind:
405            return ((FloatLiteral&) expr).fValue == value;
406        case Expression::kConstructor_Kind: {
407            Constructor& c = (Constructor&) expr;
408            if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) {
409                for (int i = 0; i < c.fType.columns(); ++i) {
410                    if (!is_constant(c.getVecComponent(i), value)) {
411                        return false;
412                    }
413                }
414                return true;
415            }
416            return false;
417        }
418        default:
419            return false;
420    }
421}
422
423/**
424 * Collapses the binary expression pointed to by iter down to just the right side (in both the IR
425 * and CFG structures).
426 */
427void delete_left(BasicBlock* b,
428                 std::vector<BasicBlock::Node>::iterator* iter,
429                 bool* outUpdated,
430                 bool* outNeedsRescan) {
431    *outUpdated = true;
432    std::unique_ptr<Expression>* target = (*iter)->expression();
433    ASSERT((*target)->fKind == Expression::kBinary_Kind);
434    BinaryExpression& bin = (BinaryExpression&) **target;
435    bool result;
436    if (bin.fOperator == Token::EQ) {
437        result = b->tryRemoveLValueBefore(iter, bin.fLeft.get());
438    } else {
439        result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get());
440    }
441    *target = std::move(bin.fRight);
442    if (!result) {
443        *outNeedsRescan = true;
444        return;
445    }
446    if (*iter == b->fNodes.begin()) {
447        *outNeedsRescan = true;
448        return;
449    }
450    --(*iter);
451    if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
452        (*iter)->expression() != &bin.fRight) {
453        *outNeedsRescan = true;
454        return;
455    }
456    *iter = b->fNodes.erase(*iter);
457    ASSERT((*iter)->expression() == target);
458}
459
460/**
461 * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and
462 * CFG structures).
463 */
464void delete_right(BasicBlock* b,
465                  std::vector<BasicBlock::Node>::iterator* iter,
466                  bool* outUpdated,
467                  bool* outNeedsRescan) {
468    *outUpdated = true;
469    std::unique_ptr<Expression>* target = (*iter)->expression();
470    ASSERT((*target)->fKind == Expression::kBinary_Kind);
471    BinaryExpression& bin = (BinaryExpression&) **target;
472    if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) {
473        *target = std::move(bin.fLeft);
474        *outNeedsRescan = true;
475        return;
476    }
477    *target = std::move(bin.fLeft);
478    if (*iter == b->fNodes.begin()) {
479        *outNeedsRescan = true;
480        return;
481    }
482    --(*iter);
483    if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind ||
484        (*iter)->expression() != &bin.fLeft)) {
485        *outNeedsRescan = true;
486        return;
487    }
488    *iter = b->fNodes.erase(*iter);
489    ASSERT((*iter)->expression() == target);
490}
491
492/**
493 * Constructs the specified type using a single argument.
494 */
495static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) {
496    std::vector<std::unique_ptr<Expression>> args;
497    args.push_back(std::move(v));
498    auto result = std::unique_ptr<Expression>(new Constructor(Position(), type, std::move(args)));
499    return result;
500}
501
502/**
503 * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an
504 * expression x, deletes the expression pointed to by iter and replaces it with <type>(x).
505 */
506static void vectorize(BasicBlock* b,
507                      std::vector<BasicBlock::Node>::iterator* iter,
508                      const Type& type,
509                      std::unique_ptr<Expression>* otherExpression,
510                      bool* outUpdated,
511                      bool* outNeedsRescan) {
512    ASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind);
513    ASSERT(type.kind() == Type::kVector_Kind);
514    ASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind);
515    *outUpdated = true;
516    std::unique_ptr<Expression>* target = (*iter)->expression();
517    if (!b->tryRemoveExpression(iter)) {
518        *target = construct(type, std::move(*otherExpression));
519        *outNeedsRescan = true;
520    } else {
521        *target = construct(type, std::move(*otherExpression));
522        if (!b->tryInsertExpression(iter, target)) {
523            *outNeedsRescan = true;
524        }
525    }
526}
527
528/**
529 * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the
530 * left to yield vec<n>(x).
531 */
532static void vectorize_left(BasicBlock* b,
533                           std::vector<BasicBlock::Node>::iterator* iter,
534                           bool* outUpdated,
535                           bool* outNeedsRescan) {
536    BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
537    vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan);
538}
539
540/**
541 * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the
542 * right to yield vec<n>(y).
543 */
544static void vectorize_right(BasicBlock* b,
545                            std::vector<BasicBlock::Node>::iterator* iter,
546                            bool* outUpdated,
547                            bool* outNeedsRescan) {
548    BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression();
549    vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan);
550}
551
552// Mark that an expression which we were writing to is no longer being written to
553void clear_write(const Expression& expr) {
554    switch (expr.fKind) {
555        case Expression::kVariableReference_Kind: {
556            ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind);
557            break;
558        }
559        case Expression::kFieldAccess_Kind:
560            clear_write(*((FieldAccess&) expr).fBase);
561            break;
562        case Expression::kSwizzle_Kind:
563            clear_write(*((Swizzle&) expr).fBase);
564            break;
565        case Expression::kIndex_Kind:
566            clear_write(*((IndexExpression&) expr).fBase);
567            break;
568        default:
569            ABORT("shouldn't be writing to this kind of expression\n");
570            break;
571    }
572}
573
574void Compiler::simplifyExpression(DefinitionMap& definitions,
575                                  BasicBlock& b,
576                                  std::vector<BasicBlock::Node>::iterator* iter,
577                                  std::unordered_set<const Variable*>* undefinedVariables,
578                                  bool* outUpdated,
579                                  bool* outNeedsRescan) {
580    Expression* expr = (*iter)->expression()->get();
581    ASSERT(expr);
582    if ((*iter)->fConstantPropagation) {
583        std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions);
584        if (optimized) {
585            *outUpdated = true;
586            if (!try_replace_expression(&b, iter, &optimized)) {
587                *outNeedsRescan = true;
588                return;
589            }
590            ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind);
591            expr = (*iter)->expression()->get();
592        }
593    }
594    switch (expr->fKind) {
595        case Expression::kVariableReference_Kind: {
596            const Variable& var = ((VariableReference*) expr)->fVariable;
597            if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] &&
598                (*undefinedVariables).find(&var) == (*undefinedVariables).end()) {
599                (*undefinedVariables).insert(&var);
600                this->error(expr->fPosition,
601                            "'" + var.fName + "' has not been assigned");
602            }
603            break;
604        }
605        case Expression::kTernary_Kind: {
606            TernaryExpression* t = (TernaryExpression*) expr;
607            if (t->fTest->fKind == Expression::kBoolLiteral_Kind) {
608                // ternary has a constant test, replace it with either the true or
609                // false branch
610                if (((BoolLiteral&) *t->fTest).fValue) {
611                    (*iter)->setExpression(std::move(t->fIfTrue));
612                } else {
613                    (*iter)->setExpression(std::move(t->fIfFalse));
614                }
615                *outUpdated = true;
616                *outNeedsRescan = true;
617            }
618            break;
619        }
620        case Expression::kBinary_Kind: {
621            BinaryExpression* bin = (BinaryExpression*) expr;
622            if (dead_assignment(*bin)) {
623                delete_left(&b, iter, outUpdated, outNeedsRescan);
624                break;
625            }
626            // collapse useless expressions like x * 1 or x + 0
627            if (((bin->fLeft->fType.kind()  != Type::kScalar_Kind) &&
628                 (bin->fLeft->fType.kind()  != Type::kVector_Kind)) ||
629                ((bin->fRight->fType.kind() != Type::kScalar_Kind) &&
630                 (bin->fRight->fType.kind() != Type::kVector_Kind))) {
631                break;
632            }
633            switch (bin->fOperator) {
634                case Token::STAR:
635                    if (is_constant(*bin->fLeft, 1)) {
636                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
637                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
638                            // vec4(1) * x -> vec4(x)
639                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
640                        } else {
641                            // 1 * x -> x
642                            // 1 * vec4(x) -> vec4(x)
643                            // vec4(1) * vec4(x) -> vec4(x)
644                            delete_left(&b, iter, outUpdated, outNeedsRescan);
645                        }
646                    }
647                    else if (is_constant(*bin->fLeft, 0)) {
648                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
649                            bin->fRight->fType.kind() == Type::kVector_Kind) {
650                            // 0 * vec4(x) -> vec4(0)
651                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
652                        } else {
653                            // 0 * x -> 0
654                            // vec4(0) * x -> vec4(0)
655                            // vec4(0) * vec4(x) -> vec4(0)
656                            delete_right(&b, iter, outUpdated, outNeedsRescan);
657                        }
658                    }
659                    else if (is_constant(*bin->fRight, 1)) {
660                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
661                            bin->fRight->fType.kind() == Type::kVector_Kind) {
662                            // x * vec4(1) -> vec4(x)
663                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
664                        } else {
665                            // x * 1 -> x
666                            // vec4(x) * 1 -> vec4(x)
667                            // vec4(x) * vec4(1) -> vec4(x)
668                            delete_right(&b, iter, outUpdated, outNeedsRescan);
669                        }
670                    }
671                    else if (is_constant(*bin->fRight, 0)) {
672                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
673                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
674                            // vec4(x) * 0 -> vec4(0)
675                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
676                        } else {
677                            // x * 0 -> 0
678                            // x * vec4(0) -> vec4(0)
679                            // vec4(x) * vec4(0) -> vec4(0)
680                            delete_left(&b, iter, outUpdated, outNeedsRescan);
681                        }
682                    }
683                    break;
684                case Token::PLUS:
685                    if (is_constant(*bin->fLeft, 0)) {
686                        if (bin->fLeft->fType.kind() == Type::kVector_Kind &&
687                            bin->fRight->fType.kind() == Type::kScalar_Kind) {
688                            // vec4(0) + x -> vec4(x)
689                            vectorize_right(&b, iter, outUpdated, outNeedsRescan);
690                        } else {
691                            // 0 + x -> x
692                            // 0 + vec4(x) -> vec4(x)
693                            // vec4(0) + vec4(x) -> vec4(x)
694                            delete_left(&b, iter, outUpdated, outNeedsRescan);
695                        }
696                    } else if (is_constant(*bin->fRight, 0)) {
697                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
698                            bin->fRight->fType.kind() == Type::kVector_Kind) {
699                            // x + vec4(0) -> vec4(x)
700                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
701                        } else {
702                            // x + 0 -> x
703                            // vec4(x) + 0 -> vec4(x)
704                            // vec4(x) + vec4(0) -> vec4(x)
705                            delete_right(&b, iter, outUpdated, outNeedsRescan);
706                        }
707                    }
708                    break;
709                case Token::MINUS:
710                    if (is_constant(*bin->fRight, 0)) {
711                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
712                            bin->fRight->fType.kind() == Type::kVector_Kind) {
713                            // x - vec4(0) -> vec4(x)
714                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
715                        } else {
716                            // x - 0 -> x
717                            // vec4(x) - 0 -> vec4(x)
718                            // vec4(x) - vec4(0) -> vec4(x)
719                            delete_right(&b, iter, outUpdated, outNeedsRescan);
720                        }
721                    }
722                    break;
723                case Token::SLASH:
724                    if (is_constant(*bin->fRight, 1)) {
725                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
726                            bin->fRight->fType.kind() == Type::kVector_Kind) {
727                            // x / vec4(1) -> vec4(x)
728                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
729                        } else {
730                            // x / 1 -> x
731                            // vec4(x) / 1 -> vec4(x)
732                            // vec4(x) / vec4(1) -> vec4(x)
733                            delete_right(&b, iter, outUpdated, outNeedsRescan);
734                        }
735                    } else if (is_constant(*bin->fLeft, 0)) {
736                        if (bin->fLeft->fType.kind() == Type::kScalar_Kind &&
737                            bin->fRight->fType.kind() == Type::kVector_Kind) {
738                            // 0 / vec4(x) -> vec4(0)
739                            vectorize_left(&b, iter, outUpdated, outNeedsRescan);
740                        } else {
741                            // 0 / x -> 0
742                            // vec4(0) / x -> vec4(0)
743                            // vec4(0) / vec4(x) -> vec4(0)
744                            delete_right(&b, iter, outUpdated, outNeedsRescan);
745                        }
746                    }
747                    break;
748                case Token::PLUSEQ:
749                    if (is_constant(*bin->fRight, 0)) {
750                        clear_write(*bin->fLeft);
751                        delete_right(&b, iter, outUpdated, outNeedsRescan);
752                    }
753                    break;
754                case Token::MINUSEQ:
755                    if (is_constant(*bin->fRight, 0)) {
756                        clear_write(*bin->fLeft);
757                        delete_right(&b, iter, outUpdated, outNeedsRescan);
758                    }
759                    break;
760                case Token::STAREQ:
761                    if (is_constant(*bin->fRight, 1)) {
762                        clear_write(*bin->fLeft);
763                        delete_right(&b, iter, outUpdated, outNeedsRescan);
764                    }
765                    break;
766                case Token::SLASHEQ:
767                    if (is_constant(*bin->fRight, 1)) {
768                        clear_write(*bin->fLeft);
769                        delete_right(&b, iter, outUpdated, outNeedsRescan);
770                    }
771                    break;
772                default:
773                    break;
774            }
775        }
776        default:
777            break;
778    }
779}
780
781
782// returns true if this statement could potentially execute a break at the current level (we ignore
783// nested loops and switches, since any breaks inside of them will merely break the loop / switch)
784static bool contains_break(Statement& s) {
785    switch (s.fKind) {
786        case Statement::kBlock_Kind:
787            for (const auto& sub : ((Block&) s).fStatements) {
788                if (contains_break(*sub)) {
789                    return true;
790                }
791            }
792            return false;
793        case Statement::kBreak_Kind:
794            return true;
795        case Statement::kIf_Kind: {
796            const IfStatement& i = (IfStatement&) s;
797            return contains_break(*i.fIfTrue) || (i.fIfFalse && contains_break(*i.fIfFalse));
798        }
799        default:
800            return false;
801    }
802}
803
804// Returns a block containing all of the statements that will be run if the given case matches
805// (which, owing to the statements being owned by unique_ptrs, means the switch itself will be
806// broken by this call and must then be discarded).
807// Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as
808// when break statements appear inside conditionals.
809static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) {
810    bool capturing = false;
811    std::vector<std::unique_ptr<Statement>*> statementPtrs;
812    for (const auto& current : s->fCases) {
813        if (current.get() == c) {
814            capturing = true;
815        }
816        if (capturing) {
817            for (auto& stmt : current->fStatements) {
818                if (stmt->fKind == Statement::kBreak_Kind) {
819                    capturing = false;
820                    break;
821                }
822                if (contains_break(*stmt)) {
823                    return nullptr;
824                }
825                statementPtrs.push_back(&stmt);
826            }
827            if (!capturing) {
828                break;
829            }
830        }
831    }
832    std::vector<std::unique_ptr<Statement>> statements;
833    for (const auto& s : statementPtrs) {
834        statements.push_back(std::move(*s));
835    }
836    return std::unique_ptr<Statement>(new Block(Position(), std::move(statements)));
837}
838
839void Compiler::simplifyStatement(DefinitionMap& definitions,
840                                 BasicBlock& b,
841                                 std::vector<BasicBlock::Node>::iterator* iter,
842                                 std::unordered_set<const Variable*>* undefinedVariables,
843                                 bool* outUpdated,
844                                 bool* outNeedsRescan) {
845    Statement* stmt = (*iter)->statement()->get();
846    switch (stmt->fKind) {
847        case Statement::kVarDeclaration_Kind: {
848            const auto& varDecl = (VarDeclaration&) *stmt;
849            if (varDecl.fVar->dead() &&
850                (!varDecl.fValue ||
851                 !varDecl.fValue->hasSideEffects())) {
852                if (varDecl.fValue) {
853                    ASSERT((*iter)->statement()->get() == stmt);
854                    if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) {
855                        *outNeedsRescan = true;
856                    }
857                }
858                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
859                *outUpdated = true;
860            }
861            break;
862        }
863        case Statement::kIf_Kind: {
864            IfStatement& i = (IfStatement&) *stmt;
865            if (i.fTest->fKind == Expression::kBoolLiteral_Kind) {
866                // constant if, collapse down to a single branch
867                if (((BoolLiteral&) *i.fTest).fValue) {
868                    ASSERT(i.fIfTrue);
869                    (*iter)->setStatement(std::move(i.fIfTrue));
870                } else {
871                    if (i.fIfFalse) {
872                        (*iter)->setStatement(std::move(i.fIfFalse));
873                    } else {
874                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
875                    }
876                }
877                *outUpdated = true;
878                *outNeedsRescan = true;
879                break;
880            }
881            if (i.fIfFalse && i.fIfFalse->isEmpty()) {
882                // else block doesn't do anything, remove it
883                i.fIfFalse.reset();
884                *outUpdated = true;
885                *outNeedsRescan = true;
886            }
887            if (!i.fIfFalse && i.fIfTrue->isEmpty()) {
888                // if block doesn't do anything, no else block
889                if (i.fTest->hasSideEffects()) {
890                    // test has side effects, keep it
891                    (*iter)->setStatement(std::unique_ptr<Statement>(
892                                                      new ExpressionStatement(std::move(i.fTest))));
893                } else {
894                    // no if, no else, no test side effects, kill the whole if
895                    // statement
896                    (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
897                }
898                *outUpdated = true;
899                *outNeedsRescan = true;
900            }
901            break;
902        }
903        case Statement::kSwitch_Kind: {
904            SwitchStatement& s = (SwitchStatement&) *stmt;
905            if (s.fValue->isConstant()) {
906                // switch is constant, replace it with the case that matches
907                bool found = false;
908                SwitchCase* defaultCase = nullptr;
909                for (const auto& c : s.fCases) {
910                    if (!c->fValue) {
911                        defaultCase = c.get();
912                        continue;
913                    }
914                    ASSERT(c->fValue->fKind == s.fValue->fKind);
915                    found = c->fValue->compareConstant(fContext, *s.fValue);
916                    if (found) {
917                        std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get());
918                        if (newBlock) {
919                            (*iter)->setStatement(std::move(newBlock));
920                            break;
921                        } else {
922                            if (s.fIsStatic) {
923                                this->error(s.fPosition,
924                                            "static switch contains non-static conditional break");
925                                s.fIsStatic = false;
926                            }
927                            return; // can't simplify
928                        }
929                    }
930                }
931                if (!found) {
932                    // no matching case. use default if it exists, or kill the whole thing
933                    if (defaultCase) {
934                        std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase);
935                        if (newBlock) {
936                            (*iter)->setStatement(std::move(newBlock));
937                        } else {
938                            if (s.fIsStatic) {
939                                this->error(s.fPosition,
940                                            "static switch contains non-static conditional break");
941                                s.fIsStatic = false;
942                            }
943                            return; // can't simplify
944                        }
945                    } else {
946                        (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
947                    }
948                }
949                *outUpdated = true;
950                *outNeedsRescan = true;
951            }
952            break;
953        }
954        case Statement::kExpression_Kind: {
955            ExpressionStatement& e = (ExpressionStatement&) *stmt;
956            ASSERT((*iter)->statement()->get() == &e);
957            if (!e.fExpression->hasSideEffects()) {
958                // Expression statement with no side effects, kill it
959                if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) {
960                    *outNeedsRescan = true;
961                }
962                ASSERT((*iter)->statement()->get() == stmt);
963                (*iter)->setStatement(std::unique_ptr<Statement>(new Nop()));
964                *outUpdated = true;
965            }
966            break;
967        }
968        default:
969            break;
970    }
971}
972
973void Compiler::scanCFG(FunctionDefinition& f) {
974    CFG cfg = CFGGenerator().getCFG(f);
975    this->computeDataFlow(&cfg);
976
977    // check for unreachable code
978    for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
979        if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
980            cfg.fBlocks[i].fNodes.size()) {
981            Position p;
982            switch (cfg.fBlocks[i].fNodes[0].fKind) {
983                case BasicBlock::Node::kStatement_Kind:
984                    p = (*cfg.fBlocks[i].fNodes[0].statement())->fPosition;
985                    break;
986                case BasicBlock::Node::kExpression_Kind:
987                    p = (*cfg.fBlocks[i].fNodes[0].expression())->fPosition;
988                    break;
989            }
990            this->error(p, String("unreachable"));
991        }
992    }
993    if (fErrorCount) {
994        return;
995    }
996
997    // check for dead code & undefined variables, perform constant propagation
998    std::unordered_set<const Variable*> undefinedVariables;
999    bool updated;
1000    bool needsRescan = false;
1001    do {
1002        if (needsRescan) {
1003            cfg = CFGGenerator().getCFG(f);
1004            this->computeDataFlow(&cfg);
1005            needsRescan = false;
1006        }
1007
1008        updated = false;
1009        for (BasicBlock& b : cfg.fBlocks) {
1010            DefinitionMap definitions = b.fBefore;
1011
1012            for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
1013                if (iter->fKind == BasicBlock::Node::kExpression_Kind) {
1014                    this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated,
1015                                             &needsRescan);
1016                } else {
1017                    this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated,
1018                                             &needsRescan);
1019                }
1020                if (needsRescan) {
1021                    break;
1022                }
1023                this->addDefinitions(*iter, &definitions);
1024            }
1025        }
1026    } while (updated);
1027    ASSERT(!needsRescan);
1028
1029    // verify static ifs & switches
1030    for (BasicBlock& b : cfg.fBlocks) {
1031        DefinitionMap definitions = b.fBefore;
1032
1033        for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) {
1034            if (iter->fKind == BasicBlock::Node::kStatement_Kind) {
1035                const Statement& s = **iter->statement();
1036                switch (s.fKind) {
1037                    case Statement::kIf_Kind:
1038                        if (((const IfStatement&) s).fIsStatic) {
1039                            this->error(s.fPosition, "static if has non-static test");
1040                        }
1041                        break;
1042                    case Statement::kSwitch_Kind:
1043                        if (((const SwitchStatement&) s).fIsStatic) {
1044                            this->error(s.fPosition, "static switch has non-static test");
1045                        }
1046                        break;
1047                    default:
1048                        break;
1049                }
1050            }
1051        }
1052    }
1053
1054    // check for missing return
1055    if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
1056        if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
1057            this->error(f.fPosition, String("function can exit without returning a value"));
1058        }
1059    }
1060}
1061
1062void Compiler::internalConvertProgram(String text,
1063                                      Modifiers::Flag* defaultPrecision,
1064                                      std::vector<std::unique_ptr<ProgramElement>>* result) {
1065    Parser parser(text, *fTypes, *this);
1066    std::vector<std::unique_ptr<ASTDeclaration>> parsed = parser.file();
1067    if (fErrorCount) {
1068        return;
1069    }
1070    *defaultPrecision = Modifiers::kHighp_Flag;
1071    for (size_t i = 0; i < parsed.size(); i++) {
1072        ASTDeclaration& decl = *parsed[i];
1073        switch (decl.fKind) {
1074            case ASTDeclaration::kVar_Kind: {
1075                std::unique_ptr<VarDeclarations> s = fIRGenerator->convertVarDeclarations(
1076                                                                         (ASTVarDeclarations&) decl,
1077                                                                         Variable::kGlobal_Storage);
1078                if (s) {
1079                    result->push_back(std::move(s));
1080                }
1081                break;
1082            }
1083            case ASTDeclaration::kFunction_Kind: {
1084                std::unique_ptr<FunctionDefinition> f = fIRGenerator->convertFunction(
1085                                                                               (ASTFunction&) decl);
1086                if (!fErrorCount && f) {
1087                    this->scanCFG(*f);
1088                    result->push_back(std::move(f));
1089                }
1090                break;
1091            }
1092            case ASTDeclaration::kModifiers_Kind: {
1093                std::unique_ptr<ModifiersDeclaration> f = fIRGenerator->convertModifiersDeclaration(
1094                                                                   (ASTModifiersDeclaration&) decl);
1095                if (f) {
1096                    result->push_back(std::move(f));
1097                }
1098                break;
1099            }
1100            case ASTDeclaration::kInterfaceBlock_Kind: {
1101                std::unique_ptr<InterfaceBlock> i = fIRGenerator->convertInterfaceBlock(
1102                                                                         (ASTInterfaceBlock&) decl);
1103                if (i) {
1104                    result->push_back(std::move(i));
1105                }
1106                break;
1107            }
1108            case ASTDeclaration::kExtension_Kind: {
1109                std::unique_ptr<Extension> e = fIRGenerator->convertExtension((ASTExtension&) decl);
1110                if (e) {
1111                    result->push_back(std::move(e));
1112                }
1113                break;
1114            }
1115            case ASTDeclaration::kPrecision_Kind: {
1116                *defaultPrecision = ((ASTPrecision&) decl).fPrecision;
1117                break;
1118            }
1119            default:
1120                ABORT("unsupported declaration: %s\n", decl.description().c_str());
1121        }
1122    }
1123}
1124
1125std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text,
1126                                                  const Program::Settings& settings) {
1127    fErrorText = "";
1128    fErrorCount = 0;
1129    fIRGenerator->start(&settings);
1130    std::vector<std::unique_ptr<ProgramElement>> elements;
1131    Modifiers::Flag ignored;
1132    switch (kind) {
1133        case Program::kVertex_Kind:
1134            this->internalConvertProgram(String(SKSL_VERT_INCLUDE), &ignored, &elements);
1135            break;
1136        case Program::kFragment_Kind:
1137            this->internalConvertProgram(String(SKSL_FRAG_INCLUDE), &ignored, &elements);
1138            break;
1139        case Program::kGeometry_Kind:
1140            this->internalConvertProgram(String(SKSL_GEOM_INCLUDE), &ignored, &elements);
1141            break;
1142    }
1143    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
1144    Modifiers::Flag defaultPrecision;
1145    this->internalConvertProgram(text, &defaultPrecision, &elements);
1146    auto result = std::unique_ptr<Program>(new Program(kind, settings, defaultPrecision, &fContext,
1147                                                       std::move(elements),
1148                                                       fIRGenerator->fSymbolTable,
1149                                                       fIRGenerator->fInputs));
1150    fIRGenerator->finish();
1151    this->writeErrorCount();
1152    if (fErrorCount) {
1153        return nullptr;
1154    }
1155    return result;
1156}
1157
1158bool Compiler::toSPIRV(const Program& program, OutputStream& out) {
1159#ifdef SK_ENABLE_SPIRV_VALIDATION
1160    StringStream buffer;
1161    SPIRVCodeGenerator cg(&fContext, &program, this, &buffer);
1162    bool result = cg.generateCode();
1163    if (result) {
1164        spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0);
1165        ASSERT(0 == buffer.size() % 4);
1166        auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) {
1167            SkDebugf("SPIR-V validation error: %s\n", m);
1168        };
1169        tools.SetMessageConsumer(dumpmsg);
1170        // Verify that the SPIR-V we produced is valid. If this assert fails, check the logs prior
1171        // to the failure to see the validation errors.
1172        ASSERT_RESULT(tools.Validate((const uint32_t*) buffer.data(), buffer.size() / 4));
1173        out.write(buffer.data(), buffer.size());
1174    }
1175#else
1176    SPIRVCodeGenerator cg(&fContext, &program, this, &out);
1177    bool result = cg.generateCode();
1178#endif
1179    this->writeErrorCount();
1180    return result;
1181}
1182
1183bool Compiler::toSPIRV(const Program& program, String* out) {
1184    StringStream buffer;
1185    bool result = this->toSPIRV(program, buffer);
1186    if (result) {
1187        *out = String(buffer.data(), buffer.size());
1188    }
1189    return result;
1190}
1191
1192bool Compiler::toGLSL(const Program& program, OutputStream& out) {
1193    GLSLCodeGenerator cg(&fContext, &program, this, &out);
1194    bool result = cg.generateCode();
1195    this->writeErrorCount();
1196    return result;
1197}
1198
1199bool Compiler::toGLSL(const Program& program, String* out) {
1200    StringStream buffer;
1201    bool result = this->toGLSL(program, buffer);
1202    if (result) {
1203        *out = String(buffer.data(), buffer.size());
1204    }
1205    return result;
1206}
1207
1208
1209void Compiler::error(Position position, String msg) {
1210    fErrorCount++;
1211    fErrorText += "error: " + position.description() + ": " + msg.c_str() + "\n";
1212}
1213
1214String Compiler::errorText() {
1215    String result = fErrorText;
1216    return result;
1217}
1218
1219void Compiler::writeErrorCount() {
1220    if (fErrorCount) {
1221        fErrorText += to_string(fErrorCount) + " error";
1222        if (fErrorCount > 1) {
1223            fErrorText += "s";
1224        }
1225        fErrorText += "\n";
1226    }
1227}
1228
1229} // namespace
1230