SkSLCompiler.cpp revision 84cda40bd7e98f4e19574c6e946395e244901408
1/* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8#include "SkSLCompiler.h" 9 10#include "ast/SkSLASTPrecision.h" 11#include "SkSLCFGGenerator.h" 12#include "SkSLGLSLCodeGenerator.h" 13#include "SkSLIRGenerator.h" 14#include "SkSLParser.h" 15#include "SkSLSPIRVCodeGenerator.h" 16#include "ir/SkSLExpression.h" 17#include "ir/SkSLExpressionStatement.h" 18#include "ir/SkSLIntLiteral.h" 19#include "ir/SkSLModifiersDeclaration.h" 20#include "ir/SkSLNop.h" 21#include "ir/SkSLSymbolTable.h" 22#include "ir/SkSLTernaryExpression.h" 23#include "ir/SkSLUnresolvedFunction.h" 24#include "ir/SkSLVarDeclarations.h" 25 26#ifdef SK_ENABLE_SPIRV_VALIDATION 27#include "spirv-tools/libspirv.hpp" 28#endif 29 30#define STRINGIFY(x) #x 31 32// include the built-in shader symbols as static strings 33 34static const char* SKSL_INCLUDE = 35#include "sksl.include" 36; 37 38static const char* SKSL_VERT_INCLUDE = 39#include "sksl_vert.include" 40; 41 42static const char* SKSL_FRAG_INCLUDE = 43#include "sksl_frag.include" 44; 45 46static const char* SKSL_GEOM_INCLUDE = 47#include "sksl_geom.include" 48; 49 50namespace SkSL { 51 52Compiler::Compiler() 53: fErrorCount(0) { 54 auto types = std::shared_ptr<SymbolTable>(new SymbolTable(this)); 55 auto symbols = std::shared_ptr<SymbolTable>(new SymbolTable(types, this)); 56 fIRGenerator = new IRGenerator(&fContext, symbols, *this); 57 fTypes = types; 58 #define ADD_TYPE(t) types->addWithoutOwnership(fContext.f ## t ## _Type->fName, \ 59 fContext.f ## t ## _Type.get()) 60 ADD_TYPE(Void); 61 ADD_TYPE(Float); 62 ADD_TYPE(Vec2); 63 ADD_TYPE(Vec3); 64 ADD_TYPE(Vec4); 65 ADD_TYPE(Double); 66 ADD_TYPE(DVec2); 67 ADD_TYPE(DVec3); 68 ADD_TYPE(DVec4); 69 ADD_TYPE(Int); 70 ADD_TYPE(IVec2); 71 ADD_TYPE(IVec3); 72 ADD_TYPE(IVec4); 73 ADD_TYPE(UInt); 74 ADD_TYPE(UVec2); 75 ADD_TYPE(UVec3); 76 ADD_TYPE(UVec4); 77 ADD_TYPE(Bool); 78 ADD_TYPE(BVec2); 79 ADD_TYPE(BVec3); 80 ADD_TYPE(BVec4); 81 ADD_TYPE(Mat2x2); 82 types->addWithoutOwnership(String("mat2x2"), fContext.fMat2x2_Type.get()); 83 ADD_TYPE(Mat2x3); 84 ADD_TYPE(Mat2x4); 85 ADD_TYPE(Mat3x2); 86 ADD_TYPE(Mat3x3); 87 types->addWithoutOwnership(String("mat3x3"), fContext.fMat3x3_Type.get()); 88 ADD_TYPE(Mat3x4); 89 ADD_TYPE(Mat4x2); 90 ADD_TYPE(Mat4x3); 91 ADD_TYPE(Mat4x4); 92 types->addWithoutOwnership(String("mat4x4"), fContext.fMat4x4_Type.get()); 93 ADD_TYPE(GenType); 94 ADD_TYPE(GenDType); 95 ADD_TYPE(GenIType); 96 ADD_TYPE(GenUType); 97 ADD_TYPE(GenBType); 98 ADD_TYPE(Mat); 99 ADD_TYPE(Vec); 100 ADD_TYPE(GVec); 101 ADD_TYPE(GVec2); 102 ADD_TYPE(GVec3); 103 ADD_TYPE(GVec4); 104 ADD_TYPE(DVec); 105 ADD_TYPE(IVec); 106 ADD_TYPE(UVec); 107 ADD_TYPE(BVec); 108 109 ADD_TYPE(Sampler1D); 110 ADD_TYPE(Sampler2D); 111 ADD_TYPE(Sampler3D); 112 ADD_TYPE(SamplerExternalOES); 113 ADD_TYPE(SamplerCube); 114 ADD_TYPE(Sampler2DRect); 115 ADD_TYPE(Sampler1DArray); 116 ADD_TYPE(Sampler2DArray); 117 ADD_TYPE(SamplerCubeArray); 118 ADD_TYPE(SamplerBuffer); 119 ADD_TYPE(Sampler2DMS); 120 ADD_TYPE(Sampler2DMSArray); 121 122 ADD_TYPE(ISampler2D); 123 124 ADD_TYPE(Image2D); 125 ADD_TYPE(IImage2D); 126 127 ADD_TYPE(SubpassInput); 128 ADD_TYPE(SubpassInputMS); 129 130 ADD_TYPE(GSampler1D); 131 ADD_TYPE(GSampler2D); 132 ADD_TYPE(GSampler3D); 133 ADD_TYPE(GSamplerCube); 134 ADD_TYPE(GSampler2DRect); 135 ADD_TYPE(GSampler1DArray); 136 ADD_TYPE(GSampler2DArray); 137 ADD_TYPE(GSamplerCubeArray); 138 ADD_TYPE(GSamplerBuffer); 139 ADD_TYPE(GSampler2DMS); 140 ADD_TYPE(GSampler2DMSArray); 141 142 ADD_TYPE(Sampler1DShadow); 143 ADD_TYPE(Sampler2DShadow); 144 ADD_TYPE(SamplerCubeShadow); 145 ADD_TYPE(Sampler2DRectShadow); 146 ADD_TYPE(Sampler1DArrayShadow); 147 ADD_TYPE(Sampler2DArrayShadow); 148 ADD_TYPE(SamplerCubeArrayShadow); 149 ADD_TYPE(GSampler2DArrayShadow); 150 ADD_TYPE(GSamplerCubeArrayShadow); 151 152 String skCapsName("sk_Caps"); 153 Variable* skCaps = new Variable(Position(), Modifiers(), skCapsName, 154 *fContext.fSkCaps_Type, Variable::kGlobal_Storage); 155 fIRGenerator->fSymbolTable->add(skCapsName, std::unique_ptr<Symbol>(skCaps)); 156 157 Modifiers::Flag ignored1; 158 std::vector<std::unique_ptr<ProgramElement>> ignored2; 159 this->internalConvertProgram(String(SKSL_INCLUDE), &ignored1, &ignored2); 160 fIRGenerator->fSymbolTable->markAllFunctionsBuiltin(); 161 ASSERT(!fErrorCount); 162} 163 164Compiler::~Compiler() { 165 delete fIRGenerator; 166} 167 168// add the definition created by assigning to the lvalue to the definition set 169void Compiler::addDefinition(const Expression* lvalue, std::unique_ptr<Expression>* expr, 170 DefinitionMap* definitions) { 171 switch (lvalue->fKind) { 172 case Expression::kVariableReference_Kind: { 173 const Variable& var = ((VariableReference*) lvalue)->fVariable; 174 if (var.fStorage == Variable::kLocal_Storage) { 175 (*definitions)[&var] = expr; 176 } 177 break; 178 } 179 case Expression::kSwizzle_Kind: 180 // We consider the variable written to as long as at least some of its components have 181 // been written to. This will lead to some false negatives (we won't catch it if you 182 // write to foo.x and then read foo.y), but being stricter could lead to false positives 183 // (we write to foo.x, and then pass foo to a function which happens to only read foo.x, 184 // but since we pass foo as a whole it is flagged as an error) unless we perform a much 185 // more complicated whole-program analysis. This is probably good enough. 186 this->addDefinition(((Swizzle*) lvalue)->fBase.get(), 187 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 188 definitions); 189 break; 190 case Expression::kIndex_Kind: 191 // see comments in Swizzle 192 this->addDefinition(((IndexExpression*) lvalue)->fBase.get(), 193 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 194 definitions); 195 break; 196 case Expression::kFieldAccess_Kind: 197 // see comments in Swizzle 198 this->addDefinition(((FieldAccess*) lvalue)->fBase.get(), 199 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 200 definitions); 201 break; 202 default: 203 // not an lvalue, can't happen 204 ASSERT(false); 205 } 206} 207 208// add local variables defined by this node to the set 209void Compiler::addDefinitions(const BasicBlock::Node& node, 210 DefinitionMap* definitions) { 211 switch (node.fKind) { 212 case BasicBlock::Node::kExpression_Kind: { 213 ASSERT(node.expression()); 214 const Expression* expr = (Expression*) node.expression()->get(); 215 switch (expr->fKind) { 216 case Expression::kBinary_Kind: { 217 BinaryExpression* b = (BinaryExpression*) expr; 218 if (b->fOperator == Token::EQ) { 219 this->addDefinition(b->fLeft.get(), &b->fRight, definitions); 220 } else if (Token::IsAssignment(b->fOperator)) { 221 this->addDefinition( 222 b->fLeft.get(), 223 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 224 definitions); 225 226 } 227 break; 228 } 229 case Expression::kPrefix_Kind: { 230 const PrefixExpression* p = (PrefixExpression*) expr; 231 if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { 232 this->addDefinition( 233 p->fOperand.get(), 234 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 235 definitions); 236 } 237 break; 238 } 239 case Expression::kPostfix_Kind: { 240 const PostfixExpression* p = (PostfixExpression*) expr; 241 if (p->fOperator == Token::MINUSMINUS || p->fOperator == Token::PLUSPLUS) { 242 this->addDefinition( 243 p->fOperand.get(), 244 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 245 definitions); 246 } 247 break; 248 } 249 case Expression::kVariableReference_Kind: { 250 const VariableReference* v = (VariableReference*) expr; 251 if (v->fRefKind != VariableReference::kRead_RefKind) { 252 this->addDefinition( 253 v, 254 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression, 255 definitions); 256 } 257 } 258 default: 259 break; 260 } 261 break; 262 } 263 case BasicBlock::Node::kStatement_Kind: { 264 const Statement* stmt = (Statement*) node.statement()->get(); 265 if (stmt->fKind == Statement::kVarDeclaration_Kind) { 266 VarDeclaration& vd = (VarDeclaration&) *stmt; 267 if (vd.fValue) { 268 (*definitions)[vd.fVar] = &vd.fValue; 269 } 270 } 271 break; 272 } 273 } 274} 275 276void Compiler::scanCFG(CFG* cfg, BlockId blockId, std::set<BlockId>* workList) { 277 BasicBlock& block = cfg->fBlocks[blockId]; 278 279 // compute definitions after this block 280 DefinitionMap after = block.fBefore; 281 for (const BasicBlock::Node& n : block.fNodes) { 282 this->addDefinitions(n, &after); 283 } 284 285 // propagate definitions to exits 286 for (BlockId exitId : block.fExits) { 287 BasicBlock& exit = cfg->fBlocks[exitId]; 288 for (const auto& pair : after) { 289 std::unique_ptr<Expression>* e1 = pair.second; 290 auto found = exit.fBefore.find(pair.first); 291 if (found == exit.fBefore.end()) { 292 // exit has no definition for it, just copy it 293 workList->insert(exitId); 294 exit.fBefore[pair.first] = e1; 295 } else { 296 // exit has a (possibly different) value already defined 297 std::unique_ptr<Expression>* e2 = exit.fBefore[pair.first]; 298 if (e1 != e2) { 299 // definition has changed, merge and add exit block to worklist 300 workList->insert(exitId); 301 if (e1 && e2) { 302 exit.fBefore[pair.first] = 303 (std::unique_ptr<Expression>*) &fContext.fDefined_Expression; 304 } else { 305 exit.fBefore[pair.first] = nullptr; 306 } 307 } 308 } 309 } 310 } 311} 312 313// returns a map which maps all local variables in the function to null, indicating that their value 314// is initially unknown 315static DefinitionMap compute_start_state(const CFG& cfg) { 316 DefinitionMap result; 317 for (const auto& block : cfg.fBlocks) { 318 for (const auto& node : block.fNodes) { 319 if (node.fKind == BasicBlock::Node::kStatement_Kind) { 320 ASSERT(node.statement()); 321 const Statement* s = node.statement()->get(); 322 if (s->fKind == Statement::kVarDeclarations_Kind) { 323 const VarDeclarationsStatement* vd = (const VarDeclarationsStatement*) s; 324 for (const auto& decl : vd->fDeclaration->fVars) { 325 result[((VarDeclaration&) *decl).fVar] = nullptr; 326 } 327 } 328 } 329 } 330 } 331 return result; 332} 333 334/** 335 * Returns true if assigning to this lvalue has no effect. 336 */ 337static bool is_dead(const Expression& lvalue) { 338 switch (lvalue.fKind) { 339 case Expression::kVariableReference_Kind: 340 return ((VariableReference&) lvalue).fVariable.dead(); 341 case Expression::kSwizzle_Kind: 342 return is_dead(*((Swizzle&) lvalue).fBase); 343 case Expression::kFieldAccess_Kind: 344 return is_dead(*((FieldAccess&) lvalue).fBase); 345 case Expression::kIndex_Kind: { 346 const IndexExpression& idx = (IndexExpression&) lvalue; 347 return is_dead(*idx.fBase) && !idx.fIndex->hasSideEffects(); 348 } 349 default: 350 ABORT("invalid lvalue: %s\n", lvalue.description().c_str()); 351 } 352} 353 354/** 355 * Returns true if this is an assignment which can be collapsed down to just the right hand side due 356 * to a dead target and lack of side effects on the left hand side. 357 */ 358static bool dead_assignment(const BinaryExpression& b) { 359 if (!Token::IsAssignment(b.fOperator)) { 360 return false; 361 } 362 return is_dead(*b.fLeft); 363} 364 365void Compiler::computeDataFlow(CFG* cfg) { 366 cfg->fBlocks[cfg->fStart].fBefore = compute_start_state(*cfg); 367 std::set<BlockId> workList; 368 for (BlockId i = 0; i < cfg->fBlocks.size(); i++) { 369 workList.insert(i); 370 } 371 while (workList.size()) { 372 BlockId next = *workList.begin(); 373 workList.erase(workList.begin()); 374 this->scanCFG(cfg, next, &workList); 375 } 376} 377 378/** 379 * Attempts to replace the expression pointed to by iter with a new one (in both the CFG and the 380 * IR). If the expression can be cleanly removed, returns true and updates the iterator to point to 381 * the newly-inserted element. Otherwise updates only the IR and returns false (and the CFG will 382 * need to be regenerated). 383 */ 384bool try_replace_expression(BasicBlock* b, 385 std::vector<BasicBlock::Node>::iterator* iter, 386 std::unique_ptr<Expression>* newExpression) { 387 std::unique_ptr<Expression>* target = (*iter)->expression(); 388 if (!b->tryRemoveExpression(iter)) { 389 *target = std::move(*newExpression); 390 return false; 391 } 392 *target = std::move(*newExpression); 393 return b->tryInsertExpression(iter, target); 394} 395 396/** 397 * Returns true if the expression is a constant numeric literal with the specified value, or a 398 * constant vector with all elements equal to the specified value. 399 */ 400bool is_constant(const Expression& expr, double value) { 401 switch (expr.fKind) { 402 case Expression::kIntLiteral_Kind: 403 return ((IntLiteral&) expr).fValue == value; 404 case Expression::kFloatLiteral_Kind: 405 return ((FloatLiteral&) expr).fValue == value; 406 case Expression::kConstructor_Kind: { 407 Constructor& c = (Constructor&) expr; 408 if (c.fType.kind() == Type::kVector_Kind && c.isConstant()) { 409 for (int i = 0; i < c.fType.columns(); ++i) { 410 if (!is_constant(c.getVecComponent(i), value)) { 411 return false; 412 } 413 } 414 return true; 415 } 416 return false; 417 } 418 default: 419 return false; 420 } 421} 422 423/** 424 * Collapses the binary expression pointed to by iter down to just the right side (in both the IR 425 * and CFG structures). 426 */ 427void delete_left(BasicBlock* b, 428 std::vector<BasicBlock::Node>::iterator* iter, 429 bool* outUpdated, 430 bool* outNeedsRescan) { 431 *outUpdated = true; 432 std::unique_ptr<Expression>* target = (*iter)->expression(); 433 ASSERT((*target)->fKind == Expression::kBinary_Kind); 434 BinaryExpression& bin = (BinaryExpression&) **target; 435 bool result; 436 if (bin.fOperator == Token::EQ) { 437 result = b->tryRemoveLValueBefore(iter, bin.fLeft.get()); 438 } else { 439 result = b->tryRemoveExpressionBefore(iter, bin.fLeft.get()); 440 } 441 *target = std::move(bin.fRight); 442 if (!result) { 443 *outNeedsRescan = true; 444 return; 445 } 446 if (*iter == b->fNodes.begin()) { 447 *outNeedsRescan = true; 448 return; 449 } 450 --(*iter); 451 if ((*iter)->fKind != BasicBlock::Node::kExpression_Kind || 452 (*iter)->expression() != &bin.fRight) { 453 *outNeedsRescan = true; 454 return; 455 } 456 *iter = b->fNodes.erase(*iter); 457 ASSERT((*iter)->expression() == target); 458} 459 460/** 461 * Collapses the binary expression pointed to by iter down to just the left side (in both the IR and 462 * CFG structures). 463 */ 464void delete_right(BasicBlock* b, 465 std::vector<BasicBlock::Node>::iterator* iter, 466 bool* outUpdated, 467 bool* outNeedsRescan) { 468 *outUpdated = true; 469 std::unique_ptr<Expression>* target = (*iter)->expression(); 470 ASSERT((*target)->fKind == Expression::kBinary_Kind); 471 BinaryExpression& bin = (BinaryExpression&) **target; 472 if (!b->tryRemoveExpressionBefore(iter, bin.fRight.get())) { 473 *target = std::move(bin.fLeft); 474 *outNeedsRescan = true; 475 return; 476 } 477 *target = std::move(bin.fLeft); 478 if (*iter == b->fNodes.begin()) { 479 *outNeedsRescan = true; 480 return; 481 } 482 --(*iter); 483 if (((*iter)->fKind != BasicBlock::Node::kExpression_Kind || 484 (*iter)->expression() != &bin.fLeft)) { 485 *outNeedsRescan = true; 486 return; 487 } 488 *iter = b->fNodes.erase(*iter); 489 ASSERT((*iter)->expression() == target); 490} 491 492/** 493 * Constructs the specified type using a single argument. 494 */ 495static std::unique_ptr<Expression> construct(const Type& type, std::unique_ptr<Expression> v) { 496 std::vector<std::unique_ptr<Expression>> args; 497 args.push_back(std::move(v)); 498 auto result = std::unique_ptr<Expression>(new Constructor(Position(), type, std::move(args))); 499 return result; 500} 501 502/** 503 * Used in the implementations of vectorize_left and vectorize_right. Given a vector type and an 504 * expression x, deletes the expression pointed to by iter and replaces it with <type>(x). 505 */ 506static void vectorize(BasicBlock* b, 507 std::vector<BasicBlock::Node>::iterator* iter, 508 const Type& type, 509 std::unique_ptr<Expression>* otherExpression, 510 bool* outUpdated, 511 bool* outNeedsRescan) { 512 ASSERT((*(*iter)->expression())->fKind == Expression::kBinary_Kind); 513 ASSERT(type.kind() == Type::kVector_Kind); 514 ASSERT((*otherExpression)->fType.kind() == Type::kScalar_Kind); 515 *outUpdated = true; 516 std::unique_ptr<Expression>* target = (*iter)->expression(); 517 if (!b->tryRemoveExpression(iter)) { 518 *target = construct(type, std::move(*otherExpression)); 519 *outNeedsRescan = true; 520 } else { 521 *target = construct(type, std::move(*otherExpression)); 522 if (!b->tryInsertExpression(iter, target)) { 523 *outNeedsRescan = true; 524 } 525 } 526} 527 528/** 529 * Given a binary expression of the form x <op> vec<n>(y), deletes the right side and vectorizes the 530 * left to yield vec<n>(x). 531 */ 532static void vectorize_left(BasicBlock* b, 533 std::vector<BasicBlock::Node>::iterator* iter, 534 bool* outUpdated, 535 bool* outNeedsRescan) { 536 BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression(); 537 vectorize(b, iter, bin.fRight->fType, &bin.fLeft, outUpdated, outNeedsRescan); 538} 539 540/** 541 * Given a binary expression of the form vec<n>(x) <op> y, deletes the left side and vectorizes the 542 * right to yield vec<n>(y). 543 */ 544static void vectorize_right(BasicBlock* b, 545 std::vector<BasicBlock::Node>::iterator* iter, 546 bool* outUpdated, 547 bool* outNeedsRescan) { 548 BinaryExpression& bin = (BinaryExpression&) **(*iter)->expression(); 549 vectorize(b, iter, bin.fLeft->fType, &bin.fRight, outUpdated, outNeedsRescan); 550} 551 552// Mark that an expression which we were writing to is no longer being written to 553void clear_write(const Expression& expr) { 554 switch (expr.fKind) { 555 case Expression::kVariableReference_Kind: { 556 ((VariableReference&) expr).setRefKind(VariableReference::kRead_RefKind); 557 break; 558 } 559 case Expression::kFieldAccess_Kind: 560 clear_write(*((FieldAccess&) expr).fBase); 561 break; 562 case Expression::kSwizzle_Kind: 563 clear_write(*((Swizzle&) expr).fBase); 564 break; 565 case Expression::kIndex_Kind: 566 clear_write(*((IndexExpression&) expr).fBase); 567 break; 568 default: 569 ABORT("shouldn't be writing to this kind of expression\n"); 570 break; 571 } 572} 573 574void Compiler::simplifyExpression(DefinitionMap& definitions, 575 BasicBlock& b, 576 std::vector<BasicBlock::Node>::iterator* iter, 577 std::unordered_set<const Variable*>* undefinedVariables, 578 bool* outUpdated, 579 bool* outNeedsRescan) { 580 Expression* expr = (*iter)->expression()->get(); 581 ASSERT(expr); 582 if ((*iter)->fConstantPropagation) { 583 std::unique_ptr<Expression> optimized = expr->constantPropagate(*fIRGenerator, definitions); 584 if (optimized) { 585 *outUpdated = true; 586 if (!try_replace_expression(&b, iter, &optimized)) { 587 *outNeedsRescan = true; 588 return; 589 } 590 ASSERT((*iter)->fKind == BasicBlock::Node::kExpression_Kind); 591 expr = (*iter)->expression()->get(); 592 } 593 } 594 switch (expr->fKind) { 595 case Expression::kVariableReference_Kind: { 596 const Variable& var = ((VariableReference*) expr)->fVariable; 597 if (var.fStorage == Variable::kLocal_Storage && !definitions[&var] && 598 (*undefinedVariables).find(&var) == (*undefinedVariables).end()) { 599 (*undefinedVariables).insert(&var); 600 this->error(expr->fPosition, 601 "'" + var.fName + "' has not been assigned"); 602 } 603 break; 604 } 605 case Expression::kTernary_Kind: { 606 TernaryExpression* t = (TernaryExpression*) expr; 607 if (t->fTest->fKind == Expression::kBoolLiteral_Kind) { 608 // ternary has a constant test, replace it with either the true or 609 // false branch 610 if (((BoolLiteral&) *t->fTest).fValue) { 611 (*iter)->setExpression(std::move(t->fIfTrue)); 612 } else { 613 (*iter)->setExpression(std::move(t->fIfFalse)); 614 } 615 *outUpdated = true; 616 *outNeedsRescan = true; 617 } 618 break; 619 } 620 case Expression::kBinary_Kind: { 621 BinaryExpression* bin = (BinaryExpression*) expr; 622 if (dead_assignment(*bin)) { 623 delete_left(&b, iter, outUpdated, outNeedsRescan); 624 break; 625 } 626 // collapse useless expressions like x * 1 or x + 0 627 if (((bin->fLeft->fType.kind() != Type::kScalar_Kind) && 628 (bin->fLeft->fType.kind() != Type::kVector_Kind)) || 629 ((bin->fRight->fType.kind() != Type::kScalar_Kind) && 630 (bin->fRight->fType.kind() != Type::kVector_Kind))) { 631 break; 632 } 633 switch (bin->fOperator) { 634 case Token::STAR: 635 if (is_constant(*bin->fLeft, 1)) { 636 if (bin->fLeft->fType.kind() == Type::kVector_Kind && 637 bin->fRight->fType.kind() == Type::kScalar_Kind) { 638 // vec4(1) * x -> vec4(x) 639 vectorize_right(&b, iter, outUpdated, outNeedsRescan); 640 } else { 641 // 1 * x -> x 642 // 1 * vec4(x) -> vec4(x) 643 // vec4(1) * vec4(x) -> vec4(x) 644 delete_left(&b, iter, outUpdated, outNeedsRescan); 645 } 646 } 647 else if (is_constant(*bin->fLeft, 0)) { 648 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 649 bin->fRight->fType.kind() == Type::kVector_Kind) { 650 // 0 * vec4(x) -> vec4(0) 651 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 652 } else { 653 // 0 * x -> 0 654 // vec4(0) * x -> vec4(0) 655 // vec4(0) * vec4(x) -> vec4(0) 656 delete_right(&b, iter, outUpdated, outNeedsRescan); 657 } 658 } 659 else if (is_constant(*bin->fRight, 1)) { 660 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 661 bin->fRight->fType.kind() == Type::kVector_Kind) { 662 // x * vec4(1) -> vec4(x) 663 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 664 } else { 665 // x * 1 -> x 666 // vec4(x) * 1 -> vec4(x) 667 // vec4(x) * vec4(1) -> vec4(x) 668 delete_right(&b, iter, outUpdated, outNeedsRescan); 669 } 670 } 671 else if (is_constant(*bin->fRight, 0)) { 672 if (bin->fLeft->fType.kind() == Type::kVector_Kind && 673 bin->fRight->fType.kind() == Type::kScalar_Kind) { 674 // vec4(x) * 0 -> vec4(0) 675 vectorize_right(&b, iter, outUpdated, outNeedsRescan); 676 } else { 677 // x * 0 -> 0 678 // x * vec4(0) -> vec4(0) 679 // vec4(x) * vec4(0) -> vec4(0) 680 delete_left(&b, iter, outUpdated, outNeedsRescan); 681 } 682 } 683 break; 684 case Token::PLUS: 685 if (is_constant(*bin->fLeft, 0)) { 686 if (bin->fLeft->fType.kind() == Type::kVector_Kind && 687 bin->fRight->fType.kind() == Type::kScalar_Kind) { 688 // vec4(0) + x -> vec4(x) 689 vectorize_right(&b, iter, outUpdated, outNeedsRescan); 690 } else { 691 // 0 + x -> x 692 // 0 + vec4(x) -> vec4(x) 693 // vec4(0) + vec4(x) -> vec4(x) 694 delete_left(&b, iter, outUpdated, outNeedsRescan); 695 } 696 } else if (is_constant(*bin->fRight, 0)) { 697 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 698 bin->fRight->fType.kind() == Type::kVector_Kind) { 699 // x + vec4(0) -> vec4(x) 700 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 701 } else { 702 // x + 0 -> x 703 // vec4(x) + 0 -> vec4(x) 704 // vec4(x) + vec4(0) -> vec4(x) 705 delete_right(&b, iter, outUpdated, outNeedsRescan); 706 } 707 } 708 break; 709 case Token::MINUS: 710 if (is_constant(*bin->fRight, 0)) { 711 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 712 bin->fRight->fType.kind() == Type::kVector_Kind) { 713 // x - vec4(0) -> vec4(x) 714 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 715 } else { 716 // x - 0 -> x 717 // vec4(x) - 0 -> vec4(x) 718 // vec4(x) - vec4(0) -> vec4(x) 719 delete_right(&b, iter, outUpdated, outNeedsRescan); 720 } 721 } 722 break; 723 case Token::SLASH: 724 if (is_constant(*bin->fRight, 1)) { 725 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 726 bin->fRight->fType.kind() == Type::kVector_Kind) { 727 // x / vec4(1) -> vec4(x) 728 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 729 } else { 730 // x / 1 -> x 731 // vec4(x) / 1 -> vec4(x) 732 // vec4(x) / vec4(1) -> vec4(x) 733 delete_right(&b, iter, outUpdated, outNeedsRescan); 734 } 735 } else if (is_constant(*bin->fLeft, 0)) { 736 if (bin->fLeft->fType.kind() == Type::kScalar_Kind && 737 bin->fRight->fType.kind() == Type::kVector_Kind) { 738 // 0 / vec4(x) -> vec4(0) 739 vectorize_left(&b, iter, outUpdated, outNeedsRescan); 740 } else { 741 // 0 / x -> 0 742 // vec4(0) / x -> vec4(0) 743 // vec4(0) / vec4(x) -> vec4(0) 744 delete_right(&b, iter, outUpdated, outNeedsRescan); 745 } 746 } 747 break; 748 case Token::PLUSEQ: 749 if (is_constant(*bin->fRight, 0)) { 750 clear_write(*bin->fLeft); 751 delete_right(&b, iter, outUpdated, outNeedsRescan); 752 } 753 break; 754 case Token::MINUSEQ: 755 if (is_constant(*bin->fRight, 0)) { 756 clear_write(*bin->fLeft); 757 delete_right(&b, iter, outUpdated, outNeedsRescan); 758 } 759 break; 760 case Token::STAREQ: 761 if (is_constant(*bin->fRight, 1)) { 762 clear_write(*bin->fLeft); 763 delete_right(&b, iter, outUpdated, outNeedsRescan); 764 } 765 break; 766 case Token::SLASHEQ: 767 if (is_constant(*bin->fRight, 1)) { 768 clear_write(*bin->fLeft); 769 delete_right(&b, iter, outUpdated, outNeedsRescan); 770 } 771 break; 772 default: 773 break; 774 } 775 } 776 default: 777 break; 778 } 779} 780 781 782// returns true if this statement could potentially execute a break at the current level (we ignore 783// nested loops and switches, since any breaks inside of them will merely break the loop / switch) 784static bool contains_break(Statement& s) { 785 switch (s.fKind) { 786 case Statement::kBlock_Kind: 787 for (const auto& sub : ((Block&) s).fStatements) { 788 if (contains_break(*sub)) { 789 return true; 790 } 791 } 792 return false; 793 case Statement::kBreak_Kind: 794 return true; 795 case Statement::kIf_Kind: { 796 const IfStatement& i = (IfStatement&) s; 797 return contains_break(*i.fIfTrue) || (i.fIfFalse && contains_break(*i.fIfFalse)); 798 } 799 default: 800 return false; 801 } 802} 803 804// Returns a block containing all of the statements that will be run if the given case matches 805// (which, owing to the statements being owned by unique_ptrs, means the switch itself will be 806// broken by this call and must then be discarded). 807// Returns null (and leaves the switch unmodified) if no such simple reduction is possible, such as 808// when break statements appear inside conditionals. 809static std::unique_ptr<Statement> block_for_case(SwitchStatement* s, SwitchCase* c) { 810 bool capturing = false; 811 std::vector<std::unique_ptr<Statement>*> statementPtrs; 812 for (const auto& current : s->fCases) { 813 if (current.get() == c) { 814 capturing = true; 815 } 816 if (capturing) { 817 for (auto& stmt : current->fStatements) { 818 if (stmt->fKind == Statement::kBreak_Kind) { 819 capturing = false; 820 break; 821 } 822 if (contains_break(*stmt)) { 823 return nullptr; 824 } 825 statementPtrs.push_back(&stmt); 826 } 827 if (!capturing) { 828 break; 829 } 830 } 831 } 832 std::vector<std::unique_ptr<Statement>> statements; 833 for (const auto& s : statementPtrs) { 834 statements.push_back(std::move(*s)); 835 } 836 return std::unique_ptr<Statement>(new Block(Position(), std::move(statements))); 837} 838 839void Compiler::simplifyStatement(DefinitionMap& definitions, 840 BasicBlock& b, 841 std::vector<BasicBlock::Node>::iterator* iter, 842 std::unordered_set<const Variable*>* undefinedVariables, 843 bool* outUpdated, 844 bool* outNeedsRescan) { 845 Statement* stmt = (*iter)->statement()->get(); 846 switch (stmt->fKind) { 847 case Statement::kVarDeclaration_Kind: { 848 const auto& varDecl = (VarDeclaration&) *stmt; 849 if (varDecl.fVar->dead() && 850 (!varDecl.fValue || 851 !varDecl.fValue->hasSideEffects())) { 852 if (varDecl.fValue) { 853 ASSERT((*iter)->statement()->get() == stmt); 854 if (!b.tryRemoveExpressionBefore(iter, varDecl.fValue.get())) { 855 *outNeedsRescan = true; 856 } 857 } 858 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop())); 859 *outUpdated = true; 860 } 861 break; 862 } 863 case Statement::kIf_Kind: { 864 IfStatement& i = (IfStatement&) *stmt; 865 if (i.fTest->fKind == Expression::kBoolLiteral_Kind) { 866 // constant if, collapse down to a single branch 867 if (((BoolLiteral&) *i.fTest).fValue) { 868 ASSERT(i.fIfTrue); 869 (*iter)->setStatement(std::move(i.fIfTrue)); 870 } else { 871 if (i.fIfFalse) { 872 (*iter)->setStatement(std::move(i.fIfFalse)); 873 } else { 874 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop())); 875 } 876 } 877 *outUpdated = true; 878 *outNeedsRescan = true; 879 break; 880 } 881 if (i.fIfFalse && i.fIfFalse->isEmpty()) { 882 // else block doesn't do anything, remove it 883 i.fIfFalse.reset(); 884 *outUpdated = true; 885 *outNeedsRescan = true; 886 } 887 if (!i.fIfFalse && i.fIfTrue->isEmpty()) { 888 // if block doesn't do anything, no else block 889 if (i.fTest->hasSideEffects()) { 890 // test has side effects, keep it 891 (*iter)->setStatement(std::unique_ptr<Statement>( 892 new ExpressionStatement(std::move(i.fTest)))); 893 } else { 894 // no if, no else, no test side effects, kill the whole if 895 // statement 896 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop())); 897 } 898 *outUpdated = true; 899 *outNeedsRescan = true; 900 } 901 break; 902 } 903 case Statement::kSwitch_Kind: { 904 SwitchStatement& s = (SwitchStatement&) *stmt; 905 if (s.fValue->isConstant()) { 906 // switch is constant, replace it with the case that matches 907 bool found = false; 908 SwitchCase* defaultCase = nullptr; 909 for (const auto& c : s.fCases) { 910 if (!c->fValue) { 911 defaultCase = c.get(); 912 continue; 913 } 914 ASSERT(c->fValue->fKind == s.fValue->fKind); 915 found = c->fValue->compareConstant(fContext, *s.fValue); 916 if (found) { 917 std::unique_ptr<Statement> newBlock = block_for_case(&s, c.get()); 918 if (newBlock) { 919 (*iter)->setStatement(std::move(newBlock)); 920 break; 921 } else { 922 if (s.fIsStatic) { 923 this->error(s.fPosition, 924 "static switch contains non-static conditional break"); 925 s.fIsStatic = false; 926 } 927 return; // can't simplify 928 } 929 } 930 } 931 if (!found) { 932 // no matching case. use default if it exists, or kill the whole thing 933 if (defaultCase) { 934 std::unique_ptr<Statement> newBlock = block_for_case(&s, defaultCase); 935 if (newBlock) { 936 (*iter)->setStatement(std::move(newBlock)); 937 } else { 938 if (s.fIsStatic) { 939 this->error(s.fPosition, 940 "static switch contains non-static conditional break"); 941 s.fIsStatic = false; 942 } 943 return; // can't simplify 944 } 945 } else { 946 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop())); 947 } 948 } 949 *outUpdated = true; 950 *outNeedsRescan = true; 951 } 952 break; 953 } 954 case Statement::kExpression_Kind: { 955 ExpressionStatement& e = (ExpressionStatement&) *stmt; 956 ASSERT((*iter)->statement()->get() == &e); 957 if (!e.fExpression->hasSideEffects()) { 958 // Expression statement with no side effects, kill it 959 if (!b.tryRemoveExpressionBefore(iter, e.fExpression.get())) { 960 *outNeedsRescan = true; 961 } 962 ASSERT((*iter)->statement()->get() == stmt); 963 (*iter)->setStatement(std::unique_ptr<Statement>(new Nop())); 964 *outUpdated = true; 965 } 966 break; 967 } 968 default: 969 break; 970 } 971} 972 973void Compiler::scanCFG(FunctionDefinition& f) { 974 CFG cfg = CFGGenerator().getCFG(f); 975 this->computeDataFlow(&cfg); 976 977 // check for unreachable code 978 for (size_t i = 0; i < cfg.fBlocks.size(); i++) { 979 if (i != cfg.fStart && !cfg.fBlocks[i].fEntrances.size() && 980 cfg.fBlocks[i].fNodes.size()) { 981 Position p; 982 switch (cfg.fBlocks[i].fNodes[0].fKind) { 983 case BasicBlock::Node::kStatement_Kind: 984 p = (*cfg.fBlocks[i].fNodes[0].statement())->fPosition; 985 break; 986 case BasicBlock::Node::kExpression_Kind: 987 p = (*cfg.fBlocks[i].fNodes[0].expression())->fPosition; 988 break; 989 } 990 this->error(p, String("unreachable")); 991 } 992 } 993 if (fErrorCount) { 994 return; 995 } 996 997 // check for dead code & undefined variables, perform constant propagation 998 std::unordered_set<const Variable*> undefinedVariables; 999 bool updated; 1000 bool needsRescan = false; 1001 do { 1002 if (needsRescan) { 1003 cfg = CFGGenerator().getCFG(f); 1004 this->computeDataFlow(&cfg); 1005 needsRescan = false; 1006 } 1007 1008 updated = false; 1009 for (BasicBlock& b : cfg.fBlocks) { 1010 DefinitionMap definitions = b.fBefore; 1011 1012 for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) { 1013 if (iter->fKind == BasicBlock::Node::kExpression_Kind) { 1014 this->simplifyExpression(definitions, b, &iter, &undefinedVariables, &updated, 1015 &needsRescan); 1016 } else { 1017 this->simplifyStatement(definitions, b, &iter, &undefinedVariables, &updated, 1018 &needsRescan); 1019 } 1020 if (needsRescan) { 1021 break; 1022 } 1023 this->addDefinitions(*iter, &definitions); 1024 } 1025 } 1026 } while (updated); 1027 ASSERT(!needsRescan); 1028 1029 // verify static ifs & switches 1030 for (BasicBlock& b : cfg.fBlocks) { 1031 DefinitionMap definitions = b.fBefore; 1032 1033 for (auto iter = b.fNodes.begin(); iter != b.fNodes.end() && !needsRescan; ++iter) { 1034 if (iter->fKind == BasicBlock::Node::kStatement_Kind) { 1035 const Statement& s = **iter->statement(); 1036 switch (s.fKind) { 1037 case Statement::kIf_Kind: 1038 if (((const IfStatement&) s).fIsStatic) { 1039 this->error(s.fPosition, "static if has non-static test"); 1040 } 1041 break; 1042 case Statement::kSwitch_Kind: 1043 if (((const SwitchStatement&) s).fIsStatic) { 1044 this->error(s.fPosition, "static switch has non-static test"); 1045 } 1046 break; 1047 default: 1048 break; 1049 } 1050 } 1051 } 1052 } 1053 1054 // check for missing return 1055 if (f.fDeclaration.fReturnType != *fContext.fVoid_Type) { 1056 if (cfg.fBlocks[cfg.fExit].fEntrances.size()) { 1057 this->error(f.fPosition, String("function can exit without returning a value")); 1058 } 1059 } 1060} 1061 1062void Compiler::internalConvertProgram(String text, 1063 Modifiers::Flag* defaultPrecision, 1064 std::vector<std::unique_ptr<ProgramElement>>* result) { 1065 Parser parser(text, *fTypes, *this); 1066 std::vector<std::unique_ptr<ASTDeclaration>> parsed = parser.file(); 1067 if (fErrorCount) { 1068 return; 1069 } 1070 *defaultPrecision = Modifiers::kHighp_Flag; 1071 for (size_t i = 0; i < parsed.size(); i++) { 1072 ASTDeclaration& decl = *parsed[i]; 1073 switch (decl.fKind) { 1074 case ASTDeclaration::kVar_Kind: { 1075 std::unique_ptr<VarDeclarations> s = fIRGenerator->convertVarDeclarations( 1076 (ASTVarDeclarations&) decl, 1077 Variable::kGlobal_Storage); 1078 if (s) { 1079 result->push_back(std::move(s)); 1080 } 1081 break; 1082 } 1083 case ASTDeclaration::kFunction_Kind: { 1084 std::unique_ptr<FunctionDefinition> f = fIRGenerator->convertFunction( 1085 (ASTFunction&) decl); 1086 if (!fErrorCount && f) { 1087 this->scanCFG(*f); 1088 result->push_back(std::move(f)); 1089 } 1090 break; 1091 } 1092 case ASTDeclaration::kModifiers_Kind: { 1093 std::unique_ptr<ModifiersDeclaration> f = fIRGenerator->convertModifiersDeclaration( 1094 (ASTModifiersDeclaration&) decl); 1095 if (f) { 1096 result->push_back(std::move(f)); 1097 } 1098 break; 1099 } 1100 case ASTDeclaration::kInterfaceBlock_Kind: { 1101 std::unique_ptr<InterfaceBlock> i = fIRGenerator->convertInterfaceBlock( 1102 (ASTInterfaceBlock&) decl); 1103 if (i) { 1104 result->push_back(std::move(i)); 1105 } 1106 break; 1107 } 1108 case ASTDeclaration::kExtension_Kind: { 1109 std::unique_ptr<Extension> e = fIRGenerator->convertExtension((ASTExtension&) decl); 1110 if (e) { 1111 result->push_back(std::move(e)); 1112 } 1113 break; 1114 } 1115 case ASTDeclaration::kPrecision_Kind: { 1116 *defaultPrecision = ((ASTPrecision&) decl).fPrecision; 1117 break; 1118 } 1119 default: 1120 ABORT("unsupported declaration: %s\n", decl.description().c_str()); 1121 } 1122 } 1123} 1124 1125std::unique_ptr<Program> Compiler::convertProgram(Program::Kind kind, String text, 1126 const Program::Settings& settings) { 1127 fErrorText = ""; 1128 fErrorCount = 0; 1129 fIRGenerator->start(&settings); 1130 std::vector<std::unique_ptr<ProgramElement>> elements; 1131 Modifiers::Flag ignored; 1132 switch (kind) { 1133 case Program::kVertex_Kind: 1134 this->internalConvertProgram(String(SKSL_VERT_INCLUDE), &ignored, &elements); 1135 break; 1136 case Program::kFragment_Kind: 1137 this->internalConvertProgram(String(SKSL_FRAG_INCLUDE), &ignored, &elements); 1138 break; 1139 case Program::kGeometry_Kind: 1140 this->internalConvertProgram(String(SKSL_GEOM_INCLUDE), &ignored, &elements); 1141 break; 1142 } 1143 fIRGenerator->fSymbolTable->markAllFunctionsBuiltin(); 1144 Modifiers::Flag defaultPrecision; 1145 this->internalConvertProgram(text, &defaultPrecision, &elements); 1146 auto result = std::unique_ptr<Program>(new Program(kind, settings, defaultPrecision, &fContext, 1147 std::move(elements), 1148 fIRGenerator->fSymbolTable, 1149 fIRGenerator->fInputs)); 1150 fIRGenerator->finish(); 1151 this->writeErrorCount(); 1152 if (fErrorCount) { 1153 return nullptr; 1154 } 1155 return result; 1156} 1157 1158bool Compiler::toSPIRV(const Program& program, OutputStream& out) { 1159#ifdef SK_ENABLE_SPIRV_VALIDATION 1160 StringStream buffer; 1161 SPIRVCodeGenerator cg(&fContext, &program, this, &buffer); 1162 bool result = cg.generateCode(); 1163 if (result) { 1164 spvtools::SpirvTools tools(SPV_ENV_VULKAN_1_0); 1165 ASSERT(0 == buffer.size() % 4); 1166 auto dumpmsg = [](spv_message_level_t, const char*, const spv_position_t&, const char* m) { 1167 SkDebugf("SPIR-V validation error: %s\n", m); 1168 }; 1169 tools.SetMessageConsumer(dumpmsg); 1170 // Verify that the SPIR-V we produced is valid. If this assert fails, check the logs prior 1171 // to the failure to see the validation errors. 1172 ASSERT_RESULT(tools.Validate((const uint32_t*) buffer.data(), buffer.size() / 4)); 1173 out.write(buffer.data(), buffer.size()); 1174 } 1175#else 1176 SPIRVCodeGenerator cg(&fContext, &program, this, &out); 1177 bool result = cg.generateCode(); 1178#endif 1179 this->writeErrorCount(); 1180 return result; 1181} 1182 1183bool Compiler::toSPIRV(const Program& program, String* out) { 1184 StringStream buffer; 1185 bool result = this->toSPIRV(program, buffer); 1186 if (result) { 1187 *out = String(buffer.data(), buffer.size()); 1188 } 1189 return result; 1190} 1191 1192bool Compiler::toGLSL(const Program& program, OutputStream& out) { 1193 GLSLCodeGenerator cg(&fContext, &program, this, &out); 1194 bool result = cg.generateCode(); 1195 this->writeErrorCount(); 1196 return result; 1197} 1198 1199bool Compiler::toGLSL(const Program& program, String* out) { 1200 StringStream buffer; 1201 bool result = this->toGLSL(program, buffer); 1202 if (result) { 1203 *out = String(buffer.data(), buffer.size()); 1204 } 1205 return result; 1206} 1207 1208 1209void Compiler::error(Position position, String msg) { 1210 fErrorCount++; 1211 fErrorText += "error: " + position.description() + ": " + msg.c_str() + "\n"; 1212} 1213 1214String Compiler::errorText() { 1215 String result = fErrorText; 1216 return result; 1217} 1218 1219void Compiler::writeErrorCount() { 1220 if (fErrorCount) { 1221 fErrorText += to_string(fErrorCount) + " error"; 1222 if (fErrorCount > 1) { 1223 fErrorText += "s"; 1224 } 1225 fErrorText += "\n"; 1226 } 1227} 1228 1229} // namespace 1230