1/*-------------------------------------------------------------------------
2 * Vulkan Conformance Tests
3 * ------------------------
4 *
5 * Copyright (c) 2016 Google Inc.
6 *
7 * Licensed under the Apache License, Version 2.0 (the "License");
8 * you may not use this file except in compliance with the License.
9 * You may obtain a copy of the License at
10 *
11 *      http://www.apache.org/licenses/LICENSE-2.0
12 *
13 * Unless required by applicable law or agreed to in writing, software
14 * distributed under the License is distributed on an "AS IS" BASIS,
15 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 * See the License for the specific language governing permissions and
17 * limitations under the License.
18 *
19 *//*!
20 * \file
21 * \brief Utility for pre-compiling source programs to SPIR-V
22 *//*--------------------------------------------------------------------*/
23
24#include "tcuDefs.hpp"
25#include "tcuCommandLine.hpp"
26#include "tcuPlatform.hpp"
27#include "tcuResource.hpp"
28#include "tcuTestLog.hpp"
29#include "tcuTestHierarchyIterator.hpp"
30#include "deUniquePtr.hpp"
31#include "vkPrograms.hpp"
32#include "vkBinaryRegistry.hpp"
33#include "vktTestCase.hpp"
34#include "vktTestPackage.hpp"
35#include "deUniquePtr.hpp"
36#include "deCommandLine.hpp"
37#include "deSharedPtr.hpp"
38#include "deThread.hpp"
39#include "deThreadSafeRingBuffer.hpp"
40#include "dePoolArray.hpp"
41
42#include <iostream>
43
44using std::vector;
45using std::string;
46using de::UniquePtr;
47using de::MovePtr;
48using de::SharedPtr;
49
50namespace vkt
51{
52
53namespace // anonymous
54{
55
56typedef de::SharedPtr<glu::ProgramSources>	ProgramSourcesSp;
57typedef de::SharedPtr<vk::SpirVAsmSource>	SpirVAsmSourceSp;
58typedef de::SharedPtr<vk::ProgramBinary>	ProgramBinarySp;
59
//! Unit of work run on one of a TaskExecutor's worker threads.
class Task
{
public:
	//! Virtual so that destroying a derived task through a Task* is well-defined.
	virtual			~Task		(void) {}

	//! Perform the work; invoked by a worker thread after the task is dequeued.
	virtual void	execute		(void) = 0;
};
65
66typedef de::ThreadSafeRingBuffer<Task*>	TaskQueue;
67
68class TaskExecutorThread : public de::Thread
69{
70public:
71	TaskExecutorThread (TaskQueue& tasks)
72		: m_tasks(tasks)
73	{
74		start();
75	}
76
77	void run (void)
78	{
79		for (;;)
80		{
81			Task* const	task	= m_tasks.popBack();
82
83			if (task)
84				task->execute();
85			else
86				break; // End of tasks - time to terminate
87		}
88	}
89
90private:
91	TaskQueue&	m_tasks;
92};
93
94class TaskExecutor
95{
96public:
97								TaskExecutor		(deUint32 numThreads);
98								~TaskExecutor		(void);
99
100	void						submit				(Task* task);
101	void						waitForComplete		(void);
102
103private:
104	typedef de::SharedPtr<TaskExecutorThread>	ExecThreadSp;
105
106	std::vector<ExecThreadSp>	m_threads;
107	TaskQueue					m_tasks;
108};
109
110TaskExecutor::TaskExecutor (deUint32 numThreads)
111	: m_threads	(numThreads)
112	, m_tasks	(m_threads.size() * 1024u)
113{
114	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
115		m_threads[ndx] = ExecThreadSp(new TaskExecutorThread(m_tasks));
116}
117
118TaskExecutor::~TaskExecutor (void)
119{
120	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
121		m_tasks.pushFront(DE_NULL);
122
123	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
124		m_threads[ndx]->join();
125}
126
127void TaskExecutor::submit (Task* task)
128{
129	DE_ASSERT(task);
130	m_tasks.pushFront(task);
131}
132
133class SyncTask : public Task
134{
135public:
136	SyncTask (de::Semaphore* enterBarrier, de::Semaphore* inBarrier, de::Semaphore* leaveBarrier)
137		: m_enterBarrier	(enterBarrier)
138		, m_inBarrier		(inBarrier)
139		, m_leaveBarrier	(leaveBarrier)
140	{}
141
142	SyncTask (void)
143		: m_enterBarrier	(DE_NULL)
144		, m_inBarrier		(DE_NULL)
145		, m_leaveBarrier	(DE_NULL)
146	{}
147
148	void execute (void)
149	{
150		m_enterBarrier->increment();
151		m_inBarrier->decrement();
152		m_leaveBarrier->increment();
153	}
154
155private:
156	de::Semaphore*	m_enterBarrier;
157	de::Semaphore*	m_inBarrier;
158	de::Semaphore*	m_leaveBarrier;
159};
160
161void TaskExecutor::waitForComplete (void)
162{
163	de::Semaphore			enterBarrier	(0);
164	de::Semaphore			inBarrier		(0);
165	de::Semaphore			leaveBarrier	(0);
166	std::vector<SyncTask>	syncTasks		(m_threads.size());
167
168	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
169	{
170		syncTasks[ndx] = SyncTask(&enterBarrier, &inBarrier, &leaveBarrier);
171		submit(&syncTasks[ndx]);
172	}
173
174	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
175		enterBarrier.decrement();
176
177	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
178		inBarrier.increment();
179
180	for (size_t ndx = 0; ndx < m_threads.size(); ++ndx)
181		leaveBarrier.decrement();
182}
183
184struct Program
185{
186	enum Status
187	{
188		STATUS_NOT_COMPLETED = 0,
189		STATUS_FAILED,
190		STATUS_PASSED,
191
192		STATUS_LAST
193	};
194
195	vk::ProgramIdentifier	id;
196
197	Status					buildStatus;
198	std::string				buildLog;
199	ProgramBinarySp			binary;
200
201	Status					validationStatus;
202	std::string				validationLog;
203
204	explicit				Program		(const vk::ProgramIdentifier& id_)
205								: id				(id_)
206								, buildStatus		(STATUS_NOT_COMPLETED)
207								, validationStatus	(STATUS_NOT_COMPLETED)
208							{}
209							Program		(void)
210								: id				("", "")
211								, buildStatus		(STATUS_NOT_COMPLETED)
212								, validationStatus	(STATUS_NOT_COMPLETED)
213							{}
214};
215
216void writeBuildLogs (const glu::ShaderProgramInfo& buildInfo, std::ostream& dst)
217{
218	for (size_t shaderNdx = 0; shaderNdx < buildInfo.shaders.size(); shaderNdx++)
219	{
220		const glu::ShaderInfo&	shaderInfo	= buildInfo.shaders[shaderNdx];
221		const char* const		shaderName	= getShaderTypeName(shaderInfo.type);
222
223		dst << shaderName << " source:\n"
224			<< "---\n"
225			<< shaderInfo.source << "\n"
226			<< "---\n"
227			<< shaderName << " compile log:\n"
228			<< "---\n"
229			<< shaderInfo.infoLog << "\n"
230			<< "---\n";
231	}
232
233	dst << "link log:\n"
234		<< "---\n"
235		<< buildInfo.program.infoLog << "\n"
236		<< "---\n";
237}
238
239class BuildGlslTask : public Task
240{
241public:
242
243	BuildGlslTask (const glu::ProgramSources& source, Program* program)
244		: m_source	(source)
245		, m_program	(program)
246	{}
247
248	BuildGlslTask (void) : m_program(DE_NULL) {}
249
250	void execute (void)
251	{
252		glu::ShaderProgramInfo buildInfo;
253
254		try
255		{
256			m_program->binary		= ProgramBinarySp(vk::buildProgram(m_source, vk::PROGRAM_FORMAT_SPIRV, &buildInfo));
257			m_program->buildStatus	= Program::STATUS_PASSED;
258		}
259		catch (const tcu::Exception&)
260		{
261			std::ostringstream log;
262
263			writeBuildLogs(buildInfo, log);
264
265			m_program->buildStatus	= Program::STATUS_FAILED;
266			m_program->buildLog		= log.str();
267
268		}
269	}
270
271private:
272	glu::ProgramSources	m_source;
273	Program*			m_program;
274};
275
276void writeBuildLogs (const vk::SpirVProgramInfo& buildInfo, std::ostream& dst)
277{
278	dst << "source:\n"
279		<< "---\n"
280		<< buildInfo.source << "\n"
281		<< "---\n";
282}
283
284class BuildSpirVAsmTask : public Task
285{
286public:
287	BuildSpirVAsmTask (const vk::SpirVAsmSource& source, Program* program)
288		: m_source	(source)
289		, m_program	(program)
290	{}
291
292	BuildSpirVAsmTask (void) : m_program(DE_NULL) {}
293
294	void execute (void)
295	{
296		vk::SpirVProgramInfo buildInfo;
297
298		try
299		{
300			m_program->binary		= ProgramBinarySp(vk::assembleProgram(m_source, &buildInfo));
301			m_program->buildStatus	= Program::STATUS_PASSED;
302		}
303		catch (const tcu::Exception&)
304		{
305			std::ostringstream log;
306
307			writeBuildLogs(buildInfo, log);
308
309			m_program->buildStatus	= Program::STATUS_FAILED;
310			m_program->buildLog		= log.str();
311		}
312	}
313
314private:
315	vk::SpirVAsmSource	m_source;
316	Program*			m_program;
317};
318
319class ValidateBinaryTask : public Task
320{
321public:
322	ValidateBinaryTask (Program* program)
323		: m_program(program)
324	{}
325
326	void execute (void)
327	{
328		DE_ASSERT(m_program->buildStatus == Program::STATUS_PASSED);
329
330		std::ostringstream validationLog;
331
332		if (vk::validateProgram(*m_program->binary, &validationLog))
333			m_program->validationStatus = Program::STATUS_PASSED;
334		else
335			m_program->validationStatus = Program::STATUS_FAILED;
336	}
337
338private:
339	Program*	m_program;
340};
341
342tcu::TestPackageRoot* createRoot (tcu::TestContext& testCtx)
343{
344	vector<tcu::TestNode*>	children;
345	children.push_back(new TestPackage(testCtx));
346	return new tcu::TestPackageRoot(testCtx, children);
347}
348
349} // anonymous
350
//! Number of programs that passed and failed in buildPrograms().
struct BuildStats
{
	int		numSucceeded;	//!< Built (and, if requested, validated) successfully.
	int		numFailed;		//!< Failed to build or failed validation.

	BuildStats (void)
		: numSucceeded	(0)
		, numFailed		(0)
	{}
};
362
363BuildStats buildPrograms (tcu::TestContext& testCtx, const std::string& dstPath, bool validateBinaries)
364{
365	const deUint32						numThreads			= deGetNumAvailableLogicalCores();
366
367	TaskExecutor						executor			(numThreads);
368
369	// de::PoolArray<> is faster to build than std::vector
370	de::MemPool							programPool;
371	de::PoolArray<Program>				programs			(&programPool);
372
373	{
374		de::MemPool							tmpPool;
375		de::PoolArray<BuildGlslTask>		buildGlslTasks		(&tmpPool);
376		de::PoolArray<BuildSpirVAsmTask>	buildSpirvAsmTasks	(&tmpPool);
377
378		// Collect build tasks
379		{
380			const UniquePtr<tcu::TestPackageRoot>	root		(createRoot(testCtx));
381			tcu::DefaultHierarchyInflater			inflater	(testCtx);
382			tcu::TestHierarchyIterator				iterator	(*root, inflater, testCtx.getCommandLine());
383
384			while (iterator.getState() != tcu::TestHierarchyIterator::STATE_FINISHED)
385			{
386				if (iterator.getState() == tcu::TestHierarchyIterator::STATE_ENTER_NODE &&
387					tcu::isTestNodeTypeExecutable(iterator.getNode()->getNodeType()))
388				{
389					const TestCase* const		testCase	= dynamic_cast<TestCase*>(iterator.getNode());
390					const string				casePath	= iterator.getNodePath();
391					vk::SourceCollections		sourcePrograms;
392
393					testCase->initPrograms(sourcePrograms);
394
395					for (vk::GlslSourceCollection::Iterator progIter = sourcePrograms.glslSources.begin();
396						 progIter != sourcePrograms.glslSources.end();
397						 ++progIter)
398					{
399						programs.pushBack(Program(vk::ProgramIdentifier(casePath, progIter.getName())));
400						buildGlslTasks.pushBack(BuildGlslTask(progIter.getProgram(), &programs.back()));
401						executor.submit(&buildGlslTasks.back());
402					}
403
404					for (vk::SpirVAsmCollection::Iterator progIter = sourcePrograms.spirvAsmSources.begin();
405						 progIter != sourcePrograms.spirvAsmSources.end();
406						 ++progIter)
407					{
408						programs.pushBack(Program(vk::ProgramIdentifier(casePath, progIter.getName())));
409						buildSpirvAsmTasks.pushBack(BuildSpirVAsmTask(progIter.getProgram(), &programs.back()));
410						executor.submit(&buildSpirvAsmTasks.back());
411					}
412				}
413
414				iterator.next();
415			}
416		}
417
418		// Need to wait until tasks completed before freeing task memory
419		executor.waitForComplete();
420	}
421
422	if (validateBinaries)
423	{
424		std::vector<ValidateBinaryTask>	validationTasks;
425
426		validationTasks.reserve(programs.size());
427
428		for (de::PoolArray<Program>::iterator progIter = programs.begin(); progIter != programs.end(); ++progIter)
429		{
430			if (progIter->buildStatus == Program::STATUS_PASSED)
431			{
432				validationTasks.push_back(ValidateBinaryTask(&*progIter));
433				executor.submit(&validationTasks.back());
434			}
435		}
436
437		executor.waitForComplete();
438	}
439
440	{
441		vk::BinaryRegistryWriter	registryWriter		(dstPath);
442
443		for (de::PoolArray<Program>::iterator progIter = programs.begin(); progIter != programs.end(); ++progIter)
444		{
445			if (progIter->buildStatus == Program::STATUS_PASSED)
446				registryWriter.addProgram(progIter->id, *progIter->binary);
447		}
448
449		registryWriter.write();
450	}
451
452	{
453		BuildStats	stats;
454
455		for (de::PoolArray<Program>::iterator progIter = programs.begin(); progIter != programs.end(); ++progIter)
456		{
457			const bool	buildOk			= progIter->buildStatus == Program::STATUS_PASSED;
458			const bool	validationOk	= progIter->validationStatus != Program::STATUS_FAILED;
459
460			if (buildOk && validationOk)
461				stats.numSucceeded += 1;
462			else
463			{
464				stats.numFailed += 1;
465				tcu::print("ERROR: %s / %s: %s failed\n",
466						   progIter->id.testCasePath.c_str(),
467						   progIter->id.programName.c_str(),
468						   (buildOk ? "validation" : "build"));
469				tcu::print("%s\n", (buildOk ? progIter->validationLog.c_str() : progIter->buildLog.c_str()));
470			}
471		}
472
473		return stats;
474	}
475}
476
477} // vkt
478
479namespace opt
480{
481
482DE_DECLARE_COMMAND_LINE_OPT(DstPath,	std::string);
483DE_DECLARE_COMMAND_LINE_OPT(Cases,		std::string);
484DE_DECLARE_COMMAND_LINE_OPT(Validate,	bool);
485
486} // opt
487
488void registerOptions (de::cmdline::Parser& parser)
489{
490	using de::cmdline::Option;
491
492	parser << Option<opt::DstPath>	("d", "dst-path",		"Destination path",	"out")
493		   << Option<opt::Cases>	("n", "deqp-case",		"Case path filter (works as in test binaries)")
494		   << Option<opt::Validate>	("v", "validate-spv",	"Validate generated SPIR-V binaries");
495}
496
497int main (int argc, const char* argv[])
498{
499	de::cmdline::CommandLine	cmdLine;
500	tcu::CommandLine			deqpCmdLine;
501
502	{
503		de::cmdline::Parser		parser;
504		registerOptions(parser);
505		if (!parser.parse(argc, argv, &cmdLine, std::cerr))
506		{
507			parser.help(std::cout);
508			return -1;
509		}
510	}
511
512	{
513		vector<const char*> deqpArgv;
514
515		deqpArgv.push_back("unused");
516
517		if (cmdLine.hasOption<opt::Cases>())
518		{
519			deqpArgv.push_back("--deqp-case");
520			deqpArgv.push_back(cmdLine.getOption<opt::Cases>().c_str());
521		}
522
523		if (!deqpCmdLine.parse((int)deqpArgv.size(), &deqpArgv[0]))
524			return -1;
525	}
526
527	try
528	{
529		tcu::DirArchive			archive			(".");
530		tcu::TestLog			log				(deqpCmdLine.getLogFileName(), deqpCmdLine.getLogFlags());
531		tcu::Platform			platform;
532		tcu::TestContext		testCtx			(platform, archive, log, deqpCmdLine, DE_NULL);
533
534		const vkt::BuildStats	stats			= vkt::buildPrograms(testCtx,
535																	 cmdLine.getOption<opt::DstPath>(),
536																	 cmdLine.getOption<opt::Validate>());
537
538		tcu::print("DONE: %d passed, %d failed\n", stats.numSucceeded, stats.numFailed);
539
540		return stats.numFailed == 0 ? 0 : -1;
541	}
542	catch (const std::exception& e)
543	{
544		tcu::die("%s", e.what());
545	}
546}
547