1/*
2 * Copyright 2016 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "SkSLMetalCodeGenerator.h"
9
10#include "SkSLCompiler.h"
11#include "ir/SkSLExpressionStatement.h"
12#include "ir/SkSLExtension.h"
13#include "ir/SkSLIndexExpression.h"
14#include "ir/SkSLModifiersDeclaration.h"
15#include "ir/SkSLNop.h"
16#include "ir/SkSLVariableReference.h"
17
18namespace SkSL {
19
20void MetalCodeGenerator::write(const char* s) {
21    if (!s[0]) {
22        return;
23    }
24    if (fAtLineStart) {
25        for (int i = 0; i < fIndentation; i++) {
26            fOut->writeText("    ");
27        }
28    }
29    fOut->writeText(s);
30    fAtLineStart = false;
31}
32
33void MetalCodeGenerator::writeLine(const char* s) {
34    this->write(s);
35    fOut->writeText(fLineEnding);
36    fAtLineStart = true;
37}
38
39void MetalCodeGenerator::write(const String& s) {
40    this->write(s.c_str());
41}
42
43void MetalCodeGenerator::writeLine(const String& s) {
44    this->writeLine(s.c_str());
45}
46
47void MetalCodeGenerator::writeLine() {
48    this->writeLine("");
49}
50
51void MetalCodeGenerator::writeExtension(const Extension& ext) {
52    this->writeLine("#extension " + ext.fName + " : enable");
53}
54
55void MetalCodeGenerator::writeType(const Type& type) {
56    switch (type.kind()) {
57        case Type::kStruct_Kind:
58            for (const Type* search : fWrittenStructs) {
59                if (*search == type) {
60                    // already written
61                    this->write(type.name());
62                    return;
63                }
64            }
65            fWrittenStructs.push_back(&type);
66            this->writeLine("struct " + type.name() + " {");
67            fIndentation++;
68            for (const auto& f : type.fields()) {
69                this->writeModifiers(f.fModifiers, false);
70                // sizes (which must be static in structs) are part of the type name here
71                this->writeType(*f.fType);
72                this->writeLine(" " + f.fName + ";");
73            }
74            fIndentation--;
75            this->write("}");
76            break;
77        case Type::kVector_Kind:
78            this->writeType(type.componentType());
79            this->write(to_string(type.columns()));
80            break;
81        default:
82            this->write(type.name());
83    }
84}
85
86void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
87    switch (expr.fKind) {
88        case Expression::kBinary_Kind:
89            this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
90            break;
91        case Expression::kBoolLiteral_Kind:
92            this->writeBoolLiteral((BoolLiteral&) expr);
93            break;
94        case Expression::kConstructor_Kind:
95            this->writeConstructor((Constructor&) expr);
96            break;
97        case Expression::kIntLiteral_Kind:
98            this->writeIntLiteral((IntLiteral&) expr);
99            break;
100        case Expression::kFieldAccess_Kind:
101            this->writeFieldAccess(((FieldAccess&) expr));
102            break;
103        case Expression::kFloatLiteral_Kind:
104            this->writeFloatLiteral(((FloatLiteral&) expr));
105            break;
106        case Expression::kFunctionCall_Kind:
107            this->writeFunctionCall((FunctionCall&) expr);
108            break;
109        case Expression::kPrefix_Kind:
110            this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
111            break;
112        case Expression::kPostfix_Kind:
113            this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
114            break;
115        case Expression::kSetting_Kind:
116            this->writeSetting((Setting&) expr);
117            break;
118        case Expression::kSwizzle_Kind:
119            this->writeSwizzle((Swizzle&) expr);
120            break;
121        case Expression::kVariableReference_Kind:
122            this->writeVariableReference((VariableReference&) expr);
123            break;
124        case Expression::kTernary_Kind:
125            this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
126            break;
127        case Expression::kIndex_Kind:
128            this->writeIndexExpression((IndexExpression&) expr);
129            break;
130        default:
131            ABORT("unsupported expression: %s", expr.description().c_str());
132    }
133}
134
135void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
136    if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
137        this->write("atan2");
138    } else {
139        this->write(c.fFunction.fName);
140    }
141    this->write("(");
142    const char* separator = "";
143    if (this->requirements(c.fFunction) & kInputs_Requirement) {
144        this->write("_in");
145        separator = ", ";
146    }
147    if (this->requirements(c.fFunction) & kOutputs_Requirement) {
148        this->write(separator);
149        this->write("_out");
150        separator = ", ";
151    }
152    if (this->requirements(c.fFunction) & kUniforms_Requirement) {
153        this->write(separator);
154        this->write("_uniforms");
155        separator = ", ";
156    }
157    for (size_t i = 0; i < c.fArguments.size(); ++i) {
158        const Expression& arg = *c.fArguments[i];
159        this->write(separator);
160        separator = ", ";
161        if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
162            this->write("&");
163        }
164        this->writeExpression(arg, kSequence_Precedence);
165    }
166    this->write(")");
167}
168
169void MetalCodeGenerator::writeConstructor(const Constructor& c) {
170    this->writeType(c.fType);
171    this->write("(");
172    const char* separator = "";
173    int scalarCount = 0;
174    for (const auto& arg : c.fArguments) {
175        this->write(separator);
176        separator = ", ";
177        if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
178            // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge to
179            // float2x2(float2, float2).
180            if (!scalarCount) {
181                this->writeType(c.fType.componentType());
182                this->write(to_string(c.fType.rows()));
183                this->write("(");
184            }
185            ++scalarCount;
186        }
187        this->writeExpression(*arg, kSequence_Precedence);
188        if (scalarCount && scalarCount == c.fType.rows()) {
189            this->write(")");
190            scalarCount = 0;
191        }
192    }
193    this->write(")");
194}
195
196void MetalCodeGenerator::writeFragCoord() {
197    this->write("_in.position");
198}
199
200void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
201    switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
202        case SK_FRAGCOLOR_BUILTIN:
203            this->write("sk_FragColor");
204            break;
205        default:
206            if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
207                if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
208                    this->write("_in.");
209                } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
210                    this->write("_out.");
211                } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
212                    this->write("_uniforms.");
213                } else {
214                    fErrors.error(ref.fVariable.fOffset, "Metal backend does not support global "
215                                  "variables");
216                }
217            }
218            this->write(ref.fVariable.fName);
219    }
220}
221
222void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
223    this->writeExpression(*expr.fBase, kPostfix_Precedence);
224    this->write("[");
225    this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
226    this->write("]");
227}
228
229void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
230    if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
231        this->writeExpression(*f.fBase, kPostfix_Precedence);
232        this->write(".");
233    }
234    switch (f.fBase->fType.fields()[f.fFieldIndex].fModifiers.fLayout.fBuiltin) {
235        case SK_CLIPDISTANCE_BUILTIN:
236            this->write("gl_ClipDistance");
237            break;
238        case SK_POSITION_BUILTIN:
239            this->write("_out.position");
240            break;
241        default:
242            this->write(f.fBase->fType.fields()[f.fFieldIndex].fName);
243    }
244}
245
246void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
247    this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
248    this->write(".");
249    for (int c : swizzle.fComponents) {
250        this->write(&("x\0y\0z\0w\0"[c * 2]));
251    }
252}
253
254MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
255    switch (op) {
256        case Token::STAR:         // fall through
257        case Token::SLASH:        // fall through
258        case Token::PERCENT:      return MetalCodeGenerator::kMultiplicative_Precedence;
259        case Token::PLUS:         // fall through
260        case Token::MINUS:        return MetalCodeGenerator::kAdditive_Precedence;
261        case Token::SHL:          // fall through
262        case Token::SHR:          return MetalCodeGenerator::kShift_Precedence;
263        case Token::LT:           // fall through
264        case Token::GT:           // fall through
265        case Token::LTEQ:         // fall through
266        case Token::GTEQ:         return MetalCodeGenerator::kRelational_Precedence;
267        case Token::EQEQ:         // fall through
268        case Token::NEQ:          return MetalCodeGenerator::kEquality_Precedence;
269        case Token::BITWISEAND:   return MetalCodeGenerator::kBitwiseAnd_Precedence;
270        case Token::BITWISEXOR:   return MetalCodeGenerator::kBitwiseXor_Precedence;
271        case Token::BITWISEOR:    return MetalCodeGenerator::kBitwiseOr_Precedence;
272        case Token::LOGICALAND:   return MetalCodeGenerator::kLogicalAnd_Precedence;
273        case Token::LOGICALXOR:   return MetalCodeGenerator::kLogicalXor_Precedence;
274        case Token::LOGICALOR:    return MetalCodeGenerator::kLogicalOr_Precedence;
275        case Token::EQ:           // fall through
276        case Token::PLUSEQ:       // fall through
277        case Token::MINUSEQ:      // fall through
278        case Token::STAREQ:       // fall through
279        case Token::SLASHEQ:      // fall through
280        case Token::PERCENTEQ:    // fall through
281        case Token::SHLEQ:        // fall through
282        case Token::SHREQ:        // fall through
283        case Token::LOGICALANDEQ: // fall through
284        case Token::LOGICALXOREQ: // fall through
285        case Token::LOGICALOREQ:  // fall through
286        case Token::BITWISEANDEQ: // fall through
287        case Token::BITWISEXOREQ: // fall through
288        case Token::BITWISEOREQ:  return MetalCodeGenerator::kAssignment_Precedence;
289        case Token::COMMA:        return MetalCodeGenerator::kSequence_Precedence;
290        default: ABORT("unsupported binary operator");
291    }
292}
293
294void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
295                                               Precedence parentPrecedence) {
296    Precedence precedence = GetBinaryPrecedence(b.fOperator);
297    if (precedence >= parentPrecedence) {
298        this->write("(");
299    }
300    if (Compiler::IsAssignment(b.fOperator) &&
301        Expression::kVariableReference_Kind == b.fLeft->fKind &&
302        Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
303        (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
304        // writing to an out parameter. Since we have to turn those into pointers, we have to
305        // dereference it here.
306        this->write("*");
307    }
308    this->writeExpression(*b.fLeft, precedence);
309    if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
310        Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
311        // This doesn't compile in Metal:
312        // float4 x = float4(1);
313        // x.xy *= float2x2(...);
314        // with the error message "non-const reference cannot bind to vector element",
315        // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
316        // as long as the LHS has no side effects, and hope for the best otherwise.
317        this->write(" = ");
318        this->writeExpression(*b.fLeft, kAssignment_Precedence);
319        this->write(" ");
320        String op = Compiler::OperatorName(b.fOperator);
321        ASSERT(op.endsWith("="));
322        this->write(op.substr(0, op.size() - 1).c_str());
323        this->write(" ");
324    } else {
325        this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
326    }
327    this->writeExpression(*b.fRight, precedence);
328    if (precedence >= parentPrecedence) {
329        this->write(")");
330    }
331}
332
333void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
334                                               Precedence parentPrecedence) {
335    if (kTernary_Precedence >= parentPrecedence) {
336        this->write("(");
337    }
338    this->writeExpression(*t.fTest, kTernary_Precedence);
339    this->write(" ? ");
340    this->writeExpression(*t.fIfTrue, kTernary_Precedence);
341    this->write(" : ");
342    this->writeExpression(*t.fIfFalse, kTernary_Precedence);
343    if (kTernary_Precedence >= parentPrecedence) {
344        this->write(")");
345    }
346}
347
348void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
349                                              Precedence parentPrecedence) {
350    if (kPrefix_Precedence >= parentPrecedence) {
351        this->write("(");
352    }
353    this->write(Compiler::OperatorName(p.fOperator));
354    this->writeExpression(*p.fOperand, kPrefix_Precedence);
355    if (kPrefix_Precedence >= parentPrecedence) {
356        this->write(")");
357    }
358}
359
360void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
361                                               Precedence parentPrecedence) {
362    if (kPostfix_Precedence >= parentPrecedence) {
363        this->write("(");
364    }
365    this->writeExpression(*p.fOperand, kPostfix_Precedence);
366    this->write(Compiler::OperatorName(p.fOperator));
367    if (kPostfix_Precedence >= parentPrecedence) {
368        this->write(")");
369    }
370}
371
372void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
373    this->write(b.fValue ? "true" : "false");
374}
375
376void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
377    if (i.fType == *fContext.fUInt_Type) {
378        this->write(to_string(i.fValue & 0xffffffff) + "u");
379    } else {
380        this->write(to_string((int32_t) i.fValue));
381    }
382}
383
384void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
385    this->write(to_string(f.fValue));
386}
387
388void MetalCodeGenerator::writeSetting(const Setting& s) {
389    ABORT("internal error; setting was not folded to a constant during compilation\n");
390}
391
392void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
393    const char* separator = "";
394    if ("main" == f.fDeclaration.fName) {
395        switch (fProgram.fKind) {
396            case Program::kFragment_Kind:
397                this->write("fragment half4 _frag");
398                break;
399            case Program::kVertex_Kind:
400                this->write("vertex Outputs _vert");
401                break;
402            default:
403                ASSERT(false);
404        }
405        this->write("(Inputs _in [[stage_in]]");
406        if (-1 != fUniformBuffer) {
407            this->write(", constant Uniforms& _uniforms [[buffer(" +
408                        to_string(fUniformBuffer) + ")]]");
409        }
410        separator = ", ";
411    } else {
412        this->writeType(f.fDeclaration.fReturnType);
413        this->write(" " + f.fDeclaration.fName + "(");
414        if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
415            this->write("Inputs _in");
416            separator = ", ";
417        }
418        if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
419            this->write(separator);
420            this->write("thread Outputs& _out");
421            separator = ", ";
422        }
423        if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
424            this->write(separator);
425            this->write("Uniforms _uniforms");
426            separator = ", ";
427        }
428    }
429    for (const auto& param : f.fDeclaration.fParameters) {
430        this->write(separator);
431        separator = ", ";
432        this->writeModifiers(param->fModifiers, false);
433        std::vector<int> sizes;
434        const Type* type = &param->fType;
435        while (Type::kArray_Kind == type->kind()) {
436            sizes.push_back(type->columns());
437            type = &type->componentType();
438        }
439        this->writeType(*type);
440        if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
441            this->write("*");
442        }
443        this->write(" " + param->fName);
444        for (int s : sizes) {
445            if (s <= 0) {
446                this->write("[]");
447            } else {
448                this->write("[" + to_string(s) + "]");
449            }
450        }
451    }
452    this->writeLine(") {");
453
454    ASSERT(!fProgram.fSettings.fFragColorIsInOut);
455
456    if ("main" == f.fDeclaration.fName) {
457        switch (fProgram.fKind) {
458            case Program::kFragment_Kind:
459                this->writeLine("    half4 sk_FragColor;");
460                break;
461            case Program::kVertex_Kind:
462                this->writeLine("    Outputs _out;");
463                break;
464            default:
465                ASSERT(false);
466        }
467    }
468    fFunctionHeader = "";
469    OutputStream* oldOut = fOut;
470    StringStream buffer;
471    fOut = &buffer;
472    fIndentation++;
473    this->writeStatements(((Block&) *f.fBody).fStatements);
474    if ("main" == f.fDeclaration.fName) {
475        switch (fProgram.fKind) {
476            case Program::kFragment_Kind:
477                this->writeLine("return sk_FragColor;");
478                break;
479            case Program::kVertex_Kind:
480                this->writeLine("return _out;");
481                break;
482            default:
483                ASSERT(false);
484        }
485    }
486    fIndentation--;
487    this->writeLine("}");
488
489    fOut = oldOut;
490    this->write(fFunctionHeader);
491    this->write(buffer.str());
492}
493
494void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
495                                       bool globalContext) {
496    if (modifiers.fFlags & Modifiers::kOut_Flag) {
497        this->write("thread ");
498    }
499    if (modifiers.fFlags & Modifiers::kConst_Flag) {
500        this->write("const ");
501    }
502}
503
504void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
505    if ("sk_PerVertex" == intf.fTypeName) {
506        return;
507    }
508    this->writeModifiers(intf.fVariable.fModifiers, true);
509    this->writeLine(intf.fTypeName + " {");
510    fIndentation++;
511    const Type* structType = &intf.fVariable.fType;
512    while (Type::kArray_Kind == structType->kind()) {
513        structType = &structType->componentType();
514    }
515    for (const auto& f : structType->fields()) {
516        this->writeModifiers(f.fModifiers, false);
517        this->writeType(*f.fType);
518        this->writeLine(" " + f.fName + ";");
519    }
520    fIndentation--;
521    this->write("}");
522    if (intf.fInstanceName.size()) {
523        this->write(" ");
524        this->write(intf.fInstanceName);
525        for (const auto& size : intf.fSizes) {
526            this->write("[");
527            if (size) {
528                this->writeExpression(*size, kTopLevel_Precedence);
529            }
530            this->write("]");
531        }
532    }
533    this->writeLine(";");
534}
535
536void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
537    this->writeExpression(value, kTopLevel_Precedence);
538}
539
540void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
541    ASSERT(decl.fVars.size() > 0);
542    bool wroteType = false;
543    for (const auto& stmt : decl.fVars) {
544        VarDeclaration& var = (VarDeclaration&) *stmt;
545        if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag |
546                                           Modifiers::kUniform_Flag)) {
547            ASSERT(global);
548            continue;
549        }
550        if (wroteType) {
551            this->write(", ");
552        } else {
553            this->writeModifiers(var.fVar->fModifiers, global);
554            this->writeType(decl.fBaseType);
555            this->write(" ");
556            wroteType = true;
557        }
558        this->write(var.fVar->fName);
559        for (const auto& size : var.fSizes) {
560            this->write("[");
561            if (size) {
562                this->writeExpression(*size, kTopLevel_Precedence);
563            }
564            this->write("]");
565        }
566        if (var.fValue) {
567            this->write(" = ");
568            this->writeVarInitializer(*var.fVar, *var.fValue);
569        }
570        if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) {
571            if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) {
572                fHeader.writeText("#extension ");
573                fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString());
574                fHeader.writeText(" : require\n");
575            }
576            fFoundImageDecl = true;
577        }
578    }
579    if (wroteType) {
580        this->write(";");
581    }
582}
583
584void MetalCodeGenerator::writeStatement(const Statement& s) {
585    switch (s.fKind) {
586        case Statement::kBlock_Kind:
587            this->writeBlock((Block&) s);
588            break;
589        case Statement::kExpression_Kind:
590            this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
591            this->write(";");
592            break;
593        case Statement::kReturn_Kind:
594            this->writeReturnStatement((ReturnStatement&) s);
595            break;
596        case Statement::kVarDeclarations_Kind:
597            this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
598            break;
599        case Statement::kIf_Kind:
600            this->writeIfStatement((IfStatement&) s);
601            break;
602        case Statement::kFor_Kind:
603            this->writeForStatement((ForStatement&) s);
604            break;
605        case Statement::kWhile_Kind:
606            this->writeWhileStatement((WhileStatement&) s);
607            break;
608        case Statement::kDo_Kind:
609            this->writeDoStatement((DoStatement&) s);
610            break;
611        case Statement::kSwitch_Kind:
612            this->writeSwitchStatement((SwitchStatement&) s);
613            break;
614        case Statement::kBreak_Kind:
615            this->write("break;");
616            break;
617        case Statement::kContinue_Kind:
618            this->write("continue;");
619            break;
620        case Statement::kDiscard_Kind:
621            this->write("discard;");
622            break;
623        case Statement::kNop_Kind:
624            this->write(";");
625            break;
626        default:
627            ABORT("unsupported statement: %s", s.description().c_str());
628    }
629}
630
631void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
632    for (const auto& s : statements) {
633        if (!s->isEmpty()) {
634            this->writeStatement(*s);
635            this->writeLine();
636        }
637    }
638}
639
640void MetalCodeGenerator::writeBlock(const Block& b) {
641    this->writeLine("{");
642    fIndentation++;
643    this->writeStatements(b.fStatements);
644    fIndentation--;
645    this->write("}");
646}
647
648void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
649    this->write("if (");
650    this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
651    this->write(") ");
652    this->writeStatement(*stmt.fIfTrue);
653    if (stmt.fIfFalse) {
654        this->write(" else ");
655        this->writeStatement(*stmt.fIfFalse);
656    }
657}
658
659void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
660    this->write("for (");
661    if (f.fInitializer && !f.fInitializer->isEmpty()) {
662        this->writeStatement(*f.fInitializer);
663    } else {
664        this->write("; ");
665    }
666    if (f.fTest) {
667        this->writeExpression(*f.fTest, kTopLevel_Precedence);
668    }
669    this->write("; ");
670    if (f.fNext) {
671        this->writeExpression(*f.fNext, kTopLevel_Precedence);
672    }
673    this->write(") ");
674    this->writeStatement(*f.fStatement);
675}
676
677void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
678    this->write("while (");
679    this->writeExpression(*w.fTest, kTopLevel_Precedence);
680    this->write(") ");
681    this->writeStatement(*w.fStatement);
682}
683
684void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
685    this->write("do ");
686    this->writeStatement(*d.fStatement);
687    this->write(" while (");
688    this->writeExpression(*d.fTest, kTopLevel_Precedence);
689    this->write(");");
690}
691
692void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
693    this->write("switch (");
694    this->writeExpression(*s.fValue, kTopLevel_Precedence);
695    this->writeLine(") {");
696    fIndentation++;
697    for (const auto& c : s.fCases) {
698        if (c->fValue) {
699            this->write("case ");
700            this->writeExpression(*c->fValue, kTopLevel_Precedence);
701            this->writeLine(":");
702        } else {
703            this->writeLine("default:");
704        }
705        fIndentation++;
706        for (const auto& stmt : c->fStatements) {
707            this->writeStatement(*stmt);
708            this->writeLine();
709        }
710        fIndentation--;
711    }
712    fIndentation--;
713    this->write("}");
714}
715
716void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
717    this->write("return");
718    if (r.fExpression) {
719        this->write(" ");
720        this->writeExpression(*r.fExpression, kTopLevel_Precedence);
721    }
722    this->write(";");
723}
724
725void MetalCodeGenerator::writeHeader() {
726    this->write("#include <metal_stdlib>\n");
727    this->write("#include <simd/simd.h>\n");
728    this->write("using namespace metal;\n");
729}
730
731void MetalCodeGenerator::writeUniformStruct() {
732    for (const auto& e : fProgram.fElements) {
733        if (ProgramElement::kVar_Kind == e->fKind) {
734            VarDeclarations& decls = (VarDeclarations&) *e;
735            if (!decls.fVars.size()) {
736                continue;
737            }
738            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
739            if (first.fModifiers.fFlags & Modifiers::kUniform_Flag) {
740                if (-1 == fUniformBuffer) {
741                    this->write("struct Uniforms {\n");
742                    fUniformBuffer = first.fModifiers.fLayout.fSet;
743                    if (-1 == fUniformBuffer) {
744                        fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
745                    }
746                } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
747                    if (-1 == fUniformBuffer) {
748                        fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
749                                    "the same 'layout(set=...)'");
750                    }
751                }
752                this->write("    ");
753                this->writeType(first.fType);
754                this->write(" ");
755                for (const auto& stmt : decls.fVars) {
756                    VarDeclaration& var = (VarDeclaration&) *stmt;
757                    this->write(var.fVar->fName);
758                }
759                this->write(";\n");
760            }
761        }
762    }
763    if (-1 != fUniformBuffer) {
764        this->write("};\n");
765    }
766}
767
768void MetalCodeGenerator::writeInputStruct() {
769    this->write("struct Inputs {\n");
770    if (Program::kFragment_Kind == fProgram.fKind) {
771        this->write("    float4 position [[position]];\n");
772    }
773    for (const auto& e : fProgram.fElements) {
774        if (ProgramElement::kVar_Kind == e->fKind) {
775            VarDeclarations& decls = (VarDeclarations&) *e;
776            if (!decls.fVars.size()) {
777                continue;
778            }
779            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
780            if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
781                -1 == first.fModifiers.fLayout.fBuiltin) {
782                this->write("    ");
783                this->writeType(first.fType);
784                this->write(" ");
785                for (const auto& stmt : decls.fVars) {
786                    VarDeclaration& var = (VarDeclaration&) *stmt;
787                    this->write(var.fVar->fName);
788                    if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
789                        this->write("  [[attribute(" +
790                                    to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
791                    }
792                }
793                this->write(";\n");
794            }
795        }
796    }
797    this->write("};\n");
798}
799
800void MetalCodeGenerator::writeOutputStruct() {
801    this->write("struct Outputs {\n");
802    this->write("    float4 position [[position]];\n");
803    for (const auto& e : fProgram.fElements) {
804        if (ProgramElement::kVar_Kind == e->fKind) {
805            VarDeclarations& decls = (VarDeclarations&) *e;
806            if (!decls.fVars.size()) {
807                continue;
808            }
809            const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
810            if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
811                -1 == first.fModifiers.fLayout.fBuiltin) {
812                this->write("    ");
813                this->writeType(first.fType);
814                this->write(" ");
815                for (const auto& stmt : decls.fVars) {
816                    VarDeclaration& var = (VarDeclaration&) *stmt;
817                    this->write(var.fVar->fName);
818                }
819                this->write(";\n");
820            }
821        }
822    }    this->write("};\n");
823}
824
825void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
826    switch (e.fKind) {
827        case ProgramElement::kExtension_Kind:
828            break;
829        case ProgramElement::kVar_Kind: {
830            VarDeclarations& decl = (VarDeclarations&) e;
831            if (decl.fVars.size() > 0) {
832                int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
833                if (-1 == builtin) {
834                    // normal var
835                    this->writeVarDeclarations(decl, true);
836                    this->writeLine();
837                } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
838                    // ignore
839                }
840            }
841            break;
842        }
843        case ProgramElement::kInterfaceBlock_Kind:
844            this->writeInterfaceBlock((InterfaceBlock&) e);
845            break;
846        case ProgramElement::kFunction_Kind:
847            this->writeFunction((FunctionDefinition&) e);
848            break;
849        case ProgramElement::kModifiers_Kind:
850            this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
851            this->writeLine(";");
852            break;
853        default:
854            printf("%s\n", e.description().c_str());
855            ABORT("unsupported program element");
856    }
857}
858
859MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
860    switch (e.fKind) {
861        case Expression::kFunctionCall_Kind: {
862            const FunctionCall& f = (const FunctionCall&) e;
863            Requirements result = this->requirements(f.fFunction);
864            for (const auto& e : f.fArguments) {
865                result |= this->requirements(*e);
866            }
867            return result;
868        }
869        case Expression::kConstructor_Kind: {
870            const Constructor& c = (const Constructor&) e;
871            Requirements result = kNo_Requirements;
872            for (const auto& e : c.fArguments) {
873                result |= this->requirements(*e);
874            }
875            return result;
876        }
877        case Expression::kFieldAccess_Kind:
878            return this->requirements(*((const FieldAccess&) e).fBase);
879        case Expression::kSwizzle_Kind:
880            return this->requirements(*((const Swizzle&) e).fBase);
881        case Expression::kBinary_Kind: {
882            const BinaryExpression& b = (const BinaryExpression&) e;
883            return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
884        }
885        case Expression::kIndex_Kind: {
886            const IndexExpression& idx = (const IndexExpression&) e;
887            return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
888        }
889        case Expression::kPrefix_Kind:
890            return this->requirements(*((const PrefixExpression&) e).fOperand);
891        case Expression::kPostfix_Kind:
892            return this->requirements(*((const PostfixExpression&) e).fOperand);
893        case Expression::kTernary_Kind: {
894            const TernaryExpression& t = (const TernaryExpression&) e;
895            return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
896                   this->requirements(*t.fIfFalse);
897        }
898        case Expression::kVariableReference_Kind: {
899            const VariableReference& v = (const VariableReference&) e;
900            Requirements result = kNo_Requirements;
901            if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
902                result = kInputs_Requirement;
903            } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
904                if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
905                    result = kInputs_Requirement;
906                } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
907                    result = kOutputs_Requirement;
908                } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
909                    result = kUniforms_Requirement;
910                }
911            }
912            return result;
913        }
914        default:
915            return kNo_Requirements;
916    }
917}
918
919MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
920    switch (s.fKind) {
921        case Statement::kBlock_Kind: {
922            Requirements result = kNo_Requirements;
923            for (const auto& child : ((const Block&) s).fStatements) {
924                result |= this->requirements(*child);
925            }
926            return result;
927        }
928        case Statement::kExpression_Kind:
929            return this->requirements(*((const ExpressionStatement&) s).fExpression);
930        case Statement::kReturn_Kind: {
931            const ReturnStatement& r = (const ReturnStatement&) s;
932            if (r.fExpression) {
933                return this->requirements(*r.fExpression);
934            }
935            return kNo_Requirements;
936        }
937        case Statement::kIf_Kind: {
938            const IfStatement& i = (const IfStatement&) s;
939            return this->requirements(*i.fTest) |
940                   this->requirements(*i.fIfTrue) |
941                   (i.fIfFalse && this->requirements(*i.fIfFalse));
942        }
943        case Statement::kFor_Kind: {
944            const ForStatement& f = (const ForStatement&) s;
945            return this->requirements(*f.fInitializer) |
946                   this->requirements(*f.fTest) |
947                   this->requirements(*f.fNext) |
948                   this->requirements(*f.fStatement);
949        }
950        case Statement::kWhile_Kind: {
951            const WhileStatement& w = (const WhileStatement&) s;
952            return this->requirements(*w.fTest) |
953                   this->requirements(*w.fStatement);
954        }
955        case Statement::kDo_Kind: {
956            const DoStatement& d = (const DoStatement&) s;
957            return this->requirements(*d.fTest) |
958                   this->requirements(*d.fStatement);
959        }
960        case Statement::kSwitch_Kind: {
961            const SwitchStatement& sw = (const SwitchStatement&) s;
962            Requirements result = this->requirements(*sw.fValue);
963            for (const auto& c : sw.fCases) {
964                for (const auto& st : c->fStatements) {
965                    result |= this->requirements(*st);
966                }
967            }
968            return result;
969        }
970        default:
971            return kNo_Requirements;
972    }
973}
974
975MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
976    if (f.fBuiltin) {
977        return kNo_Requirements;
978    }
979    auto found = fRequirements.find(&f);
980    if (found == fRequirements.end()) {
981        for (const auto& e : fProgram.fElements) {
982            if (ProgramElement::kFunction_Kind == e->fKind) {
983                const FunctionDefinition& def = (const FunctionDefinition&) *e;
984                if (&def.fDeclaration == &f) {
985                    Requirements reqs = this->requirements(*def.fBody);
986                    fRequirements[&f] = reqs;
987                    return reqs;
988                }
989            }
990        }
991    }
992    return found->second;
993}
994
995bool MetalCodeGenerator::generateCode() {
996    OutputStream* rawOut = fOut;
997    fOut = &fHeader;
998    fProgramKind = fProgram.fKind;
999    this->writeHeader();
1000    this->writeUniformStruct();
1001    this->writeInputStruct();
1002    if (Program::kVertex_Kind == fProgram.fKind) {
1003        this->writeOutputStruct();
1004    }
1005    StringStream body;
1006    fOut = &body;
1007    for (const auto& e : fProgram.fElements) {
1008        this->writeProgramElement(*e);
1009    }
1010    fOut = rawOut;
1011
1012    write_stringstream(fHeader, *rawOut);
1013    write_stringstream(body, *rawOut);
1014    return true;
1015}
1016
1017}
1018