1/*------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2017 The Khronos Group Inc.
6 * Copyright (c) 2017 Codeplay Software Ltd.
7 *
8 * Licensed under the Apache License, Version 2.0 (the "License");
9 * you may not use this file except in compliance with the License.
10 * You may obtain a copy of the License at
11 *
12 *      http://www.apache.org/licenses/LICENSE-2.0
13 *
14 * Unless required by applicable law or agreed to in writing, software
15 * distributed under the License is distributed on an "AS IS" BASIS,
16 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 * See the License for the specific language governing permissions and
18 * limitations under the License.
19 *
20 */ /*!
21 * \file
22 * \brief Subgroups Tests
23 */ /*--------------------------------------------------------------------*/
24
25#include "vktSubgroupsArithmeticTests.hpp"
26#include "vktSubgroupsTestsUtils.hpp"
27
28#include <string>
29#include <vector>
30
31using namespace tcu;
32using namespace std;
33using namespace vk;
34using namespace vkt;
35
36namespace
37{
// Subgroup operations covered by this file. The enumerators are contiguous
// and start at zero so that test generation can iterate over
// [0, OPTYPE_LAST). Each family lists its operations in the same order:
// add, mul, min, max, and, or, xor.
enum OpType
{
	// Operations named subgroup<Op>().
	OPTYPE_ADD = 0, OPTYPE_MUL, OPTYPE_MIN, OPTYPE_MAX,
	OPTYPE_AND, OPTYPE_OR, OPTYPE_XOR,

	// Operations named subgroupInclusive<Op>().
	OPTYPE_INCLUSIVE_ADD, OPTYPE_INCLUSIVE_MUL, OPTYPE_INCLUSIVE_MIN, OPTYPE_INCLUSIVE_MAX,
	OPTYPE_INCLUSIVE_AND, OPTYPE_INCLUSIVE_OR, OPTYPE_INCLUSIVE_XOR,

	// Operations named subgroupExclusive<Op>().
	OPTYPE_EXCLUSIVE_ADD, OPTYPE_EXCLUSIVE_MUL, OPTYPE_EXCLUSIVE_MIN, OPTYPE_EXCLUSIVE_MAX,
	OPTYPE_EXCLUSIVE_AND, OPTYPE_EXCLUSIVE_OR, OPTYPE_EXCLUSIVE_XOR,

	// Number of operations above; not an operation itself.
	OPTYPE_LAST
};
63
64static bool checkVertexPipelineStages(std::vector<const void*> datas,
65									  deUint32 width, deUint32)
66{
67	const deUint32* data =
68		reinterpret_cast<const deUint32*>(datas[0]);
69	for (deUint32 x = 0; x < width; ++x)
70	{
71		deUint32 val = data[x];
72
73		if (0x3 != val)
74		{
75			return false;
76		}
77	}
78
79	return true;
80}
81
82static bool checkFragment(std::vector<const void*> datas,
83						  deUint32 width, deUint32 height, deUint32)
84{
85	const deUint32* data =
86		reinterpret_cast<const deUint32*>(datas[0]);
87	for (deUint32 x = 0; x < width; ++x)
88	{
89		for (deUint32 y = 0; y < height; ++y)
90		{
91			deUint32 val = data[x * height + y];
92
93			if (0x3 != val)
94			{
95				return false;
96			}
97		}
98	}
99
100	return true;
101}
102
103static bool checkCompute(std::vector<const void*> datas,
104						 const deUint32 numWorkgroups[3], const deUint32 localSize[3],
105						 deUint32)
106{
107	const deUint32* data =
108		reinterpret_cast<const deUint32*>(datas[0]);
109
110	for (deUint32 nX = 0; nX < numWorkgroups[0]; ++nX)
111	{
112		for (deUint32 nY = 0; nY < numWorkgroups[1]; ++nY)
113		{
114			for (deUint32 nZ = 0; nZ < numWorkgroups[2]; ++nZ)
115			{
116				for (deUint32 lX = 0; lX < localSize[0]; ++lX)
117				{
118					for (deUint32 lY = 0; lY < localSize[1]; ++lY)
119					{
120						for (deUint32 lZ = 0; lZ < localSize[2];
121								++lZ)
122						{
123							const deUint32 globalInvocationX =
124								nX * localSize[0] + lX;
125							const deUint32 globalInvocationY =
126								nY * localSize[1] + lY;
127							const deUint32 globalInvocationZ =
128								nZ * localSize[2] + lZ;
129
130							const deUint32 globalSizeX =
131								numWorkgroups[0] * localSize[0];
132							const deUint32 globalSizeY =
133								numWorkgroups[1] * localSize[1];
134
135							const deUint32 offset =
136								globalSizeX *
137								((globalSizeY *
138								  globalInvocationZ) +
139								 globalInvocationY) +
140								globalInvocationX;
141
142							if (0x3 != data[offset])
143							{
144								return false;
145							}
146						}
147					}
148				}
149			}
150		}
151	}
152
153	return true;
154}
155
156std::string getOpTypeName(int opType)
157{
158	switch (opType)
159	{
160		default:
161			DE_FATAL("Unsupported op type");
162		case OPTYPE_ADD:
163			return "subgroupAdd";
164		case OPTYPE_MUL:
165			return "subgroupMul";
166		case OPTYPE_MIN:
167			return "subgroupMin";
168		case OPTYPE_MAX:
169			return "subgroupMax";
170		case OPTYPE_AND:
171			return "subgroupAnd";
172		case OPTYPE_OR:
173			return "subgroupOr";
174		case OPTYPE_XOR:
175			return "subgroupXor";
176		case OPTYPE_INCLUSIVE_ADD:
177			return "subgroupInclusiveAdd";
178		case OPTYPE_INCLUSIVE_MUL:
179			return "subgroupInclusiveMul";
180		case OPTYPE_INCLUSIVE_MIN:
181			return "subgroupInclusiveMin";
182		case OPTYPE_INCLUSIVE_MAX:
183			return "subgroupInclusiveMax";
184		case OPTYPE_INCLUSIVE_AND:
185			return "subgroupInclusiveAnd";
186		case OPTYPE_INCLUSIVE_OR:
187			return "subgroupInclusiveOr";
188		case OPTYPE_INCLUSIVE_XOR:
189			return "subgroupInclusiveXor";
190		case OPTYPE_EXCLUSIVE_ADD:
191			return "subgroupExclusiveAdd";
192		case OPTYPE_EXCLUSIVE_MUL:
193			return "subgroupExclusiveMul";
194		case OPTYPE_EXCLUSIVE_MIN:
195			return "subgroupExclusiveMin";
196		case OPTYPE_EXCLUSIVE_MAX:
197			return "subgroupExclusiveMax";
198		case OPTYPE_EXCLUSIVE_AND:
199			return "subgroupExclusiveAnd";
200		case OPTYPE_EXCLUSIVE_OR:
201			return "subgroupExclusiveOr";
202		case OPTYPE_EXCLUSIVE_XOR:
203			return "subgroupExclusiveXor";
204	}
205}
206
207std::string getOpTypeOperation(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
208{
209	switch (opType)
210	{
211		default:
212			DE_FATAL("Unsupported op type");
213		case OPTYPE_ADD:
214		case OPTYPE_INCLUSIVE_ADD:
215		case OPTYPE_EXCLUSIVE_ADD:
216			return lhs + " + " + rhs;
217		case OPTYPE_MUL:
218		case OPTYPE_INCLUSIVE_MUL:
219		case OPTYPE_EXCLUSIVE_MUL:
220			return lhs + " * " + rhs;
221		case OPTYPE_MIN:
222		case OPTYPE_INCLUSIVE_MIN:
223		case OPTYPE_EXCLUSIVE_MIN:
224			switch (format)
225			{
226				default:
227					return "min(" + lhs + ", " + rhs + ")";
228				case VK_FORMAT_R32_SFLOAT:
229				case VK_FORMAT_R64_SFLOAT:
230					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : min(" + lhs + ", " + rhs + ")))";
231				case VK_FORMAT_R32G32_SFLOAT:
232				case VK_FORMAT_R32G32B32_SFLOAT:
233				case VK_FORMAT_R32G32B32A32_SFLOAT:
234				case VK_FORMAT_R64G64_SFLOAT:
235				case VK_FORMAT_R64G64B64_SFLOAT:
236				case VK_FORMAT_R64G64B64A64_SFLOAT:
237					return "mix(mix(min(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
238			}
239		case OPTYPE_MAX:
240		case OPTYPE_INCLUSIVE_MAX:
241		case OPTYPE_EXCLUSIVE_MAX:
242			switch (format)
243			{
244				default:
245					return "max(" + lhs + ", " + rhs + ")";
246				case VK_FORMAT_R32_SFLOAT:
247				case VK_FORMAT_R64_SFLOAT:
248					return "(isnan(" + lhs + ") ? " + rhs + " : (isnan(" + rhs + ") ? " + lhs + " : max(" + lhs + ", " + rhs + ")))";
249				case VK_FORMAT_R32G32_SFLOAT:
250				case VK_FORMAT_R32G32B32_SFLOAT:
251				case VK_FORMAT_R32G32B32A32_SFLOAT:
252				case VK_FORMAT_R64G64_SFLOAT:
253				case VK_FORMAT_R64G64B64_SFLOAT:
254				case VK_FORMAT_R64G64B64A64_SFLOAT:
255					return "mix(mix(max(" + lhs + ", " + rhs + "), " + lhs + ", isnan(" + rhs + ")), " + rhs + ", isnan(" + lhs + "))";
256			}
257		case OPTYPE_AND:
258		case OPTYPE_INCLUSIVE_AND:
259		case OPTYPE_EXCLUSIVE_AND:
260			switch (format)
261			{
262				default:
263					return lhs + " & " + rhs;
264				case VK_FORMAT_R8_USCALED:
265					return lhs + " && " + rhs;
266				case VK_FORMAT_R8G8_USCALED:
267					return "bvec2(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y)";
268				case VK_FORMAT_R8G8B8_USCALED:
269					return "bvec3(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z)";
270				case VK_FORMAT_R8G8B8A8_USCALED:
271					return "bvec4(" + lhs + ".x && " + rhs + ".x, " + lhs + ".y && " + rhs + ".y, " + lhs + ".z && " + rhs + ".z, " + lhs + ".w && " + rhs + ".w)";
272			}
273		case OPTYPE_OR:
274		case OPTYPE_INCLUSIVE_OR:
275		case OPTYPE_EXCLUSIVE_OR:
276			switch (format)
277			{
278				default:
279					return lhs + " | " + rhs;
280				case VK_FORMAT_R8_USCALED:
281					return lhs + " || " + rhs;
282				case VK_FORMAT_R8G8_USCALED:
283					return "bvec2(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y)";
284				case VK_FORMAT_R8G8B8_USCALED:
285					return "bvec3(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z)";
286				case VK_FORMAT_R8G8B8A8_USCALED:
287					return "bvec4(" + lhs + ".x || " + rhs + ".x, " + lhs + ".y || " + rhs + ".y, " + lhs + ".z || " + rhs + ".z, " + lhs + ".w || " + rhs + ".w)";
288			}
289		case OPTYPE_XOR:
290		case OPTYPE_INCLUSIVE_XOR:
291		case OPTYPE_EXCLUSIVE_XOR:
292			switch (format)
293			{
294				default:
295					return lhs + " ^ " + rhs;
296				case VK_FORMAT_R8_USCALED:
297					return lhs + " ^^ " + rhs;
298				case VK_FORMAT_R8G8_USCALED:
299					return "bvec2(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y)";
300				case VK_FORMAT_R8G8B8_USCALED:
301					return "bvec3(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z)";
302				case VK_FORMAT_R8G8B8A8_USCALED:
303					return "bvec4(" + lhs + ".x ^^ " + rhs + ".x, " + lhs + ".y ^^ " + rhs + ".y, " + lhs + ".z ^^ " + rhs + ".z, " + lhs + ".w ^^ " + rhs + ".w)";
304			}
305	}
306}
307
308std::string getIdentity(int opType, vk::VkFormat format)
309{
310	bool isFloat = false;
311	bool isInt = false;
312	bool isUnsigned = false;
313
314	switch (format)
315	{
316		default:
317			DE_FATAL("Unhandled format!");
318		case VK_FORMAT_R32_SINT:
319		case VK_FORMAT_R32G32_SINT:
320		case VK_FORMAT_R32G32B32_SINT:
321		case VK_FORMAT_R32G32B32A32_SINT:
322			isInt = true;
323			break;
324		case VK_FORMAT_R32_UINT:
325		case VK_FORMAT_R32G32_UINT:
326		case VK_FORMAT_R32G32B32_UINT:
327		case VK_FORMAT_R32G32B32A32_UINT:
328			isUnsigned = true;
329			break;
330		case VK_FORMAT_R32_SFLOAT:
331		case VK_FORMAT_R32G32_SFLOAT:
332		case VK_FORMAT_R32G32B32_SFLOAT:
333		case VK_FORMAT_R32G32B32A32_SFLOAT:
334		case VK_FORMAT_R64_SFLOAT:
335		case VK_FORMAT_R64G64_SFLOAT:
336		case VK_FORMAT_R64G64B64_SFLOAT:
337		case VK_FORMAT_R64G64B64A64_SFLOAT:
338			isFloat = true;
339			break;
340		case VK_FORMAT_R8_USCALED:
341		case VK_FORMAT_R8G8_USCALED:
342		case VK_FORMAT_R8G8B8_USCALED:
343		case VK_FORMAT_R8G8B8A8_USCALED:
344			break; // bool types are not anything
345	}
346
347	switch (opType)
348	{
349		default:
350			DE_FATAL("Unsupported op type");
351		case OPTYPE_ADD:
352		case OPTYPE_INCLUSIVE_ADD:
353		case OPTYPE_EXCLUSIVE_ADD:
354			return subgroups::getFormatNameForGLSL(format) + "(0)";
355		case OPTYPE_MUL:
356		case OPTYPE_INCLUSIVE_MUL:
357		case OPTYPE_EXCLUSIVE_MUL:
358			return subgroups::getFormatNameForGLSL(format) + "(1)";
359		case OPTYPE_MIN:
360		case OPTYPE_INCLUSIVE_MIN:
361		case OPTYPE_EXCLUSIVE_MIN:
362			if (isFloat)
363			{
364				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0x7f800000))";
365			}
366			else if (isInt)
367			{
368				return subgroups::getFormatNameForGLSL(format) + "(0x7fffffff)";
369			}
370			else if (isUnsigned)
371			{
372				return subgroups::getFormatNameForGLSL(format) + "(0xffffffffu)";
373			}
374			else
375			{
376				DE_FATAL("Unhandled case");
377			}
378		case OPTYPE_MAX:
379		case OPTYPE_INCLUSIVE_MAX:
380		case OPTYPE_EXCLUSIVE_MAX:
381			if (isFloat)
382			{
383				return subgroups::getFormatNameForGLSL(format) + "(intBitsToFloat(0xff800000))";
384			}
385			else if (isInt)
386			{
387				return subgroups::getFormatNameForGLSL(format) + "(0x80000000)";
388			}
389			else if (isUnsigned)
390			{
391				return subgroups::getFormatNameForGLSL(format) + "(0)";
392			}
393			else
394			{
395				DE_FATAL("Unhandled case");
396			}
397		case OPTYPE_AND:
398		case OPTYPE_INCLUSIVE_AND:
399		case OPTYPE_EXCLUSIVE_AND:
400			return subgroups::getFormatNameForGLSL(format) + "(~0)";
401		case OPTYPE_OR:
402		case OPTYPE_INCLUSIVE_OR:
403		case OPTYPE_EXCLUSIVE_OR:
404			return subgroups::getFormatNameForGLSL(format) + "(0)";
405		case OPTYPE_XOR:
406		case OPTYPE_INCLUSIVE_XOR:
407		case OPTYPE_EXCLUSIVE_XOR:
408			return subgroups::getFormatNameForGLSL(format) + "(0)";
409	}
410}
411
412std::string getCompare(int opType, vk::VkFormat format, std::string lhs, std::string rhs)
413{
414	std::string formatName = subgroups::getFormatNameForGLSL(format);
415	switch (format)
416	{
417		default:
418			return "all(equal(" + lhs + ", " + rhs + "))";
419		case VK_FORMAT_R8_USCALED:
420		case VK_FORMAT_R32_UINT:
421		case VK_FORMAT_R32_SINT:
422			return "(" + lhs + " == " + rhs + ")";
423		case VK_FORMAT_R32_SFLOAT:
424		case VK_FORMAT_R64_SFLOAT:
425			switch (opType)
426			{
427				default:
428					return "(abs(" + lhs + " - " + rhs + ") < 0.00001)";
429				case OPTYPE_MIN:
430				case OPTYPE_INCLUSIVE_MIN:
431				case OPTYPE_EXCLUSIVE_MIN:
432				case OPTYPE_MAX:
433				case OPTYPE_INCLUSIVE_MAX:
434				case OPTYPE_EXCLUSIVE_MAX:
435					return "(" + lhs + " == " + rhs + ")";
436			}
437		case VK_FORMAT_R32G32_SFLOAT:
438		case VK_FORMAT_R32G32B32_SFLOAT:
439		case VK_FORMAT_R32G32B32A32_SFLOAT:
440		case VK_FORMAT_R64G64_SFLOAT:
441		case VK_FORMAT_R64G64B64_SFLOAT:
442		case VK_FORMAT_R64G64B64A64_SFLOAT:
443			switch (opType)
444			{
445				default:
446					return "all(lessThan(abs(" + lhs + " - " + rhs + "), " + formatName + "(0.00001)))";
447				case OPTYPE_MIN:
448				case OPTYPE_INCLUSIVE_MIN:
449				case OPTYPE_EXCLUSIVE_MIN:
450				case OPTYPE_MAX:
451				case OPTYPE_INCLUSIVE_MAX:
452				case OPTYPE_EXCLUSIVE_MAX:
453					return "all(equal(" + lhs + ", " + rhs + "))";
454			}
455	}
456}
457
458struct CaseDefinition
459{
460	int					opType;
461	VkShaderStageFlags	shaderStage;
462	VkFormat			format;
463	bool				noSSBO;
464};
465
466void initFrameBufferPrograms (SourceCollections& programCollection, CaseDefinition caseDef)
467{
468	std::string indexVars;
469	switch (caseDef.opType)
470	{
471		default:
472			indexVars = "  uint start = 0, end = gl_SubgroupSize;\n";
473			break;
474		case OPTYPE_INCLUSIVE_ADD:
475		case OPTYPE_INCLUSIVE_MUL:
476		case OPTYPE_INCLUSIVE_MIN:
477		case OPTYPE_INCLUSIVE_MAX:
478		case OPTYPE_INCLUSIVE_AND:
479		case OPTYPE_INCLUSIVE_OR:
480		case OPTYPE_INCLUSIVE_XOR:
481			indexVars = "  uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
482			break;
483		case OPTYPE_EXCLUSIVE_ADD:
484		case OPTYPE_EXCLUSIVE_MUL:
485		case OPTYPE_EXCLUSIVE_MIN:
486		case OPTYPE_EXCLUSIVE_MAX:
487		case OPTYPE_EXCLUSIVE_AND:
488		case OPTYPE_EXCLUSIVE_OR:
489		case OPTYPE_EXCLUSIVE_XOR:
490			indexVars = "  uint start = 0, end = gl_SubgroupInvocationID;\n";
491			break;
492	}
493
494	std::ostringstream bdy;
495
496	bdy << indexVars
497		<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
498		<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
499		<< "  uint tempResult = 0;\n"
500		<< "  for (uint index = start; index < end; index++)\n"
501		<< "  {\n"
502		<< "    if (subgroupBallotBitExtract(mask, index))\n"
503		<< "    {\n"
504		<< "      ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
505		<< "    }\n"
506		<< "  }\n"
507		<< "  tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref",
508											getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n"
509		<< "  if (1 == (gl_SubgroupInvocationID % 2))\n"
510		<< "  {\n"
511		<< "    mask = subgroupBallot(true);\n"
512		<< "    ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
513		<< "    for (uint index = start; index < end; index++)\n"
514		<< "    {\n"
515		<< "      if (subgroupBallotBitExtract(mask, index))\n"
516		<< "      {\n"
517		<< "        ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
518		<< "      }\n"
519		<< "    }\n"
520		<< "    tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref",
521				getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n"
522		<< "  }\n"
523		<< "  else\n"
524		<< "  {\n"
525		<< "    tempResult |= 0x2;\n"
526		<< "  }\n";
527	if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
528	{
529		std::ostringstream	src;
530		std::ostringstream	fragmentSrc;
531
532		src << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
533			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
534			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
535			<< "layout(location = 0) in highp vec4 in_position;\n"
536			<< "layout(location = 0) out float out_color;\n"
537			<< "layout(set = 0, binding = 0) uniform Buffer1\n"
538			<< "{\n"
539			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[" << subgroups::maxSupportedSubgroupSize() << "];\n"
540			<< "};\n"
541			<< "\n"
542			<< "void main (void)\n"
543			<< "{\n"
544			<< "  uvec4 mask = subgroupBallot(true);\n"
545			<< bdy.str()
546			<< "  out_color = float(tempResult);\n"
547			<< "  gl_Position = in_position;\n"
548			<< "  gl_PointSize = 1.0f;\n"
549			<< "}\n";
550
551		programCollection.glslSources.add("vert") << glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
552
553		fragmentSrc << glu::getGLSLVersionDeclaration(glu::GLSL_VERSION_450)<<"\n"
554			<< "layout(location = 0) in float in_color;\n"
555			<< "layout(location = 0) out uint out_color;\n"
556			<< "void main()\n"
557			<<"{\n"
558			<< "	out_color = uint(in_color);\n"
559			<< "}\n";
560		programCollection.glslSources.add("fragment") << glu::FragmentSource(fragmentSrc.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
561	}
562	else
563	{
564		DE_FATAL("Unsupported shader stage");
565	}
566}
567
568void initPrograms(SourceCollections& programCollection, CaseDefinition caseDef)
569{
570	std::string indexVars;
571	switch (caseDef.opType)
572	{
573		default:
574			indexVars = "  uint start = 0, end = gl_SubgroupSize;\n";
575			break;
576		case OPTYPE_INCLUSIVE_ADD:
577		case OPTYPE_INCLUSIVE_MUL:
578		case OPTYPE_INCLUSIVE_MIN:
579		case OPTYPE_INCLUSIVE_MAX:
580		case OPTYPE_INCLUSIVE_AND:
581		case OPTYPE_INCLUSIVE_OR:
582		case OPTYPE_INCLUSIVE_XOR:
583			indexVars = "  uint start = 0, end = gl_SubgroupInvocationID + 1;\n";
584			break;
585		case OPTYPE_EXCLUSIVE_ADD:
586		case OPTYPE_EXCLUSIVE_MUL:
587		case OPTYPE_EXCLUSIVE_MIN:
588		case OPTYPE_EXCLUSIVE_MAX:
589		case OPTYPE_EXCLUSIVE_AND:
590		case OPTYPE_EXCLUSIVE_OR:
591		case OPTYPE_EXCLUSIVE_XOR:
592			indexVars = "  uint start = 0, end = gl_SubgroupInvocationID;\n";
593			break;
594	}
595
596	std::ostringstream bdy;
597
598	bdy << indexVars
599		<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " ref = "
600		<< getIdentity(caseDef.opType, caseDef.format) << ";\n"
601		<< "  uint tempResult = 0;\n"
602		<< "  for (uint index = start; index < end; index++)\n"
603		<< "  {\n"
604		<< "    if (subgroupBallotBitExtract(mask, index))\n"
605		<< "    {\n"
606		<< "      ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
607		<< "    }\n"
608		<< "  }\n"
609		<< "  tempResult = " << getCompare(caseDef.opType, caseDef.format, "ref",
610										   getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x1 : 0;\n"
611		<< "  if (1 == (gl_SubgroupInvocationID % 2))\n"
612		<< "  {\n"
613		<< "    mask = subgroupBallot(true);\n"
614		<< "    ref = " << getIdentity(caseDef.opType, caseDef.format) << ";\n"
615		<< "    for (uint index = start; index < end; index++)\n"
616		<< "    {\n"
617		<< "      if (subgroupBallotBitExtract(mask, index))\n"
618		<< "      {\n"
619		<< "        ref = " << getOpTypeOperation(caseDef.opType, caseDef.format, "ref", "data[index]") << ";\n"
620		<< "      }\n"
621		<< "    }\n"
622		<< "    tempResult |= " << getCompare(caseDef.opType, caseDef.format, "ref",
623				getOpTypeName(caseDef.opType) + "(data[gl_SubgroupInvocationID])") << " ? 0x2 : 0;\n"
624		<< "  }\n"
625		<< "  else\n"
626		<< "  {\n"
627		<< "    tempResult |= 0x2;\n"
628		<< "  }\n";
629
630	if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
631	{
632		std::ostringstream src;
633
634		src << "#version 450\n"
635			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
636			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
637			<< "layout (local_size_x_id = 0, local_size_y_id = 1, "
638			"local_size_z_id = 2) in;\n"
639			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
640			<< "{\n"
641			<< "  uint result[];\n"
642			<< "};\n"
643			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
644			<< "{\n"
645			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
646			<< "};\n"
647			<< "\n"
648			<< "void main (void)\n"
649			<< "{\n"
650			<< "  uvec3 globalSize = gl_NumWorkGroups * gl_WorkGroupSize;\n"
651			<< "  highp uint offset = globalSize.x * ((globalSize.y * "
652			"gl_GlobalInvocationID.z) + gl_GlobalInvocationID.y) + "
653			"gl_GlobalInvocationID.x;\n"
654			<< "  uvec4 mask = subgroupBallot(true);\n"
655			<< bdy.str()
656			<< "  result[offset] = tempResult;\n"
657			<< "}\n";
658
659		programCollection.glslSources.add("comp")
660				<< glu::ComputeSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
661	}
662	else if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
663	{
664		programCollection.glslSources.add("vert")
665				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
666
667		std::ostringstream frag;
668
669		frag << "#version 450\n"
670			 << "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
671			 << "#extension GL_KHR_shader_subgroup_ballot: enable\n"
672			 << "layout(location = 0) out uint result;\n"
673			 << "layout(set = 0, binding = 0, std430) readonly buffer Buffer2\n"
674			 << "{\n"
675			 << "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
676			 << "};\n"
677			 << "void main (void)\n"
678			 << "{\n"
679			 << "  uvec4 mask = subgroupBallot(true);\n"
680			 << bdy.str()
681			 << "  result = tempResult;\n"
682			 << "}\n";
683
684		programCollection.glslSources.add("frag")
685				<< glu::FragmentSource(frag.str())<< vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
686	}
687	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
688	{
689		std::ostringstream src;
690
691		src << "#version 450\n"
692			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
693			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
694			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
695			<< "{\n"
696			<< "  uint result[];\n"
697			<< "};\n"
698			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
699			<< "{\n"
700			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
701			<< "};\n"
702			<< "\n"
703			<< "void main (void)\n"
704			<< "{\n"
705			<< "  uvec4 mask = subgroupBallot(true);\n"
706			<< bdy.str()
707			<< "  result[gl_VertexIndex] = tempResult;\n"
708			<< "  gl_PointSize = 1.0f;\n"
709			<< "}\n";
710
711		programCollection.glslSources.add("vert")
712				<< glu::VertexSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
713	}
714	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
715	{
716		programCollection.glslSources.add("vert")
717				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
718
719		std::ostringstream src;
720
721		src << "#version 450\n"
722			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
723			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
724			<< "layout(points) in;\n"
725			<< "layout(points, max_vertices = 1) out;\n"
726			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
727			<< "{\n"
728			<< "  uint result[];\n"
729			<< "};\n"
730			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
731			<< "{\n"
732			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
733			<< "};\n"
734			<< "\n"
735			<< "void main (void)\n"
736			<< "{\n"
737			<< "  uvec4 mask = subgroupBallot(true);\n"
738			<< bdy.str()
739			<< "  result[gl_PrimitiveIDIn] = tempResult;\n"
740			<< "}\n";
741
742		programCollection.glslSources.add("geom")
743				<< glu::GeometrySource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
744	}
745	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
746	{
747		programCollection.glslSources.add("vert")
748				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
749
750		programCollection.glslSources.add("tese")
751				<< glu::TessellationEvaluationSource("#version 450\nlayout(isolines) in;\nvoid main (void) {}\n");
752
753		std::ostringstream src;
754
755		src << "#version 450\n"
756			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
757			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
758			<< "layout(vertices=1) out;\n"
759			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
760			<< "{\n"
761			<< "  uint result[];\n"
762			<< "};\n"
763			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
764			<< "{\n"
765			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
766			<< "};\n"
767			<< "\n"
768			<< "void main (void)\n"
769			<< "{\n"
770			<< "  uvec4 mask = subgroupBallot(true);\n"
771			<< bdy.str()
772			<< "  result[gl_PrimitiveID] = tempResult;\n"
773			<< "}\n";
774
775		programCollection.glslSources.add("tesc")
776				<< glu::TessellationControlSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
777	}
778	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
779	{
780		programCollection.glslSources.add("vert")
781				<< glu::VertexSource(subgroups::getVertShaderForStage(caseDef.shaderStage)) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
782
783		programCollection.glslSources.add("tesc")
784				<< glu::TessellationControlSource("#version 450\nlayout(vertices=1) out;\nvoid main (void) { for(uint i = 0; i < 4; i++) { gl_TessLevelOuter[i] = 1.0f; } }\n");
785
786		std::ostringstream src;
787
788		src << "#version 450\n"
789			<< "#extension GL_KHR_shader_subgroup_arithmetic: enable\n"
790			<< "#extension GL_KHR_shader_subgroup_ballot: enable\n"
791			<< "layout(isolines) in;\n"
792			<< "layout(set = 0, binding = 0, std430) buffer Buffer1\n"
793			<< "{\n"
794			<< "  uint result[];\n"
795			<< "};\n"
796			<< "layout(set = 0, binding = 1, std430) buffer Buffer2\n"
797			<< "{\n"
798			<< "  " << subgroups::getFormatNameForGLSL(caseDef.format) << " data[];\n"
799			<< "};\n"
800			<< "\n"
801			<< "void main (void)\n"
802			<< "{\n"
803			<< "  uvec4 mask = subgroupBallot(true);\n"
804			<< bdy.str()
805			<< "  result[gl_PrimitiveID * 2 + uint(gl_TessCoord.x + 0.5)] = tempResult;\n"
806			<< "}\n";
807
808		programCollection.glslSources.add("tese")
809				<< glu::TessellationEvaluationSource(src.str()) << vk::ShaderBuildOptions(vk::SPIRV_VERSION_1_3, 0u);
810	}
811	else
812	{
813		DE_FATAL("Unsupported shader stage");
814	}
815}
816
817tcu::TestStatus test(Context& context, const CaseDefinition caseDef)
818{
819	if (!subgroups::isSubgroupSupported(context))
820		TCU_THROW(NotSupportedError, "Subgroup operations are not supported");
821
822	if (!subgroups::areSubgroupOperationsSupportedForStage(
823				context, caseDef.shaderStage))
824	{
825		if (subgroups::areSubgroupOperationsRequiredForStage(
826					caseDef.shaderStage))
827		{
828			return tcu::TestStatus::fail(
829					   "Shader stage " +
830					   subgroups::getShaderStageName(caseDef.shaderStage) +
831					   " is required to support subgroup operations!");
832		}
833		else
834		{
835			TCU_THROW(NotSupportedError, "Device does not support subgroup operations for this stage");
836		}
837	}
838
839	if (!subgroups::isSubgroupFeatureSupportedForDevice(context, VK_SUBGROUP_FEATURE_ARITHMETIC_BIT))
840	{
841		TCU_THROW(NotSupportedError, "Device does not support subgroup arithmetic operations");
842	}
843
844	if (subgroups::isDoubleFormat(caseDef.format) &&
845			!subgroups::isDoubleSupportedForDevice(context))
846	{
847		TCU_THROW(NotSupportedError, "Device does not support subgroup double operations");
848	}
849
850	//Tests which don't use the SSBO
851	if (caseDef.noSSBO && VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
852	{
853		subgroups::SSBOData inputData;
854		inputData.format = caseDef.format;
855		inputData.numElements = subgroups::maxSupportedSubgroupSize();
856		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
857
858		return subgroups::makeVertexFrameBufferTest(context, VK_FORMAT_R32_UINT, &inputData, 1, checkVertexPipelineStages);
859	}
860
861	if ((VK_SHADER_STAGE_FRAGMENT_BIT != caseDef.shaderStage) &&
862			(VK_SHADER_STAGE_COMPUTE_BIT != caseDef.shaderStage))
863	{
864		if (!subgroups::isVertexSSBOSupportedForDevice(context))
865		{
866			TCU_THROW(NotSupportedError, "Device does not support vertex stage SSBO writes");
867		}
868	}
869
870	if (VK_SHADER_STAGE_FRAGMENT_BIT == caseDef.shaderStage)
871	{
872		subgroups::SSBOData inputData;
873		inputData.format = caseDef.format;
874		inputData.numElements = subgroups::maxSupportedSubgroupSize();
875		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
876
877		return subgroups::makeFragmentTest(context, VK_FORMAT_R32_UINT,
878										   &inputData, 1, checkFragment);
879	}
880	else if (VK_SHADER_STAGE_COMPUTE_BIT == caseDef.shaderStage)
881	{
882		subgroups::SSBOData inputData;
883		inputData.format = caseDef.format;
884		inputData.numElements = subgroups::maxSupportedSubgroupSize();
885		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
886
887		return subgroups::makeComputeTest(context, VK_FORMAT_R32_UINT, &inputData,
888										  1, checkCompute);
889	}
890	else if (VK_SHADER_STAGE_VERTEX_BIT == caseDef.shaderStage)
891	{
892		subgroups::SSBOData inputData;
893		inputData.format = caseDef.format;
894		inputData.numElements = subgroups::maxSupportedSubgroupSize();
895		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
896
897		return subgroups::makeVertexTest(context, VK_FORMAT_R32_UINT, &inputData,
898										 1, checkVertexPipelineStages);
899	}
900	else if (VK_SHADER_STAGE_GEOMETRY_BIT == caseDef.shaderStage)
901	{
902		subgroups::SSBOData inputData;
903		inputData.format = caseDef.format;
904		inputData.numElements = subgroups::maxSupportedSubgroupSize();
905		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
906
907		return subgroups::makeGeometryTest(context, VK_FORMAT_R32_UINT, &inputData,
908										   1, checkVertexPipelineStages);
909	}
910	else if (VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT == caseDef.shaderStage)
911	{
912		subgroups::SSBOData inputData;
913		inputData.format = caseDef.format;
914		inputData.numElements = subgroups::maxSupportedSubgroupSize();
915		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
916
917		return subgroups::makeTessellationControlTest(context, VK_FORMAT_R32_UINT, &inputData,
918				1, checkVertexPipelineStages);
919	}
920	else if (VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT == caseDef.shaderStage)
921	{
922		subgroups::SSBOData inputData;
923		inputData.format = caseDef.format;
924		inputData.numElements = subgroups::maxSupportedSubgroupSize();
925		inputData.initializeType = subgroups::SSBOData::InitializeNonZero;
926
927		return subgroups::makeTessellationEvaluationTest(context, VK_FORMAT_R32_UINT, &inputData,
928				1, checkVertexPipelineStages);
929	}
930	else
931	{
932		TCU_THROW(InternalError, "Unhandled shader stage");
933	}
934}
935}
936
937namespace vkt
938{
939namespace subgroups
940{
941tcu::TestCaseGroup* createSubgroupsArithmeticTests(tcu::TestContext& testCtx)
942{
943	de::MovePtr<tcu::TestCaseGroup> group(new tcu::TestCaseGroup(
944			testCtx, "arithmetic", "Subgroup arithmetic category tests"));
945
946	const VkShaderStageFlags stages[] =
947	{
948		VK_SHADER_STAGE_TESSELLATION_EVALUATION_BIT,
949		VK_SHADER_STAGE_TESSELLATION_CONTROL_BIT,
950		VK_SHADER_STAGE_GEOMETRY_BIT,
951		VK_SHADER_STAGE_VERTEX_BIT,
952		VK_SHADER_STAGE_FRAGMENT_BIT,
953		VK_SHADER_STAGE_COMPUTE_BIT
954	};
955
956	const VkFormat formats[] =
957	{
958		VK_FORMAT_R32_SINT, VK_FORMAT_R32G32_SINT, VK_FORMAT_R32G32B32_SINT,
959		VK_FORMAT_R32G32B32A32_SINT, VK_FORMAT_R32_UINT, VK_FORMAT_R32G32_UINT,
960		VK_FORMAT_R32G32B32_UINT, VK_FORMAT_R32G32B32A32_UINT,
961		VK_FORMAT_R32_SFLOAT, VK_FORMAT_R32G32_SFLOAT,
962		VK_FORMAT_R32G32B32_SFLOAT, VK_FORMAT_R32G32B32A32_SFLOAT,
963		VK_FORMAT_R64_SFLOAT, VK_FORMAT_R64G64_SFLOAT,
964		VK_FORMAT_R64G64B64_SFLOAT, VK_FORMAT_R64G64B64A64_SFLOAT,
965		VK_FORMAT_R8_USCALED, VK_FORMAT_R8G8_USCALED,
966		VK_FORMAT_R8G8B8_USCALED, VK_FORMAT_R8G8B8A8_USCALED,
967	};
968
969	for (int stageIndex = 0; stageIndex < DE_LENGTH_OF_ARRAY(stages); ++stageIndex)
970	{
971		const VkShaderStageFlags stage = stages[stageIndex];
972
973		for (int formatIndex = 0; formatIndex < DE_LENGTH_OF_ARRAY(formats); ++formatIndex)
974		{
975			const VkFormat format = formats[formatIndex];
976
977			for (int opTypeIndex = 0; opTypeIndex < OPTYPE_LAST; ++opTypeIndex)
978			{
979				bool isBool = false;
980				bool isFloat = false;
981
982				switch (format)
983				{
984					default:
985						break;
986					case VK_FORMAT_R32_SFLOAT:
987					case VK_FORMAT_R32G32_SFLOAT:
988					case VK_FORMAT_R32G32B32_SFLOAT:
989					case VK_FORMAT_R32G32B32A32_SFLOAT:
990					case VK_FORMAT_R64_SFLOAT:
991					case VK_FORMAT_R64G64_SFLOAT:
992					case VK_FORMAT_R64G64B64_SFLOAT:
993					case VK_FORMAT_R64G64B64A64_SFLOAT:
994						isFloat = true;
995						break;
996					case VK_FORMAT_R8_USCALED:
997					case VK_FORMAT_R8G8_USCALED:
998					case VK_FORMAT_R8G8B8_USCALED:
999					case VK_FORMAT_R8G8B8A8_USCALED:
1000						isBool = true;
1001						break;
1002				}
1003
1004				bool isBitwiseOp = false;
1005
1006				switch (opTypeIndex)
1007				{
1008					default:
1009						break;
1010					case OPTYPE_AND:
1011					case OPTYPE_INCLUSIVE_AND:
1012					case OPTYPE_EXCLUSIVE_AND:
1013					case OPTYPE_OR:
1014					case OPTYPE_INCLUSIVE_OR:
1015					case OPTYPE_EXCLUSIVE_OR:
1016					case OPTYPE_XOR:
1017					case OPTYPE_INCLUSIVE_XOR:
1018					case OPTYPE_EXCLUSIVE_XOR:
1019						isBitwiseOp = true;
1020						break;
1021				}
1022
1023				if (isFloat && isBitwiseOp)
1024				{
1025					// Skip float with bitwise category.
1026					continue;
1027				}
1028
1029				if (isBool && !isBitwiseOp)
1030				{
1031					// Skip bool when its not the bitwise category.
1032					continue;
1033				}
1034
1035				CaseDefinition caseDef = {opTypeIndex, stage, format, false};
1036
1037				std::string op = getOpTypeName(opTypeIndex);
1038
1039				addFunctionCaseWithPrograms(group.get(),
1040											de::toLower(op) + "_" +
1041											subgroups::getFormatNameForGLSL(format) +
1042											"_" + getShaderStageName(stage),
1043											"", initPrograms, test, caseDef);
1044
1045				if (VK_SHADER_STAGE_VERTEX_BIT == stage)
1046				{
1047					caseDef.noSSBO = true;
1048					addFunctionCaseWithPrograms(group.get(), de::toLower(op) + "_" + subgroups::getFormatNameForGLSL(format) +
1049												"_" + getShaderStageName(stage) + "_framebuffer", "",
1050												initFrameBufferPrograms, test, caseDef);
1051				}
1052			}
1053		}
1054	}
1055
1056	return group.release();
1057}
1058
1059} // subgroups
1060} // vkt
1061