1/*------------------------------------------------------------------------ 2 * Vulkan Conformance Tests 3 * ------------------------ 4 * 5 * Copyright (c) 2017 The Khronos Group Inc. 6 * Copyright (c) 2017 Codeplay Software Ltd. 7 * 8 * Licensed under the Apache License, Version 2.0 (the "License"); 9 * you may not use this file except in compliance with the License. 10 * You may obtain a copy of the License at 11 * 12 * http://www.apache.org/licenses/LICENSE-2.0 13 * 14 * Unless required by applicable law or agreed to in writing, software 15 * distributed under the License is distributed on an "AS IS" BASIS, 16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 * See the License for the specific language governing permissions and 18 * limitations under the License. 19 * 20 */ /*! 21 * \file 22 * \brief Subgroups Tests 23 */ /*--------------------------------------------------------------------*/ 24 25#include "vktSubgroupsArithmeticTests.hpp" 26#include "vktSubgroupsTestsUtils.hpp" 27 28#include <string> 29#include <vector> 30 31using namespace tcu; 32using namespace std; 33using namespace vk; 34using namespace vkt; 35 36namespace 37{ 38enum OpType 39{ 40 OPTYPE_ADD = 0, 41 OPTYPE_MUL, 42 OPTYPE_MIN, 43 OPTYPE_MAX, 44 OPTYPE_AND, 45 OPTYPE_OR, 46 OPTYPE_XOR, 47 OPTYPE_INCLUSIVE_ADD, 48 OPTYPE_INCLUSIVE_MUL, 49 OPTYPE_INCLUSIVE_MIN, 50 OPTYPE_INCLUSIVE_MAX, 51 OPTYPE_INCLUSIVE_AND, 52 OPTYPE_INCLUSIVE_OR, 53 OPTYPE_INCLUSIVE_XOR, 54 OPTYPE_EXCLUSIVE_ADD, 55 OPTYPE_EXCLUSIVE_MUL, 56 OPTYPE_EXCLUSIVE_MIN, 57 OPTYPE_EXCLUSIVE_MAX, 58 OPTYPE_EXCLUSIVE_AND, 59 OPTYPE_EXCLUSIVE_OR, 60 OPTYPE_EXCLUSIVE_XOR, 61 OPTYPE_LAST 62}; 63 64static bool checkVertexPipelineStages(std::vector<const void*> datas, 65 deUint32 width, deUint32) 66{ 67 const deUint32* data = 68 reinterpret_cast<const deUint32*>(datas[0]); 69 for (deUint32 x = 0; x < width; ++x) 70 { 71 deUint32 val = data[x]; 72 73 if (0x3 != val) 74 { 75 return false; 76 } 77 } 78 79 return true; 80} 81 82static bool checkFragment(std::vector<const void*> datas, 83 deUint32 width, deUint32 height, deUint32) 84{ 85 const deUint32* data = 86 reinterpret_cast<const deUint32*>(datas[0]); 87 for (deUint32 x = 0; x < width; ++x) 88 { 89 for (deUint32 y = 0; y < height; ++y) 90 { 91 deUint32 val = data[x * height + y]; 92 93 if (0x3 != val) 94 { 95 return false; 96 } 97 } 98 } 99 100 return true; 101} 102 103static bool checkCompute(std::vector<const void*> datas, 104 const deUint32 numWorkgroups[3], const deUint32 localSize[3], 105 deUint32) 106{ 107 const deUint32* data = 108 reinterpret_cast<const deUint32*>(datas[0]); 109 110 for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX) 111 { 112 for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY) 113 { 114 for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ) 115 { 116 for (deUint32 lX = 0; lX < localSize[0]; ++lX) 117 { 118 for (deUint32 lY = 0; lY < localSize[1]; ++lY) 119 { 120 for (deUint32 lZ = 0; lZ < localSize[2]; 121 ++lZ) 122 { 123 const deUint32 globalInvocationX = 124 nX * localSize[0] + lX; 125 const deUint32 globalInvocationY = 126 nY * localSize[1] + lY; 127 const deUint32 globalInvocationZ = 128 nZ * localSize[2] + lZ; 129 130 const deUint32 globalSizeX = 131 numWorkgroups[0] * localSize[0]; 132 const deUint32 globalSizeY = 133 numWorkgroups[1] * localSize[1]; 134 135 const deUint32 offset = 136 globalSizeX * 137 ((globalSizeY * 138 globalInvocationZ) + 139 globalInvocationY) + 140 globalInvocationX; 141 142 if (0x3 != data[offset]) 143 { 144 return false; 145 } 146 } 147 } 148 } 149 } 150 } 151 } 152 153 return true; 154} 155 156std::string getOpTypeName(int opType) 157{ 158 switch (opType) 159 { 160 default: 161 DE_FATAL("Unsupported op type"); 162 case OPTYPE_ADD: 163 return "subgroupAdd"; 164 case OPTYPE_MUL: 165 return "subgroupMul"; 166 case OPTYPE_MIN: 167 return "subgroupMin"; 168 case OPTYPE_MAX: 169 return "subgroupMax"; 170 case OPTYPE_AND: 171 return "subgroupAnd"; 172 case OPTYPE_OR: 173 return "subgroupOr"; 174 case OPTYPE_XOR: 175 return "subgroupXor"; 176 case OPTYPE_INCLUSIVE_ADD: 177 return "subgroupInclusiveAdd"; 178 case OPTYPE_INCLUSIVE_MUL: 179 return "subgroupInclusiveMul"; 180 case OPTYPE_INCLUSIVE_MIN: 181 return "subgroupInclusiveMin"; 182 case OPTYPE_INCLUSIVE_MAX: 183 return "subgroupInclusiveMax"; 184 case OPTYPE_INCLUSIVE_AND: 185 return "subgroupInclusiveAnd"; 186 case OPTYPE_INCLUSIVE_OR: 187 return "subgroupInclusiveOr"; 188 case OPTYPE_INCLUSIVE_XOR: 189 return "subgroupInclusiveXor"; 190 case OPTYPE_EXCLUSIVE_ADD: 191 return "subgroupExclusiveAdd"; 192 case OPTYPE_EXCLUSIVE_MUL: 193 return "subgroupExclusiveMul"; 194 case OPTYPE_EXCLUSIVE_MIN: 195 return "subgroupExclusiveMin"; 196 case OPTYPE_EXCLUSIVE_MAX: 197 return "subgroupExclusiveMax"; 198 case OPTYPE_EXCLUSIVE_AND: 199 return "subgroupExclusiveAnd"; 200 case OPTYPE_EXCLUSIVE_OR: 201 return "subgroupExclusiveOr"; 202 case OPTYPE_EXCLUSIVE_XOR: 203 return "subgroupExclusiveXor"; 204 } 205} 206 207std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs) 208{ 209 switch (opType) 210 { 211 default: 212 DE_FATAL("Unsupported op type"); 213 case OPTYPE_ADD: 214 case OPTYPE_INCLUSIVE_ADD: 215 case OPTYPE_EXCLUSIVE_ADD: 216 return lhs + " + " + rhs; 217 case OPTYPE_MUL: 218 case OPTYPE_INCLUSIVE_MUL: 219 case OPTYPE_EXCLUSIVE_MUL: 220 return lhs + " * " + rhs; 221 case OPTYPE_MIN: 222 case OPTYPE_INCLUSIVE_MIN: 223 case OPTYPE_EXCLUSIVE_MIN: 224 switch (format) 225 { 226 default: 227 return "min(" + lhs + ", " + rhs + ")"; 228 case VK_FORMAT_R32_SFLOAT: 229 case VK_FORMAT_R64_SFLOAT: 230 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))"; 231 case VK_FORMAT_R32G32_SFLOAT: 232 case VK_FORMAT_R32G32B32_SFLOAT: 233 case VK_FORMAT_R32G32B32A32_SFLOAT: 234 case VK_FORMAT_R64G64_SFLOAT: 235 case VK_FORMAT_R64G64B64_SFLOAT: 236 case VK_FORMAT_R64G64B64A64_SFLOAT: 237 return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; 238 } 239 case OPTYPE_MAX: 240 case OPTYPE_INCLUSIVE_MAX: 241 case OPTYPE_EXCLUSIVE_MAX: 242 switch (format) 243 { 244 default: 245 return "max(" + lhs + ", " + rhs + ")"; 246 case VK_FORMAT_R32_SFLOAT: 247 case VK_FORMAT_R64_SFLOAT: 248 return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))"; 249 case VK_FORMAT_R32G32_SFLOAT: 250 case VK_FORMAT_R32G32B32_SFLOAT: 251 case VK_FORMAT_R32G32B32A32_SFLOAT: 252 case VK_FORMAT_R64G64_SFLOAT: 253 case VK_FORMAT_R64G64B64_SFLOAT: 254 case VK_FORMAT_R64G64B64A64_SFLOAT: 255 return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))"; 256 } 257 case OPTYPE_AND: 258 case OPTYPE_INCLUSIVE_AND: 259 case OPTYPE_EXCLUSIVE_AND: 260 switch (format) 261 { 262 default: 263 return lhs + " & " + rhs; 264 case VK_FORMAT_R8_USCALED: 265 return lhs + " && " + rhs; 266 case VK_FORMAT_R8G8_USCALED: 267 return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)"; 268 case VK_FORMAT_R8G8B8_USCALED: 269 return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)"; 270 case VK_FORMAT_R8G8B8A8_USCALED: 271 return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)"; 272 } 273 case OPTYPE_OR: 274 case OPTYPE_INCLUSIVE_OR: 275 case OPTYPE_EXCLUSIVE_OR: 276 switch (format) 277 { 278 default: 279 return lhs + " | " + rhs; 280 case VK_FORMAT_R8_USCALED: 281 return lhs + " || " + rhs; 282 case VK_FORMAT_R8G8_USCALED: 283 return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)"; 284 case VK_FORMAT_R8G8B8_USCALED: 285 return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)"; 286 case VK_FORMAT_R8G8B8A8_USCALED: 287 return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)"; 288 } 289 case OPTYPE_XOR: 290 case OPTYPE_INCLUSIVE_XOR: 291 case OPTYPE_EXCLUSIVE_XOR: 292 switch (format) 293 { 294 default: 295 return lhs + " ^ " + rhs; 296 case VK_FORMAT_R8_USCALED: 297 return lhs + " ^^ " + rhs; 298 case VK_FORMAT_R8G8_USCALED: 299 return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)"; 300 case VK_FORMAT_R8G8B8_USCALED: 301 return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)"; 302 case VK_FORMAT_R8G8B8A8_USCALED: 303 return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)"; 304 } 305 } 306} 307 308std::string getIdentity(int opType, vk::VkFormat format) 309{ 310 bool isFloat = false; 311 bool isInt = false; 312 bool isUnsigned = false; 313 314 switch (format) 315 { 316 default: 317 DE_FATAL("Unhandled format!"); 318 case VK_FORMAT_R32_SINT: 319 case VK_FORMAT_R32G32_SINT: 320 case VK_FORMAT_R32G32B32_SINT: 321 case VK_FORMAT_R32G32B32A32_SINT: 322 isInt = true; 323 break; 324 case VK_FORMAT_R32_UINT: 325 case VK_FORMAT_R32G32_UINT: 326 case VK_FORMAT_R32G32B32_UINT: 327 case VK_FORMAT_R32G32B32A32_UINT: 328 isUnsigned = true; 329 break; 330 case VK_FORMAT_R32_SFLOAT: 331 case VK_FORMAT_R32G32_SFLOAT: 332 case VK_FORMAT_R32G32B32_SFLOAT: 333 case VK_FORMAT_R32G32B32A32_SFLOAT: 334 case VK_FORMAT_R64_SFLOAT: 335 case VK_FORMAT_R64G64_SFLOAT: 336 case VK_FORMAT_R64G64B64_SFLOAT: 337 case VK_FORMAT_R64G64B64A64_SFLOAT: 338 isFloat = true; 339 break; 340 case VK_FORMAT_R8_USCALED: 341 case VK_FORMAT_R8G8_USCALED: 342 case VK_FORMAT_R8G8B8_USCALED: 343 case VK_FORMAT_R8G8B8A8_USCALED: 344 break; // bool types are not anything 345 } 346 347 switch (opType) 348 { 349 default: 350 DE_FATAL("Unsupported op type"); 351 case OPTYPE_ADD: 352 case OPTYPE_INCLUSIVE_ADD: 353 case OPTYPE_EXCLUSIVE_ADD: 354 return subgroups::getFormatNameForGLSL(format) + "(0)"; 355 case OPTYPE_MUL: 356 case OPTYPE_INCLUSIVE_MUL: 357 case OPTYPE_EXCLUSIVE_MUL: 358 return subgroups::getFormatNameForGLSL(format) + "(1)"; 359 case OPTYPE_MIN: 360 case OPTYPE_INCLUSIVE_MIN: 361 case OPTYPE_EXCLUSIVE_MIN: 362 if (isFloat) 363 { 364 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))"; 365 } 366 else if (isInt) 367 { 368 return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)"; 369 } 370 else if (isUnsigned) 371 { 372 return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)"; 373 } 374 else 375 { 376 DE_FATAL("Unhandled case"); 377 } 378 case OPTYPE_MAX: 379 case OPTYPE_INCLUSIVE_MAX: 380 case OPTYPE_EXCLUSIVE_MAX: 381 if (isFloat) 382 { 383 return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))"; 384 } 385 else if (isInt) 386 { 387 return subgroups::getFormatNameForGLSL(format) + "(0x80000000)"; 388 } 389 else if (isUnsigned) 390 { 391 return subgroups::getFormatNameForGLSL(format) + "(0)"; 392 } 393 else 394 { 395 DE_FATAL("Unhandled case"); 396 } 397 case OPTYPE_AND: 398 case OPTYPE_INCLUSIVE_AND: 399 case OPTYPE_EXCLUSIVE_AND: 400 return subgroups::getFormatNameForGLSL(format) + "(~0)"; 401 case OPTYPE_OR: 402 case OPTYPE_INCLUSIVE_OR: 403 case OPTYPE_EXCLUSIVE_OR: 404 return subgroups::getFormatNameForGLSL(format) + "(0)"; 405 case OPTYPE_XOR: 406 case OPTYPE_INCLUSIVE_XOR: 407 case OPTYPE_EXCLUSIVE_XOR: 408 return subgroups::getFormatNameForGLSL(format) + "(0)"; 409 } 410} 411 412std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs) 413{ 414 std::string formatName = subgroups::getFormatNameForGLSL(format); 415 switch (format) 416 { 417 default: 418 return "all(equal(" + lhs + ", " + rhs + "))"; 419 case VK_FORMAT_R8_USCALED: 420 case VK_FORMAT_R32_UINT: 421 case VK_FORMAT_R32_SINT: 422 return "(" + lhs + " == " + rhs + ")"; 423 case VK_FORMAT_R32_SFLOAT: 424 case VK_FORMAT_R64_SFLOAT: 425 switch (opType) 426 { 427 default: 428 return "(abs(" + lhs + " - " + rhs + ") < 0.00001)"; 429 case OPTYPE_MIN: 430 case OPTYPE_INCLUSIVE_MIN: 431 case OPTYPE_EXCLUSIVE_MIN: 432 case OPTYPE_MAX: 433 case OPTYPE_INCLUSIVE_MAX: 434 case OPTYPE_EXCLUSIVE_MAX: 435 return "(" + lhs + " == " + rhs + ")"; 436 } 437 case VK_FORMAT_R32G32_SFLOAT: 438 case VK_FORMAT_R32G32B32_SFLOAT: 439 case VK_FORMAT_R32G32B32A32_SFLOAT: 440 case VK_FORMAT_R64G64_SFLOAT: 441 case VK_FORMAT_R64G64B64_SFLOAT: 442 case VK_FORMAT_R64G64B64A64_SFLOAT: 443 switch (opType) 444 { 445 default: 446 return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))"; 447 case OPTYPE_MIN: 448 case OPTYPE_INCLUSIVE_MIN: 449 case OPTYPE_EXCLUSIVE_MIN: 450 case OPTYPE_MAX: 451 case OPTYPE_INCLUSIVE_MAX: 452 case OPTYPE_EXCLUSIVE_MAX: 453 return "all(equal(" + lhs + ", " + rhs + "))"; 454 } 455 } 456} 457 458struct CaseDefinition 459{ 460 int opType; 461 VkShaderStageFlags shaderStage; 462 VkFormat format; 463 bool noSSBO; 464}; 465 466void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef) 467{ 468 std::string indexVars; 469 switch (caseDef.opType) 470 { 471 default: 472 indexVars = " uint start = 0, end = gl_SubgroupSize;\n"; 473 break; 474 case OPTYPE_INCLUSIVE_ADD: 475 case OPTYPE_INCLUSIVE_MUL: 476 case OPTYPE_INCLUSIVE_MIN: 477 case OPTYPE_INCLUSIVE_MAX: 478 case OPTYPE_INCLUSIVE_AND: 479 case OPTYPE_INCLUSIVE_OR: 480 case OPTYPE_INCLUSIVE_XOR: 481 indexVars = " uint start = 0, end = gl_SubgroupInvocationID + 1;\n"; 482 break; 483 case OPTYPE_EXCLUSIVE_ADD: 484 case OPTYPE_EXCLUSIVE_MUL: 485 case OPTYPE_EXCLUSIVE_MIN: 486 case OPTYPE_EXCLUSIVE_MAX: 487 case OPTYPE_EXCLUSIVE_AND: 488 case OPTYPE_EXCLUSIVE_OR: 489 case OPTYPE_EXCLUSIVE_XOR: 490 indexVars = " uint start = 0, end = gl_SubgroupInvocationID;\n"; 491 break; 492 } 493 494 std::ostringstream bdy; 495 496 bdy << indexVars 497 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = " 498 << getIdentity(caseDef.opType, caseDef.format) << ";\n" 499 << " uint tempResult = 0;\n" 500 << " for (uint index = start; index < end; index++)\n" 501 << " {\n" 502 << " if (subgroupBallotBitExtract(mask, index))\n" 503 << " {\n" 504 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n" 505 << " }\n" 506 << " }\n" 507 << " tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref", 508 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n" 509 << " if (1 == (gl_SubgroupInvocationID % 2))\n" 510 << " {\n" 511 << " mask = subgroupBallot(true);\n" 512 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n" 513 << " for (uint index = start; index < end; index++)\n" 514 << " {\n" 515 << " if (subgroupBallotBitExtract(mask, index))\n" 516 << " {\n" 517 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n" 518 << " }\n" 519 << " }\n" 520 << " tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref", 521 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n" 522 << " }\n" 523 << " else\n" 524 << " {\n" 525 << " tempResult |= 0x2;\n" 526 << " }\n"; 527 if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) 528 { 529 std::ostringstream src; 530 std::ostringstream fragmentSrc; 531 532 src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n" 533 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 534 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 535 << "layout(location = 0) in highp vec4 in_position;\n" 536 << "layout(location = 0) out float out_color;\n" 537 << "layout(set = 0, binding = 0) uniform Buffer1\n" 538 << "{\n" 539 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n" 540 << "};\n" 541 << "\n" 542 << "void main (void)\n" 543 << "{\n" 544 << " uvec4 mask = subgroupBallot(true);\n" 545 << bdy.str() 546 << " out_color = float(tempResult);\n" 547 << " gl_Position = in_position;\n" 548 << " gl_PointSize = 1.0f;\n" 549 << "}\n"; 550 551 programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 552 553 fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n" 554 << "layout(location = 0) in float in_color;\n" 555 << "layout(location = 0) out uint out_color;\n" 556 << "void main()\n" 557 <<"{\n" 558 << " out_color = uint(in_color);\n" 559 << "}\n"; 560 programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 561 } 562 else 563 { 564 DE_FATAL("Unsupported shader stage"); 565 } 566} 567 568void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef) 569{ 570 std::string indexVars; 571 switch (caseDef.opType) 572 { 573 default: 574 indexVars = " uint start = 0, end = gl_SubgroupSize;\n"; 575 break; 576 case OPTYPE_INCLUSIVE_ADD: 577 case OPTYPE_INCLUSIVE_MUL: 578 case OPTYPE_INCLUSIVE_MIN: 579 case OPTYPE_INCLUSIVE_MAX: 580 case OPTYPE_INCLUSIVE_AND: 581 case OPTYPE_INCLUSIVE_OR: 582 case OPTYPE_INCLUSIVE_XOR: 583 indexVars = " uint start = 0, end = gl_SubgroupInvocationID + 1;\n"; 584 break; 585 case OPTYPE_EXCLUSIVE_ADD: 586 case OPTYPE_EXCLUSIVE_MUL: 587 case OPTYPE_EXCLUSIVE_MIN: 588 case OPTYPE_EXCLUSIVE_MAX: 589 case OPTYPE_EXCLUSIVE_AND: 590 case OPTYPE_EXCLUSIVE_OR: 591 case OPTYPE_EXCLUSIVE_XOR: 592 indexVars = " uint start = 0, end = gl_SubgroupInvocationID;\n"; 593 break; 594 } 595 596 std::ostringstream bdy; 597 598 bdy << indexVars 599 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = " 600 << getIdentity(caseDef.opType, caseDef.format) << ";\n" 601 << " uint tempResult = 0;\n" 602 << " for (uint index = start; index < end; index++)\n" 603 << " {\n" 604 << " if (subgroupBallotBitExtract(mask, index))\n" 605 << " {\n" 606 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n" 607 << " }\n" 608 << " }\n" 609 << " tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref", 610 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n" 611 << " if (1 == (gl_SubgroupInvocationID % 2))\n" 612 << " {\n" 613 << " mask = subgroupBallot(true);\n" 614 << " ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n" 615 << " for (uint index = start; index < end; index++)\n" 616 << " {\n" 617 << " if (subgroupBallotBitExtract(mask, index))\n" 618 << " {\n" 619 << " ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n" 620 << " }\n" 621 << " }\n" 622 << " tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref", 623 getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n" 624 << " }\n" 625 << " else\n" 626 << " {\n" 627 << " tempResult |= 0x2;\n" 628 << " }\n"; 629 630 if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage) 631 { 632 std::ostringstream src; 633 634 src << "#version 450\n" 635 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 636 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 637 << "layout (local_size_x_id = 0, local_size_y_id = 1, " 638 "local_size_z_id = 2) in;\n" 639 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n" 640 << "{\n" 641 << " uint result[];\n" 642 << "};\n" 643 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n" 644 << "{\n" 645 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 646 << "};\n" 647 << "\n" 648 << "void main (void)\n" 649 << "{\n" 650 << " uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n" 651 << " highp uint offset = globalSize.x * ((globalSize.y * " 652 "gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + " 653 "gl_GlobalInvocationID.x;\n" 654 << " uvec4 mask = subgroupBallot(true);\n" 655 << bdy.str() 656 << " result[offset] = tempResult;\n" 657 << "}\n"; 658 659 programCollection.glslSources.add("comp") 660 << glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 661 } 662 else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage) 663 { 664 programCollection.glslSources.add("vert") 665 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 666 667 std::ostringstream frag; 668 669 frag << "#version 450\n" 670 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 671 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 672 << "layout(location = 0) out uint result;\n" 673 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n" 674 << "{\n" 675 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 676 << "};\n" 677 << "void main (void)\n" 678 << "{\n" 679 << " uvec4 mask = subgroupBallot(true);\n" 680 << bdy.str() 681 << " result = tempResult;\n" 682 << "}\n"; 683 684 programCollection.glslSources.add("frag") 685 << glu::FragmentSource(frag.str())<< vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 686 } 687 else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) 688 { 689 std::ostringstream src; 690 691 src << "#version 450\n" 692 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 693 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 694 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n" 695 << "{\n" 696 << " uint result[];\n" 697 << "};\n" 698 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n" 699 << "{\n" 700 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 701 << "};\n" 702 << "\n" 703 << "void main (void)\n" 704 << "{\n" 705 << " uvec4 mask = subgroupBallot(true);\n" 706 << bdy.str() 707 << " result[gl_VertexIndex] = tempResult;\n" 708 << " gl_PointSize = 1.0f;\n" 709 << "}\n"; 710 711 programCollection.glslSources.add("vert") 712 << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 713 } 714 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage) 715 { 716 programCollection.glslSources.add("vert") 717 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 718 719 std::ostringstream src; 720 721 src << "#version 450\n" 722 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 723 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 724 << "layout(points) in;\n" 725 << "layout(points, max_vertices = 1) out;\n" 726 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n" 727 << "{\n" 728 << " uint result[];\n" 729 << "};\n" 730 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n" 731 << "{\n" 732 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 733 << "};\n" 734 << "\n" 735 << "void main (void)\n" 736 << "{\n" 737 << " uvec4 mask = subgroupBallot(true);\n" 738 << bdy.str() 739 << " result[gl_PrimitiveIDIn] = tempResult;\n" 740 << "}\n"; 741 742 programCollection.glslSources.add("geom") 743 << glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 744 } 745 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage) 746 { 747 programCollection.glslSources.add("vert") 748 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 749 750 programCollection.glslSources.add("tese") 751 << glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n"); 752 753 std::ostringstream src; 754 755 src << "#version 450\n" 756 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 757 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 758 << "layout(vertices=1) out;\n" 759 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n" 760 << "{\n" 761 << " uint result[];\n" 762 << "};\n" 763 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n" 764 << "{\n" 765 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 766 << "};\n" 767 << "\n" 768 << "void main (void)\n" 769 << "{\n" 770 << " uvec4 mask = subgroupBallot(true);\n" 771 << bdy.str() 772 << " result[gl_PrimitiveID] = tempResult;\n" 773 << "}\n"; 774 775 programCollection.glslSources.add("tesc") 776 << glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 777 } 778 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage) 779 { 780 programCollection.glslSources.add("vert") 781 << glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 782 783 programCollection.glslSources.add("tesc") 784 << glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n"); 785 786 std::ostringstream src; 787 788 src << "#version 450\n" 789 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n" 790 << "#extension GL_KHR_shader_subgroup_ballot: enable\n" 791 << "layout(isolines) in;\n" 792 << "layout(set = 0, binding = 0, std430) buffer Buffer1\n" 793 << "{\n" 794 << " uint result[];\n" 795 << "};\n" 796 << "layout(set = 0, binding = 1, std430) buffer Buffer2\n" 797 << "{\n" 798 << " " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n" 799 << "};\n" 800 << "\n" 801 << "void main (void)\n" 802 << "{\n" 803 << " uvec4 mask = subgroupBallot(true);\n" 804 << bdy.str() 805 << " result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n" 806 << "}\n"; 807 808 programCollection.glslSources.add("tese") 809 << glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u); 810 } 811 else 812 { 813 DE_FATAL("Unsupported shader stage"); 814 } 815} 816 817tcu::TestStatus test(Context& context, const CaseDefinition caseDef) 818{ 819 if (!subgroups::isSubgroupSupported(context)) 820 TCU_THROW(NotSupportedError, "Subgroup operations are not supported"); 821 822 if (!subgroups::areSubgroupOperationsSupportedForStage( 823 context, caseDef.shaderStage)) 824 { 825 if (subgroups::areSubgroupOperationsRequiredForStage( 826 caseDef.shaderStage)) 827 { 828 return tcu::TestStatus::fail( 829 "Shader stage " + 830 subgroups::getShaderStageName(caseDef.shaderStage) + 831 " is required to support subgroup operations!"); 832 } 833 else 834 { 835 TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage"); 836 } 837 } 838 839 if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_ARITHMETIC_BIT)) 840 { 841 TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations"); 842 } 843 844 if (subgroups::isDoubleFormat(caseDef.format) && 845 !subgroups::isDoubleSupportedForDevice(context)) 846 { 847 TCU_THROW(NotSupportedError, "Device does not support subgroup double operations"); 848 } 849 850 //Tests which don't use the SSBO 851 if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) 852 { 853 subgroups::SSBOData inputData; 854 inputData.format = caseDef.format; 855 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 856 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 857 858 return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages); 859 } 860 861 if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) && 862 (VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage)) 863 { 864 if (!subgroups::isVertexSSBOSupportedForDevice(context)) 865 { 866 TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes"); 867 } 868 } 869 870 if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage) 871 { 872 subgroups::SSBOData inputData; 873 inputData.format = caseDef.format; 874 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 875 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 876 877 return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT, 878 &inputData, 1, checkFragment); 879 } 880 else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage) 881 { 882 subgroups::SSBOData inputData; 883 inputData.format = caseDef.format; 884 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 885 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 886 887 return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData, 888 1, checkCompute); 889 } 890 else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage) 891 { 892 subgroups::SSBOData inputData; 893 inputData.format = caseDef.format; 894 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 895 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 896 897 return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData, 898 1, checkVertexPipelineStages); 899 } 900 else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage) 901 { 902 subgroups::SSBOData inputData; 903 inputData.format = caseDef.format; 904 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 905 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 906 907 return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData, 908 1, checkVertexPipelineStages); 909 } 910 else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage) 911 { 912 subgroups::SSBOData inputData; 913 inputData.format = caseDef.format; 914 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 915 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 916 917 return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData, 918 1, checkVertexPipelineStages); 919 } 920 else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage) 921 { 922 subgroups::SSBOData inputData; 923 inputData.format = caseDef.format; 924 inputData.numElements = subgroups::maxSupportedSubgroupSize(); 925 inputData.initializeType = subgroups::SSBOData::InitializeNonZero; 926 927 return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData, 928 1, checkVertexPipelineStages); 929 } 930 else 931 { 932 TCU_THROW(InternalError, "Unhandled shader stage"); 933 } 934} 935} 936 937namespace vkt 938{ 939namespace subgroups 940{ 941tcu::TestCaseGroup* createSubgroupsArithmeticTests(tcu::TestContext& testCtx) 942{ 943 de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup( 944 testCtx, "arithmetic", "Subgroup arithmetic category tests")); 945 946 const VkShaderStageFlags stages[] = 947 { 948 VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT, 949 VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT, 950 VK_SHADER_STAGE_GEOMETRY_BIT, 951 VK_SHADER_STAGE_VERTEX_BIT, 952 VK_SHADER_STAGE_FRAGMENT_BIT, 953 VK_SHADER_STAGE_COMPUTE_BIT 954 }; 955 956 const VkFormat formats[] = 957 { 958 VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT, 959 VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT, 960 VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT, 961 VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT, 962 VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT, 963 VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT, 964 VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT, 965 VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED, 966 VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED, 967 }; 968 969 for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex) 970 { 971 const VkShaderStageFlags stage = stages[stageIndex]; 972 973 for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex) 974 { 975 const VkFormat format = formats[formatIndex]; 976 977 for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex) 978 { 979 bool isBool = false; 980 bool isFloat = false; 981 982 switch (format) 983 { 984 default: 985 break; 986 case VK_FORMAT_R32_SFLOAT: 987 case VK_FORMAT_R32G32_SFLOAT: 988 case VK_FORMAT_R32G32B32_SFLOAT: 989 case VK_FORMAT_R32G32B32A32_SFLOAT: 990 case VK_FORMAT_R64_SFLOAT: 991 case VK_FORMAT_R64G64_SFLOAT: 992 case VK_FORMAT_R64G64B64_SFLOAT: 993 case VK_FORMAT_R64G64B64A64_SFLOAT: 994 isFloat = true; 995 break; 996 case VK_FORMAT_R8_USCALED: 997 case VK_FORMAT_R8G8_USCALED: 998 case VK_FORMAT_R8G8B8_USCALED: 999 case VK_FORMAT_R8G8B8A8_USCALED: 1000 isBool = true; 1001 break; 1002 } 1003 1004 bool isBitwiseOp = false; 1005 1006 switch (opTypeIndex) 1007 { 1008 default: 1009 break; 1010 case OPTYPE_AND: 1011 case OPTYPE_INCLUSIVE_AND: 1012 case OPTYPE_EXCLUSIVE_AND: 1013 case OPTYPE_OR: 1014 case OPTYPE_INCLUSIVE_OR: 1015 case OPTYPE_EXCLUSIVE_OR: 1016 case OPTYPE_XOR: 1017 case OPTYPE_INCLUSIVE_XOR: 1018 case OPTYPE_EXCLUSIVE_XOR: 1019 isBitwiseOp = true; 1020 break; 1021 } 1022 1023 if (isFloat && isBitwiseOp) 1024 { 1025 // Skip float with bitwise category. 1026 continue; 1027 } 1028 1029 if (isBool && !isBitwiseOp) 1030 { 1031 // Skip bool when its not the bitwise category. 1032 continue; 1033 } 1034 1035 CaseDefinition caseDef = {opTypeIndex, stage, format, false}; 1036 1037 std::string op = getOpTypeName(opTypeIndex); 1038 1039 addFunctionCaseWithPrograms(group.get(), 1040 de::toLower(op) + "_" + 1041 subgroups::getFormatNameForGLSL(format) + 1042 "_" + getShaderStageName(stage), 1043 "", initPrograms, test, caseDef); 1044 1045 if (VK_SHADER_STAGE_VERTEX_BIT == stage) 1046 { 1047 caseDef.noSSBO = true; 1048 addFunctionCaseWithPrograms(group.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) + 1049 "_" + getShaderStageName(stage) + "_framebuffer", "", 1050 initFrameBufferPrograms, test, caseDef); 1051 } 1052 } 1053 } 1054 } 1055 1056 return group.release(); 1057} 1058 1059} // subgroups 1060} // vkt 1061