1//
2// Copyright (c) 2002-2010 The ANGLE Project Authors. All rights reserved.
3// Use of this source code is governed by a BSD-style license that can be
4// found in the LICENSE file.
5//
6
7#include "compiler/OutputHLSL.h"
8
9#include "compiler/debug.h"
10#include "compiler/InfoSink.h"
11#include "compiler/UnfoldSelect.h"
12#include "compiler/SearchSymbol.h"
13
14#include <stdio.h>
15#include <algorithm>
16
17namespace sh
18{
19// Integer to TString conversion
20TString str(int i)
21{
22    char buffer[20];
23    sprintf(buffer, "%d", i);
24    return buffer;
25}
26
27OutputHLSL::OutputHLSL(TParseContext &context) : TIntermTraverser(true, true, true), mContext(context)
28{
29    mUnfoldSelect = new UnfoldSelect(context, this);
30    mInsideFunction = false;
31
32    mUsesTexture2D = false;
33    mUsesTexture2D_bias = false;
34    mUsesTexture2DProj = false;
35    mUsesTexture2DProj_bias = false;
36    mUsesTextureCube = false;
37    mUsesTextureCube_bias = false;
38    mUsesDepthRange = false;
39    mUsesFragCoord = false;
40    mUsesPointCoord = false;
41    mUsesFrontFacing = false;
42    mUsesPointSize = false;
43    mUsesXor = false;
44    mUsesMod1 = false;
45    mUsesMod2 = false;
46    mUsesMod3 = false;
47    mUsesMod4 = false;
48    mUsesFaceforward1 = false;
49    mUsesFaceforward2 = false;
50    mUsesFaceforward3 = false;
51    mUsesFaceforward4 = false;
52    mUsesEqualMat2 = false;
53    mUsesEqualMat3 = false;
54    mUsesEqualMat4 = false;
55    mUsesEqualVec2 = false;
56    mUsesEqualVec3 = false;
57    mUsesEqualVec4 = false;
58    mUsesEqualIVec2 = false;
59    mUsesEqualIVec3 = false;
60    mUsesEqualIVec4 = false;
61    mUsesEqualBVec2 = false;
62    mUsesEqualBVec3 = false;
63    mUsesEqualBVec4 = false;
64    mUsesAtan2 = false;
65
66    mScopeDepth = 0;
67
68    mUniqueIndex = 0;
69}
70
71OutputHLSL::~OutputHLSL()
72{
73    delete mUnfoldSelect;
74}
75
76void OutputHLSL::output()
77{
78    mContext.treeRoot->traverse(this);   // Output the body first to determine what has to go in the header
79    header();
80
81    mContext.infoSink.obj << mHeader.c_str();
82    mContext.infoSink.obj << mBody.c_str();
83}
84
85TInfoSinkBase &OutputHLSL::getBodyStream()
86{
87    return mBody;
88}
89
90int OutputHLSL::vectorSize(const TType &type) const
91{
92    int elementSize = type.isMatrix() ? type.getNominalSize() : 1;
93    int arraySize = type.isArray() ? type.getArraySize() : 1;
94
95    return elementSize * arraySize;
96}
97
98void OutputHLSL::header()
99{
100    ShShaderType shaderType = mContext.shaderType;
101    TInfoSinkBase &out = mHeader;
102
103    for (StructDeclarations::iterator structDeclaration = mStructDeclarations.begin(); structDeclaration != mStructDeclarations.end(); structDeclaration++)
104    {
105        out << *structDeclaration;
106    }
107
108    for (Constructors::iterator constructor = mConstructors.begin(); constructor != mConstructors.end(); constructor++)
109    {
110        out << *constructor;
111    }
112
113    if (shaderType == SH_FRAGMENT_SHADER)
114    {
115        TString uniforms;
116        TString varyings;
117
118        TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
119        int semanticIndex = 0;
120
121        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
122        {
123            const TSymbol *symbol = (*namedSymbol).second;
124            const TString &name = symbol->getName();
125
126            if (symbol->isVariable())
127            {
128                const TVariable *variable = static_cast<const TVariable*>(symbol);
129                const TType &type = variable->getType();
130                TQualifier qualifier = type.getQualifier();
131
132                if (qualifier == EvqUniform)
133                {
134                    if (mReferencedUniforms.find(name.c_str()) != mReferencedUniforms.end())
135                    {
136                        uniforms += "uniform " + typeString(type) + " " + decorate(name) + arrayString(type) + ";\n";
137                    }
138                }
139                else if (qualifier == EvqVaryingIn || qualifier == EvqInvariantVaryingIn)
140                {
141                    if (mReferencedVaryings.find(name.c_str()) != mReferencedVaryings.end())
142                    {
143                        // Program linking depends on this exact format
144                        varyings += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
145
146                        semanticIndex += type.isArray() ? type.getArraySize() : 1;
147                    }
148                }
149                else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
150                {
151                    // Globals are declared and intialized as an aggregate node
152                }
153                else if (qualifier == EvqConst)
154                {
155                    // Constants are repeated as literals where used
156                }
157                else UNREACHABLE();
158            }
159        }
160
161        out << "// Varyings\n";
162        out <<  varyings;
163        out << "\n"
164               "static float4 gl_Color[1] = {float4(0, 0, 0, 0)};\n";
165
166        if (mUsesFragCoord)
167        {
168            out << "static float4 gl_FragCoord = float4(0, 0, 0, 0);\n";
169        }
170
171        if (mUsesPointCoord)
172        {
173            out << "static float2 gl_PointCoord = float2(0.5, 0.5);\n";
174        }
175
176        if (mUsesFrontFacing)
177        {
178            out << "static bool gl_FrontFacing = false;\n";
179        }
180
181        out << "\n";
182
183        if (mUsesFragCoord)
184        {
185            out << "uniform float4 dx_Viewport;\n"
186                   "uniform float2 dx_Depth;\n";
187        }
188
189        if (mUsesFrontFacing)
190        {
191            out << "uniform bool dx_PointsOrLines;\n"
192                   "uniform bool dx_FrontCCW;\n";
193        }
194
195        out << "\n";
196        out <<  uniforms;
197        out << "\n";
198
199        if (mUsesTexture2D)
200        {
201            out << "float4 gl_texture2D(sampler2D s, float2 t)\n"
202                   "{\n"
203                   "    return tex2D(s, t);\n"
204                   "}\n"
205                   "\n";
206        }
207
208        if (mUsesTexture2D_bias)
209        {
210            out << "float4 gl_texture2D(sampler2D s, float2 t, float bias)\n"
211                   "{\n"
212                   "    return tex2Dbias(s, float4(t.x, t.y, 0, bias));\n"
213                   "}\n"
214                   "\n";
215        }
216
217        if (mUsesTexture2DProj)
218        {
219            out << "float4 gl_texture2DProj(sampler2D s, float3 t)\n"
220                   "{\n"
221                   "    return tex2Dproj(s, float4(t.x, t.y, 0, t.z));\n"
222                   "}\n"
223                   "\n"
224                   "float4 gl_texture2DProj(sampler2D s, float4 t)\n"
225                   "{\n"
226                   "    return tex2Dproj(s, t);\n"
227                   "}\n"
228                   "\n";
229        }
230
231        if (mUsesTexture2DProj_bias)
232        {
233            out << "float4 gl_texture2DProj(sampler2D s, float3 t, float bias)\n"
234                   "{\n"
235                   "    return tex2Dbias(s, float4(t.x / t.z, t.y / t.z, 0, bias));\n"
236                   "}\n"
237                   "\n"
238                   "float4 gl_texture2DProj(sampler2D s, float4 t, float bias)\n"
239                   "{\n"
240                   "    return tex2Dbias(s, float4(t.x / t.w, t.y / t.w, 0, bias));\n"
241                   "}\n"
242                   "\n";
243        }
244
245        if (mUsesTextureCube)
246        {
247            out << "float4 gl_textureCube(samplerCUBE s, float3 t)\n"
248                   "{\n"
249                   "    return texCUBE(s, t);\n"
250                   "}\n"
251                   "\n";
252        }
253
254        if (mUsesTextureCube_bias)
255        {
256            out << "float4 gl_textureCube(samplerCUBE s, float3 t, float bias)\n"
257                   "{\n"
258                   "    return texCUBEbias(s, float4(t.x, t.y, t.z, bias));\n"
259                   "}\n"
260                   "\n";
261        }
262    }
263    else   // Vertex shader
264    {
265        TString uniforms;
266        TString attributes;
267        TString varyings;
268
269        TSymbolTableLevel *symbols = mContext.symbolTable.getGlobalLevel();
270
271        for (TSymbolTableLevel::const_iterator namedSymbol = symbols->begin(); namedSymbol != symbols->end(); namedSymbol++)
272        {
273            const TSymbol *symbol = (*namedSymbol).second;
274            const TString &name = symbol->getName();
275
276            if (symbol->isVariable())
277            {
278                const TVariable *variable = static_cast<const TVariable*>(symbol);
279                const TType &type = variable->getType();
280                TQualifier qualifier = type.getQualifier();
281
282                if (qualifier == EvqUniform)
283                {
284                    if (mReferencedUniforms.find(name.c_str()) != mReferencedUniforms.end())
285                    {
286                        uniforms += "uniform " + typeString(type) + " " + decorate(name) + arrayString(type) + ";\n";
287                    }
288                }
289                else if (qualifier == EvqAttribute)
290                {
291                    if (mReferencedAttributes.find(name.c_str()) != mReferencedAttributes.end())
292                    {
293                        attributes += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
294                    }
295                }
296                else if (qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut)
297                {
298                    if (mReferencedVaryings.find(name.c_str()) != mReferencedVaryings.end())
299                    {
300                        // Program linking depends on this exact format
301                        varyings += "static " + typeString(type) + " " + decorate(name) + arrayString(type) + " = " + initializer(type) + ";\n";
302                    }
303                }
304                else if (qualifier == EvqGlobal || qualifier == EvqTemporary)
305                {
306                    // Globals are declared and intialized as an aggregate node
307                }
308                else if (qualifier == EvqConst)
309                {
310                    // Constants are repeated as literals where used
311                }
312                else UNREACHABLE();
313            }
314        }
315
316        out << "// Attributes\n";
317        out <<  attributes;
318        out << "\n"
319               "static float4 gl_Position = float4(0, 0, 0, 0);\n";
320
321        if (mUsesPointSize)
322        {
323            out << "static float gl_PointSize = float(1);\n";
324        }
325
326        out << "\n"
327               "// Varyings\n";
328        out <<  varyings;
329        out << "\n"
330               "uniform float2 dx_HalfPixelSize;\n"
331               "\n";
332        out <<  uniforms;
333        out << "\n";
334    }
335
336    if (mUsesFragCoord)
337    {
338        out << "#define GL_USES_FRAG_COORD\n";
339    }
340
341    if (mUsesPointCoord)
342    {
343        out << "#define GL_USES_POINT_COORD\n";
344    }
345
346    if (mUsesFrontFacing)
347    {
348        out << "#define GL_USES_FRONT_FACING\n";
349    }
350
351    if (mUsesPointSize)
352    {
353        out << "#define GL_USES_POINT_SIZE\n";
354    }
355
356    if (mUsesDepthRange)
357    {
358        out << "struct gl_DepthRangeParameters\n"
359               "{\n"
360               "    float near;\n"
361               "    float far;\n"
362               "    float diff;\n"
363               "};\n"
364               "\n"
365               "uniform float3 dx_DepthRange;"
366               "static gl_DepthRangeParameters gl_DepthRange = {dx_DepthRange.x, dx_DepthRange.y, dx_DepthRange.z};\n"
367               "\n";
368    }
369
370    if (mUsesXor)
371    {
372        out << "bool xor(bool p, bool q)\n"
373               "{\n"
374               "    return (p || q) && !(p && q);\n"
375               "}\n"
376               "\n";
377    }
378
379    if (mUsesMod1)
380    {
381        out << "float mod(float x, float y)\n"
382               "{\n"
383               "    return x - y * floor(x / y);\n"
384               "}\n"
385               "\n";
386    }
387
388    if (mUsesMod2)
389    {
390        out << "float2 mod(float2 x, float y)\n"
391               "{\n"
392               "    return x - y * floor(x / y);\n"
393               "}\n"
394               "\n";
395    }
396
397    if (mUsesMod3)
398    {
399        out << "float3 mod(float3 x, float y)\n"
400               "{\n"
401               "    return x - y * floor(x / y);\n"
402               "}\n"
403               "\n";
404    }
405
406    if (mUsesMod4)
407    {
408        out << "float4 mod(float4 x, float y)\n"
409               "{\n"
410               "    return x - y * floor(x / y);\n"
411               "}\n"
412               "\n";
413    }
414
415    if (mUsesFaceforward1)
416    {
417        out << "float faceforward(float N, float I, float Nref)\n"
418               "{\n"
419               "    if(dot(Nref, I) >= 0)\n"
420               "    {\n"
421               "        return -N;\n"
422               "    }\n"
423               "    else\n"
424               "    {\n"
425               "        return N;\n"
426               "    }\n"
427               "}\n"
428               "\n";
429    }
430
431    if (mUsesFaceforward2)
432    {
433        out << "float2 faceforward(float2 N, float2 I, float2 Nref)\n"
434               "{\n"
435               "    if(dot(Nref, I) >= 0)\n"
436               "    {\n"
437               "        return -N;\n"
438               "    }\n"
439               "    else\n"
440               "    {\n"
441               "        return N;\n"
442               "    }\n"
443               "}\n"
444               "\n";
445    }
446
447    if (mUsesFaceforward3)
448    {
449        out << "float3 faceforward(float3 N, float3 I, float3 Nref)\n"
450               "{\n"
451               "    if(dot(Nref, I) >= 0)\n"
452               "    {\n"
453               "        return -N;\n"
454               "    }\n"
455               "    else\n"
456               "    {\n"
457               "        return N;\n"
458               "    }\n"
459               "}\n"
460               "\n";
461    }
462
463    if (mUsesFaceforward4)
464    {
465        out << "float4 faceforward(float4 N, float4 I, float4 Nref)\n"
466               "{\n"
467               "    if(dot(Nref, I) >= 0)\n"
468               "    {\n"
469               "        return -N;\n"
470               "    }\n"
471               "    else\n"
472               "    {\n"
473               "        return N;\n"
474               "    }\n"
475               "}\n"
476               "\n";
477    }
478
479    if (mUsesEqualMat2)
480    {
481        out << "bool equal(float2x2 m, float2x2 n)\n"
482               "{\n"
483               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] &&\n"
484               "           m[1][0] == n[1][0] && m[1][1] == n[1][1];\n"
485               "}\n";
486    }
487
488    if (mUsesEqualMat3)
489    {
490        out << "bool equal(float3x3 m, float3x3 n)\n"
491               "{\n"
492               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] &&\n"
493               "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] &&\n"
494               "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2];\n"
495               "}\n";
496    }
497
498    if (mUsesEqualMat4)
499    {
500        out << "bool equal(float4x4 m, float4x4 n)\n"
501               "{\n"
502               "    return m[0][0] == n[0][0] && m[0][1] == n[0][1] && m[0][2] == n[0][2] && m[0][3] == n[0][3] &&\n"
503               "           m[1][0] == n[1][0] && m[1][1] == n[1][1] && m[1][2] == n[1][2] && m[1][3] == n[1][3] &&\n"
504               "           m[2][0] == n[2][0] && m[2][1] == n[2][1] && m[2][2] == n[2][2] && m[2][3] == n[2][3] &&\n"
505               "           m[3][0] == n[3][0] && m[3][1] == n[3][1] && m[3][2] == n[3][2] && m[3][3] == n[3][3];\n"
506               "}\n";
507    }
508
509    if (mUsesEqualVec2)
510    {
511        out << "bool equal(float2 v, float2 u)\n"
512               "{\n"
513               "    return v.x == u.x && v.y == u.y;\n"
514               "}\n";
515    }
516
517    if (mUsesEqualVec3)
518    {
519        out << "bool equal(float3 v, float3 u)\n"
520               "{\n"
521               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
522               "}\n";
523    }
524
525    if (mUsesEqualVec4)
526    {
527        out << "bool equal(float4 v, float4 u)\n"
528               "{\n"
529               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
530               "}\n";
531    }
532
533    if (mUsesEqualIVec2)
534    {
535        out << "bool equal(int2 v, int2 u)\n"
536               "{\n"
537               "    return v.x == u.x && v.y == u.y;\n"
538               "}\n";
539    }
540
541    if (mUsesEqualIVec3)
542    {
543        out << "bool equal(int3 v, int3 u)\n"
544               "{\n"
545               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
546               "}\n";
547    }
548
549    if (mUsesEqualIVec4)
550    {
551        out << "bool equal(int4 v, int4 u)\n"
552               "{\n"
553               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
554               "}\n";
555    }
556
557    if (mUsesEqualBVec2)
558    {
559        out << "bool equal(bool2 v, bool2 u)\n"
560               "{\n"
561               "    return v.x == u.x && v.y == u.y;\n"
562               "}\n";
563    }
564
565    if (mUsesEqualBVec3)
566    {
567        out << "bool equal(bool3 v, bool3 u)\n"
568               "{\n"
569               "    return v.x == u.x && v.y == u.y && v.z == u.z;\n"
570               "}\n";
571    }
572
573    if (mUsesEqualBVec4)
574    {
575        out << "bool equal(bool4 v, bool4 u)\n"
576               "{\n"
577               "    return v.x == u.x && v.y == u.y && v.z == u.z && v.w == u.w;\n"
578               "}\n";
579    }
580
581    if (mUsesAtan2)
582    {
583        out << "float atanyx(float y, float x)\n"
584               "{\n"
585               "    if(x == 0 && y == 0) x = 1;\n"   // Avoid producing a NaN
586               "    return atan2(y, x);\n"
587               "}\n";
588    }
589}
590
591void OutputHLSL::visitSymbol(TIntermSymbol *node)
592{
593    TInfoSinkBase &out = mBody;
594
595    TString name = node->getSymbol();
596
597    if (name == "gl_FragColor")
598    {
599        out << "gl_Color[0]";
600    }
601    else if (name == "gl_FragData")
602    {
603        out << "gl_Color";
604    }
605    else if (name == "gl_DepthRange")
606    {
607        mUsesDepthRange = true;
608        out << name;
609    }
610    else if (name == "gl_FragCoord")
611    {
612        mUsesFragCoord = true;
613        out << name;
614    }
615    else if (name == "gl_PointCoord")
616    {
617        mUsesPointCoord = true;
618        out << name;
619    }
620    else if (name == "gl_FrontFacing")
621    {
622        mUsesFrontFacing = true;
623        out << name;
624    }
625    else if (name == "gl_PointSize")
626    {
627        mUsesPointSize = true;
628        out << name;
629    }
630    else
631    {
632        TQualifier qualifier = node->getQualifier();
633
634        if (qualifier == EvqUniform)
635        {
636            mReferencedUniforms.insert(name.c_str());
637        }
638        else if (qualifier == EvqAttribute)
639        {
640            mReferencedAttributes.insert(name.c_str());
641        }
642        else if (qualifier == EvqVaryingOut || qualifier == EvqInvariantVaryingOut || qualifier == EvqVaryingIn || qualifier == EvqInvariantVaryingIn)
643        {
644            mReferencedVaryings.insert(name.c_str());
645        }
646
647        out << decorate(name);
648    }
649}
650
651bool OutputHLSL::visitBinary(Visit visit, TIntermBinary *node)
652{
653    TInfoSinkBase &out = mBody;
654
655    switch (node->getOp())
656    {
657      case EOpAssign:                  outputTriplet(visit, "(", " = ", ")");           break;
658      case EOpInitialize:
659        if (visit == PreVisit)
660        {
661            // GLSL allows to write things like "float x = x;" where a new variable x is defined
662            // and the value of an existing variable x is assigned. HLSL uses C semantics (the
663            // new variable is created before the assignment is evaluated), so we need to convert
664            // this to "float t = x, x = t;".
665
666            TIntermSymbol *symbolNode = node->getLeft()->getAsSymbolNode();
667            TIntermTyped *expression = node->getRight();
668
669            sh::SearchSymbol searchSymbol(symbolNode->getSymbol());
670            expression->traverse(&searchSymbol);
671            bool sameSymbol = searchSymbol.foundMatch();
672
673            if (sameSymbol)
674            {
675                // Type already printed
676                out << "t" + str(mUniqueIndex) + " = ";
677                expression->traverse(this);
678                out << ", ";
679                symbolNode->traverse(this);
680                out << " = t" + str(mUniqueIndex);
681
682                mUniqueIndex++;
683                return false;
684            }
685        }
686        else if (visit == InVisit)
687        {
688            out << " = ";
689        }
690        break;
691      case EOpAddAssign:               outputTriplet(visit, "(", " += ", ")");          break;
692      case EOpSubAssign:               outputTriplet(visit, "(", " -= ", ")");          break;
693      case EOpMulAssign:               outputTriplet(visit, "(", " *= ", ")");          break;
694      case EOpVectorTimesScalarAssign: outputTriplet(visit, "(", " *= ", ")");          break;
695      case EOpMatrixTimesScalarAssign: outputTriplet(visit, "(", " *= ", ")");          break;
696      case EOpVectorTimesMatrixAssign:
697        if (visit == PreVisit)
698        {
699            out << "(";
700        }
701        else if (visit == InVisit)
702        {
703            out << " = mul(";
704            node->getLeft()->traverse(this);
705            out << ", transpose(";
706        }
707        else
708        {
709            out << ")))";
710        }
711        break;
712      case EOpMatrixTimesMatrixAssign:
713        if (visit == PreVisit)
714        {
715            out << "(";
716        }
717        else if (visit == InVisit)
718        {
719            out << " = mul(";
720            node->getLeft()->traverse(this);
721            out << ", ";
722        }
723        else
724        {
725            out << "))";
726        }
727        break;
728      case EOpDivAssign:               outputTriplet(visit, "(", " /= ", ")");          break;
729      case EOpIndexDirect:             outputTriplet(visit, "", "[", "]");              break;
730      case EOpIndexIndirect:           outputTriplet(visit, "", "[", "]");              break;
731      case EOpIndexDirectStruct:
732        if (visit == InVisit)
733        {
734            out << "." + node->getType().getFieldName();
735
736            return false;
737        }
738        break;
739      case EOpVectorSwizzle:
740        if (visit == InVisit)
741        {
742            out << ".";
743
744            TIntermAggregate *swizzle = node->getRight()->getAsAggregate();
745
746            if (swizzle)
747            {
748                TIntermSequence &sequence = swizzle->getSequence();
749
750                for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
751                {
752                    TIntermConstantUnion *element = (*sit)->getAsConstantUnion();
753
754                    if (element)
755                    {
756                        int i = element->getUnionArrayPointer()[0].getIConst();
757
758                        switch (i)
759                        {
760                        case 0: out << "x"; break;
761                        case 1: out << "y"; break;
762                        case 2: out << "z"; break;
763                        case 3: out << "w"; break;
764                        default: UNREACHABLE();
765                        }
766                    }
767                    else UNREACHABLE();
768                }
769            }
770            else UNREACHABLE();
771
772            return false;   // Fully processed
773        }
774        break;
775      case EOpAdd:               outputTriplet(visit, "(", " + ", ")"); break;
776      case EOpSub:               outputTriplet(visit, "(", " - ", ")"); break;
777      case EOpMul:               outputTriplet(visit, "(", " * ", ")"); break;
778      case EOpDiv:               outputTriplet(visit, "(", " / ", ")"); break;
779      case EOpEqual:
780      case EOpNotEqual:
781        if (node->getLeft()->isScalar())
782        {
783            if (node->getOp() == EOpEqual)
784            {
785                outputTriplet(visit, "(", " == ", ")");
786            }
787            else
788            {
789                outputTriplet(visit, "(", " != ", ")");
790            }
791        }
792        else if (node->getLeft()->getBasicType() == EbtStruct)
793        {
794            if (node->getOp() == EOpEqual)
795            {
796                out << "(";
797            }
798            else
799            {
800                out << "!(";
801            }
802
803            const TTypeList *fields = node->getLeft()->getType().getStruct();
804
805            for (size_t i = 0; i < fields->size(); i++)
806            {
807                const TType *fieldType = (*fields)[i].type;
808
809                node->getLeft()->traverse(this);
810                out << "." + fieldType->getFieldName() + " == ";
811                node->getRight()->traverse(this);
812                out << "." + fieldType->getFieldName();
813
814                if (i < fields->size() - 1)
815                {
816                    out << " && ";
817                }
818            }
819
820            out << ")";
821
822            return false;
823        }
824        else
825        {
826            if (node->getLeft()->isMatrix())
827            {
828                switch (node->getLeft()->getNominalSize())
829                {
830                  case 2: mUsesEqualMat2 = true; break;
831                  case 3: mUsesEqualMat3 = true; break;
832                  case 4: mUsesEqualMat4 = true; break;
833                  default: UNREACHABLE();
834                }
835            }
836            else if (node->getLeft()->isVector())
837            {
838                switch (node->getLeft()->getBasicType())
839                {
840                  case EbtFloat:
841                    switch (node->getLeft()->getNominalSize())
842                    {
843                      case 2: mUsesEqualVec2 = true; break;
844                      case 3: mUsesEqualVec3 = true; break;
845                      case 4: mUsesEqualVec4 = true; break;
846                      default: UNREACHABLE();
847                    }
848                    break;
849                  case EbtInt:
850                    switch (node->getLeft()->getNominalSize())
851                    {
852                      case 2: mUsesEqualIVec2 = true; break;
853                      case 3: mUsesEqualIVec3 = true; break;
854                      case 4: mUsesEqualIVec4 = true; break;
855                      default: UNREACHABLE();
856                    }
857                    break;
858                  case EbtBool:
859                    switch (node->getLeft()->getNominalSize())
860                    {
861                      case 2: mUsesEqualBVec2 = true; break;
862                      case 3: mUsesEqualBVec3 = true; break;
863                      case 4: mUsesEqualBVec4 = true; break;
864                      default: UNREACHABLE();
865                    }
866                    break;
867                  default: UNREACHABLE();
868                }
869            }
870            else UNREACHABLE();
871
872            if (node->getOp() == EOpEqual)
873            {
874                outputTriplet(visit, "equal(", ", ", ")");
875            }
876            else
877            {
878                outputTriplet(visit, "!equal(", ", ", ")");
879            }
880        }
881        break;
882      case EOpLessThan:          outputTriplet(visit, "(", " < ", ")");   break;
883      case EOpGreaterThan:       outputTriplet(visit, "(", " > ", ")");   break;
884      case EOpLessThanEqual:     outputTriplet(visit, "(", " <= ", ")");  break;
885      case EOpGreaterThanEqual:  outputTriplet(visit, "(", " >= ", ")");  break;
886      case EOpVectorTimesScalar: outputTriplet(visit, "(", " * ", ")");   break;
887      case EOpMatrixTimesScalar: outputTriplet(visit, "(", " * ", ")");   break;
888      case EOpVectorTimesMatrix: outputTriplet(visit, "mul(", ", transpose(", "))"); break;
889      case EOpMatrixTimesVector: outputTriplet(visit, "mul(transpose(", "), ", ")"); break;
890      case EOpMatrixTimesMatrix: outputTriplet(visit, "transpose(mul(transpose(", "), transpose(", ")))"); break;
891      case EOpLogicalOr:         outputTriplet(visit, "(", " || ", ")");  break;
892      case EOpLogicalXor:
893        mUsesXor = true;
894        outputTriplet(visit, "xor(", ", ", ")");
895        break;
896      case EOpLogicalAnd:        outputTriplet(visit, "(", " && ", ")");  break;
897      default: UNREACHABLE();
898    }
899
900    return true;
901}
902
903bool OutputHLSL::visitUnary(Visit visit, TIntermUnary *node)
904{
905    TInfoSinkBase &out = mBody;
906
907    switch (node->getOp())
908    {
909      case EOpNegative:         outputTriplet(visit, "(-", "", ")");  break;
910      case EOpVectorLogicalNot: outputTriplet(visit, "(!", "", ")");  break;
911      case EOpLogicalNot:       outputTriplet(visit, "(!", "", ")");  break;
912      case EOpPostIncrement:    outputTriplet(visit, "(", "", "++)"); break;
913      case EOpPostDecrement:    outputTriplet(visit, "(", "", "--)"); break;
914      case EOpPreIncrement:     outputTriplet(visit, "(++", "", ")"); break;
915      case EOpPreDecrement:     outputTriplet(visit, "(--", "", ")"); break;
916      case EOpConvIntToBool:
917      case EOpConvFloatToBool:
918        switch (node->getOperand()->getType().getNominalSize())
919        {
920          case 1:    outputTriplet(visit, "bool(", "", ")");  break;
921          case 2:    outputTriplet(visit, "bool2(", "", ")"); break;
922          case 3:    outputTriplet(visit, "bool3(", "", ")"); break;
923          case 4:    outputTriplet(visit, "bool4(", "", ")"); break;
924          default: UNREACHABLE();
925        }
926        break;
927      case EOpConvBoolToFloat:
928      case EOpConvIntToFloat:
929        switch (node->getOperand()->getType().getNominalSize())
930        {
931          case 1:    outputTriplet(visit, "float(", "", ")");  break;
932          case 2:    outputTriplet(visit, "float2(", "", ")"); break;
933          case 3:    outputTriplet(visit, "float3(", "", ")"); break;
934          case 4:    outputTriplet(visit, "float4(", "", ")"); break;
935          default: UNREACHABLE();
936        }
937        break;
938      case EOpConvFloatToInt:
939      case EOpConvBoolToInt:
940        switch (node->getOperand()->getType().getNominalSize())
941        {
942          case 1:    outputTriplet(visit, "int(", "", ")");  break;
943          case 2:    outputTriplet(visit, "int2(", "", ")"); break;
944          case 3:    outputTriplet(visit, "int3(", "", ")"); break;
945          case 4:    outputTriplet(visit, "int4(", "", ")"); break;
946          default: UNREACHABLE();
947        }
948        break;
949      case EOpRadians:          outputTriplet(visit, "radians(", "", ")");   break;
950      case EOpDegrees:          outputTriplet(visit, "degrees(", "", ")");   break;
951      case EOpSin:              outputTriplet(visit, "sin(", "", ")");       break;
952      case EOpCos:              outputTriplet(visit, "cos(", "", ")");       break;
953      case EOpTan:              outputTriplet(visit, "tan(", "", ")");       break;
954      case EOpAsin:             outputTriplet(visit, "asin(", "", ")");      break;
955      case EOpAcos:             outputTriplet(visit, "acos(", "", ")");      break;
956      case EOpAtan:             outputTriplet(visit, "atan(", "", ")");      break;
957      case EOpExp:              outputTriplet(visit, "exp(", "", ")");       break;
958      case EOpLog:              outputTriplet(visit, "log(", "", ")");       break;
959      case EOpExp2:             outputTriplet(visit, "exp2(", "", ")");      break;
960      case EOpLog2:             outputTriplet(visit, "log2(", "", ")");      break;
961      case EOpSqrt:             outputTriplet(visit, "sqrt(", "", ")");      break;
962      case EOpInverseSqrt:      outputTriplet(visit, "rsqrt(", "", ")");     break;
963      case EOpAbs:              outputTriplet(visit, "abs(", "", ")");       break;
964      case EOpSign:             outputTriplet(visit, "sign(", "", ")");      break;
965      case EOpFloor:            outputTriplet(visit, "floor(", "", ")");     break;
966      case EOpCeil:             outputTriplet(visit, "ceil(", "", ")");      break;
967      case EOpFract:            outputTriplet(visit, "frac(", "", ")");      break;
968      case EOpLength:           outputTriplet(visit, "length(", "", ")");    break;
969      case EOpNormalize:        outputTriplet(visit, "normalize(", "", ")"); break;
970      case EOpDFdx:             outputTriplet(visit, "ddx(", "", ")");       break;
971      case EOpDFdy:             outputTriplet(visit, "ddy(", "", ")");       break;
972      case EOpFwidth:           outputTriplet(visit, "fwidth(", "", ")");    break;
973      case EOpAny:              outputTriplet(visit, "any(", "", ")");       break;
974      case EOpAll:              outputTriplet(visit, "all(", "", ")");       break;
975      default: UNREACHABLE();
976    }
977
978    return true;
979}
980
981bool OutputHLSL::visitAggregate(Visit visit, TIntermAggregate *node)
982{
983    ShShaderType shaderType = mContext.shaderType;
984    TInfoSinkBase &out = mBody;
985
986    switch (node->getOp())
987    {
988      case EOpSequence:
989        {
990            if (mInsideFunction)
991            {
992                out << "{\n";
993
994                mScopeDepth++;
995
996                if (mScopeBracket.size() < mScopeDepth)
997                {
998                    mScopeBracket.push_back(0);   // New scope level
999                }
1000                else
1001                {
1002                    mScopeBracket[mScopeDepth - 1]++;   // New scope at existing level
1003                }
1004            }
1005
1006            for (TIntermSequence::iterator sit = node->getSequence().begin(); sit != node->getSequence().end(); sit++)
1007            {
1008                if (isSingleStatement(*sit))
1009                {
1010                    mUnfoldSelect->traverse(*sit);
1011                }
1012
1013                (*sit)->traverse(this);
1014
1015                out << ";\n";
1016            }
1017
1018            if (mInsideFunction)
1019            {
1020                out << "}\n";
1021
1022                mScopeDepth--;
1023            }
1024
1025            return false;
1026        }
1027      case EOpDeclaration:
1028        if (visit == PreVisit)
1029        {
1030            TIntermSequence &sequence = node->getSequence();
1031            TIntermTyped *variable = sequence[0]->getAsTyped();
1032            bool visit = true;
1033
1034            if (variable && (variable->getQualifier() == EvqTemporary || variable->getQualifier() == EvqGlobal))
1035            {
1036                if (variable->getType().getStruct())
1037                {
1038                    addConstructor(variable->getType(), scopedStruct(variable->getType().getTypeName()), NULL);
1039                }
1040
1041                if (!variable->getAsSymbolNode() || variable->getAsSymbolNode()->getSymbol() != "")   // Variable declaration
1042                {
1043                    if (!mInsideFunction)
1044                    {
1045                        out << "static ";
1046                    }
1047
1048                    out << typeString(variable->getType()) + " ";
1049
1050                    for (TIntermSequence::iterator sit = sequence.begin(); sit != sequence.end(); sit++)
1051                    {
1052                        TIntermSymbol *symbol = (*sit)->getAsSymbolNode();
1053
1054                        if (symbol)
1055                        {
1056                            symbol->traverse(this);
1057                            out << arrayString(symbol->getType());
1058                            out << " = " + initializer(variable->getType());
1059                        }
1060                        else
1061                        {
1062                            (*sit)->traverse(this);
1063                        }
1064
1065                        if (visit && this->inVisit)
1066                        {
1067                            if (*sit != sequence.back())
1068                            {
1069                                visit = this->visitAggregate(InVisit, node);
1070                            }
1071                        }
1072                    }
1073
1074                    if (visit && this->postVisit)
1075                    {
1076                        this->visitAggregate(PostVisit, node);
1077                    }
1078                }
1079                else if (variable->getAsSymbolNode() && variable->getAsSymbolNode()->getSymbol() == "")   // Type (struct) declaration
1080                {
1081                    // Already added to constructor map
1082                }
1083                else UNREACHABLE();
1084            }
1085
1086            return false;
1087        }
1088        else if (visit == InVisit)
1089        {
1090            out << ", ";
1091        }
1092        break;
1093      case EOpPrototype:
1094        if (visit == PreVisit)
1095        {
1096            out << typeString(node->getType()) << " " << decorate(node->getName()) << "(";
1097
1098            TIntermSequence &arguments = node->getSequence();
1099
1100            for (unsigned int i = 0; i < arguments.size(); i++)
1101            {
1102                TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
1103
1104                if (symbol)
1105                {
1106                    out << argumentString(symbol);
1107
1108                    if (i < arguments.size() - 1)
1109                    {
1110                        out << ", ";
1111                    }
1112                }
1113                else UNREACHABLE();
1114            }
1115
1116            out << ");\n";
1117
1118            return false;
1119        }
1120        break;
1121      case EOpComma:            outputTriplet(visit, "", ", ", "");                break;
1122      case EOpFunction:
1123        {
1124            TString name = TFunction::unmangleName(node->getName());
1125
1126            if (visit == PreVisit)
1127            {
1128                out << typeString(node->getType()) << " ";
1129
1130                if (name == "main")
1131                {
1132                    out << "gl_main(";
1133                }
1134                else
1135                {
1136                    out << decorate(name) << "(";
1137                }
1138
1139                TIntermSequence &sequence = node->getSequence();
1140                TIntermSequence &arguments = sequence[0]->getAsAggregate()->getSequence();
1141
1142                for (unsigned int i = 0; i < arguments.size(); i++)
1143                {
1144                    TIntermSymbol *symbol = arguments[i]->getAsSymbolNode();
1145
1146                    if (symbol)
1147                    {
1148                        out << argumentString(symbol);
1149
1150                        if (i < arguments.size() - 1)
1151                        {
1152                            out << ", ";
1153                        }
1154                    }
1155                    else UNREACHABLE();
1156                }
1157
1158                sequence.erase(sequence.begin());
1159
1160                out << ")\n"
1161                       "{\n";
1162
1163                mInsideFunction = true;
1164            }
1165            else if (visit == PostVisit)
1166            {
1167                out << "}\n";
1168
1169                mInsideFunction = false;
1170            }
1171        }
1172        break;
1173      case EOpFunctionCall:
1174        {
1175            if (visit == PreVisit)
1176            {
1177                TString name = TFunction::unmangleName(node->getName());
1178
1179                if (node->isUserDefined())
1180                {
1181                    out << decorate(name) << "(";
1182                }
1183                else
1184                {
1185                    if (name == "texture2D")
1186                    {
1187                        if (node->getSequence().size() == 2)
1188                        {
1189                            mUsesTexture2D = true;
1190                        }
1191                        else if (node->getSequence().size() == 3)
1192                        {
1193                            mUsesTexture2D_bias = true;
1194                        }
1195                        else UNREACHABLE();
1196
1197                        out << "gl_texture2D(";
1198                    }
1199                    else if (name == "texture2DProj")
1200                    {
1201                        if (node->getSequence().size() == 2)
1202                        {
1203                            mUsesTexture2DProj = true;
1204                        }
1205                        else if (node->getSequence().size() == 3)
1206                        {
1207                            mUsesTexture2DProj_bias = true;
1208                        }
1209                        else UNREACHABLE();
1210
1211                        out << "gl_texture2DProj(";
1212                    }
1213                    else if (name == "textureCube")
1214                    {
1215                        if (node->getSequence().size() == 2)
1216                        {
1217                            mUsesTextureCube = true;
1218                        }
1219                        else if (node->getSequence().size() == 3)
1220                        {
1221                            mUsesTextureCube_bias = true;
1222                        }
1223                        else UNREACHABLE();
1224
1225                        out << "gl_textureCube(";
1226                    }
1227                    else if (name == "texture2DLod")
1228                    {
1229                        UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1230                    }
1231                    else if (name == "texture2DProjLod")
1232                    {
1233                        UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1234                    }
1235                    else if (name == "textureCubeLod")
1236                    {
1237                        UNIMPLEMENTED();   // Requires the vertex shader texture sampling extension
1238                    }
1239                    else UNREACHABLE();
1240                }
1241            }
1242            else if (visit == InVisit)
1243            {
1244                out << ", ";
1245            }
1246            else
1247            {
1248                out << ")";
1249            }
1250        }
1251        break;
1252      case EOpParameters:       outputTriplet(visit, "(", ", ", ")\n{\n");             break;
1253      case EOpConstructFloat:
1254        addConstructor(node->getType(), "vec1", &node->getSequence());
1255        outputTriplet(visit, "vec1(", "", ")");
1256        break;
1257      case EOpConstructVec2:
1258        addConstructor(node->getType(), "vec2", &node->getSequence());
1259        outputTriplet(visit, "vec2(", ", ", ")");
1260        break;
1261      case EOpConstructVec3:
1262        addConstructor(node->getType(), "vec3", &node->getSequence());
1263        outputTriplet(visit, "vec3(", ", ", ")");
1264        break;
1265      case EOpConstructVec4:
1266        addConstructor(node->getType(), "vec4", &node->getSequence());
1267        outputTriplet(visit, "vec4(", ", ", ")");
1268        break;
1269      case EOpConstructBool:
1270        addConstructor(node->getType(), "bvec1", &node->getSequence());
1271        outputTriplet(visit, "bvec1(", "", ")");
1272        break;
1273      case EOpConstructBVec2:
1274        addConstructor(node->getType(), "bvec2", &node->getSequence());
1275        outputTriplet(visit, "bvec2(", ", ", ")");
1276        break;
1277      case EOpConstructBVec3:
1278        addConstructor(node->getType(), "bvec3", &node->getSequence());
1279        outputTriplet(visit, "bvec3(", ", ", ")");
1280        break;
1281      case EOpConstructBVec4:
1282        addConstructor(node->getType(), "bvec4", &node->getSequence());
1283        outputTriplet(visit, "bvec4(", ", ", ")");
1284        break;
1285      case EOpConstructInt:
1286        addConstructor(node->getType(), "ivec1", &node->getSequence());
1287        outputTriplet(visit, "ivec1(", "", ")");
1288        break;
1289      case EOpConstructIVec2:
1290        addConstructor(node->getType(), "ivec2", &node->getSequence());
1291        outputTriplet(visit, "ivec2(", ", ", ")");
1292        break;
1293      case EOpConstructIVec3:
1294        addConstructor(node->getType(), "ivec3", &node->getSequence());
1295        outputTriplet(visit, "ivec3(", ", ", ")");
1296        break;
1297      case EOpConstructIVec4:
1298        addConstructor(node->getType(), "ivec4", &node->getSequence());
1299        outputTriplet(visit, "ivec4(", ", ", ")");
1300        break;
1301      case EOpConstructMat2:
1302        addConstructor(node->getType(), "mat2", &node->getSequence());
1303        outputTriplet(visit, "mat2(", ", ", ")");
1304        break;
1305      case EOpConstructMat3:
1306        addConstructor(node->getType(), "mat3", &node->getSequence());
1307        outputTriplet(visit, "mat3(", ", ", ")");
1308        break;
1309      case EOpConstructMat4:
1310        addConstructor(node->getType(), "mat4", &node->getSequence());
1311        outputTriplet(visit, "mat4(", ", ", ")");
1312        break;
1313      case EOpConstructStruct:
1314        addConstructor(node->getType(), scopedStruct(node->getType().getTypeName()), &node->getSequence());
1315        outputTriplet(visit, structLookup(node->getType().getTypeName()) + "_ctor(", ", ", ")");
1316        break;
1317      case EOpLessThan:         outputTriplet(visit, "(", " < ", ")");                 break;
1318      case EOpGreaterThan:      outputTriplet(visit, "(", " > ", ")");                 break;
1319      case EOpLessThanEqual:    outputTriplet(visit, "(", " <= ", ")");                break;
1320      case EOpGreaterThanEqual: outputTriplet(visit, "(", " >= ", ")");                break;
1321      case EOpVectorEqual:      outputTriplet(visit, "(", " == ", ")");                break;
1322      case EOpVectorNotEqual:   outputTriplet(visit, "(", " != ", ")");                break;
1323      case EOpMod:
1324        {
1325            switch (node->getSequence()[0]->getAsTyped()->getNominalSize())   // Number of components in the first argument
1326            {
1327              case 1: mUsesMod1 = true; break;
1328              case 2: mUsesMod2 = true; break;
1329              case 3: mUsesMod3 = true; break;
1330              case 4: mUsesMod4 = true; break;
1331              default: UNREACHABLE();
1332            }
1333
1334            outputTriplet(visit, "mod(", ", ", ")");
1335        }
1336        break;
1337      case EOpPow:              outputTriplet(visit, "pow(", ", ", ")");               break;
1338      case EOpAtan:
1339        ASSERT(node->getSequence().size() == 2);   // atan(x) is a unary operator
1340        mUsesAtan2 = true;
1341        outputTriplet(visit, "atanyx(", ", ", ")");
1342        break;
1343      case EOpMin:           outputTriplet(visit, "min(", ", ", ")");           break;
1344      case EOpMax:           outputTriplet(visit, "max(", ", ", ")");           break;
1345      case EOpClamp:         outputTriplet(visit, "clamp(", ", ", ")");         break;
1346      case EOpMix:           outputTriplet(visit, "lerp(", ", ", ")");          break;
1347      case EOpStep:          outputTriplet(visit, "step(", ", ", ")");          break;
1348      case EOpSmoothStep:    outputTriplet(visit, "smoothstep(", ", ", ")");    break;
1349      case EOpDistance:      outputTriplet(visit, "distance(", ", ", ")");      break;
1350      case EOpDot:           outputTriplet(visit, "dot(", ", ", ")");           break;
1351      case EOpCross:         outputTriplet(visit, "cross(", ", ", ")");         break;
1352      case EOpFaceForward:
1353        {
1354            switch (node->getSequence()[0]->getAsTyped()->getNominalSize())   // Number of components in the first argument
1355            {
1356            case 1: mUsesFaceforward1 = true; break;
1357            case 2: mUsesFaceforward2 = true; break;
1358            case 3: mUsesFaceforward3 = true; break;
1359            case 4: mUsesFaceforward4 = true; break;
1360            default: UNREACHABLE();
1361            }
1362
1363            outputTriplet(visit, "faceforward(", ", ", ")");
1364        }
1365        break;
1366      case EOpReflect:       outputTriplet(visit, "reflect(", ", ", ")");       break;
1367      case EOpRefract:       outputTriplet(visit, "refract(", ", ", ")");       break;
1368      case EOpMul:           outputTriplet(visit, "(", " * ", ")");             break;
1369      default: UNREACHABLE();
1370    }
1371
1372    return true;
1373}
1374
1375bool OutputHLSL::visitSelection(Visit visit, TIntermSelection *node)
1376{
1377    TInfoSinkBase &out = mBody;
1378
1379    if (node->usesTernaryOperator())
1380    {
1381        out << "t" << mUnfoldSelect->getTemporaryIndex();
1382    }
1383    else  // if/else statement
1384    {
1385        mUnfoldSelect->traverse(node->getCondition());
1386
1387        out << "if(";
1388
1389        node->getCondition()->traverse(this);
1390
1391        out << ")\n"
1392               "{\n";
1393
1394        if (node->getTrueBlock())
1395        {
1396            node->getTrueBlock()->traverse(this);
1397        }
1398
1399        out << ";}\n";
1400
1401        if (node->getFalseBlock())
1402        {
1403            out << "else\n"
1404                   "{\n";
1405
1406            node->getFalseBlock()->traverse(this);
1407
1408            out << ";}\n";
1409        }
1410    }
1411
1412    return false;
1413}
1414
1415void OutputHLSL::visitConstantUnion(TIntermConstantUnion *node)
1416{
1417    writeConstantUnion(node->getType(), node->getUnionArrayPointer());
1418}
1419
1420bool OutputHLSL::visitLoop(Visit visit, TIntermLoop *node)
1421{
1422    if (handleExcessiveLoop(node))
1423    {
1424        return false;
1425    }
1426
1427    TInfoSinkBase &out = mBody;
1428
1429    if (node->getType() == ELoopDoWhile)
1430    {
1431        out << "do\n"
1432               "{\n";
1433    }
1434    else
1435    {
1436        if (node->getInit())
1437        {
1438            mUnfoldSelect->traverse(node->getInit());
1439        }
1440
1441        if (node->getCondition())
1442        {
1443            mUnfoldSelect->traverse(node->getCondition());
1444        }
1445
1446        if (node->getExpression())
1447        {
1448            mUnfoldSelect->traverse(node->getExpression());
1449        }
1450
1451        out << "for(";
1452
1453        if (node->getInit())
1454        {
1455            node->getInit()->traverse(this);
1456        }
1457
1458        out << "; ";
1459
1460        if (node->getCondition())
1461        {
1462            node->getCondition()->traverse(this);
1463        }
1464
1465        out << "; ";
1466
1467        if (node->getExpression())
1468        {
1469            node->getExpression()->traverse(this);
1470        }
1471
1472        out << ")\n"
1473               "{\n";
1474    }
1475
1476    if (node->getBody())
1477    {
1478        node->getBody()->traverse(this);
1479    }
1480
1481    out << "}\n";
1482
1483    if (node->getType() == ELoopDoWhile)
1484    {
1485        out << "while(\n";
1486
1487        node->getCondition()->traverse(this);
1488
1489        out << ")";
1490    }
1491
1492    out << ";\n";
1493
1494    return false;
1495}
1496
1497bool OutputHLSL::visitBranch(Visit visit, TIntermBranch *node)
1498{
1499    TInfoSinkBase &out = mBody;
1500
1501    switch (node->getFlowOp())
1502    {
1503      case EOpKill:     outputTriplet(visit, "discard", "", "");  break;
1504      case EOpBreak:    outputTriplet(visit, "break", "", "");    break;
1505      case EOpContinue: outputTriplet(visit, "continue", "", ""); break;
1506      case EOpReturn:
1507        if (visit == PreVisit)
1508        {
1509            if (node->getExpression())
1510            {
1511                out << "return ";
1512            }
1513            else
1514            {
1515                out << "return;\n";
1516            }
1517        }
1518        else if (visit == PostVisit)
1519        {
1520            out << ";\n";
1521        }
1522        break;
1523      default: UNREACHABLE();
1524    }
1525
1526    return true;
1527}
1528
1529bool OutputHLSL::isSingleStatement(TIntermNode *node)
1530{
1531    TIntermAggregate *aggregate = node->getAsAggregate();
1532
1533    if (aggregate)
1534    {
1535        if (aggregate->getOp() == EOpSequence)
1536        {
1537            return false;
1538        }
1539        else
1540        {
1541            for (TIntermSequence::iterator sit = aggregate->getSequence().begin(); sit != aggregate->getSequence().end(); sit++)
1542            {
1543                if (!isSingleStatement(*sit))
1544                {
1545                    return false;
1546                }
1547            }
1548
1549            return true;
1550        }
1551    }
1552
1553    return true;
1554}
1555
1556// Handle loops with more than 255 iterations (unsupported by D3D9) by splitting them
1557bool OutputHLSL::handleExcessiveLoop(TIntermLoop *node)
1558{
1559    TInfoSinkBase &out = mBody;
1560
1561    // Parse loops of the form:
1562    // for(int index = initial; index [comparator] limit; index += increment)
1563    TIntermSymbol *index = NULL;
1564    TOperator comparator = EOpNull;
1565    int initial = 0;
1566    int limit = 0;
1567    int increment = 0;
1568
1569    // Parse index name and intial value
1570    if (node->getInit())
1571    {
1572        TIntermAggregate *init = node->getInit()->getAsAggregate();
1573
1574        if (init)
1575        {
1576            TIntermSequence &sequence = init->getSequence();
1577            TIntermTyped *variable = sequence[0]->getAsTyped();
1578
1579            if (variable && variable->getQualifier() == EvqTemporary)
1580            {
1581                TIntermBinary *assign = variable->getAsBinaryNode();
1582
1583                if (assign->getOp() == EOpInitialize)
1584                {
1585                    TIntermSymbol *symbol = assign->getLeft()->getAsSymbolNode();
1586                    TIntermConstantUnion *constant = assign->getRight()->getAsConstantUnion();
1587
1588                    if (symbol && constant)
1589                    {
1590                        if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1591                        {
1592                            index = symbol;
1593                            initial = constant->getUnionArrayPointer()[0].getIConst();
1594                        }
1595                    }
1596                }
1597            }
1598        }
1599    }
1600
1601    // Parse comparator and limit value
1602    if (index != NULL && node->getCondition())
1603    {
1604        TIntermBinary *test = node->getCondition()->getAsBinaryNode();
1605
1606        if (test && test->getLeft()->getAsSymbolNode()->getId() == index->getId())
1607        {
1608            TIntermConstantUnion *constant = test->getRight()->getAsConstantUnion();
1609
1610            if (constant)
1611            {
1612                if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1613                {
1614                    comparator = test->getOp();
1615                    limit = constant->getUnionArrayPointer()[0].getIConst();
1616                }
1617            }
1618        }
1619    }
1620
1621    // Parse increment
1622    if (index != NULL && comparator != EOpNull && node->getExpression())
1623    {
1624        TIntermBinary *binaryTerminal = node->getExpression()->getAsBinaryNode();
1625        TIntermUnary *unaryTerminal = node->getExpression()->getAsUnaryNode();
1626
1627        if (binaryTerminal)
1628        {
1629            TOperator op = binaryTerminal->getOp();
1630            TIntermConstantUnion *constant = binaryTerminal->getRight()->getAsConstantUnion();
1631
1632            if (constant)
1633            {
1634                if (constant->getBasicType() == EbtInt && constant->getNominalSize() == 1)
1635                {
1636                    int value = constant->getUnionArrayPointer()[0].getIConst();
1637
1638                    switch (op)
1639                    {
1640                      case EOpAddAssign: increment = value;  break;
1641                      case EOpSubAssign: increment = -value; break;
1642                      default: UNIMPLEMENTED();
1643                    }
1644                }
1645            }
1646        }
1647        else if (unaryTerminal)
1648        {
1649            TOperator op = unaryTerminal->getOp();
1650
1651            switch (op)
1652            {
1653              case EOpPostIncrement: increment = 1;  break;
1654              case EOpPostDecrement: increment = -1; break;
1655              case EOpPreIncrement:  increment = 1;  break;
1656              case EOpPreDecrement:  increment = -1; break;
1657              default: UNIMPLEMENTED();
1658            }
1659        }
1660    }
1661
1662    if (index != NULL && comparator != EOpNull && increment != 0)
1663    {
1664        if (comparator == EOpLessThanEqual)
1665        {
1666            comparator = EOpLessThan;
1667            limit += 1;
1668        }
1669
1670        if (comparator == EOpLessThan)
1671        {
1672            int iterations = (limit - initial + 1) / increment;
1673
1674            if (iterations <= 255)
1675            {
1676                return false;   // Not an excessive loop
1677            }
1678
1679            while (iterations > 0)
1680            {
1681                int remainder = (limit - initial + 1) % increment;
1682                int clampedLimit = initial + increment * std::min(255, iterations) - 1 - remainder;
1683
1684                // for(int index = initial; index < clampedLimit; index += increment)
1685
1686                out << "for(int ";
1687                index->traverse(this);
1688                out << " = ";
1689                out << initial;
1690
1691                out << "; ";
1692                index->traverse(this);
1693                out << " < ";
1694                out << clampedLimit;
1695
1696                out << "; ";
1697                index->traverse(this);
1698                out << " += ";
1699                out << increment;
1700                out << ")\n"
1701                       "{\n";
1702
1703                if (node->getBody())
1704                {
1705                    node->getBody()->traverse(this);
1706                }
1707
1708                out << "}\n";
1709
1710                initial += 255 * increment;
1711                iterations -= 255;
1712            }
1713
1714            return true;
1715        }
1716        else UNIMPLEMENTED();
1717    }
1718
1719    return false;   // Not handled as an excessive loop
1720}
1721
1722void OutputHLSL::outputTriplet(Visit visit, const TString &preString, const TString &inString, const TString &postString)
1723{
1724    TInfoSinkBase &out = mBody;
1725
1726    if (visit == PreVisit)
1727    {
1728        out << preString;
1729    }
1730    else if (visit == InVisit)
1731    {
1732        out << inString;
1733    }
1734    else if (visit == PostVisit)
1735    {
1736        out << postString;
1737    }
1738}
1739
1740TString OutputHLSL::argumentString(const TIntermSymbol *symbol)
1741{
1742    TQualifier qualifier = symbol->getQualifier();
1743    const TType &type = symbol->getType();
1744    TString name = symbol->getSymbol();
1745
1746    if (name.empty())   // HLSL demands named arguments, also for prototypes
1747    {
1748        name = "x" + str(mUniqueIndex++);
1749    }
1750    else
1751    {
1752        name = decorate(name);
1753    }
1754
1755    return qualifierString(qualifier) + " " + typeString(type) + " " + name + arrayString(type);
1756}
1757
1758TString OutputHLSL::qualifierString(TQualifier qualifier)
1759{
1760    switch(qualifier)
1761    {
1762      case EvqIn:            return "in";
1763      case EvqOut:           return "out";
1764      case EvqInOut:         return "inout";
1765      case EvqConstReadOnly: return "const";
1766      default: UNREACHABLE();
1767    }
1768
1769    return "";
1770}
1771
1772TString OutputHLSL::typeString(const TType &type)
1773{
1774    if (type.getBasicType() == EbtStruct)
1775    {
1776        if (type.getTypeName() != "")
1777        {
1778            return structLookup(type.getTypeName());
1779        }
1780        else   // Nameless structure, define in place
1781        {
1782            const TTypeList &fields = *type.getStruct();
1783
1784            TString string = "struct\n"
1785                             "{\n";
1786
1787            for (unsigned int i = 0; i < fields.size(); i++)
1788            {
1789                const TType &field = *fields[i].type;
1790
1791                string += "    " + typeString(field) + " " + field.getFieldName() + arrayString(field) + ";\n";
1792            }
1793
1794            string += "} ";
1795
1796            return string;
1797        }
1798    }
1799    else if (type.isMatrix())
1800    {
1801        switch (type.getNominalSize())
1802        {
1803          case 2: return "float2x2";
1804          case 3: return "float3x3";
1805          case 4: return "float4x4";
1806        }
1807    }
1808    else
1809    {
1810        switch (type.getBasicType())
1811        {
1812          case EbtFloat:
1813            switch (type.getNominalSize())
1814            {
1815              case 1: return "float";
1816              case 2: return "float2";
1817              case 3: return "float3";
1818              case 4: return "float4";
1819            }
1820          case EbtInt:
1821            switch (type.getNominalSize())
1822            {
1823              case 1: return "int";
1824              case 2: return "int2";
1825              case 3: return "int3";
1826              case 4: return "int4";
1827            }
1828          case EbtBool:
1829            switch (type.getNominalSize())
1830            {
1831              case 1: return "bool";
1832              case 2: return "bool2";
1833              case 3: return "bool3";
1834              case 4: return "bool4";
1835            }
1836          case EbtVoid:
1837            return "void";
1838          case EbtSampler2D:
1839            return "sampler2D";
1840          case EbtSamplerCube:
1841            return "samplerCUBE";
1842        }
1843    }
1844
1845    UNIMPLEMENTED();   // FIXME
1846    return "<unknown type>";
1847}
1848
1849TString OutputHLSL::arrayString(const TType &type)
1850{
1851    if (!type.isArray())
1852    {
1853        return "";
1854    }
1855
1856    return "[" + str(type.getArraySize()) + "]";
1857}
1858
1859TString OutputHLSL::initializer(const TType &type)
1860{
1861    TString string;
1862
1863    for (int component = 0; component < type.getObjectSize(); component++)
1864    {
1865        string += "0";
1866
1867        if (component < type.getObjectSize() - 1)
1868        {
1869            string += ", ";
1870        }
1871    }
1872
1873    return "{" + string + "}";
1874}
1875
// Generates the source text of an HLSL constructor function for a struct or
// built-in type and records it in mConstructors.
//
// type       - the type being constructed. A copy with array-ness cleared,
//              high precision and a temporary qualifier is used as the
//              constructor's return type.
// name       - the constructor's GLSL name. For structs it is decorate()d and
//              the emitted function is named "<decorated>_ctor". An empty
//              name (nameless struct) produces nothing.
// parameters - argument nodes of a constructor call; their types become the
//              function's parameter types. When null, `type` is expected to be
//              a struct (otherwise UNREACHABLE()): its declaration is added to
//              mStructDeclarations and its fields become the parameters.
//
// The generated function's parameters are named x0, x1, x2, ...
void OutputHLSL::addConstructor(const TType &type, const TString &name, const TIntermSequence *parameters)
{
    if (name == "")
    {
        return;   // Nameless structures don't have constructors
    }

    // Normalized copy used for the constructor's return type.
    TType ctorType = type;
    ctorType.clearArrayness();
    ctorType.setPrecision(EbpHigh);
    ctorType.setQualifier(EvqTemporary);

    TString ctorName = type.getStruct() ? decorate(name) : name;

    typedef std::vector<TType> ParameterArray;
    ParameterArray ctorParameters;

    if (parameters)
    {
        // Constructor call: parameter types come from the call's arguments.
        for (TIntermSequence::const_iterator parameter = parameters->begin(); parameter != parameters->end(); parameter++)
        {
            ctorParameters.push_back((*parameter)->getAsTyped()->getType());
        }
    }
    else if (type.getStruct())
    {
        // Struct declaration: register the decorated name (structLookup()
        // searches mStructNames) and emit the HLSL struct definition.
        mStructNames.insert(decorate(name));

        TString structure;
        structure += "struct " + decorate(name) + "\n"
                     "{\n";

        const TTypeList &fields = *type.getStruct();

        for (unsigned int i = 0; i < fields.size(); i++)
        {
            const TType &field = *fields[i].type;

            structure += "    " + typeString(field) + " " + field.getFieldName() + arrayString(field) + ";\n";
        }

        structure += "};\n";

        // Only append a declaration whose exact text is not already present.
        if (std::find(mStructDeclarations.begin(), mStructDeclarations.end(), structure) == mStructDeclarations.end())
        {
            mStructDeclarations.push_back(structure);
        }

        // The struct constructor takes one parameter per field, in order.
        for (unsigned int i = 0; i < fields.size(); i++)
        {
            ctorParameters.push_back(*fields[i].type);
        }
    }
    else UNREACHABLE();

    TString constructor;

    // Signature: "<Name> <Name>_ctor(" for structs, "<hlsl type> <name>(" otherwise.
    if (ctorType.getStruct())
    {
        constructor += ctorName + " " + ctorName + "_ctor(";
    }
    else   // Built-in type
    {
        constructor += typeString(ctorType) + " " + ctorName + "(";
    }

    for (unsigned int parameter = 0; parameter < ctorParameters.size(); parameter++)
    {
        // Note: this local `type` shadows the function parameter of the same name.
        const TType &type = ctorParameters[parameter];

        constructor += typeString(type) + " x" + str(parameter) + arrayString(type);

        if (parameter < ctorParameters.size() - 1)
        {
            constructor += ", ";
        }
    }

    constructor += ")\n"
                   "{\n";

    // Body opening: structs are built with an aggregate initializer and
    // returned; built-ins return an HLSL type-constructor expression.
    if (ctorType.getStruct())
    {
        constructor += "    " + ctorName + " structure = {";
    }
    else
    {
        constructor += "    return " + typeString(ctorType) + "(";
    }

    if (ctorType.isMatrix() && ctorParameters.size() == 1)
    {
        // Matrix from a single argument: every element is written out explicitly.
        int dim = ctorType.getNominalSize();
        const TType &parameter = ctorParameters[0];

        if (parameter.isScalar())
        {
            // Scalar: x0 on the diagonal, 0.0 elsewhere.
            for (int row = 0; row < dim; row++)
            {
                for (int col = 0; col < dim; col++)
                {
                    constructor += TString((row == col) ? "x0" : "0.0");

                    if (row < dim - 1 || col < dim - 1)
                    {
                        constructor += ", ";
                    }
                }
            }
        }
        else if (parameter.isMatrix())
        {
            // Matrix: copy the overlapping elements of x0; elements outside the
            // source matrix come from the identity (1.0 on the diagonal, 0.0 off it).
            for (int row = 0; row < dim; row++)
            {
                for (int col = 0; col < dim; col++)
                {
                    if (row < parameter.getNominalSize() && col < parameter.getNominalSize())
                    {
                        constructor += TString("x0") + "[" + str(row) + "]" + "[" + str(col) + "]";
                    }
                    else
                    {
                        constructor += TString((row == col) ? "1.0" : "0.0");
                    }

                    if (row < dim - 1 || col < dim - 1)
                    {
                        constructor += ", ";
                    }
                }
            }
        }
        else UNREACHABLE();
    }
    else
    {
        // General case: emit x0, x1, ... left to right until the target's
        // scalar component count (getObjectSize()) has been covered.
        // NOTE(review): assumes ctorParameters is non-empty; with no parameters
        // ctorParameters[0] below would be out of bounds.
        int remainingComponents = ctorType.getObjectSize();
        int parameterIndex = 0;

        while (remainingComponents > 0)
        {
            const TType &parameter = ctorParameters[parameterIndex];
            bool moreParameters = parameterIndex < (int)ctorParameters.size() - 1;

            constructor += "x" + str(parameterIndex);

            if (parameter.isScalar())
            {
                remainingComponents -= parameter.getObjectSize();
            }
            else if (parameter.isVector())
            {
                if (remainingComponents == parameter.getObjectSize() || moreParameters)
                {
                    remainingComponents -= parameter.getObjectSize();
                }
                else if (remainingComponents < parameter.getNominalSize())
                {
                    // Last argument supplies more components than needed:
                    // swizzle it down to exactly the remaining count.
                    switch (remainingComponents)
                    {
                      case 1: constructor += ".x";    break;
                      case 2: constructor += ".xy";   break;
                      case 3: constructor += ".xyz";  break;
                      case 4: constructor += ".xyzw"; break;
                      default: UNREACHABLE();
                    }

                    remainingComponents = 0;
                }
                else UNREACHABLE();
            }
            else if (parameter.isMatrix() || parameter.getStruct())
            {
                ASSERT(remainingComponents == parameter.getObjectSize() || moreParameters);

                remainingComponents -= parameter.getObjectSize();
            }
            else UNREACHABLE();

            // The index stops advancing at the last parameter, so a lone trailing
            // scalar is emitted again for every remaining component
            // (e.g. a single scalar argument to a 4-component type yields
            // "x0, x0, x0, x0").
            if (moreParameters)
            {
                parameterIndex++;
            }

            if (remainingComponents)
            {
                constructor += ", ";
            }
        }
    }

    if (ctorType.getStruct())
    {
        constructor += "};\n"
                       "    return structure;\n"
                       "}\n";
    }
    else
    {
        constructor += ");\n"
                       "}\n";
    }

    mConstructors.insert(constructor);
}
2081
2082const ConstantUnion *OutputHLSL::writeConstantUnion(const TType &type, const ConstantUnion *constUnion)
2083{
2084    TInfoSinkBase &out = mBody;
2085
2086    if (type.getBasicType() == EbtStruct)
2087    {
2088        out << structLookup(type.getTypeName()) + "_ctor(";
2089
2090        const TTypeList *structure = type.getStruct();
2091
2092        for (size_t i = 0; i < structure->size(); i++)
2093        {
2094            const TType *fieldType = (*structure)[i].type;
2095
2096            constUnion = writeConstantUnion(*fieldType, constUnion);
2097
2098            if (i != structure->size() - 1)
2099            {
2100                out << ", ";
2101            }
2102        }
2103
2104        out << ")";
2105    }
2106    else
2107    {
2108        int size = type.getObjectSize();
2109        bool writeType = size > 1;
2110
2111        if (writeType)
2112        {
2113            out << typeString(type) << "(";
2114        }
2115
2116        for (int i = 0; i < size; i++, constUnion++)
2117        {
2118            switch (constUnion->getType())
2119            {
2120              case EbtFloat: out << constUnion->getFConst(); break;
2121              case EbtInt:   out << constUnion->getIConst(); break;
2122              case EbtBool:  out << constUnion->getBConst(); break;
2123              default: UNREACHABLE();
2124            }
2125
2126            if (i != size - 1)
2127            {
2128                out << ", ";
2129            }
2130        }
2131
2132        if (writeType)
2133        {
2134            out << ")";
2135        }
2136    }
2137
2138    return constUnion;
2139}
2140
2141TString OutputHLSL::scopeString(unsigned int depthLimit)
2142{
2143    TString string;
2144
2145    for (unsigned int i = 0; i < mScopeBracket.size() && i < depthLimit; i++)
2146    {
2147        string += "_" + str(i);
2148    }
2149
2150    return string;
2151}
2152
2153TString OutputHLSL::scopedStruct(const TString &typeName)
2154{
2155    if (typeName == "")
2156    {
2157        return typeName;
2158    }
2159
2160    return typeName + scopeString(mScopeDepth);
2161}
2162
2163TString OutputHLSL::structLookup(const TString &typeName)
2164{
2165    for (int depth = mScopeDepth; depth >= 0; depth--)
2166    {
2167        TString scopedName = decorate(typeName + scopeString(depth));
2168
2169        for (StructNames::iterator structName = mStructNames.begin(); structName != mStructNames.end(); structName++)
2170        {
2171            if (*structName == scopedName)
2172            {
2173                return scopedName;
2174            }
2175        }
2176    }
2177
2178    UNREACHABLE();   // Should have found a matching constructor
2179
2180    return typeName;
2181}
2182
2183TString OutputHLSL::decorate(const TString &string)
2184{
2185    if (string.substr(0, 3) != "gl_" && string.substr(0, 3) != "dx_")
2186    {
2187        return "_" + string;
2188    }
2189    else
2190    {
2191        return string;
2192    }
2193}
2194}
2195