1/****************************************************************************
2* Copyright (C) 2014-2015 Intel Corporation.   All Rights Reserved.
3*
4* Permission is hereby granted, free of charge, to any person obtaining a
5* copy of this software and associated documentation files (the "Software"),
6* to deal in the Software without restriction, including without limitation
7* the rights to use, copy, modify, merge, publish, distribute, sublicense,
8* and/or sell copies of the Software, and to permit persons to whom the
9* Software is furnished to do so, subject to the following conditions:
10*
11* The above copyright notice and this permission notice (including the next
12* paragraph) shall be included in all copies or substantial portions of the
13* Software.
14*
15* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21* IN THE SOFTWARE.
22*
23* @file streamout_jit.cpp
24*
25* @brief Implementation of the streamout jitter
26*
27* Notes:
28*
29******************************************************************************/
30#include "jit_api.h"
31#include "streamout_jit.h"
32#include "builder.h"
33#include "state_llvm.h"
34#include "llvm/IR/DataLayout.h"
35
36#include <sstream>
37#include <unordered_set>
38
39using namespace llvm;
40using namespace SwrJit;
41
42//////////////////////////////////////////////////////////////////////////
43/// Interface to Jitting a fetch shader
44//////////////////////////////////////////////////////////////////////////
45struct StreamOutJit : public Builder
46{
47    StreamOutJit(JitManager* pJitMgr) : Builder(pJitMgr){};
48
49    // returns pointer to SWR_STREAMOUT_BUFFER
50    Value* getSOBuffer(Value* pSoCtx, uint32_t buffer)
51    {
52        return LOAD(pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_pBuffer, buffer });
53    }
54
55
56    //////////////////////////////////////////////////////////////////////////
57    // @brief checks if streamout buffer is oob
58    // @return <i1> true/false
59    Value* oob(const STREAMOUT_COMPILE_STATE& state, Value* pSoCtx, uint32_t buffer)
60    {
61        Value* returnMask = C(false);
62
63        Value* pBuf = getSOBuffer(pSoCtx, buffer);
64
65        // load enable
66        // @todo bool data types should generate <i1> llvm type
67        Value* enabled = TRUNC(LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_enable }), IRB()->getInt1Ty());
68
69        // load buffer size
70        Value* bufferSize = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_bufferSize });
71
72        // load current streamOffset
73        Value* streamOffset = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_streamOffset });
74
75        // load buffer pitch
76        Value* pitch = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_pitch });
77
78        // buffer is considered oob if in use in a decl but not enabled
79        returnMask = OR(returnMask, NOT(enabled));
80
81        // buffer is oob if cannot fit a prims worth of verts
82        Value* newOffset = ADD(streamOffset, MUL(pitch, C(state.numVertsPerPrim)));
83        returnMask = OR(returnMask, ICMP_SGT(newOffset, bufferSize));
84
85        return returnMask;
86    }
87
88
89    //////////////////////////////////////////////////////////////////////////
90    // @brief converts scalar bitmask to <4 x i32> suitable for shuffle vector,
91    //        packing the active mask bits
92    //        ex. bitmask 0011 -> (0, 1, 0, 0)
93    //            bitmask 1000 -> (3, 0, 0, 0)
94    //            bitmask 1100 -> (2, 3, 0, 0)
95    Value* PackMask(uint32_t bitmask)
96    {
97        std::vector<Constant*> indices(4, C(0));
98        DWORD index;
99        uint32_t elem = 0;
100        while (_BitScanForward(&index, bitmask))
101        {
102            indices[elem++] = C((int)index);
103            bitmask &= ~(1 << index);
104        }
105
106        return ConstantVector::get(indices);
107    }
108
109    //////////////////////////////////////////////////////////////////////////
110    // @brief convert scalar bitmask to <4xfloat> bitmask
111    Value* ToMask(uint32_t bitmask)
112    {
113        std::vector<Constant*> indices;
114        for (uint32_t i = 0; i < 4; ++i)
115        {
116            if (bitmask & (1 << i))
117            {
118                indices.push_back(C(-1.0f));
119            }
120            else
121            {
122                indices.push_back(C(0.0f));
123            }
124        }
125        return ConstantVector::get(indices);
126    }
127
128    //////////////////////////////////////////////////////////////////////////
129    // @brief processes a single decl from the streamout stream. Reads 4 components from the input
130    //        stream and writes N components to the output buffer given the componentMask or if
131    //        a hole, just increments the buffer pointer
132    // @param pStream - pointer to current attribute
133    // @param pOutBuffers - pointers to the current location of each output buffer
134    // @param decl - input decl
135    void buildDecl(Value* pStream, Value* pOutBuffers[4], const STREAMOUT_DECL& decl)
136    {
137        // @todo add this to x86 macros
138        Function* maskStore = Intrinsic::getDeclaration(JM()->mpCurrentModule, Intrinsic::x86_avx_maskstore_ps);
139
140        uint32_t numComponents = _mm_popcnt_u32(decl.componentMask);
141        uint32_t packedMask = (1 << numComponents) - 1;
142        if (!decl.hole)
143        {
144            // increment stream pointer to correct slot
145            Value* pAttrib = GEP(pStream, C(4 * decl.attribSlot));
146
147            // load 4 components from stream
148            Type* simd4Ty = VectorType::get(IRB()->getFloatTy(), 4);
149            Type* simd4PtrTy = PointerType::get(simd4Ty, 0);
150            pAttrib = BITCAST(pAttrib, simd4PtrTy);
151            Value *vattrib = LOAD(pAttrib);
152
153            // shuffle/pack enabled components
154            Value* vpackedAttrib = VSHUFFLE(vattrib, vattrib, PackMask(decl.componentMask));
155
156            // store to output buffer
157            // cast SO buffer to i8*, needed by maskstore
158            Value* pOut = BITCAST(pOutBuffers[decl.bufferIndex], PointerType::get(mInt8Ty, 0));
159
160            // cast input to <4xfloat>
161            Value* src = BITCAST(vpackedAttrib, simd4Ty);
162            CALL(maskStore, {pOut, ToMask(packedMask), src});
163        }
164
165        // increment SO buffer
166        pOutBuffers[decl.bufferIndex] = GEP(pOutBuffers[decl.bufferIndex], C(numComponents));
167    }
168
169    //////////////////////////////////////////////////////////////////////////
170    // @brief builds a single vertex worth of data for the given stream
171    // @param streamState - state for this stream
172    // @param pCurVertex - pointer to src stream vertex data
173    // @param pOutBuffer - pointers to up to 4 SO buffers
174    void buildVertex(const STREAMOUT_STREAM& streamState, Value* pCurVertex, Value* pOutBuffer[4])
175    {
176        for (uint32_t d = 0; d < streamState.numDecls; ++d)
177        {
178            const STREAMOUT_DECL& decl = streamState.decl[d];
179            buildDecl(pCurVertex, pOutBuffer, decl);
180        }
181    }
182
183    void buildStream(const STREAMOUT_COMPILE_STATE& state, const STREAMOUT_STREAM& streamState, Value* pSoCtx, BasicBlock* returnBB, Function* soFunc)
184    {
185        // get list of active SO buffers
186        std::unordered_set<uint32_t> activeSOBuffers;
187        for (uint32_t d = 0; d < streamState.numDecls; ++d)
188        {
189            const STREAMOUT_DECL& decl = streamState.decl[d];
190            activeSOBuffers.insert(decl.bufferIndex);
191        }
192
193        // always increment numPrimStorageNeeded
194        Value *numPrimStorageNeeded = LOAD(pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded });
195        numPrimStorageNeeded = ADD(numPrimStorageNeeded, C(1));
196        STORE(numPrimStorageNeeded, pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded });
197
198        // check OOB on active SO buffers.  If any buffer is out of bound, don't write
199        // the primitive to any buffer
200        Value* oobMask = C(false);
201        for (uint32_t buffer : activeSOBuffers)
202        {
203            oobMask = OR(oobMask, oob(state, pSoCtx, buffer));
204        }
205
206        BasicBlock* validBB = BasicBlock::Create(JM()->mContext, "valid", soFunc);
207
208        // early out if OOB
209        COND_BR(oobMask, returnBB, validBB);
210
211        IRB()->SetInsertPoint(validBB);
212
213        Value* numPrimsWritten = LOAD(pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_numPrimsWritten });
214        numPrimsWritten = ADD(numPrimsWritten, C(1));
215        STORE(numPrimsWritten, pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_numPrimsWritten });
216
217        // compute start pointer for each output buffer
218        Value* pOutBuffer[4];
219        Value* pOutBufferStartVertex[4];
220        Value* outBufferPitch[4];
221        for (uint32_t b: activeSOBuffers)
222        {
223            Value* pBuf = getSOBuffer(pSoCtx, b);
224            Value* pData = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_pBuffer });
225            Value* streamOffset = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_streamOffset });
226            pOutBuffer[b] = GEP(pData, streamOffset);
227            pOutBufferStartVertex[b] = pOutBuffer[b];
228
229            outBufferPitch[b] = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_pitch });
230        }
231
232        // loop over the vertices of the prim
233        Value* pStreamData = LOAD(pSoCtx, { 0, SWR_STREAMOUT_CONTEXT_pPrimData });
234        for (uint32_t v = 0; v < state.numVertsPerPrim; ++v)
235        {
236            buildVertex(streamState, pStreamData, pOutBuffer);
237
238            // increment stream and output buffer pointers
239            // stream verts are always 32*4 dwords apart
240            pStreamData = GEP(pStreamData, C(KNOB_NUM_ATTRIBUTES * 4));
241
242            // output buffers offset using pitch in buffer state
243            for (uint32_t b : activeSOBuffers)
244            {
245                pOutBufferStartVertex[b] = GEP(pOutBufferStartVertex[b], outBufferPitch[b]);
246                pOutBuffer[b] = pOutBufferStartVertex[b];
247            }
248        }
249
250        // update each active buffer's streamOffset
251        for (uint32_t b : activeSOBuffers)
252        {
253            Value* pBuf = getSOBuffer(pSoCtx, b);
254            Value* streamOffset = LOAD(pBuf, { 0, SWR_STREAMOUT_BUFFER_streamOffset });
255            streamOffset = ADD(streamOffset, MUL(C(state.numVertsPerPrim), outBufferPitch[b]));
256            STORE(streamOffset, pBuf, { 0, SWR_STREAMOUT_BUFFER_streamOffset });
257        }
258    }
259
260    Function* Create(const STREAMOUT_COMPILE_STATE& state)
261    {
262        static std::size_t soNum = 0;
263
264        std::stringstream fnName("SOShader", std::ios_base::in | std::ios_base::out | std::ios_base::ate);
265        fnName << soNum++;
266
267        // SO function signature
268        // typedef void(__cdecl *PFN_SO_FUNC)(SWR_STREAMOUT_CONTEXT*)
269
270        std::vector<Type*> args{
271            PointerType::get(Gen_SWR_STREAMOUT_CONTEXT(JM()), 0), // SWR_STREAMOUT_CONTEXT*
272        };
273
274        FunctionType* fTy = FunctionType::get(IRB()->getVoidTy(), args, false);
275        Function* soFunc = Function::Create(fTy, GlobalValue::ExternalLinkage, fnName.str(), JM()->mpCurrentModule);
276
277        // create return basic block
278        BasicBlock* entry = BasicBlock::Create(JM()->mContext, "entry", soFunc);
279        BasicBlock* returnBB = BasicBlock::Create(JM()->mContext, "return", soFunc);
280
281        IRB()->SetInsertPoint(entry);
282
283        // arguments
284        auto argitr = soFunc->arg_begin();
285        Value* pSoCtx = &*argitr++;
286        pSoCtx->setName("pSoCtx");
287
288        const STREAMOUT_STREAM& streamState = state.stream;
289        buildStream(state, streamState, pSoCtx, returnBB, soFunc);
290
291        BR(returnBB);
292
293        IRB()->SetInsertPoint(returnBB);
294        RET_VOID();
295
296        JitManager::DumpToFile(soFunc, "SoFunc");
297
298        ::FunctionPassManager passes(JM()->mpCurrentModule);
299
300        passes.add(createBreakCriticalEdgesPass());
301        passes.add(createCFGSimplificationPass());
302        passes.add(createEarlyCSEPass());
303        passes.add(createPromoteMemoryToRegisterPass());
304        passes.add(createCFGSimplificationPass());
305        passes.add(createEarlyCSEPass());
306        passes.add(createInstructionCombiningPass());
307        passes.add(createInstructionSimplifierPass());
308        passes.add(createConstantPropagationPass());
309        passes.add(createSCCPPass());
310        passes.add(createAggressiveDCEPass());
311
312        passes.run(*soFunc);
313
314        JitManager::DumpToFile(soFunc, "SoFunc_optimized");
315
316        return soFunc;
317    }
318};
319
320//////////////////////////////////////////////////////////////////////////
321/// @brief JITs from streamout shader IR
322/// @param hJitMgr - JitManager handle
323/// @param func   - LLVM function IR
324/// @return PFN_SO_FUNC - pointer to SOS function
325PFN_SO_FUNC JitStreamoutFunc(HANDLE hJitMgr, const HANDLE hFunc)
326{
327    const llvm::Function *func = (const llvm::Function*)hFunc;
328    JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
329    PFN_SO_FUNC pfnStreamOut;
330    pfnStreamOut = (PFN_SO_FUNC)(pJitMgr->mpExec->getFunctionAddress(func->getName().str()));
331    // MCJIT finalizes modules the first time you JIT code from them. After finalized, you cannot add new IR to the module
332    pJitMgr->mIsModuleFinalized = true;
333
334    return pfnStreamOut;
335}
336
337//////////////////////////////////////////////////////////////////////////
338/// @brief JIT compiles streamout shader
339/// @param hJitMgr - JitManager handle
340/// @param state   - SO state to build function from
341extern "C" PFN_SO_FUNC JITCALL JitCompileStreamout(HANDLE hJitMgr, const STREAMOUT_COMPILE_STATE& state)
342{
343    JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
344
345    STREAMOUT_COMPILE_STATE soState = state;
346    if (soState.offsetAttribs)
347    {
348        for (uint32_t i = 0; i < soState.stream.numDecls; ++i)
349        {
350            soState.stream.decl[i].attribSlot -= soState.offsetAttribs;
351        }
352    }
353
354    pJitMgr->SetupNewModule();
355
356    StreamOutJit theJit(pJitMgr);
357    HANDLE hFunc = theJit.Create(soState);
358
359    return JitStreamoutFunc(hJitMgr, hFunc);
360}
361