SkSLSPIRVCodeGenerator.cpp revision 8a83ca4e9afc9e3c08b4e8c33a74392f9b3154d7
1/* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8#include "SkSLSPIRVCodeGenerator.h" 9 10#include "GLSL.std.450.h" 11 12#include "ir/SkSLExpressionStatement.h" 13#include "ir/SkSLExtension.h" 14#include "ir/SkSLIndexExpression.h" 15#include "ir/SkSLVariableReference.h" 16#include "SkSLCompiler.h" 17 18namespace SkSL { 19 20static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number 21 22void SPIRVCodeGenerator::setupIntrinsics() { 23#define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \ 24 GLSLstd450 ## x, GLSLstd450 ## x) 25#define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \ 26 GLSLstd450 ## ifFloat, \ 27 GLSLstd450 ## ifInt, \ 28 GLSLstd450 ## ifUInt, \ 29 SpvOpUndef) 30#define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \ 31 SpvOp ## x) 32#define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \ 33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \ 34 k ## x ## _SpecialIntrinsic) 35 fIntrinsicMap[String("round")] = ALL_GLSL(Round); 36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven); 37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc); 38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs); 39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign); 40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor); 41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil); 42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract); 43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians); 44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees); 45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin); 46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos); 47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan); 48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin); 49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos); 50 fIntrinsicMap[String("atan")] = SPECIAL(Atan); 51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh); 52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh); 53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh); 54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh); 55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh); 56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh); 57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow); 58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp); 59 fIntrinsicMap[String("log")] = ALL_GLSL(Log); 60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2); 61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2); 62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt); 63 fIntrinsicMap[String("inverse")] = ALL_GLSL(MatrixInverse); 64 fIntrinsicMap[String("transpose")] = ALL_SPIRV(Transpose); 65 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt); 66 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant); 67 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse); 68 fIntrinsicMap[String("mod")] = SPECIAL(Mod); 69 fIntrinsicMap[String("min")] = BY_TYPE_GLSL(FMin, SMin, UMin); 70 fIntrinsicMap[String("max")] = BY_TYPE_GLSL(FMax, SMax, UMax); 71 fIntrinsicMap[String("clamp")] = BY_TYPE_GLSL(FClamp, SClamp, UClamp); 72 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot, 73 SpvOpUndef, SpvOpUndef, SpvOpUndef); 74 fIntrinsicMap[String("mix")] = ALL_GLSL(FMix); 75 fIntrinsicMap[String("step")] = ALL_GLSL(Step); 76 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep); 77 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma); 78 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp); 79 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp); 80 81#define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \ 82 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type) 83 PACK(Snorm4x8); 84 PACK(Unorm4x8); 85 PACK(Snorm2x16); 86 PACK(Unorm2x16); 87 PACK(Half2x16); 88 PACK(Double2x32); 89 fIntrinsicMap[String("length")] = ALL_GLSL(Length); 90 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance); 91 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross); 92 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize); 93 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward); 94 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect); 95 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract); 96 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb); 97 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb); 98 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx, 99 SpvOpUndef, SpvOpUndef, SpvOpUndef); 100 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 101 SpvOpUndef, SpvOpUndef, SpvOpUndef); 102 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 103 SpvOpUndef, SpvOpUndef, SpvOpUndef); 104 fIntrinsicMap[String("texture")] = SPECIAL(Texture); 105 fIntrinsicMap[String("texelFetch")] = SPECIAL(TexelFetch); 106 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad); 107 108 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 109 SpvOpUndef, SpvOpUndef, SpvOpAny); 110 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 111 SpvOpUndef, SpvOpUndef, SpvOpAll); 112 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind, 113 SpvOpFOrdEqual, SpvOpIEqual, 114 SpvOpIEqual, SpvOpLogicalEqual); 115 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 116 SpvOpFOrdNotEqual, SpvOpINotEqual, 117 SpvOpINotEqual, 118 SpvOpLogicalNotEqual); 119 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 120 SpvOpFOrdLessThan, SpvOpSLessThan, 121 SpvOpULessThan, SpvOpUndef); 122 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 123 SpvOpFOrdLessThanEqual, 124 SpvOpSLessThanEqual, 125 SpvOpULessThanEqual, 126 SpvOpUndef); 127 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 128 SpvOpFOrdGreaterThan, 129 SpvOpSGreaterThan, 130 SpvOpUGreaterThan, 131 SpvOpUndef); 132 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 133 SpvOpFOrdGreaterThanEqual, 134 SpvOpSGreaterThanEqual, 135 SpvOpUGreaterThanEqual, 136 SpvOpUndef); 137 fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex); 138 fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive); 139// interpolateAt* not yet supported... 140} 141 142void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) { 143 out.write((const char*) &word, sizeof(word)); 144} 145 146static bool is_float(const Context& context, const Type& type) { 147 if (type.kind() == Type::kVector_Kind) { 148 return is_float(context, type.componentType()); 149 } 150 return type == *context.fFloat_Type || type == *context.fHalf_Type || 151 type == *context.fDouble_Type; 152} 153 154static bool is_signed(const Context& context, const Type& type) { 155 if (type.kind() == Type::kVector_Kind) { 156 return is_signed(context, type.componentType()); 157 } 158 return type == *context.fInt_Type || type == *context.fShort_Type; 159} 160 161static bool is_unsigned(const Context& context, const Type& type) { 162 if (type.kind() == Type::kVector_Kind) { 163 return is_unsigned(context, type.componentType()); 164 } 165 return type == *context.fUInt_Type || type == *context.fUShort_Type; 166} 167 168static bool is_bool(const Context& context, const Type& type) { 169 if (type.kind() == Type::kVector_Kind) { 170 return is_bool(context, type.componentType()); 171 } 172 return type == *context.fBool_Type; 173} 174 175static bool is_out(const Variable& var) { 176 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0; 177} 178 179void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) { 180 ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer); 181 ASSERT(opCode != SpvOpUndef); 182 switch (opCode) { 183 case SpvOpReturn: // fall through 184 case SpvOpReturnValue: // fall through 185 case SpvOpKill: // fall through 186 case SpvOpBranch: // fall through 187 case SpvOpBranchConditional: 188 ASSERT(fCurrentBlock); 189 fCurrentBlock = 0; 190 break; 191 case SpvOpConstant: // fall through 192 case SpvOpConstantTrue: // fall through 193 case SpvOpConstantFalse: // fall through 194 case SpvOpConstantComposite: // fall through 195 case SpvOpTypeVoid: // fall through 196 case SpvOpTypeInt: // fall through 197 case SpvOpTypeFloat: // fall through 198 case SpvOpTypeBool: // fall through 199 case SpvOpTypeVector: // fall through 200 case SpvOpTypeMatrix: // fall through 201 case SpvOpTypeArray: // fall through 202 case SpvOpTypePointer: // fall through 203 case SpvOpTypeFunction: // fall through 204 case SpvOpTypeRuntimeArray: // fall through 205 case SpvOpTypeStruct: // fall through 206 case SpvOpTypeImage: // fall through 207 case SpvOpTypeSampledImage: // fall through 208 case SpvOpVariable: // fall through 209 case SpvOpFunction: // fall through 210 case SpvOpFunctionParameter: // fall through 211 case SpvOpFunctionEnd: // fall through 212 case SpvOpExecutionMode: // fall through 213 case SpvOpMemoryModel: // fall through 214 case SpvOpCapability: // fall through 215 case SpvOpExtInstImport: // fall through 216 case SpvOpEntryPoint: // fall through 217 case SpvOpSource: // fall through 218 case SpvOpSourceExtension: // fall through 219 case SpvOpName: // fall through 220 case SpvOpMemberName: // fall through 221 case SpvOpDecorate: // fall through 222 case SpvOpMemberDecorate: 223 break; 224 default: 225 ASSERT(fCurrentBlock); 226 } 227 this->writeWord((length << 16) | opCode, out); 228} 229 230void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) { 231 fCurrentBlock = label; 232 this->writeInstruction(SpvOpLabel, label, out); 233} 234 235void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) { 236 this->writeOpCode(opCode, 1, out); 237} 238 239void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) { 240 this->writeOpCode(opCode, 2, out); 241 this->writeWord(word1, out); 242} 243 244void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) { 245 out.write(string, length); 246 switch (length % 4) { 247 case 1: 248 out.write8(0); 249 // fall through 250 case 2: 251 out.write8(0); 252 // fall through 253 case 3: 254 out.write8(0); 255 break; 256 default: 257 this->writeWord(0, out); 258 } 259} 260 261void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) { 262 this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out); 263 this->writeString(string.fChars, string.fLength, out); 264} 265 266 267void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, 268 OutputStream& out) { 269 this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out); 270 this->writeWord(word1, out); 271 this->writeString(string.fChars, string.fLength, out); 272} 273 274void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 275 StringFragment string, OutputStream& out) { 276 this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out); 277 this->writeWord(word1, out); 278 this->writeWord(word2, out); 279 this->writeString(string.fChars, string.fLength, out); 280} 281 282void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 283 OutputStream& out) { 284 this->writeOpCode(opCode, 3, out); 285 this->writeWord(word1, out); 286 this->writeWord(word2, out); 287} 288 289void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 290 int32_t word3, OutputStream& out) { 291 this->writeOpCode(opCode, 4, out); 292 this->writeWord(word1, out); 293 this->writeWord(word2, out); 294 this->writeWord(word3, out); 295} 296 297void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 298 int32_t word3, int32_t word4, OutputStream& out) { 299 this->writeOpCode(opCode, 5, out); 300 this->writeWord(word1, out); 301 this->writeWord(word2, out); 302 this->writeWord(word3, out); 303 this->writeWord(word4, out); 304} 305 306void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 307 int32_t word3, int32_t word4, int32_t word5, 308 OutputStream& out) { 309 this->writeOpCode(opCode, 6, out); 310 this->writeWord(word1, out); 311 this->writeWord(word2, out); 312 this->writeWord(word3, out); 313 this->writeWord(word4, out); 314 this->writeWord(word5, out); 315} 316 317void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 318 int32_t word3, int32_t word4, int32_t word5, 319 int32_t word6, OutputStream& out) { 320 this->writeOpCode(opCode, 7, out); 321 this->writeWord(word1, out); 322 this->writeWord(word2, out); 323 this->writeWord(word3, out); 324 this->writeWord(word4, out); 325 this->writeWord(word5, out); 326 this->writeWord(word6, out); 327} 328 329void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 330 int32_t word3, int32_t word4, int32_t word5, 331 int32_t word6, int32_t word7, OutputStream& out) { 332 this->writeOpCode(opCode, 8, out); 333 this->writeWord(word1, out); 334 this->writeWord(word2, out); 335 this->writeWord(word3, out); 336 this->writeWord(word4, out); 337 this->writeWord(word5, out); 338 this->writeWord(word6, out); 339 this->writeWord(word7, out); 340} 341 342void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 343 int32_t word3, int32_t word4, int32_t word5, 344 int32_t word6, int32_t word7, int32_t word8, 345 OutputStream& out) { 346 this->writeOpCode(opCode, 9, out); 347 this->writeWord(word1, out); 348 this->writeWord(word2, out); 349 this->writeWord(word3, out); 350 this->writeWord(word4, out); 351 this->writeWord(word5, out); 352 this->writeWord(word6, out); 353 this->writeWord(word7, out); 354 this->writeWord(word8, out); 355} 356 357void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) { 358 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) { 359 if (fCapabilities & bit) { 360 this->writeInstruction(SpvOpCapability, (SpvId) i, out); 361 } 362 } 363 if (fProgram.fKind == Program::kGeometry_Kind) { 364 this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out); 365 } 366} 367 368SpvId SPIRVCodeGenerator::nextId() { 369 return fIdCount++; 370} 371 372void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout, 373 SpvId resultId) { 374 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer); 375 // go ahead and write all of the field types, so we don't inadvertently write them while we're 376 // in the middle of writing the struct instruction 377 std::vector<SpvId> types; 378 for (const auto& f : type.fields()) { 379 types.push_back(this->getType(*f.fType, memoryLayout)); 380 } 381 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); 382 this->writeWord(resultId, fConstantBuffer); 383 for (SpvId id : types) { 384 this->writeWord(id, fConstantBuffer); 385 } 386 size_t offset = 0; 387 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { 388 size_t size = memoryLayout.size(*type.fields()[i].fType); 389 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType); 390 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout; 391 if (fieldLayout.fOffset >= 0) { 392 if (fieldLayout.fOffset < (int) offset) { 393 fErrors.error(type.fOffset, 394 "offset of field '" + type.fields()[i].fName + "' must be at " 395 "least " + to_string((int) offset)); 396 } 397 if (fieldLayout.fOffset % alignment) { 398 fErrors.error(type.fOffset, 399 "offset of field '" + type.fields()[i].fName + "' must be a multiple" 400 " of " + to_string((int) alignment)); 401 } 402 offset = fieldLayout.fOffset; 403 } else { 404 size_t mod = offset % alignment; 405 if (mod) { 406 offset += alignment - mod; 407 } 408 } 409 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer); 410 this->writeLayout(fieldLayout, resultId, i); 411 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) { 412 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, 413 (SpvId) offset, fDecorationBuffer); 414 } 415 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) { 416 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, 417 fDecorationBuffer); 418 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, 419 (SpvId) memoryLayout.stride(*type.fields()[i].fType), 420 fDecorationBuffer); 421 } 422 offset += size; 423 Type::Kind kind = type.fields()[i].fType->kind(); 424 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) { 425 offset += alignment - offset % alignment; 426 } 427 } 428} 429 430Type SPIRVCodeGenerator::getActualType(const Type& type) { 431 if (type == *fContext.fHalf_Type) { 432 return *fContext.fFloat_Type; 433 } 434 if (type == *fContext.fShort_Type) { 435 return *fContext.fInt_Type; 436 } 437 if (type == *fContext.fUShort_Type) { 438 return *fContext.fUInt_Type; 439 } 440 if (type.kind() == Type::kMatrix_Kind || type.kind() == Type::kVector_Kind) { 441 if (type.componentType() == *fContext.fHalf_Type) { 442 return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows()); 443 } 444 if (type.componentType() == *fContext.fShort_Type) { 445 return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows()); 446 } 447 if (type.componentType() == *fContext.fUShort_Type) { 448 return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows()); 449 } 450 } 451 return type; 452} 453 454SpvId SPIRVCodeGenerator::getType(const Type& type) { 455 return this->getType(type, fDefaultLayout); 456} 457 458SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) { 459 Type type = this->getActualType(rawType); 460 String key = type.name() + to_string((int) layout.fStd); 461 auto entry = fTypeMap.find(key); 462 if (entry == fTypeMap.end()) { 463 SpvId result = this->nextId(); 464 switch (type.kind()) { 465 case Type::kScalar_Kind: 466 if (type == *fContext.fBool_Type) { 467 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer); 468 } else if (type == *fContext.fInt_Type) { 469 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer); 470 } else if (type == *fContext.fUInt_Type) { 471 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer); 472 } else if (type == *fContext.fFloat_Type) { 473 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer); 474 } else if (type == *fContext.fDouble_Type) { 475 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer); 476 } else { 477 ASSERT(false); 478 } 479 break; 480 case Type::kVector_Kind: 481 this->writeInstruction(SpvOpTypeVector, result, 482 this->getType(type.componentType(), layout), 483 type.columns(), fConstantBuffer); 484 break; 485 case Type::kMatrix_Kind: 486 this->writeInstruction(SpvOpTypeMatrix, result, 487 this->getType(index_type(fContext, type), layout), 488 type.columns(), fConstantBuffer); 489 break; 490 case Type::kStruct_Kind: 491 this->writeStruct(type, layout, result); 492 break; 493 case Type::kArray_Kind: { 494 if (type.columns() > 0) { 495 IntLiteral count(fContext, -1, type.columns()); 496 this->writeInstruction(SpvOpTypeArray, result, 497 this->getType(type.componentType(), layout), 498 this->writeIntLiteral(count), fConstantBuffer); 499 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 500 (int32_t) layout.stride(type), 501 fDecorationBuffer); 502 } else { 503 this->writeInstruction(SpvOpTypeRuntimeArray, result, 504 this->getType(type.componentType(), layout), 505 fConstantBuffer); 506 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 507 (int32_t) layout.stride(type), 508 fDecorationBuffer); 509 } 510 break; 511 } 512 case Type::kSampler_Kind: { 513 SpvId image = result; 514 if (SpvDimSubpassData != type.dimensions()) { 515 image = this->nextId(); 516 } 517 if (SpvDimBuffer == type.dimensions()) { 518 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer); 519 } 520 this->writeInstruction(SpvOpTypeImage, image, 521 this->getType(*fContext.fFloat_Type, layout), 522 type.dimensions(), type.isDepth(), type.isArrayed(), 523 type.isMultisampled(), type.isSampled() ? 1 : 2, 524 SpvImageFormatUnknown, fConstantBuffer); 525 fImageTypeMap[key] = image; 526 if (SpvDimSubpassData != type.dimensions()) { 527 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer); 528 } 529 break; 530 } 531 default: 532 if (type == *fContext.fVoid_Type) { 533 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer); 534 } else { 535 ABORT("invalid type: %s", type.description().c_str()); 536 } 537 } 538 fTypeMap[key] = result; 539 return result; 540 } 541 return entry->second; 542} 543 544SpvId SPIRVCodeGenerator::getImageType(const Type& type) { 545 ASSERT(type.kind() == Type::kSampler_Kind); 546 this->getType(type); 547 String key = type.name() + to_string((int) fDefaultLayout.fStd); 548 ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end()); 549 return fImageTypeMap[key]; 550} 551 552SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { 553 String key = function.fReturnType.description() + "("; 554 String separator; 555 for (size_t i = 0; i < function.fParameters.size(); i++) { 556 key += separator; 557 separator = ", "; 558 key += function.fParameters[i]->fType.description(); 559 } 560 key += ")"; 561 auto entry = fTypeMap.find(key); 562 if (entry == fTypeMap.end()) { 563 SpvId result = this->nextId(); 564 int32_t length = 3 + (int32_t) function.fParameters.size(); 565 SpvId returnType = this->getType(function.fReturnType); 566 std::vector<SpvId> parameterTypes; 567 for (size_t i = 0; i < function.fParameters.size(); i++) { 568 // glslang seems to treat all function arguments as pointers whether they need to be or 569 // not. I was initially puzzled by this until I ran bizarre failures with certain 570 // patterns of function calls and control constructs, as exemplified by this minimal 571 // failure case: 572 // 573 // void sphere(float x) { 574 // } 575 // 576 // void map() { 577 // sphere(1.0); 578 // } 579 // 580 // void main() { 581 // for (int i = 0; i < 1; i++) { 582 // map(); 583 // } 584 // } 585 // 586 // As of this writing, compiling this in the "obvious" way (with sphere taking a float) 587 // crashes. Making it take a float* and storing the argument in a temporary variable, 588 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of 589 // the spec makes this make sense. 590// if (is_out(function->fParameters[i])) { 591 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType, 592 SpvStorageClassFunction)); 593// } else { 594// parameterTypes.push_back(this->getType(function.fParameters[i]->fType)); 595// } 596 } 597 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); 598 this->writeWord(result, fConstantBuffer); 599 this->writeWord(returnType, fConstantBuffer); 600 for (SpvId id : parameterTypes) { 601 this->writeWord(id, fConstantBuffer); 602 } 603 fTypeMap[key] = result; 604 return result; 605 } 606 return entry->second; 607} 608 609SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) { 610 return this->getPointerType(type, fDefaultLayout, storageClass); 611} 612 613SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout, 614 SpvStorageClass_ storageClass) { 615 Type type = this->getActualType(rawType); 616 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass); 617 auto entry = fTypeMap.find(key); 618 if (entry == fTypeMap.end()) { 619 SpvId result = this->nextId(); 620 this->writeInstruction(SpvOpTypePointer, result, storageClass, 621 this->getType(type), fConstantBuffer); 622 fTypeMap[key] = result; 623 return result; 624 } 625 return entry->second; 626} 627 628SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) { 629 switch (expr.fKind) { 630 case Expression::kBinary_Kind: 631 return this->writeBinaryExpression((BinaryExpression&) expr, out); 632 case Expression::kBoolLiteral_Kind: 633 return this->writeBoolLiteral((BoolLiteral&) expr); 634 case Expression::kConstructor_Kind: 635 return this->writeConstructor((Constructor&) expr, out); 636 case Expression::kIntLiteral_Kind: 637 return this->writeIntLiteral((IntLiteral&) expr); 638 case Expression::kFieldAccess_Kind: 639 return this->writeFieldAccess(((FieldAccess&) expr), out); 640 case Expression::kFloatLiteral_Kind: 641 return this->writeFloatLiteral(((FloatLiteral&) expr)); 642 case Expression::kFunctionCall_Kind: 643 return this->writeFunctionCall((FunctionCall&) expr, out); 644 case Expression::kPrefix_Kind: 645 return this->writePrefixExpression((PrefixExpression&) expr, out); 646 case Expression::kPostfix_Kind: 647 return this->writePostfixExpression((PostfixExpression&) expr, out); 648 case Expression::kSwizzle_Kind: 649 return this->writeSwizzle((Swizzle&) expr, out); 650 case Expression::kVariableReference_Kind: 651 return this->writeVariableReference((VariableReference&) expr, out); 652 case Expression::kTernary_Kind: 653 return this->writeTernaryExpression((TernaryExpression&) expr, out); 654 case Expression::kIndex_Kind: 655 return this->writeIndexExpression((IndexExpression&) expr, out); 656 default: 657 ABORT("unsupported expression: %s", expr.description().c_str()); 658 } 659 return -1; 660} 661 662SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) { 663 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName); 664 ASSERT(intrinsic != fIntrinsicMap.end()); 665 int32_t intrinsicId; 666 if (c.fArguments.size() > 0) { 667 const Type& type = c.fArguments[0]->fType; 668 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) { 669 intrinsicId = std::get<1>(intrinsic->second); 670 } else if (is_signed(fContext, type)) { 671 intrinsicId = std::get<2>(intrinsic->second); 672 } else if (is_unsigned(fContext, type)) { 673 intrinsicId = std::get<3>(intrinsic->second); 674 } else if (is_bool(fContext, type)) { 675 intrinsicId = std::get<4>(intrinsic->second); 676 } else { 677 intrinsicId = std::get<1>(intrinsic->second); 678 } 679 } else { 680 intrinsicId = std::get<1>(intrinsic->second); 681 } 682 switch (std::get<0>(intrinsic->second)) { 683 case kGLSL_STD_450_IntrinsicKind: { 684 SpvId result = this->nextId(); 685 std::vector<SpvId> arguments; 686 for (size_t i = 0; i < c.fArguments.size(); i++) { 687 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 688 } 689 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 690 this->writeWord(this->getType(c.fType), out); 691 this->writeWord(result, out); 692 this->writeWord(fGLSLExtendedInstructions, out); 693 this->writeWord(intrinsicId, out); 694 for (SpvId id : arguments) { 695 this->writeWord(id, out); 696 } 697 return result; 698 } 699 case kSPIRV_IntrinsicKind: { 700 SpvId result = this->nextId(); 701 std::vector<SpvId> arguments; 702 for (size_t i = 0; i < c.fArguments.size(); i++) { 703 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 704 } 705 if (c.fType != *fContext.fVoid_Type) { 706 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); 707 this->writeWord(this->getType(c.fType), out); 708 this->writeWord(result, out); 709 } else { 710 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out); 711 } 712 for (SpvId id : arguments) { 713 this->writeWord(id, out); 714 } 715 return result; 716 } 717 case kSpecial_IntrinsicKind: 718 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out); 719 default: 720 ABORT("unsupported intrinsic kind"); 721 } 722} 723 724SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, 725 OutputStream& out) { 726 SpvId result = this->nextId(); 727 switch (kind) { 728 case kAtan_SpecialIntrinsic: { 729 std::vector<SpvId> arguments; 730 for (size_t i = 0; i < c.fArguments.size(); i++) { 731 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 732 } 733 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 734 this->writeWord(this->getType(c.fType), out); 735 this->writeWord(result, out); 736 this->writeWord(fGLSLExtendedInstructions, out); 737 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); 738 for (SpvId id : arguments) { 739 this->writeWord(id, out); 740 } 741 break; 742 } 743 case kSubpassLoad_SpecialIntrinsic: { 744 SpvId img = this->writeExpression(*c.fArguments[0], out); 745 std::vector<std::unique_ptr<Expression>> args; 746 args.emplace_back(new FloatLiteral(fContext, -1, 0.0)); 747 args.emplace_back(new FloatLiteral(fContext, -1, 0.0)); 748 Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args)); 749 SpvId coords = this->writeConstantVector(ctor); 750 if (1 == c.fArguments.size()) { 751 this->writeInstruction(SpvOpImageRead, 752 this->getType(c.fType), 753 result, 754 img, 755 coords, 756 out); 757 } else { 758 ASSERT(2 == c.fArguments.size()); 759 SpvId sample = this->writeExpression(*c.fArguments[1], out); 760 this->writeInstruction(SpvOpImageRead, 761 this->getType(c.fType), 762 result, 763 img, 764 coords, 765 SpvImageOperandsSampleMask, 766 sample, 767 out); 768 } 769 break; 770 } 771 case kTexelFetch_SpecialIntrinsic: { 772 ASSERT(c.fArguments.size() == 2); 773 SpvId image = this->nextId(); 774 this->writeInstruction(SpvOpImage, 775 this->getImageType(c.fArguments[0]->fType), 776 image, 777 this->writeExpression(*c.fArguments[0], out), 778 out); 779 this->writeInstruction(SpvOpImageFetch, 780 this->getType(c.fType), 781 result, 782 image, 783 this->writeExpression(*c.fArguments[1], out), 784 out); 785 break; 786 } 787 case kTexture_SpecialIntrinsic: { 788 SpvOp_ op = SpvOpImageSampleImplicitLod; 789 switch (c.fArguments[0]->fType.dimensions()) { 790 case SpvDim1D: 791 if (c.fArguments[1]->fType == *fContext.fFloat2_Type) { 792 op = SpvOpImageSampleProjImplicitLod; 793 } else { 794 ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type); 795 } 796 break; 797 case SpvDim2D: 798 if (c.fArguments[1]->fType == *fContext.fFloat3_Type) { 799 op = SpvOpImageSampleProjImplicitLod; 800 } else { 801 ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type); 802 } 803 break; 804 case SpvDim3D: 805 if (c.fArguments[1]->fType == *fContext.fFloat4_Type) { 806 op = SpvOpImageSampleProjImplicitLod; 807 } else { 808 ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type); 809 } 810 break; 811 case SpvDimCube: // fall through 812 case SpvDimRect: // fall through 813 case SpvDimBuffer: // fall through 814 case SpvDimSubpassData: 815 break; 816 } 817 SpvId type = this->getType(c.fType); 818 SpvId sampler = this->writeExpression(*c.fArguments[0], out); 819 SpvId uv = this->writeExpression(*c.fArguments[1], out); 820 if (c.fArguments.size() == 3) { 821 this->writeInstruction(op, type, result, sampler, uv, 822 SpvImageOperandsBiasMask, 823 this->writeExpression(*c.fArguments[2], out), 824 out); 825 } else { 826 ASSERT(c.fArguments.size() == 2); 827 if (fProgram.fSettings.fSharpenTextures) { 828 FloatLiteral lodBias(fContext, -1, -0.5); 829 this->writeInstruction(op, type, result, sampler, uv, 830 SpvImageOperandsBiasMask, 831 this->writeFloatLiteral(lodBias), 832 out); 833 } else { 834 this->writeInstruction(op, type, result, sampler, uv, 835 out); 836 } 837 } 838 break; 839 } 840 case kMod_SpecialIntrinsic: { 841 ASSERT(c.fArguments.size() == 2); 842 SpvId arg1 = this->writeExpression(*c.fArguments[0], out); 843 SpvId arg2 = this->writeExpression(*c.fArguments[1], out); 844 if (c.fArguments[0]->fType != c.fArguments[1]->fType) { 845 // we have mod(vector, scalar), but SPIR-V wants mod(vector, vector) 846 ASSERT(c.fArguments[0]->fType.componentType() == c.fArguments[1]->fType); 847 SpvId scalar = arg2; 848 const Type& type = c.fArguments[0]->fType; 849 arg2 = this->nextId(); 850 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), out); 851 this->writeWord(this->getType(type), out); 852 this->writeWord(arg2, out); 853 for (int i = 0; i < type.columns(); i++) { 854 this->writeWord(scalar, out); 855 } 856 } 857 const Type& operandType = c.fArguments[0]->fType; 858 SpvOp_ op; 859 if (is_float(fContext, operandType)) { 860 op = SpvOpFMod; 861 } else if (is_signed(fContext, operandType)) { 862 op = SpvOpSMod; 863 } else if (is_unsigned(fContext, operandType)) { 864 op = SpvOpUMod; 865 } else { 866 ASSERT(false); 867 return 0; 868 } 869 this->writeOpCode(op, 5, out); 870 this->writeWord(this->getType(operandType), out); 871 this->writeWord(result, out); 872 this->writeWord(arg1, out); 873 this->writeWord(arg2, out); 874 } 875 } 876 return result; 877} 878 879SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) { 880 const auto& entry = fFunctionMap.find(&c.fFunction); 881 if (entry == fFunctionMap.end()) { 882 return this->writeIntrinsicCall(c, out); 883 } 884 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete 885 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues; 886 std::vector<SpvId> arguments; 887 for (size_t i = 0; i < c.fArguments.size(); i++) { 888 // id of temporary variable that we will use to hold this argument, or 0 if it is being 889 // passed directly 890 SpvId tmpVar; 891 // if we need a temporary var to store this argument, this is the value to store in the var 892 SpvId tmpValueId; 893 if (is_out(*c.fFunction.fParameters[i])) { 894 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out); 895 SpvId ptr = lv->getPointer(); 896 if (ptr) { 897 arguments.push_back(ptr); 898 continue; 899 } else { 900 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to 901 // copy it into a temp, call the function, read the value out of the temp, and then 902 // update the lvalue. 903 tmpValueId = lv->load(out); 904 tmpVar = this->nextId(); 905 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType), 906 std::move(lv))); 907 } 908 } else { 909 // see getFunctionType for an explanation of why we're always using pointer parameters 910 tmpValueId = this->writeExpression(*c.fArguments[i], out); 911 tmpVar = this->nextId(); 912 } 913 this->writeInstruction(SpvOpVariable, 914 this->getPointerType(c.fArguments[i]->fType, 915 SpvStorageClassFunction), 916 tmpVar, 917 SpvStorageClassFunction, 918 fVariableBuffer); 919 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); 920 arguments.push_back(tmpVar); 921 } 922 SpvId result = this->nextId(); 923 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out); 924 this->writeWord(this->getType(c.fType), out); 925 this->writeWord(result, out); 926 this->writeWord(entry->second, out); 927 for (SpvId id : arguments) { 928 this->writeWord(id, out); 929 } 930 // now that the call is complete, we may need to update some lvalues with the new values of out 931 // arguments 932 for (const auto& tuple : lvalues) { 933 SpvId load = this->nextId(); 934 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out); 935 std::get<2>(tuple)->store(load, out); 936 } 937 return result; 938} 939 940SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) { 941 ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant()); 942 SpvId result = this->nextId(); 943 std::vector<SpvId> arguments; 944 for (size_t i = 0; i < c.fArguments.size(); i++) { 945 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer)); 946 } 947 SpvId type = this->getType(c.fType); 948 if (c.fArguments.size() == 1) { 949 // with a single argument, a vector will have all of its entries equal to the argument 950 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer); 951 this->writeWord(type, fConstantBuffer); 952 this->writeWord(result, fConstantBuffer); 953 for (int i = 0; i < c.fType.columns(); i++) { 954 this->writeWord(arguments[0], fConstantBuffer); 955 } 956 } else { 957 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(), 958 fConstantBuffer); 959 this->writeWord(type, fConstantBuffer); 960 this->writeWord(result, fConstantBuffer); 961 for (SpvId id : arguments) { 962 this->writeWord(id, fConstantBuffer); 963 } 964 } 965 return result; 966} 967 968SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) { 969 ASSERT(c.fType.isFloat()); 970 ASSERT(c.fArguments.size() == 1); 971 ASSERT(c.fArguments[0]->fType.isNumber()); 972 SpvId result = this->nextId(); 973 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 974 if (c.fArguments[0]->fType.isSigned()) { 975 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, 976 out); 977 } else { 978 ASSERT(c.fArguments[0]->fType.isUnsigned()); 979 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, 980 out); 981 } 982 return result; 983} 984 985SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) { 986 ASSERT(c.fType.isSigned()); 987 ASSERT(c.fArguments.size() == 1); 988 ASSERT(c.fArguments[0]->fType.isNumber()); 989 SpvId result = this->nextId(); 990 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 991 if (c.fArguments[0]->fType.isFloat()) { 992 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, 993 out); 994 } 995 else { 996 ASSERT(c.fArguments[0]->fType.isUnsigned()); 997 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 998 out); 999 } 1000 return result; 1001} 1002 1003SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) { 1004 ASSERT(c.fType.isUnsigned()); 1005 ASSERT(c.fArguments.size() == 1); 1006 ASSERT(c.fArguments[0]->fType.isNumber()); 1007 SpvId result = this->nextId(); 1008 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1009 if (c.fArguments[0]->fType.isFloat()) { 1010 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter, 1011 out); 1012 } else { 1013 ASSERT(c.fArguments[0]->fType.isSigned()); 1014 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 1015 out); 1016 } 1017 return result; 1018} 1019 1020void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, 1021 OutputStream& out) { 1022 FloatLiteral zero(fContext, -1, 0); 1023 SpvId zeroId = this->writeFloatLiteral(zero); 1024 std::vector<SpvId> columnIds; 1025 for (int column = 0; column < type.columns(); column++) { 1026 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(), 1027 out); 1028 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)), 1029 out); 1030 SpvId columnId = this->nextId(); 1031 this->writeWord(columnId, out); 1032 columnIds.push_back(columnId); 1033 for (int row = 0; row < type.columns(); row++) { 1034 this->writeWord(row == column ? diagonal : zeroId, out); 1035 } 1036 } 1037 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), 1038 out); 1039 this->writeWord(this->getType(type), out); 1040 this->writeWord(id, out); 1041 for (SpvId id : columnIds) { 1042 this->writeWord(id, out); 1043 } 1044} 1045 1046void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, 1047 const Type& dstType, OutputStream& out) { 1048 ASSERT(srcType.kind() == Type::kMatrix_Kind); 1049 ASSERT(dstType.kind() == Type::kMatrix_Kind); 1050 ASSERT(srcType.componentType() == dstType.componentType()); 1051 SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext, 1052 srcType.rows(), 1053 1)); 1054 SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext, 1055 dstType.rows(), 1056 1)); 1057 SpvId zeroId; 1058 if (dstType.componentType() == *fContext.fFloat_Type) { 1059 FloatLiteral zero(fContext, -1, 0.0); 1060 zeroId = this->writeFloatLiteral(zero); 1061 } else if (dstType.componentType() == *fContext.fInt_Type) { 1062 IntLiteral zero(fContext, -1, 0); 1063 zeroId = this->writeIntLiteral(zero); 1064 } else { 1065 ABORT("unsupported matrix component type"); 1066 } 1067 SpvId zeroColumn = 0; 1068 SpvId columns[4]; 1069 for (int i = 0; i < dstType.columns(); i++) { 1070 if (i < srcType.columns()) { 1071 // we're still inside the src matrix, copy the column 1072 SpvId srcColumn = this->nextId(); 1073 this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out); 1074 SpvId dstColumn; 1075 if (srcType.rows() == dstType.rows()) { 1076 // columns are equal size, don't need to do anything 1077 dstColumn = srcColumn; 1078 } 1079 else if (dstType.rows() > srcType.rows()) { 1080 // dst column is bigger, need to zero-pad it 1081 dstColumn = this->nextId(); 1082 int delta = dstType.rows() - srcType.rows(); 1083 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out); 1084 this->writeWord(dstColumnType, out); 1085 this->writeWord(dstColumn, out); 1086 this->writeWord(srcColumn, out); 1087 for (int i = 0; i < delta; ++i) { 1088 this->writeWord(zeroId, out); 1089 } 1090 } 1091 else { 1092 // dst column is smaller, need to swizzle the src column 1093 dstColumn = this->nextId(); 1094 int count = dstType.rows(); 1095 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out); 1096 this->writeWord(dstColumnType, out); 1097 this->writeWord(dstColumn, out); 1098 this->writeWord(srcColumn, out); 1099 this->writeWord(srcColumn, out); 1100 for (int i = 0; i < count; i++) { 1101 this->writeWord(i, out); 1102 } 1103 } 1104 columns[i] = dstColumn; 1105 } else { 1106 // we're past the end of the src matrix, need a vector of zeroes 1107 if (!zeroColumn) { 1108 zeroColumn = this->nextId(); 1109 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out); 1110 this->writeWord(dstColumnType, out); 1111 this->writeWord(zeroColumn, out); 1112 for (int i = 0; i < dstType.rows(); ++i) { 1113 this->writeWord(zeroId, out); 1114 } 1115 } 1116 columns[i] = zeroColumn; 1117 } 1118 } 1119 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out); 1120 this->writeWord(this->getType(dstType), out); 1121 this->writeWord(id, out); 1122 for (int i = 0; i < dstType.columns(); i++) { 1123 this->writeWord(columns[i], out); 1124 } 1125} 1126 1127SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) { 1128 ASSERT(c.fType.kind() == Type::kMatrix_Kind); 1129 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1130 // an instruction 1131 std::vector<SpvId> arguments; 1132 for (size_t i = 0; i < c.fArguments.size(); i++) { 1133 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1134 } 1135 SpvId result = this->nextId(); 1136 int rows = c.fType.rows(); 1137 int columns = c.fType.columns(); 1138 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1139 this->writeUniformScaleMatrix(result, arguments[0], c.fType, out); 1140 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) { 1141 this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out); 1142 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) { 1143 ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2); 1144 ASSERT(c.fArguments[0]->fType.columns() == 4); 1145 SpvId componentType = this->getType(c.fType.componentType()); 1146 SpvId v[4]; 1147 for (int i = 0; i < 4; ++i) { 1148 v[i] = this->nextId(); 1149 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out); 1150 } 1151 SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1)); 1152 SpvId column1 = this->nextId(); 1153 this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out); 1154 SpvId column2 = this->nextId(); 1155 this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out); 1156 this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1, 1157 column2, out); 1158 } else { 1159 std::vector<SpvId> columnIds; 1160 // ids of vectors and scalars we have written to the current column so far 1161 std::vector<SpvId> currentColumn; 1162 // the total number of scalars represented by currentColumn's entries 1163 int currentCount = 0; 1164 for (size_t i = 0; i < arguments.size(); i++) { 1165 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind && 1166 c.fArguments[i]->fType.columns() == c.fType.rows()) { 1167 // this is a complete column by itself 1168 ASSERT(currentCount == 0); 1169 columnIds.push_back(arguments[i]); 1170 } else { 1171 currentColumn.push_back(arguments[i]); 1172 currentCount += c.fArguments[i]->fType.columns(); 1173 if (currentCount == rows) { 1174 currentCount = 0; 1175 this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out); 1176 this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1177 1)), 1178 out); 1179 SpvId columnId = this->nextId(); 1180 this->writeWord(columnId, out); 1181 columnIds.push_back(columnId); 1182 for (SpvId id : currentColumn) { 1183 this->writeWord(id, out); 1184 } 1185 currentColumn.clear(); 1186 } 1187 ASSERT(currentCount < rows); 1188 } 1189 } 1190 ASSERT(columnIds.size() == (size_t) columns); 1191 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); 1192 this->writeWord(this->getType(c.fType), out); 1193 this->writeWord(result, out); 1194 for (SpvId id : columnIds) { 1195 this->writeWord(id, out); 1196 } 1197 } 1198 return result; 1199} 1200 1201SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) { 1202 ASSERT(c.fType.kind() == Type::kVector_Kind); 1203 if (c.isConstant()) { 1204 return this->writeConstantVector(c); 1205 } 1206 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1207 // an instruction 1208 std::vector<SpvId> arguments; 1209 for (size_t i = 0; i < c.fArguments.size(); i++) { 1210 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) { 1211 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to 1212 // extract the components and convert them in that case manually. On top of that, 1213 // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite 1214 // doesn't handle vector arguments at all, so we always extract vector components and 1215 // pass them into OpCreateComposite individually. 1216 SpvId vec = this->writeExpression(*c.fArguments[i], out); 1217 SpvOp_ op = SpvOpUndef; 1218 const Type& src = c.fArguments[i]->fType.componentType(); 1219 const Type& dst = c.fType.componentType(); 1220 if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) { 1221 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1222 if (c.fArguments.size() == 1) { 1223 return vec; 1224 } 1225 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1226 op = SpvOpConvertSToF; 1227 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1228 op = SpvOpConvertUToF; 1229 } else { 1230 ASSERT(false); 1231 } 1232 } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) { 1233 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1234 op = SpvOpConvertFToS; 1235 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1236 if (c.fArguments.size() == 1) { 1237 return vec; 1238 } 1239 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1240 op = SpvOpBitcast; 1241 } else { 1242 ASSERT(false); 1243 } 1244 } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) { 1245 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1246 op = SpvOpConvertFToS; 1247 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1248 op = SpvOpBitcast; 1249 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1250 if (c.fArguments.size() == 1) { 1251 return vec; 1252 } 1253 } else { 1254 ASSERT(false); 1255 } 1256 } 1257 for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) { 1258 SpvId swizzle = this->nextId(); 1259 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j, 1260 out); 1261 if (op != SpvOpUndef) { 1262 SpvId cast = this->nextId(); 1263 this->writeInstruction(op, this->getType(dst), cast, swizzle, out); 1264 arguments.push_back(cast); 1265 } else { 1266 arguments.push_back(swizzle); 1267 } 1268 } 1269 } else { 1270 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1271 } 1272 } 1273 SpvId result = this->nextId(); 1274 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1275 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out); 1276 this->writeWord(this->getType(c.fType), out); 1277 this->writeWord(result, out); 1278 for (int i = 0; i < c.fType.columns(); i++) { 1279 this->writeWord(arguments[0], out); 1280 } 1281 } else { 1282 ASSERT(arguments.size() > 1); 1283 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out); 1284 this->writeWord(this->getType(c.fType), out); 1285 this->writeWord(result, out); 1286 for (SpvId id : arguments) { 1287 this->writeWord(id, out); 1288 } 1289 } 1290 return result; 1291} 1292 1293SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) { 1294 ASSERT(c.fType.kind() == Type::kArray_Kind); 1295 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1296 // an instruction 1297 std::vector<SpvId> arguments; 1298 for (size_t i = 0; i < c.fArguments.size(); i++) { 1299 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1300 } 1301 SpvId result = this->nextId(); 1302 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out); 1303 this->writeWord(this->getType(c.fType), out); 1304 this->writeWord(result, out); 1305 for (SpvId id : arguments) { 1306 this->writeWord(id, out); 1307 } 1308 return result; 1309} 1310 1311SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) { 1312 if (c.fArguments.size() == 1 && 1313 this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) { 1314 return this->writeExpression(*c.fArguments[0], out); 1315 } 1316 if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) { 1317 return this->writeFloatConstructor(c, out); 1318 } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) { 1319 return this->writeIntConstructor(c, out); 1320 } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) { 1321 return this->writeUIntConstructor(c, out); 1322 } 1323 switch (c.fType.kind()) { 1324 case Type::kVector_Kind: 1325 return this->writeVectorConstructor(c, out); 1326 case Type::kMatrix_Kind: 1327 return this->writeMatrixConstructor(c, out); 1328 case Type::kArray_Kind: 1329 return this->writeArrayConstructor(c, out); 1330 default: 1331 ABORT("unsupported constructor: %s", c.description().c_str()); 1332 } 1333} 1334 1335SpvStorageClass_ get_storage_class(const Modifiers& modifiers) { 1336 if (modifiers.fFlags & Modifiers::kIn_Flag) { 1337 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1338 return SpvStorageClassInput; 1339 } else if (modifiers.fFlags & Modifiers::kOut_Flag) { 1340 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1341 return SpvStorageClassOutput; 1342 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) { 1343 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) { 1344 return SpvStorageClassPushConstant; 1345 } 1346 return SpvStorageClassUniform; 1347 } else { 1348 return SpvStorageClassFunction; 1349 } 1350} 1351 1352SpvStorageClass_ get_storage_class(const Expression& expr) { 1353 switch (expr.fKind) { 1354 case Expression::kVariableReference_Kind: { 1355 const Variable& var = ((VariableReference&) expr).fVariable; 1356 if (var.fStorage != Variable::kGlobal_Storage) { 1357 return SpvStorageClassFunction; 1358 } 1359 SpvStorageClass_ result = get_storage_class(var.fModifiers); 1360 if (result == SpvStorageClassFunction) { 1361 result = SpvStorageClassPrivate; 1362 } 1363 return result; 1364 } 1365 case Expression::kFieldAccess_Kind: 1366 return get_storage_class(*((FieldAccess&) expr).fBase); 1367 case Expression::kIndex_Kind: 1368 return get_storage_class(*((IndexExpression&) expr).fBase); 1369 default: 1370 return SpvStorageClassFunction; 1371 } 1372} 1373 1374std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) { 1375 std::vector<SpvId> chain; 1376 switch (expr.fKind) { 1377 case Expression::kIndex_Kind: { 1378 IndexExpression& indexExpr = (IndexExpression&) expr; 1379 chain = this->getAccessChain(*indexExpr.fBase, out); 1380 chain.push_back(this->writeExpression(*indexExpr.fIndex, out)); 1381 break; 1382 } 1383 case Expression::kFieldAccess_Kind: { 1384 FieldAccess& fieldExpr = (FieldAccess&) expr; 1385 chain = this->getAccessChain(*fieldExpr.fBase, out); 1386 IntLiteral index(fContext, -1, fieldExpr.fFieldIndex); 1387 chain.push_back(this->writeIntLiteral(index)); 1388 break; 1389 } 1390 default: 1391 chain.push_back(this->getLValue(expr, out)->getPointer()); 1392 } 1393 return chain; 1394} 1395 1396class PointerLValue : public SPIRVCodeGenerator::LValue { 1397public: 1398 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type) 1399 : fGen(gen) 1400 , fPointer(pointer) 1401 , fType(type) {} 1402 1403 virtual SpvId getPointer() override { 1404 return fPointer; 1405 } 1406 1407 virtual SpvId load(OutputStream& out) override { 1408 SpvId result = fGen.nextId(); 1409 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out); 1410 return result; 1411 } 1412 1413 virtual void store(SpvId value, OutputStream& out) override { 1414 fGen.writeInstruction(SpvOpStore, fPointer, value, out); 1415 } 1416 1417private: 1418 SPIRVCodeGenerator& fGen; 1419 const SpvId fPointer; 1420 const SpvId fType; 1421}; 1422 1423class SwizzleLValue : public SPIRVCodeGenerator::LValue { 1424public: 1425 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components, 1426 const Type& baseType, const Type& swizzleType) 1427 : fGen(gen) 1428 , fVecPointer(vecPointer) 1429 , fComponents(components) 1430 , fBaseType(baseType) 1431 , fSwizzleType(swizzleType) {} 1432 1433 virtual SpvId getPointer() override { 1434 return 0; 1435 } 1436 1437 virtual SpvId load(OutputStream& out) override { 1438 SpvId base = fGen.nextId(); 1439 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1440 SpvId result = fGen.nextId(); 1441 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out); 1442 fGen.writeWord(fGen.getType(fSwizzleType), out); 1443 fGen.writeWord(result, out); 1444 fGen.writeWord(base, out); 1445 fGen.writeWord(base, out); 1446 for (int component : fComponents) { 1447 fGen.writeWord(component, out); 1448 } 1449 return result; 1450 } 1451 1452 virtual void store(SpvId value, OutputStream& out) override { 1453 // use OpVectorShuffle to mix and match the vector components. We effectively create 1454 // a virtual vector out of the concatenation of the left and right vectors, and then 1455 // select components from this virtual vector to make the result vector. For 1456 // instance, given: 1457 // float3L = ...; 1458 // float3R = ...; 1459 // L.xz = R.xy; 1460 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want 1461 // our result vector to look like (R.x, L.y, R.y), so we need to select indices 1462 // (3, 1, 4). 1463 SpvId base = fGen.nextId(); 1464 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1465 SpvId shuffle = fGen.nextId(); 1466 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out); 1467 fGen.writeWord(fGen.getType(fBaseType), out); 1468 fGen.writeWord(shuffle, out); 1469 fGen.writeWord(base, out); 1470 fGen.writeWord(value, out); 1471 for (int i = 0; i < fBaseType.columns(); i++) { 1472 // current offset into the virtual vector, defaults to pulling the unmodified 1473 // value from the left side 1474 int offset = i; 1475 // check to see if we are writing this component 1476 for (size_t j = 0; j < fComponents.size(); j++) { 1477 if (fComponents[j] == i) { 1478 // we're writing to this component, so adjust the offset to pull from 1479 // the correct component of the right side instead of preserving the 1480 // value from the left 1481 offset = (int) (j + fBaseType.columns()); 1482 break; 1483 } 1484 } 1485 fGen.writeWord(offset, out); 1486 } 1487 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out); 1488 } 1489 1490private: 1491 SPIRVCodeGenerator& fGen; 1492 const SpvId fVecPointer; 1493 const std::vector<int>& fComponents; 1494 const Type& fBaseType; 1495 const Type& fSwizzleType; 1496}; 1497 1498std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr, 1499 OutputStream& out) { 1500 switch (expr.fKind) { 1501 case Expression::kVariableReference_Kind: { 1502 const Variable& var = ((VariableReference&) expr).fVariable; 1503 auto entry = fVariableMap.find(&var); 1504 ASSERT(entry != fVariableMap.end()); 1505 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1506 *this, 1507 entry->second, 1508 this->getType(expr.fType))); 1509 } 1510 case Expression::kIndex_Kind: // fall through 1511 case Expression::kFieldAccess_Kind: { 1512 std::vector<SpvId> chain = this->getAccessChain(expr, out); 1513 SpvId member = this->nextId(); 1514 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out); 1515 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out); 1516 this->writeWord(member, out); 1517 for (SpvId idx : chain) { 1518 this->writeWord(idx, out); 1519 } 1520 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1521 *this, 1522 member, 1523 this->getType(expr.fType))); 1524 } 1525 case Expression::kSwizzle_Kind: { 1526 Swizzle& swizzle = (Swizzle&) expr; 1527 size_t count = swizzle.fComponents.size(); 1528 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer(); 1529 ASSERT(base); 1530 if (count == 1) { 1531 IntLiteral index(fContext, -1, swizzle.fComponents[0]); 1532 SpvId member = this->nextId(); 1533 this->writeInstruction(SpvOpAccessChain, 1534 this->getPointerType(swizzle.fType, 1535 get_storage_class(*swizzle.fBase)), 1536 member, 1537 base, 1538 this->writeIntLiteral(index), 1539 out); 1540 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1541 *this, 1542 member, 1543 this->getType(expr.fType))); 1544 } else { 1545 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue( 1546 *this, 1547 base, 1548 swizzle.fComponents, 1549 swizzle.fBase->fType, 1550 expr.fType)); 1551 } 1552 } 1553 case Expression::kTernary_Kind: { 1554 TernaryExpression& t = (TernaryExpression&) expr; 1555 SpvId test = this->writeExpression(*t.fTest, out); 1556 SpvId end = this->nextId(); 1557 SpvId ifTrueLabel = this->nextId(); 1558 SpvId ifFalseLabel = this->nextId(); 1559 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 1560 this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out); 1561 this->writeLabel(ifTrueLabel, out); 1562 SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer(); 1563 ASSERT(ifTrue); 1564 this->writeInstruction(SpvOpBranch, end, out); 1565 ifTrueLabel = fCurrentBlock; 1566 SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer(); 1567 ASSERT(ifFalse); 1568 ifFalseLabel = fCurrentBlock; 1569 this->writeInstruction(SpvOpBranch, end, out); 1570 SpvId result = this->nextId(); 1571 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue, 1572 ifTrueLabel, ifFalse, ifFalseLabel, out); 1573 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1574 *this, 1575 result, 1576 this->getType(expr.fType))); 1577 } 1578 default: 1579 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due 1580 // to the need to store values in temporary variables during function calls (see 1581 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been 1582 // caught by IRGenerator 1583 SpvId result = this->nextId(); 1584 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction); 1585 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, 1586 fVariableBuffer); 1587 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out); 1588 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1589 *this, 1590 result, 1591 this->getType(expr.fType))); 1592 } 1593} 1594 1595SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) { 1596 SpvId result = this->nextId(); 1597 auto entry = fVariableMap.find(&ref.fVariable); 1598 ASSERT(entry != fVariableMap.end()); 1599 SpvId var = entry->second; 1600 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out); 1601 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN && 1602 fProgram.fSettings.fFlipY) { 1603 // need to remap to a top-left coordinate system 1604 if (fRTHeightStructId == (SpvId) -1) { 1605 // height variable hasn't been written yet 1606 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors)); 1607 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 1608 std::vector<Type::Field> fields; 1609 fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get()); 1610 StringFragment name("sksl_synthetic_uniforms"); 1611 Type intfStruct(-1, name, fields); 1612 Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified, 1613 Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key, 1614 StringFragment()); 1615 Variable* intfVar = new Variable(-1, 1616 Modifiers(layout, Modifiers::kUniform_Flag), 1617 name, 1618 intfStruct, 1619 Variable::kGlobal_Storage); 1620 fSynthetics.takeOwnership(intfVar); 1621 InterfaceBlock intf(-1, intfVar, name, String(""), 1622 std::vector<std::unique_ptr<Expression>>(), st); 1623 fRTHeightStructId = this->writeInterfaceBlock(intf); 1624 fRTHeightFieldIndex = 0; 1625 } 1626 ASSERT(fRTHeightFieldIndex != (SpvId) -1); 1627 // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0) 1628 SpvId xId = this->nextId(); 1629 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId, 1630 result, 0, out); 1631 IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex); 1632 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex); 1633 SpvId heightPtr = this->nextId(); 1634 this->writeOpCode(SpvOpAccessChain, 5, out); 1635 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out); 1636 this->writeWord(heightPtr, out); 1637 this->writeWord(fRTHeightStructId, out); 1638 this->writeWord(fieldIndexId, out); 1639 SpvId heightRead = this->nextId(); 1640 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead, 1641 heightPtr, out); 1642 SpvId rawYId = this->nextId(); 1643 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId, 1644 result, 1, out); 1645 SpvId flippedYId = this->nextId(); 1646 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId, 1647 heightRead, rawYId, out); 1648 FloatLiteral zero(fContext, -1, 0.0); 1649 SpvId zeroId = writeFloatLiteral(zero); 1650 FloatLiteral one(fContext, -1, 1.0); 1651 SpvId oneId = writeFloatLiteral(one); 1652 SpvId flipped = this->nextId(); 1653 this->writeOpCode(SpvOpCompositeConstruct, 7, out); 1654 this->writeWord(this->getType(*fContext.fFloat4_Type), out); 1655 this->writeWord(flipped, out); 1656 this->writeWord(xId, out); 1657 this->writeWord(flippedYId, out); 1658 this->writeWord(zeroId, out); 1659 this->writeWord(oneId, out); 1660 return flipped; 1661 } 1662 return result; 1663} 1664 1665SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) { 1666 return getLValue(expr, out)->load(out); 1667} 1668 1669SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) { 1670 return getLValue(f, out)->load(out); 1671} 1672 1673SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) { 1674 SpvId base = this->writeExpression(*swizzle.fBase, out); 1675 SpvId result = this->nextId(); 1676 size_t count = swizzle.fComponents.size(); 1677 if (count == 1) { 1678 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, 1679 swizzle.fComponents[0], out); 1680 } else { 1681 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); 1682 this->writeWord(this->getType(swizzle.fType), out); 1683 this->writeWord(result, out); 1684 this->writeWord(base, out); 1685 this->writeWord(base, out); 1686 for (int component : swizzle.fComponents) { 1687 this->writeWord(component, out); 1688 } 1689 } 1690 return result; 1691} 1692 1693SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, 1694 const Type& operandType, SpvId lhs, 1695 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, 1696 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) { 1697 SpvId result = this->nextId(); 1698 if (is_float(fContext, operandType)) { 1699 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out); 1700 } else if (is_signed(fContext, operandType)) { 1701 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out); 1702 } else if (is_unsigned(fContext, operandType)) { 1703 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out); 1704 } else if (operandType == *fContext.fBool_Type) { 1705 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out); 1706 } else { 1707 ABORT("invalid operandType: %s", operandType.description().c_str()); 1708 } 1709 return result; 1710} 1711 1712bool is_assignment(Token::Kind op) { 1713 switch (op) { 1714 case Token::EQ: // fall through 1715 case Token::PLUSEQ: // fall through 1716 case Token::MINUSEQ: // fall through 1717 case Token::STAREQ: // fall through 1718 case Token::SLASHEQ: // fall through 1719 case Token::PERCENTEQ: // fall through 1720 case Token::SHLEQ: // fall through 1721 case Token::SHREQ: // fall through 1722 case Token::BITWISEOREQ: // fall through 1723 case Token::BITWISEXOREQ: // fall through 1724 case Token::BITWISEANDEQ: // fall through 1725 case Token::LOGICALOREQ: // fall through 1726 case Token::LOGICALXOREQ: // fall through 1727 case Token::LOGICALANDEQ: 1728 return true; 1729 default: 1730 return false; 1731 } 1732} 1733 1734SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) { 1735 if (operandType.kind() == Type::kVector_Kind) { 1736 SpvId result = this->nextId(); 1737 this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out); 1738 return result; 1739 } 1740 return id; 1741} 1742 1743SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, 1744 SpvOp_ floatOperator, SpvOp_ intOperator, 1745 OutputStream& out) { 1746 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator; 1747 ASSERT(operandType.kind() == Type::kMatrix_Kind); 1748 SpvId rowType = this->getType(operandType.componentType().toCompound(fContext, 1749 operandType.columns(), 1750 1)); 1751 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext, 1752 operandType.columns(), 1753 1)); 1754 SpvId boolType = this->getType(*fContext.fBool_Type); 1755 SpvId result = 0; 1756 for (int i = 0; i < operandType.rows(); i++) { 1757 SpvId rowL = this->nextId(); 1758 this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out); 1759 SpvId rowR = this->nextId(); 1760 this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out); 1761 SpvId compare = this->nextId(); 1762 this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out); 1763 SpvId all = this->nextId(); 1764 this->writeInstruction(SpvOpAll, boolType, all, compare, out); 1765 if (result != 0) { 1766 SpvId next = this->nextId(); 1767 this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out); 1768 result = next; 1769 } 1770 else { 1771 result = all; 1772 } 1773 } 1774 return result; 1775} 1776 1777SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) { 1778 // handle cases where we don't necessarily evaluate both LHS and RHS 1779 switch (b.fOperator) { 1780 case Token::EQ: { 1781 SpvId rhs = this->writeExpression(*b.fRight, out); 1782 this->getLValue(*b.fLeft, out)->store(rhs, out); 1783 return rhs; 1784 } 1785 case Token::LOGICALAND: 1786 return this->writeLogicalAnd(b, out); 1787 case Token::LOGICALOR: 1788 return this->writeLogicalOr(b, out); 1789 default: 1790 break; 1791 } 1792 1793 // "normal" operators 1794 const Type& resultType = b.fType; 1795 std::unique_ptr<LValue> lvalue; 1796 SpvId lhs; 1797 if (is_assignment(b.fOperator)) { 1798 lvalue = this->getLValue(*b.fLeft, out); 1799 lhs = lvalue->load(out); 1800 } else { 1801 lvalue = nullptr; 1802 lhs = this->writeExpression(*b.fLeft, out); 1803 } 1804 SpvId rhs = this->writeExpression(*b.fRight, out); 1805 if (b.fOperator == Token::COMMA) { 1806 return rhs; 1807 } 1808 Type tmp("<invalid>"); 1809 // component type we are operating on: float, int, uint 1810 const Type* operandType; 1811 // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling 1812 // in SPIR-V 1813 if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) { 1814 if (b.fLeft->fType.kind() == Type::kVector_Kind && 1815 b.fRight->fType.isNumber()) { 1816 // promote number to vector 1817 SpvId vec = this->nextId(); 1818 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 1819 this->writeWord(this->getType(resultType), out); 1820 this->writeWord(vec, out); 1821 for (int i = 0; i < resultType.columns(); i++) { 1822 this->writeWord(rhs, out); 1823 } 1824 rhs = vec; 1825 operandType = &b.fRight->fType; 1826 } else if (b.fRight->fType.kind() == Type::kVector_Kind && 1827 b.fLeft->fType.isNumber()) { 1828 // promote number to vector 1829 SpvId vec = this->nextId(); 1830 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 1831 this->writeWord(this->getType(resultType), out); 1832 this->writeWord(vec, out); 1833 for (int i = 0; i < resultType.columns(); i++) { 1834 this->writeWord(lhs, out); 1835 } 1836 lhs = vec; 1837 ASSERT(!lvalue); 1838 operandType = &b.fLeft->fType; 1839 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) { 1840 SpvOp_ op; 1841 if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 1842 op = SpvOpMatrixTimesMatrix; 1843 } else if (b.fRight->fType.kind() == Type::kVector_Kind) { 1844 op = SpvOpMatrixTimesVector; 1845 } else { 1846 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind); 1847 op = SpvOpMatrixTimesScalar; 1848 } 1849 SpvId result = this->nextId(); 1850 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out); 1851 if (b.fOperator == Token::STAREQ) { 1852 lvalue->store(result, out); 1853 } else { 1854 ASSERT(b.fOperator == Token::STAR); 1855 } 1856 return result; 1857 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 1858 SpvId result = this->nextId(); 1859 if (b.fLeft->fType.kind() == Type::kVector_Kind) { 1860 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, 1861 lhs, rhs, out); 1862 } else { 1863 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind); 1864 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, 1865 lhs, out); 1866 } 1867 if (b.fOperator == Token::STAREQ) { 1868 lvalue->store(result, out); 1869 } else { 1870 ASSERT(b.fOperator == Token::STAR); 1871 } 1872 return result; 1873 } else { 1874 ABORT("unsupported binary expression: %s", b.description().c_str()); 1875 } 1876 } else { 1877 tmp = this->getActualType(b.fLeft->fType); 1878 operandType = &tmp; 1879 ASSERT(*operandType == this->getActualType(b.fRight->fType)); 1880 } 1881 switch (b.fOperator) { 1882 case Token::EQEQ: { 1883 if (operandType->kind() == Type::kMatrix_Kind) { 1884 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual, 1885 SpvOpIEqual, out); 1886 } 1887 ASSERT(resultType == *fContext.fBool_Type); 1888 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1889 SpvOpFOrdEqual, SpvOpIEqual, 1890 SpvOpIEqual, SpvOpLogicalEqual, out), 1891 *operandType, out); 1892 } 1893 case Token::NEQ: 1894 if (operandType->kind() == Type::kMatrix_Kind) { 1895 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual, 1896 SpvOpINotEqual, out); 1897 } 1898 ASSERT(resultType == *fContext.fBool_Type); 1899 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1900 SpvOpFOrdNotEqual, SpvOpINotEqual, 1901 SpvOpINotEqual, SpvOpLogicalNotEqual, 1902 out), 1903 *operandType, out); 1904 case Token::GT: 1905 ASSERT(resultType == *fContext.fBool_Type); 1906 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1907 SpvOpFOrdGreaterThan, SpvOpSGreaterThan, 1908 SpvOpUGreaterThan, SpvOpUndef, out); 1909 case Token::LT: 1910 ASSERT(resultType == *fContext.fBool_Type); 1911 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, 1912 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out); 1913 case Token::GTEQ: 1914 ASSERT(resultType == *fContext.fBool_Type); 1915 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1916 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, 1917 SpvOpUGreaterThanEqual, SpvOpUndef, out); 1918 case Token::LTEQ: 1919 ASSERT(resultType == *fContext.fBool_Type); 1920 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1921 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, 1922 SpvOpULessThanEqual, SpvOpUndef, out); 1923 case Token::PLUS: 1924 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 1925 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 1926 case Token::MINUS: 1927 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 1928 SpvOpISub, SpvOpISub, SpvOpUndef, out); 1929 case Token::STAR: 1930 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 1931 b.fRight->fType.kind() == Type::kMatrix_Kind) { 1932 // matrix multiply 1933 SpvId result = this->nextId(); 1934 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 1935 lhs, rhs, out); 1936 return result; 1937 } 1938 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 1939 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 1940 case Token::SLASH: 1941 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 1942 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 1943 case Token::PERCENT: 1944 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 1945 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 1946 case Token::SHL: 1947 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 1948 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, 1949 SpvOpUndef, out); 1950 case Token::SHR: 1951 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 1952 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical, 1953 SpvOpUndef, out); 1954 case Token::BITWISEAND: 1955 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 1956 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out); 1957 case Token::BITWISEOR: 1958 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 1959 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out); 1960 case Token::BITWISEXOR: 1961 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 1962 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out); 1963 case Token::PLUSEQ: { 1964 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 1965 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 1966 ASSERT(lvalue); 1967 lvalue->store(result, out); 1968 return result; 1969 } 1970 case Token::MINUSEQ: { 1971 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 1972 SpvOpISub, SpvOpISub, SpvOpUndef, out); 1973 ASSERT(lvalue); 1974 lvalue->store(result, out); 1975 return result; 1976 } 1977 case Token::STAREQ: { 1978 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 1979 b.fRight->fType.kind() == Type::kMatrix_Kind) { 1980 // matrix multiply 1981 SpvId result = this->nextId(); 1982 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 1983 lhs, rhs, out); 1984 ASSERT(lvalue); 1985 lvalue->store(result, out); 1986 return result; 1987 } 1988 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 1989 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 1990 ASSERT(lvalue); 1991 lvalue->store(result, out); 1992 return result; 1993 } 1994 case Token::SLASHEQ: { 1995 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 1996 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 1997 ASSERT(lvalue); 1998 lvalue->store(result, out); 1999 return result; 2000 } 2001 case Token::PERCENTEQ: { 2002 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2003 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2004 ASSERT(lvalue); 2005 lvalue->store(result, out); 2006 return result; 2007 } 2008 case Token::SHLEQ: { 2009 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2010 SpvOpUndef, SpvOpShiftLeftLogical, 2011 SpvOpShiftLeftLogical, SpvOpUndef, out); 2012 ASSERT(lvalue); 2013 lvalue->store(result, out); 2014 return result; 2015 } 2016 case Token::SHREQ: { 2017 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2018 SpvOpUndef, SpvOpShiftRightArithmetic, 2019 SpvOpShiftRightLogical, SpvOpUndef, out); 2020 ASSERT(lvalue); 2021 lvalue->store(result, out); 2022 return result; 2023 } 2024 case Token::BITWISEANDEQ: { 2025 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2026 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd, 2027 SpvOpUndef, out); 2028 ASSERT(lvalue); 2029 lvalue->store(result, out); 2030 return result; 2031 } 2032 case Token::BITWISEOREQ: { 2033 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2034 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr, 2035 SpvOpUndef, out); 2036 ASSERT(lvalue); 2037 lvalue->store(result, out); 2038 return result; 2039 } 2040 case Token::BITWISEXOREQ: { 2041 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2042 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor, 2043 SpvOpUndef, out); 2044 ASSERT(lvalue); 2045 lvalue->store(result, out); 2046 return result; 2047 } 2048 default: 2049 ABORT("unsupported binary expression: %s", b.description().c_str()); 2050 } 2051} 2052 2053SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) { 2054 ASSERT(a.fOperator == Token::LOGICALAND); 2055 BoolLiteral falseLiteral(fContext, -1, false); 2056 SpvId falseConstant = this->writeBoolLiteral(falseLiteral); 2057 SpvId lhs = this->writeExpression(*a.fLeft, out); 2058 SpvId rhsLabel = this->nextId(); 2059 SpvId end = this->nextId(); 2060 SpvId lhsBlock = fCurrentBlock; 2061 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2062 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out); 2063 this->writeLabel(rhsLabel, out); 2064 SpvId rhs = this->writeExpression(*a.fRight, out); 2065 SpvId rhsBlock = fCurrentBlock; 2066 this->writeInstruction(SpvOpBranch, end, out); 2067 this->writeLabel(end, out); 2068 SpvId result = this->nextId(); 2069 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant, 2070 lhsBlock, rhs, rhsBlock, out); 2071 return result; 2072} 2073 2074SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) { 2075 ASSERT(o.fOperator == Token::LOGICALOR); 2076 BoolLiteral trueLiteral(fContext, -1, true); 2077 SpvId trueConstant = this->writeBoolLiteral(trueLiteral); 2078 SpvId lhs = this->writeExpression(*o.fLeft, out); 2079 SpvId rhsLabel = this->nextId(); 2080 SpvId end = this->nextId(); 2081 SpvId lhsBlock = fCurrentBlock; 2082 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2083 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out); 2084 this->writeLabel(rhsLabel, out); 2085 SpvId rhs = this->writeExpression(*o.fRight, out); 2086 SpvId rhsBlock = fCurrentBlock; 2087 this->writeInstruction(SpvOpBranch, end, out); 2088 this->writeLabel(end, out); 2089 SpvId result = this->nextId(); 2090 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant, 2091 lhsBlock, rhs, rhsBlock, out); 2092 return result; 2093} 2094 2095SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { 2096 SpvId test = this->writeExpression(*t.fTest, out); 2097 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) { 2098 // both true and false are constants, can just use OpSelect 2099 SpvId result = this->nextId(); 2100 SpvId trueId = this->writeExpression(*t.fIfTrue, out); 2101 SpvId falseId = this->writeExpression(*t.fIfFalse, out); 2102 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, 2103 out); 2104 return result; 2105 } 2106 // was originally using OpPhi to choose the result, but for some reason that is crashing on 2107 // Adreno. Switched to storing the result in a temp variable as glslang does. 2108 SpvId var = this->nextId(); 2109 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), 2110 var, SpvStorageClassFunction, fVariableBuffer); 2111 SpvId trueLabel = this->nextId(); 2112 SpvId falseLabel = this->nextId(); 2113 SpvId end = this->nextId(); 2114 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2115 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); 2116 this->writeLabel(trueLabel, out); 2117 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out); 2118 this->writeInstruction(SpvOpBranch, end, out); 2119 this->writeLabel(falseLabel, out); 2120 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out); 2121 this->writeInstruction(SpvOpBranch, end, out); 2122 this->writeLabel(end, out); 2123 SpvId result = this->nextId(); 2124 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out); 2125 return result; 2126} 2127 2128std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) { 2129 if (type.isInteger()) { 2130 return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type)); 2131 } 2132 else if (type.isFloat()) { 2133 return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type)); 2134 } else { 2135 ABORT("math is unsupported on type '%s'", type.name().c_str()); 2136 } 2137} 2138 2139SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) { 2140 if (p.fOperator == Token::MINUS) { 2141 SpvId result = this->nextId(); 2142 SpvId typeId = this->getType(p.fType); 2143 SpvId expr = this->writeExpression(*p.fOperand, out); 2144 if (is_float(fContext, p.fType)) { 2145 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); 2146 } else if (is_signed(fContext, p.fType)) { 2147 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); 2148 } else { 2149 ABORT("unsupported prefix expression %s", p.description().c_str()); 2150 }; 2151 return result; 2152 } 2153 switch (p.fOperator) { 2154 case Token::PLUS: 2155 return this->writeExpression(*p.fOperand, out); 2156 case Token::PLUSPLUS: { 2157 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2158 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2159 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2160 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, 2161 out); 2162 lv->store(result, out); 2163 return result; 2164 } 2165 case Token::MINUSMINUS: { 2166 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2167 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2168 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2169 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, 2170 out); 2171 lv->store(result, out); 2172 return result; 2173 } 2174 case Token::LOGICALNOT: { 2175 ASSERT(p.fOperand->fType == *fContext.fBool_Type); 2176 SpvId result = this->nextId(); 2177 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result, 2178 this->writeExpression(*p.fOperand, out), out); 2179 return result; 2180 } 2181 case Token::BITWISENOT: { 2182 SpvId result = this->nextId(); 2183 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result, 2184 this->writeExpression(*p.fOperand, out), out); 2185 return result; 2186 } 2187 default: 2188 ABORT("unsupported prefix expression: %s", p.description().c_str()); 2189 } 2190} 2191 2192SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) { 2193 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2194 SpvId result = lv->load(out); 2195 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2196 switch (p.fOperator) { 2197 case Token::PLUSPLUS: { 2198 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, 2199 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2200 lv->store(temp, out); 2201 return result; 2202 } 2203 case Token::MINUSMINUS: { 2204 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, 2205 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2206 lv->store(temp, out); 2207 return result; 2208 } 2209 default: 2210 ABORT("unsupported postfix expression %s", p.description().c_str()); 2211 } 2212} 2213 2214SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { 2215 if (b.fValue) { 2216 if (fBoolTrue == 0) { 2217 fBoolTrue = this->nextId(); 2218 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, 2219 fConstantBuffer); 2220 } 2221 return fBoolTrue; 2222 } else { 2223 if (fBoolFalse == 0) { 2224 fBoolFalse = this->nextId(); 2225 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, 2226 fConstantBuffer); 2227 } 2228 return fBoolFalse; 2229 } 2230} 2231 2232SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) { 2233 if (i.fType == *fContext.fInt_Type) { 2234 auto entry = fIntConstants.find(i.fValue); 2235 if (entry == fIntConstants.end()) { 2236 SpvId result = this->nextId(); 2237 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2238 fConstantBuffer); 2239 fIntConstants[i.fValue] = result; 2240 return result; 2241 } 2242 return entry->second; 2243 } else { 2244 ASSERT(i.fType == *fContext.fUInt_Type); 2245 auto entry = fUIntConstants.find(i.fValue); 2246 if (entry == fUIntConstants.end()) { 2247 SpvId result = this->nextId(); 2248 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2249 fConstantBuffer); 2250 fUIntConstants[i.fValue] = result; 2251 return result; 2252 } 2253 return entry->second; 2254 } 2255} 2256 2257SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { 2258 if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) { 2259 float value = (float) f.fValue; 2260 auto entry = fFloatConstants.find(value); 2261 if (entry == fFloatConstants.end()) { 2262 SpvId result = this->nextId(); 2263 uint32_t bits; 2264 ASSERT(sizeof(bits) == sizeof(value)); 2265 memcpy(&bits, &value, sizeof(bits)); 2266 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, 2267 fConstantBuffer); 2268 fFloatConstants[value] = result; 2269 return result; 2270 } 2271 return entry->second; 2272 } else { 2273 ASSERT(f.fType == *fContext.fDouble_Type); 2274 auto entry = fDoubleConstants.find(f.fValue); 2275 if (entry == fDoubleConstants.end()) { 2276 SpvId result = this->nextId(); 2277 uint64_t bits; 2278 ASSERT(sizeof(bits) == sizeof(f.fValue)); 2279 memcpy(&bits, &f.fValue, sizeof(bits)); 2280 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, 2281 bits & 0xffffffff, bits >> 32, fConstantBuffer); 2282 fDoubleConstants[f.fValue] = result; 2283 return result; 2284 } 2285 return entry->second; 2286 } 2287} 2288 2289SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) { 2290 SpvId result = fFunctionMap[&f]; 2291 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, 2292 SpvFunctionControlMaskNone, this->getFunctionType(f), out); 2293 this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer); 2294 for (size_t i = 0; i < f.fParameters.size(); i++) { 2295 SpvId id = this->nextId(); 2296 fVariableMap[f.fParameters[i]] = id; 2297 SpvId type; 2298 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction); 2299 this->writeInstruction(SpvOpFunctionParameter, type, id, out); 2300 } 2301 return result; 2302} 2303 2304SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) { 2305 fVariableBuffer.reset(); 2306 SpvId result = this->writeFunctionStart(f.fDeclaration, out); 2307 this->writeLabel(this->nextId(), out); 2308 if (f.fDeclaration.fName == "main") { 2309 write_stringstream(fGlobalInitializersBuffer, out); 2310 } 2311 StringStream bodyBuffer; 2312 this->writeBlock((Block&) *f.fBody, bodyBuffer); 2313 write_stringstream(fVariableBuffer, out); 2314 write_stringstream(bodyBuffer, out); 2315 if (fCurrentBlock) { 2316 if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) { 2317 this->writeInstruction(SpvOpReturn, out); 2318 } else { 2319 this->writeInstruction(SpvOpUnreachable, out); 2320 } 2321 } 2322 this->writeInstruction(SpvOpFunctionEnd, out); 2323 return result; 2324} 2325 2326void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) { 2327 if (layout.fLocation >= 0) { 2328 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation, 2329 fDecorationBuffer); 2330 } 2331 if (layout.fBinding >= 0) { 2332 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding, 2333 fDecorationBuffer); 2334 } 2335 if (layout.fIndex >= 0) { 2336 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex, 2337 fDecorationBuffer); 2338 } 2339 if (layout.fSet >= 0) { 2340 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet, 2341 fDecorationBuffer); 2342 } 2343 if (layout.fInputAttachmentIndex >= 0) { 2344 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex, 2345 layout.fInputAttachmentIndex, fDecorationBuffer); 2346 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment); 2347 } 2348 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN && 2349 layout.fBuiltin != SK_IN_BUILTIN) { 2350 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin, 2351 fDecorationBuffer); 2352 } 2353} 2354 2355void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) { 2356 if (layout.fLocation >= 0) { 2357 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation, 2358 layout.fLocation, fDecorationBuffer); 2359 } 2360 if (layout.fBinding >= 0) { 2361 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding, 2362 layout.fBinding, fDecorationBuffer); 2363 } 2364 if (layout.fIndex >= 0) { 2365 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex, 2366 layout.fIndex, fDecorationBuffer); 2367 } 2368 if (layout.fSet >= 0) { 2369 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet, 2370 layout.fSet, fDecorationBuffer); 2371 } 2372 if (layout.fInputAttachmentIndex >= 0) { 2373 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex, 2374 layout.fInputAttachmentIndex, fDecorationBuffer); 2375 } 2376 if (layout.fBuiltin >= 0) { 2377 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn, 2378 layout.fBuiltin, fDecorationBuffer); 2379 } 2380} 2381 2382SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) { 2383 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag)); 2384 bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags & 2385 Layout::kPushConstant_Flag)); 2386 MemoryLayout layout = (pushConstant || isBuffer) ? 2387 MemoryLayout(MemoryLayout::k430_Standard) : 2388 fDefaultLayout; 2389 SpvId result = this->nextId(); 2390 const Type* type = &intf.fVariable.fType; 2391 if (fProgram.fInputs.fRTHeight) { 2392 ASSERT(fRTHeightStructId == (SpvId) -1); 2393 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 2394 std::vector<Type::Field> fields = type->fields(); 2395 fRTHeightStructId = result; 2396 fRTHeightFieldIndex = fields.size(); 2397 fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get()); 2398 type = new Type(type->fOffset, type->name(), fields); 2399 } 2400 SpvId typeId = this->getType(*type, layout); 2401 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) { 2402 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer); 2403 } else { 2404 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer); 2405 } 2406 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers); 2407 SpvId ptrType = this->nextId(); 2408 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer); 2409 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); 2410 this->writeLayout(intf.fVariable.fModifiers.fLayout, result); 2411 fVariableMap[&intf.fVariable] = result; 2412 if (fProgram.fInputs.fRTHeight) { 2413 delete type; 2414 } 2415 return result; 2416} 2417 2418void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) { 2419 if ((modifiers.fFlags & Modifiers::kLowp_Flag) | 2420 (modifiers.fFlags & Modifiers::kMediump_Flag)) { 2421 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer); 2422 } 2423} 2424 2425#define BUILTIN_IGNORE 9999 2426void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl, 2427 OutputStream& out) { 2428 for (size_t i = 0; i < decl.fVars.size(); i++) { 2429 if (decl.fVars[i]->fKind == Statement::kNop_Kind) { 2430 continue; 2431 } 2432 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i]; 2433 const Variable* var = varDecl.fVar; 2434 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2435 // in the OpenGL backend. 2436 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2437 Modifiers::kWriteOnly_Flag | 2438 Modifiers::kCoherent_Flag | 2439 Modifiers::kVolatile_Flag | 2440 Modifiers::kRestrict_Flag))); 2441 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) { 2442 continue; 2443 } 2444 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN && 2445 kind != Program::kFragment_Kind) { 2446 continue; 2447 } 2448 if (!var->fReadCount && !var->fWriteCount && 2449 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag | 2450 Modifiers::kOut_Flag | 2451 Modifiers::kUniform_Flag | 2452 Modifiers::kBuffer_Flag))) { 2453 // variable is dead and not an input / output var (the Vulkan debug layers complain if 2454 // we elide an interface var, even if it's dead) 2455 continue; 2456 } 2457 SpvStorageClass_ storageClass; 2458 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) { 2459 storageClass = SpvStorageClassInput; 2460 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) { 2461 storageClass = SpvStorageClassOutput; 2462 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) { 2463 if (var->fType.kind() == Type::kSampler_Kind) { 2464 storageClass = SpvStorageClassUniformConstant; 2465 } else { 2466 storageClass = SpvStorageClassUniform; 2467 } 2468 } else { 2469 storageClass = SpvStorageClassPrivate; 2470 } 2471 SpvId id = this->nextId(); 2472 fVariableMap[var] = id; 2473 SpvId type = this->getPointerType(var->fType, storageClass); 2474 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer); 2475 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer); 2476 this->writePrecisionModifier(var->fModifiers, id); 2477 if (varDecl.fValue) { 2478 ASSERT(!fCurrentBlock); 2479 fCurrentBlock = -1; 2480 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer); 2481 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer); 2482 fCurrentBlock = 0; 2483 } 2484 this->writeLayout(var->fModifiers.fLayout, id); 2485 if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) { 2486 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer); 2487 } 2488 if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) { 2489 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective, 2490 fDecorationBuffer); 2491 } 2492 } 2493} 2494 2495void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) { 2496 for (const auto& stmt : decl.fVars) { 2497 ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind); 2498 VarDeclaration& varDecl = (VarDeclaration&) *stmt; 2499 const Variable* var = varDecl.fVar; 2500 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2501 // in the OpenGL backend. 2502 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2503 Modifiers::kWriteOnly_Flag | 2504 Modifiers::kCoherent_Flag | 2505 Modifiers::kVolatile_Flag | 2506 Modifiers::kRestrict_Flag))); 2507 SpvId id = this->nextId(); 2508 fVariableMap[var] = id; 2509 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction); 2510 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer); 2511 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer); 2512 if (varDecl.fValue) { 2513 SpvId value = this->writeExpression(*varDecl.fValue, out); 2514 this->writeInstruction(SpvOpStore, id, value, out); 2515 } 2516 } 2517} 2518 2519void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) { 2520 switch (s.fKind) { 2521 case Statement::kNop_Kind: 2522 break; 2523 case Statement::kBlock_Kind: 2524 this->writeBlock((Block&) s, out); 2525 break; 2526 case Statement::kExpression_Kind: 2527 this->writeExpression(*((ExpressionStatement&) s).fExpression, out); 2528 break; 2529 case Statement::kReturn_Kind: 2530 this->writeReturnStatement((ReturnStatement&) s, out); 2531 break; 2532 case Statement::kVarDeclarations_Kind: 2533 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out); 2534 break; 2535 case Statement::kIf_Kind: 2536 this->writeIfStatement((IfStatement&) s, out); 2537 break; 2538 case Statement::kFor_Kind: 2539 this->writeForStatement((ForStatement&) s, out); 2540 break; 2541 case Statement::kWhile_Kind: 2542 this->writeWhileStatement((WhileStatement&) s, out); 2543 break; 2544 case Statement::kDo_Kind: 2545 this->writeDoStatement((DoStatement&) s, out); 2546 break; 2547 case Statement::kSwitch_Kind: 2548 this->writeSwitchStatement((SwitchStatement&) s, out); 2549 break; 2550 case Statement::kBreak_Kind: 2551 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out); 2552 break; 2553 case Statement::kContinue_Kind: 2554 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out); 2555 break; 2556 case Statement::kDiscard_Kind: 2557 this->writeInstruction(SpvOpKill, out); 2558 break; 2559 default: 2560 ABORT("unsupported statement: %s", s.description().c_str()); 2561 } 2562} 2563 2564void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) { 2565 for (size_t i = 0; i < b.fStatements.size(); i++) { 2566 this->writeStatement(*b.fStatements[i], out); 2567 } 2568} 2569 2570void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) { 2571 SpvId test = this->writeExpression(*stmt.fTest, out); 2572 SpvId ifTrue = this->nextId(); 2573 SpvId ifFalse = this->nextId(); 2574 if (stmt.fIfFalse) { 2575 SpvId end = this->nextId(); 2576 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2577 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2578 this->writeLabel(ifTrue, out); 2579 this->writeStatement(*stmt.fIfTrue, out); 2580 if (fCurrentBlock) { 2581 this->writeInstruction(SpvOpBranch, end, out); 2582 } 2583 this->writeLabel(ifFalse, out); 2584 this->writeStatement(*stmt.fIfFalse, out); 2585 if (fCurrentBlock) { 2586 this->writeInstruction(SpvOpBranch, end, out); 2587 } 2588 this->writeLabel(end, out); 2589 } else { 2590 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out); 2591 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2592 this->writeLabel(ifTrue, out); 2593 this->writeStatement(*stmt.fIfTrue, out); 2594 if (fCurrentBlock) { 2595 this->writeInstruction(SpvOpBranch, ifFalse, out); 2596 } 2597 this->writeLabel(ifFalse, out); 2598 } 2599} 2600 2601void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) { 2602 if (f.fInitializer) { 2603 this->writeStatement(*f.fInitializer, out); 2604 } 2605 SpvId header = this->nextId(); 2606 SpvId start = this->nextId(); 2607 SpvId body = this->nextId(); 2608 SpvId next = this->nextId(); 2609 fContinueTarget.push(next); 2610 SpvId end = this->nextId(); 2611 fBreakTarget.push(end); 2612 this->writeInstruction(SpvOpBranch, header, out); 2613 this->writeLabel(header, out); 2614 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out); 2615 this->writeInstruction(SpvOpBranch, start, out); 2616 this->writeLabel(start, out); 2617 if (f.fTest) { 2618 SpvId test = this->writeExpression(*f.fTest, out); 2619 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2620 } 2621 this->writeLabel(body, out); 2622 this->writeStatement(*f.fStatement, out); 2623 if (fCurrentBlock) { 2624 this->writeInstruction(SpvOpBranch, next, out); 2625 } 2626 this->writeLabel(next, out); 2627 if (f.fNext) { 2628 this->writeExpression(*f.fNext, out); 2629 } 2630 this->writeInstruction(SpvOpBranch, header, out); 2631 this->writeLabel(end, out); 2632 fBreakTarget.pop(); 2633 fContinueTarget.pop(); 2634} 2635 2636void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) { 2637 // We believe the while loop code below will work, but Skia doesn't actually use them and 2638 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2639 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2640 // message, simply remove the error call below to see whether our while loop support actually 2641 // works. 2642 fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, " 2643 "see SkSLSPIRVCodeGenerator.cpp for details"); 2644 2645 SpvId header = this->nextId(); 2646 SpvId start = this->nextId(); 2647 SpvId body = this->nextId(); 2648 fContinueTarget.push(start); 2649 SpvId end = this->nextId(); 2650 fBreakTarget.push(end); 2651 this->writeInstruction(SpvOpBranch, header, out); 2652 this->writeLabel(header, out); 2653 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2654 this->writeInstruction(SpvOpBranch, start, out); 2655 this->writeLabel(start, out); 2656 SpvId test = this->writeExpression(*w.fTest, out); 2657 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2658 this->writeLabel(body, out); 2659 this->writeStatement(*w.fStatement, out); 2660 if (fCurrentBlock) { 2661 this->writeInstruction(SpvOpBranch, start, out); 2662 } 2663 this->writeLabel(end, out); 2664 fBreakTarget.pop(); 2665 fContinueTarget.pop(); 2666} 2667 2668void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) { 2669 // We believe the do loop code below will work, but Skia doesn't actually use them and 2670 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2671 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2672 // message, simply remove the error call below to see whether our do loop support actually 2673 // works. 2674 fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see " 2675 "SkSLSPIRVCodeGenerator.cpp for details"); 2676 2677 SpvId header = this->nextId(); 2678 SpvId start = this->nextId(); 2679 SpvId next = this->nextId(); 2680 fContinueTarget.push(next); 2681 SpvId end = this->nextId(); 2682 fBreakTarget.push(end); 2683 this->writeInstruction(SpvOpBranch, header, out); 2684 this->writeLabel(header, out); 2685 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2686 this->writeInstruction(SpvOpBranch, start, out); 2687 this->writeLabel(start, out); 2688 this->writeStatement(*d.fStatement, out); 2689 if (fCurrentBlock) { 2690 this->writeInstruction(SpvOpBranch, next, out); 2691 } 2692 this->writeLabel(next, out); 2693 SpvId test = this->writeExpression(*d.fTest, out); 2694 this->writeInstruction(SpvOpBranchConditional, test, start, end, out); 2695 this->writeLabel(end, out); 2696 fBreakTarget.pop(); 2697 fContinueTarget.pop(); 2698} 2699 2700void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) { 2701 SpvId value = this->writeExpression(*s.fValue, out); 2702 std::vector<SpvId> labels; 2703 SpvId end = this->nextId(); 2704 SpvId defaultLabel = end; 2705 fBreakTarget.push(end); 2706 int size = 3; 2707 for (const auto& c : s.fCases) { 2708 SpvId label = this->nextId(); 2709 labels.push_back(label); 2710 if (c->fValue) { 2711 size += 2; 2712 } else { 2713 defaultLabel = label; 2714 } 2715 } 2716 labels.push_back(end); 2717 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2718 this->writeOpCode(SpvOpSwitch, size, out); 2719 this->writeWord(value, out); 2720 this->writeWord(defaultLabel, out); 2721 for (size_t i = 0; i < s.fCases.size(); ++i) { 2722 if (!s.fCases[i]->fValue) { 2723 continue; 2724 } 2725 ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind); 2726 this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out); 2727 this->writeWord(labels[i], out); 2728 } 2729 for (size_t i = 0; i < s.fCases.size(); ++i) { 2730 this->writeLabel(labels[i], out); 2731 for (const auto& stmt : s.fCases[i]->fStatements) { 2732 this->writeStatement(*stmt, out); 2733 } 2734 if (fCurrentBlock) { 2735 this->writeInstruction(SpvOpBranch, labels[i + 1], out); 2736 } 2737 } 2738 this->writeLabel(end, out); 2739 fBreakTarget.pop(); 2740} 2741 2742void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) { 2743 if (r.fExpression) { 2744 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out), 2745 out); 2746 } else { 2747 this->writeInstruction(SpvOpReturn, out); 2748 } 2749} 2750 2751void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) { 2752 ASSERT(fProgram.fKind == Program::kGeometry_Kind); 2753 int invocations = 1; 2754 for (size_t i = 0; i < fProgram.fElements.size(); i++) { 2755 if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) { 2756 const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers; 2757 if (m.fFlags & Modifiers::kIn_Flag) { 2758 if (m.fLayout.fInvocations != -1) { 2759 invocations = m.fLayout.fInvocations; 2760 } 2761 SpvId input; 2762 switch (m.fLayout.fPrimitive) { 2763 case Layout::kPoints_Primitive: 2764 input = SpvExecutionModeInputPoints; 2765 break; 2766 case Layout::kLines_Primitive: 2767 input = SpvExecutionModeInputLines; 2768 break; 2769 case Layout::kLinesAdjacency_Primitive: 2770 input = SpvExecutionModeInputLinesAdjacency; 2771 break; 2772 case Layout::kTriangles_Primitive: 2773 input = SpvExecutionModeTriangles; 2774 break; 2775 case Layout::kTrianglesAdjacency_Primitive: 2776 input = SpvExecutionModeInputTrianglesAdjacency; 2777 break; 2778 default: 2779 input = 0; 2780 break; 2781 } 2782 if (input) { 2783 this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out); 2784 } 2785 } else if (m.fFlags & Modifiers::kOut_Flag) { 2786 SpvId output; 2787 switch (m.fLayout.fPrimitive) { 2788 case Layout::kPoints_Primitive: 2789 output = SpvExecutionModeOutputPoints; 2790 break; 2791 case Layout::kLineStrip_Primitive: 2792 output = SpvExecutionModeOutputLineStrip; 2793 break; 2794 case Layout::kTriangleStrip_Primitive: 2795 output = SpvExecutionModeOutputTriangleStrip; 2796 break; 2797 default: 2798 output = 0; 2799 break; 2800 } 2801 if (output) { 2802 this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out); 2803 } 2804 if (m.fLayout.fMaxVertices != -1) { 2805 this->writeInstruction(SpvOpExecutionMode, entryPoint, 2806 SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices, 2807 out); 2808 } 2809 } 2810 } 2811 } 2812 this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations, 2813 invocations, out); 2814} 2815 2816void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) { 2817 fGLSLExtendedInstructions = this->nextId(); 2818 StringStream body; 2819 std::set<SpvId> interfaceVars; 2820 // assign IDs to functions, determine sk_in size 2821 int skInSize = -1; 2822 for (size_t i = 0; i < program.fElements.size(); i++) { 2823 switch (program.fElements[i]->fKind) { 2824 case ProgramElement::kFunction_Kind: { 2825 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i]; 2826 fFunctionMap[&f.fDeclaration] = this->nextId(); 2827 break; 2828 } 2829 case ProgramElement::kModifiers_Kind: { 2830 Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers; 2831 if (m.fFlags & Modifiers::kIn_Flag) { 2832 switch (m.fLayout.fPrimitive) { 2833 case Layout::kPoints_Primitive: // break 2834 case Layout::kLines_Primitive: 2835 skInSize = 1; 2836 break; 2837 case Layout::kLinesAdjacency_Primitive: // break 2838 skInSize = 2; 2839 break; 2840 case Layout::kTriangles_Primitive: // break 2841 case Layout::kTrianglesAdjacency_Primitive: 2842 skInSize = 3; 2843 break; 2844 default: 2845 break; 2846 } 2847 } 2848 break; 2849 } 2850 default: 2851 break; 2852 } 2853 } 2854 for (size_t i = 0; i < program.fElements.size(); i++) { 2855 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) { 2856 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i]; 2857 if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) { 2858 ASSERT(skInSize != -1); 2859 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize)); 2860 } 2861 SpvId id = this->writeInterfaceBlock(intf); 2862 if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) || 2863 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { 2864 interfaceVars.insert(id); 2865 } 2866 } 2867 } 2868 for (size_t i = 0; i < program.fElements.size(); i++) { 2869 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) { 2870 this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]), 2871 body); 2872 } 2873 } 2874 for (size_t i = 0; i < program.fElements.size(); i++) { 2875 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { 2876 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body); 2877 } 2878 } 2879 const FunctionDeclaration* main = nullptr; 2880 for (auto entry : fFunctionMap) { 2881 if (entry.first->fName == "main") { 2882 main = entry.first; 2883 } 2884 } 2885 ASSERT(main); 2886 for (auto entry : fVariableMap) { 2887 const Variable* var = entry.first; 2888 if (var->fStorage == Variable::kGlobal_Storage && 2889 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) || 2890 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) { 2891 interfaceVars.insert(entry.second); 2892 } 2893 } 2894 this->writeCapabilities(out); 2895 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out); 2896 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out); 2897 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) + 2898 (int32_t) interfaceVars.size(), out); 2899 switch (program.fKind) { 2900 case Program::kVertex_Kind: 2901 this->writeWord(SpvExecutionModelVertex, out); 2902 break; 2903 case Program::kFragment_Kind: 2904 this->writeWord(SpvExecutionModelFragment, out); 2905 break; 2906 case Program::kGeometry_Kind: 2907 this->writeWord(SpvExecutionModelGeometry, out); 2908 break; 2909 default: 2910 ABORT("cannot write this kind of program to SPIR-V\n"); 2911 } 2912 SpvId entryPoint = fFunctionMap[main]; 2913 this->writeWord(entryPoint, out); 2914 this->writeString(main->fName.fChars, main->fName.fLength, out); 2915 for (int var : interfaceVars) { 2916 this->writeWord(var, out); 2917 } 2918 if (program.fKind == Program::kGeometry_Kind) { 2919 this->writeGeometryShaderExecutionMode(entryPoint, out); 2920 } 2921 if (program.fKind == Program::kFragment_Kind) { 2922 this->writeInstruction(SpvOpExecutionMode, 2923 fFunctionMap[main], 2924 SpvExecutionModeOriginUpperLeft, 2925 out); 2926 } 2927 for (size_t i = 0; i < program.fElements.size(); i++) { 2928 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) { 2929 this->writeInstruction(SpvOpSourceExtension, 2930 ((Extension&) *program.fElements[i]).fName.c_str(), 2931 out); 2932 } 2933 } 2934 2935 write_stringstream(fExtraGlobalsBuffer, out); 2936 write_stringstream(fNameBuffer, out); 2937 write_stringstream(fDecorationBuffer, out); 2938 write_stringstream(fConstantBuffer, out); 2939 write_stringstream(fExternalFunctionsBuffer, out); 2940 write_stringstream(body, out); 2941} 2942 2943bool SPIRVCodeGenerator::generateCode() { 2944 ASSERT(!fErrors.errorCount()); 2945 this->writeWord(SpvMagicNumber, *fOut); 2946 this->writeWord(SpvVersion, *fOut); 2947 this->writeWord(SKSL_MAGIC, *fOut); 2948 StringStream buffer; 2949 this->writeInstructions(fProgram, buffer); 2950 this->writeWord(fIdCount, *fOut); 2951 this->writeWord(0, *fOut); // reserved, always zero 2952 write_stringstream(buffer, *fOut); 2953 return 0 == fErrors.errorCount(); 2954} 2955 2956} 2957