SkSLCompiler.cpp revision 2a51de82ceb6790f329b9f4cc85e61f34fc2d0d4
1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLCompiler.h"
9
10#include <fstream>
11#include <streambuf>
12
13#include "ast/SkSLASTPrecision.h"
14#include "SkSLCFGGenerator.h"
15#include "SkSLIRGenerator.h"
16#include "SkSLParser.h"
17#include "SkSLSPIRVCodeGenerator.h"
18#include "ir/SkSLExpression.h"
19#include "ir/SkSLIntLiteral.h"
20#include "ir/SkSLModifiersDeclaration.h"
21#include "ir/SkSLSymbolTable.h"
22#include "ir/SkSLUnresolvedFunction.h"
23#include "ir/SkSLVarDeclarations.h"
24#include "SkMutex.h"
25
26#define STRINGIFY(x) #x
27
28// include the built-in shader symbols as static strings
29
30static const char* SKSL_INCLUDE =
31#include "sksl.include"
32;
33
34static const char* SKSL_VERT_INCLUDE =
35#include "sksl_vert.include"
36;
37
38static const char* SKSL_FRAG_INCLUDE =
39#include "sksl_frag.include"
40;
41
42namespace SkSL {
43
44Compiler::Compiler()
45: fErrorCount(0) {
46    auto types = std::shared_ptr<SymbolTable>(new SymbolTable(*this));
47    auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, *this));
48    fIRGenerator = new IRGenerator(&fContext, symbols, *this);
49    fTypes = types;
50    #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \
51                                                   fContext.f ## t ## _Type.get())
52    ADD_TYPE(Void);
53    ADD_TYPE(Float);
54    ADD_TYPE(Vec2);
55    ADD_TYPE(Vec3);
56    ADD_TYPE(Vec4);
57    ADD_TYPE(Double);
58    ADD_TYPE(DVec2);
59    ADD_TYPE(DVec3);
60    ADD_TYPE(DVec4);
61    ADD_TYPE(Int);
62    ADD_TYPE(IVec2);
63    ADD_TYPE(IVec3);
64    ADD_TYPE(IVec4);
65    ADD_TYPE(UInt);
66    ADD_TYPE(UVec2);
67    ADD_TYPE(UVec3);
68    ADD_TYPE(UVec4);
69    ADD_TYPE(Bool);
70    ADD_TYPE(BVec2);
71    ADD_TYPE(BVec3);
72    ADD_TYPE(BVec4);
73    ADD_TYPE(Mat2x2);
74    types->addWithoutOwnership("mat2x2", fContext.fMat2x2_Type.get());
75    ADD_TYPE(Mat2x3);
76    ADD_TYPE(Mat2x4);
77    ADD_TYPE(Mat3x2);
78    ADD_TYPE(Mat3x3);
79    types->addWithoutOwnership("mat3x3", fContext.fMat3x3_Type.get());
80    ADD_TYPE(Mat3x4);
81    ADD_TYPE(Mat4x2);
82    ADD_TYPE(Mat4x3);
83    ADD_TYPE(Mat4x4);
84    types->addWithoutOwnership("mat4x4", fContext.fMat4x4_Type.get());
85    ADD_TYPE(GenType);
86    ADD_TYPE(GenDType);
87    ADD_TYPE(GenIType);
88    ADD_TYPE(GenUType);
89    ADD_TYPE(GenBType);
90    ADD_TYPE(Mat);
91    ADD_TYPE(Vec);
92    ADD_TYPE(GVec);
93    ADD_TYPE(GVec2);
94    ADD_TYPE(GVec3);
95    ADD_TYPE(GVec4);
96    ADD_TYPE(DVec);
97    ADD_TYPE(IVec);
98    ADD_TYPE(UVec);
99    ADD_TYPE(BVec);
100
101    ADD_TYPE(Sampler1D);
102    ADD_TYPE(Sampler2D);
103    ADD_TYPE(Sampler3D);
104    ADD_TYPE(SamplerExternalOES);
105    ADD_TYPE(SamplerCube);
106    ADD_TYPE(Sampler2DRect);
107    ADD_TYPE(Sampler1DArray);
108    ADD_TYPE(Sampler2DArray);
109    ADD_TYPE(SamplerCubeArray);
110    ADD_TYPE(SamplerBuffer);
111    ADD_TYPE(Sampler2DMS);
112    ADD_TYPE(Sampler2DMSArray);
113
114    ADD_TYPE(ISampler2D);
115
116    ADD_TYPE(Image2D);
117    ADD_TYPE(IImage2D);
118
119    ADD_TYPE(GSampler1D);
120    ADD_TYPE(GSampler2D);
121    ADD_TYPE(GSampler3D);
122    ADD_TYPE(GSamplerCube);
123    ADD_TYPE(GSampler2DRect);
124    ADD_TYPE(GSampler1DArray);
125    ADD_TYPE(GSampler2DArray);
126    ADD_TYPE(GSamplerCubeArray);
127    ADD_TYPE(GSamplerBuffer);
128    ADD_TYPE(GSampler2DMS);
129    ADD_TYPE(GSampler2DMSArray);
130
131    ADD_TYPE(Sampler1DShadow);
132    ADD_TYPE(Sampler2DShadow);
133    ADD_TYPE(SamplerCubeShadow);
134    ADD_TYPE(Sampler2DRectShadow);
135    ADD_TYPE(Sampler1DArrayShadow);
136    ADD_TYPE(Sampler2DArrayShadow);
137    ADD_TYPE(SamplerCubeArrayShadow);
138    ADD_TYPE(GSampler2DArrayShadow);
139    ADD_TYPE(GSamplerCubeArrayShadow);
140
141    Modifiers::Flag ignored1;
142    std::vector<std::unique_ptr<ProgramElement>> ignored2;
143    this->internalConvertProgram(SKSL_INCLUDE, &ignored1, &ignored2);
144    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
145    ASSERT(!fErrorCount);
146}
147
148Compiler::~Compiler() {
149    delete fIRGenerator;
150}
151
152// add the definition created by assigning to the lvalue to the definition set
153void Compiler::addDefinition(const Expression* lvalue, const Expression* expr,
154                           std::unordered_map<const Variable*, const Expression*>* definitions) {
155    switch (lvalue->fKind) {
156        case Expression::kVariableReference_Kind: {
157            const Variable& var = ((VariableReference*) lvalue)->fVariable;
158            if (var.fStorage == Variable::kLocal_Storage) {
159                (*definitions)[&var] = expr;
160            }
161            break;
162        }
163        case Expression::kSwizzle_Kind:
164            // We consider the variable written to as long as at least some of its components have
165            // been written to. This will lead to some false negatives (we won't catch it if you
166            // write to foo.x and then read foo.y), but being stricter could lead to false positives
167            // (we write to foo.x, and then pass foo to a function which happens to only read foo.x,
168            // but since we pass foo as a whole it is flagged as an error) unless we perform a much
169            // more complicated whole-program analysis. This is probably good enough.
170            this->addDefinition(((Swizzle*) lvalue)->fBase.get(),
171                                fContext.fDefined_Expression.get(),
172                                definitions);
173            break;
174        case Expression::kIndex_Kind:
175            // see comments in Swizzle
176            this->addDefinition(((IndexExpression*) lvalue)->fBase.get(),
177                                fContext.fDefined_Expression.get(),
178                                definitions);
179            break;
180        case Expression::kFieldAccess_Kind:
181            // see comments in Swizzle
182            this->addDefinition(((FieldAccess*) lvalue)->fBase.get(),
183                                fContext.fDefined_Expression.get(),
184                                definitions);
185            break;
186        default:
187            // not an lvalue, can't happen
188            ASSERT(false);
189    }
190}
191
192// add local variables defined by this node to the set
193void Compiler::addDefinitions(const BasicBlock::Node& node,
194                              std::unordered_map<const Variable*, const Expression*>* definitions) {
195    switch (node.fKind) {
196        case BasicBlock::Node::kExpression_Kind: {
197            const Expression* expr = (Expression*) node.fNode;
198            if (expr->fKind == Expression::kBinary_Kind) {
199                const BinaryExpression* b = (BinaryExpression*) expr;
200                if (b->fOperator == Token::EQ) {
201                    this->addDefinition(b->fLeft.get(), b->fRight.get(), definitions);
202                }
203            }
204            break;
205        }
206        case BasicBlock::Node::kStatement_Kind: {
207            const Statement* stmt = (Statement*) node.fNode;
208            if (stmt->fKind == Statement::kVarDeclarations_Kind) {
209                const VarDeclarationsStatement* vd = (VarDeclarationsStatement*) stmt;
210                for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
211                    if (decl.fValue) {
212                        (*definitions)[decl.fVar] = decl.fValue.get();
213                    }
214                }
215            }
216            break;
217        }
218    }
219}
220
221void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) {
222    BasicBlock& block = cfg->fBlocks[blockId];
223
224    // compute definitions after this block
225    std::unordered_map<const Variable*, const Expression*> after = block.fBefore;
226    for (const BasicBlock::Node& n : block.fNodes) {
227        this->addDefinitions(n, &after);
228    }
229
230    // propagate definitions to exits
231    for (BlockId exitId : block.fExits) {
232        BasicBlock& exit = cfg->fBlocks[exitId];
233        for (const auto& pair : after) {
234            const Expression* e1 = pair.second;
235            if (exit.fBefore.find(pair.first) == exit.fBefore.end()) {
236                exit.fBefore[pair.first] = e1;
237            } else {
238                const Expression* e2 = exit.fBefore[pair.first];
239                if (e1 != e2) {
240                    // definition has changed, merge and add exit block to worklist
241                    workList->insert(exitId);
242                    if (!e1 || !e2) {
243                        exit.fBefore[pair.first] = nullptr;
244                    } else {
245                        exit.fBefore[pair.first] = fContext.fDefined_Expression.get();
246                    }
247                }
248            }
249        }
250    }
251}
252
253// returns a map which maps all local variables in the function to null, indicating that their value
254// is initially unknown
255static std::unordered_map<const Variable*, const Expression*> compute_start_state(const CFG& cfg) {
256    std::unordered_map<const Variable*, const Expression*> result;
257    for (const auto& block : cfg.fBlocks) {
258        for (const auto& node : block.fNodes) {
259            if (node.fKind == BasicBlock::Node::kStatement_Kind) {
260                const Statement* s = (Statement*) node.fNode;
261                if (s->fKind == Statement::kVarDeclarations_Kind) {
262                    const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s;
263                    for (const VarDeclaration& decl : vd->fDeclaration->fVars) {
264                        result[decl.fVar] = nullptr;
265                    }
266                }
267            }
268        }
269    }
270    return result;
271}
272
273void Compiler::scanCFG(const FunctionDefinition& f) {
274    CFG cfg = CFGGenerator().getCFG(f);
275
276    // compute the data flow
277    cfg.fBlocks[cfg.fStart].fBefore = compute_start_state(cfg);
278    std::set<BlockId> workList;
279    for (BlockId i = 0; i < cfg.fBlocks.size(); i++) {
280        workList.insert(i);
281    }
282    while (workList.size()) {
283        BlockId next = *workList.begin();
284        workList.erase(workList.begin());
285        this->scanCFG(&cfg, next, &workList);
286    }
287
288    // check for unreachable code
289    for (size_t i = 0; i < cfg.fBlocks.size(); i++) {
290        if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() &&
291            cfg.fBlocks[i].fNodes.size()) {
292            this->error(cfg.fBlocks[i].fNodes[0].fNode->fPosition, "unreachable");
293        }
294    }
295    if (fErrorCount) {
296        return;
297    }
298
299    // check for undefined variables
300    for (const BasicBlock& b : cfg.fBlocks) {
301        std::unordered_map<const Variable*, const Expression*> definitions = b.fBefore;
302        for (const BasicBlock::Node& n : b.fNodes) {
303            if (n.fKind == BasicBlock::Node::kExpression_Kind) {
304                const Expression* expr = (const Expression*) n.fNode;
305                if (expr->fKind == Expression::kVariableReference_Kind) {
306                    const Variable& var = ((VariableReference*) expr)->fVariable;
307                    if (var.fStorage == Variable::kLocal_Storage &&
308                        !definitions[&var]) {
309                        this->error(expr->fPosition,
310                                    "'" + var.fName + "' has not been assigned");
311                    }
312                }
313            }
314            this->addDefinitions(n, &definitions);
315        }
316    }
317
318    // check for missing return
319    if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) {
320        if (cfg.fBlocks[cfg.fExit].fEntrances.size()) {
321            this->error(f.fPosition, "function can exit without returning a value");
322        }
323    }
324}
325
326void Compiler::internalConvertProgram(std::string text,
327                                      Modifiers::Flag* defaultPrecision,
328                                      std::vector<std::unique_ptr<ProgramElement>>* result) {
329    Parser parser(text, *fTypes, *this);
330    std::vector<std::unique_ptr<ASTDeclaration>> parsed = parser.file();
331    if (fErrorCount) {
332        return;
333    }
334    *defaultPrecision = Modifiers::kHighp_Flag;
335    for (size_t i = 0; i < parsed.size(); i++) {
336        ASTDeclaration& decl = *parsed[i];
337        switch (decl.fKind) {
338            case ASTDeclaration::kVar_Kind: {
339                std::unique_ptr<VarDeclarations> s = fIRGenerator->convertVarDeclarations(
340                                                                         (ASTVarDeclarations&) decl,
341                                                                         Variable::kGlobal_Storage);
342                if (s) {
343                    result->push_back(std::move(s));
344                }
345                break;
346            }
347            case ASTDeclaration::kFunction_Kind: {
348                std::unique_ptr<FunctionDefinition> f = fIRGenerator->convertFunction(
349                                                                               (ASTFunction&) decl);
350                if (!fErrorCount && f) {
351                    this->scanCFG(*f);
352                    result->push_back(std::move(f));
353                }
354                break;
355            }
356            case ASTDeclaration::kModifiers_Kind: {
357                std::unique_ptr<ModifiersDeclaration> f = fIRGenerator->convertModifiersDeclaration(
358                                                                   (ASTModifiersDeclaration&) decl);
359                if (f) {
360                    result->push_back(std::move(f));
361                }
362                break;
363            }
364            case ASTDeclaration::kInterfaceBlock_Kind: {
365                std::unique_ptr<InterfaceBlock> i = fIRGenerator->convertInterfaceBlock(
366                                                                         (ASTInterfaceBlock&) decl);
367                if (i) {
368                    result->push_back(std::move(i));
369                }
370                break;
371            }
372            case ASTDeclaration::kExtension_Kind: {
373                std::unique_ptr<Extension> e = fIRGenerator->convertExtension((ASTExtension&) decl);
374                if (e) {
375                    result->push_back(std::move(e));
376                }
377                break;
378            }
379            case ASTDeclaration::kPrecision_Kind: {
380                *defaultPrecision = ((ASTPrecision&) decl).fPrecision;
381                break;
382            }
383            default:
384                ABORT("unsupported declaration: %s\n", decl.description().c_str());
385        }
386    }
387}
388
389std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, std::string text) {
390    fErrorText = "";
391    fErrorCount = 0;
392    fIRGenerator->pushSymbolTable();
393    std::vector<std::unique_ptr<ProgramElement>> elements;
394    Modifiers::Flag ignored;
395    switch (kind) {
396        case Program::kVertex_Kind:
397            this->internalConvertProgram(SKSL_VERT_INCLUDE, &ignored, &elements);
398            break;
399        case Program::kFragment_Kind:
400            this->internalConvertProgram(SKSL_FRAG_INCLUDE, &ignored, &elements);
401            break;
402    }
403    fIRGenerator->fSymbolTable->markAllFunctionsBuiltin();
404    Modifiers::Flag defaultPrecision;
405    this->internalConvertProgram(text, &defaultPrecision, &elements);
406    auto result = std::unique_ptr<Program>(new Program(kind, defaultPrecision, std::move(elements),
407                                                       fIRGenerator->fSymbolTable));
408    fIRGenerator->popSymbolTable();
409    this->writeErrorCount();
410    return result;
411}
412
413void Compiler::error(Position position, std::string msg) {
414    fErrorCount++;
415    fErrorText += "error: " + position.description() + ": " + msg.c_str() + "\n";
416}
417
418std::string Compiler::errorText() {
419    std::string result = fErrorText;
420    return result;
421}
422
423void Compiler::writeErrorCount() {
424    if (fErrorCount) {
425        fErrorText += to_string(fErrorCount) + " error";
426        if (fErrorCount > 1) {
427            fErrorText += "s";
428        }
429        fErrorText += "\n";
430    }
431}
432
433bool Compiler::toSPIRV(Program::Kind kind, const std::string& text, std::ostream& out) {
434    auto program = this->convertProgram(kind, text);
435    if (fErrorCount == 0) {
436        SkSL::SPIRVCodeGenerator cg(&fContext);
437        cg.generateCode(*program.get(), out);
438        ASSERT(!out.rdstate());
439    }
440    return fErrorCount == 0;
441}
442
443bool Compiler::toSPIRV(Program::Kind kind, const std::string& text, std::string* out) {
444    std::stringstream buffer;
445    bool result = this->toSPIRV(kind, text, buffer);
446    if (result) {
447        *out = buffer.str();
448    }
449    return result;
450}
451
452bool Compiler::toGLSL(Program::Kind kind, const std::string& text, const GrGLSLCaps& caps,
453                      std::ostream& out) {
454    auto program = this->convertProgram(kind, text);
455    if (fErrorCount == 0) {
456        SkSL::GLSLCodeGenerator cg(&fContext, &caps);
457        cg.generateCode(*program.get(), out);
458        ASSERT(!out.rdstate());
459    }
460    return fErrorCount == 0;
461}
462
463bool Compiler::toGLSL(Program::Kind kind, const std::string& text, const GrGLSLCaps& caps,
464                      std::string* out) {
465    std::stringstream buffer;
466    bool result = this->toGLSL(kind, text, caps, buffer);
467    if (result) {
468        *out = buffer.str();
469    }
470    return result;
471}
472
473} // namespace
474