1#include "llvm/ADT/STLExtras.h"
2#include "llvm/Analysis/BasicAliasAnalysis.h"
3#include "llvm/Analysis/Passes.h"
4#include "llvm/IR/DIBuilder.h"
5#include "llvm/IR/IRBuilder.h"
6#include "llvm/IR/LLVMContext.h"
7#include "llvm/IR/LegacyPassManager.h"
8#include "llvm/IR/Module.h"
9#include "llvm/IR/Verifier.h"
10#include "llvm/Support/TargetSelect.h"
11#include "llvm/Transforms/Scalar.h"
12#include <cctype>
13#include <cstdio>
14#include <map>
15#include <string>
16#include <vector>
17#include "../include/KaleidoscopeJIT.h"
18
19using namespace llvm;
20using namespace llvm::orc;
21
22//===----------------------------------------------------------------------===//
23// Lexer
24//===----------------------------------------------------------------------===//
25
26// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
27// of these for known things.
/// Token - Codes for the constructs the lexer recognizes.
///
/// All codes are negative so they can never collide with the raw character
/// values (0-255) that gettok() returns for anything it does not recognize.
enum Token {
  tok_eof = -1, // end of input

  // Commands.
  tok_def = -2,    // 'def'
  tok_extern = -3, // 'extern'

  // Primary expressions.
  tok_identifier = -4, // spelling is left in IdentifierStr
  tok_number = -5,     // value is left in NumVal

  // Control flow keywords.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // User-defined operator declarations.
  tok_binary = -11,
  tok_unary = -12,

  // Mutable variable definition ('var').
  tok_var = -13
};
53
54std::string getTokName(int Tok) {
55  switch (Tok) {
56  case tok_eof:
57    return "eof";
58  case tok_def:
59    return "def";
60  case tok_extern:
61    return "extern";
62  case tok_identifier:
63    return "identifier";
64  case tok_number:
65    return "number";
66  case tok_if:
67    return "if";
68  case tok_then:
69    return "then";
70  case tok_else:
71    return "else";
72  case tok_for:
73    return "for";
74  case tok_in:
75    return "in";
76  case tok_binary:
77    return "binary";
78  case tok_unary:
79    return "unary";
80  case tok_var:
81    return "var";
82  }
83  return std::string(1, (char)Tok);
84}
85
86namespace {
87class PrototypeAST;
88class ExprAST;
89}
90static IRBuilder<> Builder(getGlobalContext());
91struct DebugInfo {
92  DICompileUnit *TheCU;
93  DIType *DblTy;
94  std::vector<DIScope *> LexicalBlocks;
95
96  void emitLocation(ExprAST *AST);
97  DIType *getDoubleTy();
98} KSDbgInfo;
99
/// SourceLocation - A line/column position in the input text.
struct SourceLocation {
  int Line; // line number; the lexer starts counting at 1
  int Col;  // column within the line; reset to 0 at each newline
};

// Start position of the token most recently lexed by gettok().
static SourceLocation CurLoc;
// The lexer's running read position, advanced by every character consumed.
static SourceLocation LexLoc{1, 0};
106
107static int advance() {
108  int LastChar = getchar();
109
110  if (LastChar == '\n' || LastChar == '\r') {
111    LexLoc.Line++;
112    LexLoc.Col = 0;
113  } else
114    LexLoc.Col++;
115  return LastChar;
116}
117
118static std::string IdentifierStr; // Filled in if tok_identifier
119static double NumVal;             // Filled in if tok_number
120
121/// gettok - Return the next token from standard input.
122static int gettok() {
123  static int LastChar = ' ';
124
125  // Skip any whitespace.
126  while (isspace(LastChar))
127    LastChar = advance();
128
129  CurLoc = LexLoc;
130
131  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
132    IdentifierStr = LastChar;
133    while (isalnum((LastChar = advance())))
134      IdentifierStr += LastChar;
135
136    if (IdentifierStr == "def")
137      return tok_def;
138    if (IdentifierStr == "extern")
139      return tok_extern;
140    if (IdentifierStr == "if")
141      return tok_if;
142    if (IdentifierStr == "then")
143      return tok_then;
144    if (IdentifierStr == "else")
145      return tok_else;
146    if (IdentifierStr == "for")
147      return tok_for;
148    if (IdentifierStr == "in")
149      return tok_in;
150    if (IdentifierStr == "binary")
151      return tok_binary;
152    if (IdentifierStr == "unary")
153      return tok_unary;
154    if (IdentifierStr == "var")
155      return tok_var;
156    return tok_identifier;
157  }
158
159  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
160    std::string NumStr;
161    do {
162      NumStr += LastChar;
163      LastChar = advance();
164    } while (isdigit(LastChar) || LastChar == '.');
165
166    NumVal = strtod(NumStr.c_str(), nullptr);
167    return tok_number;
168  }
169
170  if (LastChar == '#') {
171    // Comment until end of line.
172    do
173      LastChar = advance();
174    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
175
176    if (LastChar != EOF)
177      return gettok();
178  }
179
180  // Check for end of file.  Don't eat the EOF.
181  if (LastChar == EOF)
182    return tok_eof;
183
184  // Otherwise, just return the character as its ascii value.
185  int ThisChar = LastChar;
186  LastChar = advance();
187  return ThisChar;
188}
189
190//===----------------------------------------------------------------------===//
191// Abstract Syntax Tree (aka Parse Tree)
192//===----------------------------------------------------------------------===//
193namespace {
194
195raw_ostream &indent(raw_ostream &O, int size) {
196  return O << std::string(size, ' ');
197}
198
199/// ExprAST - Base class for all expression nodes.
200class ExprAST {
201  SourceLocation Loc;
202
203public:
204  ExprAST(SourceLocation Loc = CurLoc) : Loc(Loc) {}
205  virtual ~ExprAST() {}
206  virtual Value *codegen() = 0;
207  int getLine() const { return Loc.Line; }
208  int getCol() const { return Loc.Col; }
209  virtual raw_ostream &dump(raw_ostream &out, int ind) {
210    return out << ':' << getLine() << ':' << getCol() << '\n';
211  }
212};
213
214/// NumberExprAST - Expression class for numeric literals like "1.0".
215class NumberExprAST : public ExprAST {
216  double Val;
217
218public:
219  NumberExprAST(double Val) : Val(Val) {}
220  raw_ostream &dump(raw_ostream &out, int ind) override {
221    return ExprAST::dump(out << Val, ind);
222  }
223  Value *codegen() override;
224};
225
226/// VariableExprAST - Expression class for referencing a variable, like "a".
227class VariableExprAST : public ExprAST {
228  std::string Name;
229
230public:
231  VariableExprAST(SourceLocation Loc, const std::string &Name)
232      : ExprAST(Loc), Name(Name) {}
233  const std::string &getName() const { return Name; }
234  Value *codegen() override;
235  raw_ostream &dump(raw_ostream &out, int ind) override {
236    return ExprAST::dump(out << Name, ind);
237  }
238};
239
240/// UnaryExprAST - Expression class for a unary operator.
241class UnaryExprAST : public ExprAST {
242  char Opcode;
243  std::unique_ptr<ExprAST> Operand;
244
245public:
246  UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
247      : Opcode(Opcode), Operand(std::move(Operand)) {}
248  Value *codegen() override;
249  raw_ostream &dump(raw_ostream &out, int ind) override {
250    ExprAST::dump(out << "unary" << Opcode, ind);
251    Operand->dump(out, ind + 1);
252    return out;
253  }
254};
255
256/// BinaryExprAST - Expression class for a binary operator.
257class BinaryExprAST : public ExprAST {
258  char Op;
259  std::unique_ptr<ExprAST> LHS, RHS;
260
261public:
262  BinaryExprAST(SourceLocation Loc, char Op, std::unique_ptr<ExprAST> LHS,
263                std::unique_ptr<ExprAST> RHS)
264      : ExprAST(Loc), Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
265  Value *codegen() override;
266  raw_ostream &dump(raw_ostream &out, int ind) override {
267    ExprAST::dump(out << "binary" << Op, ind);
268    LHS->dump(indent(out, ind) << "LHS:", ind + 1);
269    RHS->dump(indent(out, ind) << "RHS:", ind + 1);
270    return out;
271  }
272};
273
274/// CallExprAST - Expression class for function calls.
275class CallExprAST : public ExprAST {
276  std::string Callee;
277  std::vector<std::unique_ptr<ExprAST>> Args;
278
279public:
280  CallExprAST(SourceLocation Loc, const std::string &Callee,
281              std::vector<std::unique_ptr<ExprAST>> Args)
282      : ExprAST(Loc), Callee(Callee), Args(std::move(Args)) {}
283  Value *codegen() override;
284  raw_ostream &dump(raw_ostream &out, int ind) override {
285    ExprAST::dump(out << "call " << Callee, ind);
286    for (const auto &Arg : Args)
287      Arg->dump(indent(out, ind + 1), ind + 1);
288    return out;
289  }
290};
291
292/// IfExprAST - Expression class for if/then/else.
293class IfExprAST : public ExprAST {
294  std::unique_ptr<ExprAST> Cond, Then, Else;
295
296public:
297  IfExprAST(SourceLocation Loc, std::unique_ptr<ExprAST> Cond,
298            std::unique_ptr<ExprAST> Then, std::unique_ptr<ExprAST> Else)
299      : ExprAST(Loc), Cond(std::move(Cond)), Then(std::move(Then)),
300        Else(std::move(Else)) {}
301  Value *codegen() override;
302  raw_ostream &dump(raw_ostream &out, int ind) override {
303    ExprAST::dump(out << "if", ind);
304    Cond->dump(indent(out, ind) << "Cond:", ind + 1);
305    Then->dump(indent(out, ind) << "Then:", ind + 1);
306    Else->dump(indent(out, ind) << "Else:", ind + 1);
307    return out;
308  }
309};
310
311/// ForExprAST - Expression class for for/in.
312class ForExprAST : public ExprAST {
313  std::string VarName;
314  std::unique_ptr<ExprAST> Start, End, Step, Body;
315
316public:
317  ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
318             std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
319             std::unique_ptr<ExprAST> Body)
320      : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
321        Step(std::move(Step)), Body(std::move(Body)) {}
322  Value *codegen() override;
323  raw_ostream &dump(raw_ostream &out, int ind) override {
324    ExprAST::dump(out << "for", ind);
325    Start->dump(indent(out, ind) << "Cond:", ind + 1);
326    End->dump(indent(out, ind) << "End:", ind + 1);
327    Step->dump(indent(out, ind) << "Step:", ind + 1);
328    Body->dump(indent(out, ind) << "Body:", ind + 1);
329    return out;
330  }
331};
332
333/// VarExprAST - Expression class for var/in
334class VarExprAST : public ExprAST {
335  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
336  std::unique_ptr<ExprAST> Body;
337
338public:
339  VarExprAST(
340      std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
341      std::unique_ptr<ExprAST> Body)
342      : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
343  Value *codegen() override;
344  raw_ostream &dump(raw_ostream &out, int ind) override {
345    ExprAST::dump(out << "var", ind);
346    for (const auto &NamedVar : VarNames)
347      NamedVar.second->dump(indent(out, ind) << NamedVar.first << ':', ind + 1);
348    Body->dump(indent(out, ind) << "Body:", ind + 1);
349    return out;
350  }
351};
352
353/// PrototypeAST - This class represents the "prototype" for a function,
354/// which captures its name, and its argument names (thus implicitly the number
355/// of arguments the function takes), as well as if it is an operator.
356class PrototypeAST {
357  std::string Name;
358  std::vector<std::string> Args;
359  bool IsOperator;
360  unsigned Precedence; // Precedence if a binary op.
361  int Line;
362
363public:
364  PrototypeAST(SourceLocation Loc, const std::string &Name,
365               std::vector<std::string> Args, bool IsOperator = false,
366               unsigned Prec = 0)
367      : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
368        Precedence(Prec), Line(Loc.Line) {}
369  Function *codegen();
370  const std::string &getName() const { return Name; }
371
372  bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
373  bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
374
375  char getOperatorName() const {
376    assert(isUnaryOp() || isBinaryOp());
377    return Name[Name.size() - 1];
378  }
379
380  unsigned getBinaryPrecedence() const { return Precedence; }
381  int getLine() const { return Line; }
382};
383
384/// FunctionAST - This class represents a function definition itself.
385class FunctionAST {
386  std::unique_ptr<PrototypeAST> Proto;
387  std::unique_ptr<ExprAST> Body;
388
389public:
390  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
391              std::unique_ptr<ExprAST> Body)
392      : Proto(std::move(Proto)), Body(std::move(Body)) {}
393  Function *codegen();
394  raw_ostream &dump(raw_ostream &out, int ind) {
395    indent(out, ind) << "FunctionAST\n";
396    ++ind;
397    indent(out, ind) << "Body:";
398    return Body ? Body->dump(out, ind) : out << "null\n";
399  }
400};
401} // end anonymous namespace
402
403//===----------------------------------------------------------------------===//
404// Parser
405//===----------------------------------------------------------------------===//
406
407/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
408/// token the parser is looking at.  getNextToken reads another token from the
409/// lexer and updates CurTok with its results.
410static int CurTok;
411static int getNextToken() { return CurTok = gettok(); }
412
413/// BinopPrecedence - This holds the precedence for each binary operator that is
414/// defined.
415static std::map<char, int> BinopPrecedence;
416
417/// GetTokPrecedence - Get the precedence of the pending binary operator token.
418static int GetTokPrecedence() {
419  if (!isascii(CurTok))
420    return -1;
421
422  // Make sure it's a declared binop.
423  int TokPrec = BinopPrecedence[CurTok];
424  if (TokPrec <= 0)
425    return -1;
426  return TokPrec;
427}
428
429/// Error* - These are little helper functions for error handling.
430std::unique_ptr<ExprAST> Error(const char *Str) {
431  fprintf(stderr, "Error: %s\n", Str);
432  return nullptr;
433}
434
435std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
436  Error(Str);
437  return nullptr;
438}
439
440static std::unique_ptr<ExprAST> ParseExpression();
441
442/// numberexpr ::= number
443static std::unique_ptr<ExprAST> ParseNumberExpr() {
444  auto Result = llvm::make_unique<NumberExprAST>(NumVal);
445  getNextToken(); // consume the number
446  return std::move(Result);
447}
448
449/// parenexpr ::= '(' expression ')'
450static std::unique_ptr<ExprAST> ParseParenExpr() {
451  getNextToken(); // eat (.
452  auto V = ParseExpression();
453  if (!V)
454    return nullptr;
455
456  if (CurTok != ')')
457    return Error("expected ')'");
458  getNextToken(); // eat ).
459  return V;
460}
461
462/// identifierexpr
463///   ::= identifier
464///   ::= identifier '(' expression* ')'
465static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
466  std::string IdName = IdentifierStr;
467
468  SourceLocation LitLoc = CurLoc;
469
470  getNextToken(); // eat identifier.
471
472  if (CurTok != '(') // Simple variable ref.
473    return llvm::make_unique<VariableExprAST>(LitLoc, IdName);
474
475  // Call.
476  getNextToken(); // eat (
477  std::vector<std::unique_ptr<ExprAST>> Args;
478  if (CurTok != ')') {
479    while (1) {
480      if (auto Arg = ParseExpression())
481        Args.push_back(std::move(Arg));
482      else
483        return nullptr;
484
485      if (CurTok == ')')
486        break;
487
488      if (CurTok != ',')
489        return Error("Expected ')' or ',' in argument list");
490      getNextToken();
491    }
492  }
493
494  // Eat the ')'.
495  getNextToken();
496
497  return llvm::make_unique<CallExprAST>(LitLoc, IdName, std::move(Args));
498}
499
500/// ifexpr ::= 'if' expression 'then' expression 'else' expression
501static std::unique_ptr<ExprAST> ParseIfExpr() {
502  SourceLocation IfLoc = CurLoc;
503
504  getNextToken(); // eat the if.
505
506  // condition.
507  auto Cond = ParseExpression();
508  if (!Cond)
509    return nullptr;
510
511  if (CurTok != tok_then)
512    return Error("expected then");
513  getNextToken(); // eat the then
514
515  auto Then = ParseExpression();
516  if (!Then)
517    return nullptr;
518
519  if (CurTok != tok_else)
520    return Error("expected else");
521
522  getNextToken();
523
524  auto Else = ParseExpression();
525  if (!Else)
526    return nullptr;
527
528  return llvm::make_unique<IfExprAST>(IfLoc, std::move(Cond), std::move(Then),
529                                      std::move(Else));
530}
531
532/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
533static std::unique_ptr<ExprAST> ParseForExpr() {
534  getNextToken(); // eat the for.
535
536  if (CurTok != tok_identifier)
537    return Error("expected identifier after for");
538
539  std::string IdName = IdentifierStr;
540  getNextToken(); // eat identifier.
541
542  if (CurTok != '=')
543    return Error("expected '=' after for");
544  getNextToken(); // eat '='.
545
546  auto Start = ParseExpression();
547  if (!Start)
548    return nullptr;
549  if (CurTok != ',')
550    return Error("expected ',' after for start value");
551  getNextToken();
552
553  auto End = ParseExpression();
554  if (!End)
555    return nullptr;
556
557  // The step value is optional.
558  std::unique_ptr<ExprAST> Step;
559  if (CurTok == ',') {
560    getNextToken();
561    Step = ParseExpression();
562    if (!Step)
563      return nullptr;
564  }
565
566  if (CurTok != tok_in)
567    return Error("expected 'in' after for");
568  getNextToken(); // eat 'in'.
569
570  auto Body = ParseExpression();
571  if (!Body)
572    return nullptr;
573
574  return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
575                                       std::move(Step), std::move(Body));
576}
577
578/// varexpr ::= 'var' identifier ('=' expression)?
579//                    (',' identifier ('=' expression)?)* 'in' expression
580static std::unique_ptr<ExprAST> ParseVarExpr() {
581  getNextToken(); // eat the var.
582
583  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
584
585  // At least one variable name is required.
586  if (CurTok != tok_identifier)
587    return Error("expected identifier after var");
588
589  while (1) {
590    std::string Name = IdentifierStr;
591    getNextToken(); // eat identifier.
592
593    // Read the optional initializer.
594    std::unique_ptr<ExprAST> Init = nullptr;
595    if (CurTok == '=') {
596      getNextToken(); // eat the '='.
597
598      Init = ParseExpression();
599      if (!Init)
600        return nullptr;
601    }
602
603    VarNames.push_back(std::make_pair(Name, std::move(Init)));
604
605    // End of var list, exit loop.
606    if (CurTok != ',')
607      break;
608    getNextToken(); // eat the ','.
609
610    if (CurTok != tok_identifier)
611      return Error("expected identifier list after var");
612  }
613
614  // At this point, we have to have 'in'.
615  if (CurTok != tok_in)
616    return Error("expected 'in' keyword after 'var'");
617  getNextToken(); // eat 'in'.
618
619  auto Body = ParseExpression();
620  if (!Body)
621    return nullptr;
622
623  return llvm::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
624}
625
626/// primary
627///   ::= identifierexpr
628///   ::= numberexpr
629///   ::= parenexpr
630///   ::= ifexpr
631///   ::= forexpr
632///   ::= varexpr
633static std::unique_ptr<ExprAST> ParsePrimary() {
634  switch (CurTok) {
635  default:
636    return Error("unknown token when expecting an expression");
637  case tok_identifier:
638    return ParseIdentifierExpr();
639  case tok_number:
640    return ParseNumberExpr();
641  case '(':
642    return ParseParenExpr();
643  case tok_if:
644    return ParseIfExpr();
645  case tok_for:
646    return ParseForExpr();
647  case tok_var:
648    return ParseVarExpr();
649  }
650}
651
652/// unary
653///   ::= primary
654///   ::= '!' unary
655static std::unique_ptr<ExprAST> ParseUnary() {
656  // If the current token is not an operator, it must be a primary expr.
657  if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
658    return ParsePrimary();
659
660  // If this is a unary operator, read it.
661  int Opc = CurTok;
662  getNextToken();
663  if (auto Operand = ParseUnary())
664    return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
665  return nullptr;
666}
667
/// binoprhs
///   ::= ('+' unary)*
///
/// Operator-precedence parsing of the tail of an expression. \p LHS is the
/// already-parsed left operand. \p ExprPrec is the minimum precedence an
/// operator must have to be consumed by this call; weaker operators are left
/// in CurTok for the caller. Returns the combined expression, or null if an
/// operand fails to parse.
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
                                              std::unique_ptr<ExprAST> LHS) {
  // If this is a binop, find its precedence.
  while (1) {
    int TokPrec = GetTokPrecedence();

    // If this is a binop that binds at least as tightly as the current binop,
    // consume it, otherwise we are done.
    // Non-operators report -1, so with the non-negative ExprPrec values the
    // callers in this file pass, they always end the loop here.
    if (TokPrec < ExprPrec)
      return LHS;

    // Okay, we know this is a binop.
    int BinOp = CurTok;
    // The node is located at the operator token, not at its operands.
    SourceLocation BinLoc = CurLoc;
    getNextToken(); // eat binop

    // Parse the unary expression after the binary operator.
    auto RHS = ParseUnary();
    if (!RHS)
      return nullptr;

    // If BinOp binds less tightly with RHS than the operator after RHS, let
    // the pending operator take RHS as its LHS.
    // Passing TokPrec + 1 makes the recursive call stop at operators of equal
    // precedence, so equal-precedence operators associate to the left.
    int NextPrec = GetTokPrecedence();
    if (TokPrec < NextPrec) {
      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
      if (!RHS)
        return nullptr;
    }

    // Merge LHS/RHS.
    LHS = llvm::make_unique<BinaryExprAST>(BinLoc, BinOp, std::move(LHS),
                                           std::move(RHS));
  }
}
705
706/// expression
707///   ::= unary binoprhs
708///
709static std::unique_ptr<ExprAST> ParseExpression() {
710  auto LHS = ParseUnary();
711  if (!LHS)
712    return nullptr;
713
714  return ParseBinOpRHS(0, std::move(LHS));
715}
716
717/// prototype
718///   ::= id '(' id* ')'
719///   ::= binary LETTER number? (id, id)
720///   ::= unary LETTER (id)
721static std::unique_ptr<PrototypeAST> ParsePrototype() {
722  std::string FnName;
723
724  SourceLocation FnLoc = CurLoc;
725
726  unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
727  unsigned BinaryPrecedence = 30;
728
729  switch (CurTok) {
730  default:
731    return ErrorP("Expected function name in prototype");
732  case tok_identifier:
733    FnName = IdentifierStr;
734    Kind = 0;
735    getNextToken();
736    break;
737  case tok_unary:
738    getNextToken();
739    if (!isascii(CurTok))
740      return ErrorP("Expected unary operator");
741    FnName = "unary";
742    FnName += (char)CurTok;
743    Kind = 1;
744    getNextToken();
745    break;
746  case tok_binary:
747    getNextToken();
748    if (!isascii(CurTok))
749      return ErrorP("Expected binary operator");
750    FnName = "binary";
751    FnName += (char)CurTok;
752    Kind = 2;
753    getNextToken();
754
755    // Read the precedence if present.
756    if (CurTok == tok_number) {
757      if (NumVal < 1 || NumVal > 100)
758        return ErrorP("Invalid precedecnce: must be 1..100");
759      BinaryPrecedence = (unsigned)NumVal;
760      getNextToken();
761    }
762    break;
763  }
764
765  if (CurTok != '(')
766    return ErrorP("Expected '(' in prototype");
767
768  std::vector<std::string> ArgNames;
769  while (getNextToken() == tok_identifier)
770    ArgNames.push_back(IdentifierStr);
771  if (CurTok != ')')
772    return ErrorP("Expected ')' in prototype");
773
774  // success.
775  getNextToken(); // eat ')'.
776
777  // Verify right number of names for operator.
778  if (Kind && ArgNames.size() != Kind)
779    return ErrorP("Invalid number of operands for operator");
780
781  return llvm::make_unique<PrototypeAST>(FnLoc, FnName, ArgNames, Kind != 0,
782                                         BinaryPrecedence);
783}
784
785/// definition ::= 'def' prototype expression
786static std::unique_ptr<FunctionAST> ParseDefinition() {
787  getNextToken(); // eat def.
788  auto Proto = ParsePrototype();
789  if (!Proto)
790    return nullptr;
791
792  if (auto E = ParseExpression())
793    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
794  return nullptr;
795}
796
797/// toplevelexpr ::= expression
798static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
799  SourceLocation FnLoc = CurLoc;
800  if (auto E = ParseExpression()) {
801    // Make an anonymous proto.
802    auto Proto = llvm::make_unique<PrototypeAST>(FnLoc, "__anon_expr",
803                                                 std::vector<std::string>());
804    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
805  }
806  return nullptr;
807}
808
809/// external ::= 'extern' prototype
810static std::unique_ptr<PrototypeAST> ParseExtern() {
811  getNextToken(); // eat extern.
812  return ParsePrototype();
813}
814
815//===----------------------------------------------------------------------===//
816// Debug Info Support
817//===----------------------------------------------------------------------===//
818
819static std::unique_ptr<DIBuilder> DBuilder;
820
821DIType *DebugInfo::getDoubleTy() {
822  if (DblTy)
823    return DblTy;
824
825  DblTy = DBuilder->createBasicType("double", 64, 64, dwarf::DW_ATE_float);
826  return DblTy;
827}
828
829void DebugInfo::emitLocation(ExprAST *AST) {
830  if (!AST)
831    return Builder.SetCurrentDebugLocation(DebugLoc());
832  DIScope *Scope;
833  if (LexicalBlocks.empty())
834    Scope = TheCU;
835  else
836    Scope = LexicalBlocks.back();
837  Builder.SetCurrentDebugLocation(
838      DebugLoc::get(AST->getLine(), AST->getCol(), Scope));
839}
840
841static DISubroutineType *CreateFunctionType(unsigned NumArgs, DIFile *Unit) {
842  SmallVector<Metadata *, 8> EltTys;
843  DIType *DblTy = KSDbgInfo.getDoubleTy();
844
845  // Add the result type.
846  EltTys.push_back(DblTy);
847
848  for (unsigned i = 0, e = NumArgs; i != e; ++i)
849    EltTys.push_back(DblTy);
850
851  return DBuilder->createSubroutineType(DBuilder->getOrCreateTypeArray(EltTys));
852}
853
854//===----------------------------------------------------------------------===//
855// Code Generation
856//===----------------------------------------------------------------------===//
857
858static std::unique_ptr<Module> TheModule;
859static std::map<std::string, AllocaInst *> NamedValues;
860static std::unique_ptr<KaleidoscopeJIT> TheJIT;
861static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
862
863Value *ErrorV(const char *Str) {
864  Error(Str);
865  return nullptr;
866}
867
868Function *getFunction(std::string Name) {
869  // First, see if the function has already been added to the current module.
870  if (auto *F = TheModule->getFunction(Name))
871    return F;
872
873  // If not, check whether we can codegen the declaration from some existing
874  // prototype.
875  auto FI = FunctionProtos.find(Name);
876  if (FI != FunctionProtos.end())
877    return FI->second->codegen();
878
879  // If no existing prototype exists, return null.
880  return nullptr;
881}
882
883/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
884/// the function.  This is used for mutable variables etc.
885static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
886                                          const std::string &VarName) {
887  IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
888                   TheFunction->getEntryBlock().begin());
889  return TmpB.CreateAlloca(Type::getDoubleTy(getGlobalContext()), nullptr,
890                           VarName.c_str());
891}
892
893Value *NumberExprAST::codegen() {
894  KSDbgInfo.emitLocation(this);
895  return ConstantFP::get(getGlobalContext(), APFloat(Val));
896}
897
898Value *VariableExprAST::codegen() {
899  // Look this variable up in the function.
900  Value *V = NamedValues[Name];
901  if (!V)
902    return ErrorV("Unknown variable name");
903
904  KSDbgInfo.emitLocation(this);
905  // Load the value.
906  return Builder.CreateLoad(V, Name.c_str());
907}
908
909Value *UnaryExprAST::codegen() {
910  Value *OperandV = Operand->codegen();
911  if (!OperandV)
912    return nullptr;
913
914  Function *F = getFunction(std::string("unary") + Opcode);
915  if (!F)
916    return ErrorV("Unknown unary operator");
917
918  KSDbgInfo.emitLocation(this);
919  return Builder.CreateCall(F, OperandV, "unop");
920}
921
922Value *BinaryExprAST::codegen() {
923  KSDbgInfo.emitLocation(this);
924
925  // Special case '=' because we don't want to emit the LHS as an expression.
926  if (Op == '=') {
927    // Assignment requires the LHS to be an identifier.
928    // This assume we're building without RTTI because LLVM builds that way by
929    // default.  If you build LLVM with RTTI this can be changed to a
930    // dynamic_cast for automatic error checking.
931    VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
932    if (!LHSE)
933      return ErrorV("destination of '=' must be a variable");
934    // Codegen the RHS.
935    Value *Val = RHS->codegen();
936    if (!Val)
937      return nullptr;
938
939    // Look up the name.
940    Value *Variable = NamedValues[LHSE->getName()];
941    if (!Variable)
942      return ErrorV("Unknown variable name");
943
944    Builder.CreateStore(Val, Variable);
945    return Val;
946  }
947
948  Value *L = LHS->codegen();
949  Value *R = RHS->codegen();
950  if (!L || !R)
951    return nullptr;
952
953  switch (Op) {
954  case '+':
955    return Builder.CreateFAdd(L, R, "addtmp");
956  case '-':
957    return Builder.CreateFSub(L, R, "subtmp");
958  case '*':
959    return Builder.CreateFMul(L, R, "multmp");
960  case '<':
961    L = Builder.CreateFCmpULT(L, R, "cmptmp");
962    // Convert bool 0/1 to double 0.0 or 1.0
963    return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
964                                "booltmp");
965  default:
966    break;
967  }
968
969  // If it wasn't a builtin binary operator, it must be a user defined one. Emit
970  // a call to it.
971  Function *F = getFunction(std::string("binary") + Op);
972  assert(F && "binary operator not found!");
973
974  Value *Ops[] = {L, R};
975  return Builder.CreateCall(F, Ops, "binop");
976}
977
978Value *CallExprAST::codegen() {
979  KSDbgInfo.emitLocation(this);
980
981  // Look up the name in the global module table.
982  Function *CalleeF = getFunction(Callee);
983  if (!CalleeF)
984    return ErrorV("Unknown function referenced");
985
986  // If argument mismatch error.
987  if (CalleeF->arg_size() != Args.size())
988    return ErrorV("Incorrect # arguments passed");
989
990  std::vector<Value *> ArgsV;
991  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
992    ArgsV.push_back(Args[i]->codegen());
993    if (!ArgsV.back())
994      return nullptr;
995  }
996
997  return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
998}
999
/// IfExprAST::codegen - Emit if/then/else as a diamond of basic blocks.
///
/// 1. Evaluate the condition and test it against 0.0.
/// 2. Branch to a 'then' or 'else' block; each arm branches to a merge block.
/// 3. In the merge block, a PHI selects the value of whichever arm ran.
/// Returns the PHI, or null if any sub-expression fails to generate.
Value *IfExprAST::codegen() {
  KSDbgInfo.emitLocation(this);

  Value *CondV = Cond->codegen();
  if (!CondV)
    return nullptr;

  // Convert condition to a bool by comparing not-equal to 0.0. FCmpONE is an
  // *ordered* comparison, so a NaN condition counts as false.
  CondV = Builder.CreateFCmpONE(
      CondV, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "ifcond");

  Function *TheFunction = Builder.GetInsertBlock()->getParent();

  // Create blocks for the then and else cases.  Insert the 'then' block at the
  // end of the function.
  // 'else' and 'ifcont' are created detached and appended later, so the block
  // order follows the order in which their contents are generated.
  // NOTE(review): on the early `return nullptr` paths below, ElseBB/MergeBB
  // may never be attached to the function and are not freed here.
  BasicBlock *ThenBB =
      BasicBlock::Create(getGlobalContext(), "then", TheFunction);
  BasicBlock *ElseBB = BasicBlock::Create(getGlobalContext(), "else");
  BasicBlock *MergeBB = BasicBlock::Create(getGlobalContext(), "ifcont");

  Builder.CreateCondBr(CondV, ThenBB, ElseBB);

  // Emit then value.
  Builder.SetInsertPoint(ThenBB);

  Value *ThenV = Then->codegen();
  if (!ThenV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
  ThenBB = Builder.GetInsertBlock();

  // Emit else block.
  TheFunction->getBasicBlockList().push_back(ElseBB);
  Builder.SetInsertPoint(ElseBB);

  Value *ElseV = Else->codegen();
  if (!ElseV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
  ElseBB = Builder.GetInsertBlock();

  // Emit merge block.
  TheFunction->getBasicBlockList().push_back(MergeBB);
  Builder.SetInsertPoint(MergeBB);
  PHINode *PN =
      Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()), 2, "iftmp");

  // The incoming blocks are the arms' final blocks, which may differ from the
  // blocks created above if the arms contained nested control flow.
  PN->addIncoming(ThenV, ThenBB);
  PN->addIncoming(ElseV, ElseBB);
  return PN;
}
1055
1056// Output for-loop as:
1057//   var = alloca double
1058//   ...
1059//   start = startexpr
1060//   store start -> var
1061//   goto loop
1062// loop:
1063//   ...
1064//   bodyexpr
1065//   ...
1066// loopend:
1067//   step = stepexpr
1068//   endcond = endexpr
1069//
1070//   curvar = load var
1071//   nextvar = curvar + step
1072//   store nextvar -> var
1073//   br endcond, loop, endloop
1074// outloop:
1075Value *ForExprAST::codegen() {
1076  Function *TheFunction = Builder.GetInsertBlock()->getParent();
1077
1078  // Create an alloca for the variable in the entry block.
1079  AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
1080
1081  KSDbgInfo.emitLocation(this);
1082
1083  // Emit the start code first, without 'variable' in scope.
1084  Value *StartVal = Start->codegen();
1085  if (!StartVal)
1086    return nullptr;
1087
1088  // Store the value into the alloca.
1089  Builder.CreateStore(StartVal, Alloca);
1090
1091  // Make the new basic block for the loop header, inserting after current
1092  // block.
1093  BasicBlock *LoopBB =
1094      BasicBlock::Create(getGlobalContext(), "loop", TheFunction);
1095
1096  // Insert an explicit fall through from the current block to the LoopBB.
1097  Builder.CreateBr(LoopBB);
1098
1099  // Start insertion in LoopBB.
1100  Builder.SetInsertPoint(LoopBB);
1101
1102  // Within the loop, the variable is defined equal to the PHI node.  If it
1103  // shadows an existing variable, we have to restore it, so save it now.
1104  AllocaInst *OldVal = NamedValues[VarName];
1105  NamedValues[VarName] = Alloca;
1106
1107  // Emit the body of the loop.  This, like any other expr, can change the
1108  // current BB.  Note that we ignore the value computed by the body, but don't
1109  // allow an error.
1110  if (!Body->codegen())
1111    return nullptr;
1112
1113  // Emit the step value.
1114  Value *StepVal = nullptr;
1115  if (Step) {
1116    StepVal = Step->codegen();
1117    if (!StepVal)
1118      return nullptr;
1119  } else {
1120    // If not specified, use 1.0.
1121    StepVal = ConstantFP::get(getGlobalContext(), APFloat(1.0));
1122  }
1123
1124  // Compute the end condition.
1125  Value *EndCond = End->codegen();
1126  if (!EndCond)
1127    return nullptr;
1128
1129  // Reload, increment, and restore the alloca.  This handles the case where
1130  // the body of the loop mutates the variable.
1131  Value *CurVar = Builder.CreateLoad(Alloca, VarName.c_str());
1132  Value *NextVar = Builder.CreateFAdd(CurVar, StepVal, "nextvar");
1133  Builder.CreateStore(NextVar, Alloca);
1134
1135  // Convert condition to a bool by comparing equal to 0.0.
1136  EndCond = Builder.CreateFCmpONE(
1137      EndCond, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "loopcond");
1138
1139  // Create the "after loop" block and insert it.
1140  BasicBlock *AfterBB =
1141      BasicBlock::Create(getGlobalContext(), "afterloop", TheFunction);
1142
1143  // Insert the conditional branch into the end of LoopEndBB.
1144  Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
1145
1146  // Any new code will be inserted in AfterBB.
1147  Builder.SetInsertPoint(AfterBB);
1148
1149  // Restore the unshadowed variable.
1150  if (OldVal)
1151    NamedValues[VarName] = OldVal;
1152  else
1153    NamedValues.erase(VarName);
1154
1155  // for expr always returns 0.0.
1156  return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
1157}
1158
1159Value *VarExprAST::codegen() {
1160  std::vector<AllocaInst *> OldBindings;
1161
1162  Function *TheFunction = Builder.GetInsertBlock()->getParent();
1163
1164  // Register all variables and emit their initializer.
1165  for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
1166    const std::string &VarName = VarNames[i].first;
1167    ExprAST *Init = VarNames[i].second.get();
1168
1169    // Emit the initializer before adding the variable to scope, this prevents
1170    // the initializer from referencing the variable itself, and permits stuff
1171    // like this:
1172    //  var a = 1 in
1173    //    var a = a in ...   # refers to outer 'a'.
1174    Value *InitVal;
1175    if (Init) {
1176      InitVal = Init->codegen();
1177      if (!InitVal)
1178        return nullptr;
1179    } else { // If not specified, use 0.0.
1180      InitVal = ConstantFP::get(getGlobalContext(), APFloat(0.0));
1181    }
1182
1183    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
1184    Builder.CreateStore(InitVal, Alloca);
1185
1186    // Remember the old variable binding so that we can restore the binding when
1187    // we unrecurse.
1188    OldBindings.push_back(NamedValues[VarName]);
1189
1190    // Remember this binding.
1191    NamedValues[VarName] = Alloca;
1192  }
1193
1194  KSDbgInfo.emitLocation(this);
1195
1196  // Codegen the body, now that all vars are in scope.
1197  Value *BodyVal = Body->codegen();
1198  if (!BodyVal)
1199    return nullptr;
1200
1201  // Pop all our variables from scope.
1202  for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1203    NamedValues[VarNames[i].first] = OldBindings[i];
1204
1205  // Return the body computation.
1206  return BodyVal;
1207}
1208
1209Function *PrototypeAST::codegen() {
1210  // Make the function type:  double(double,double) etc.
1211  std::vector<Type *> Doubles(Args.size(),
1212                              Type::getDoubleTy(getGlobalContext()));
1213  FunctionType *FT =
1214      FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
1215
1216  Function *F =
1217      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1218
1219  // Set names for all arguments.
1220  unsigned Idx = 0;
1221  for (auto &Arg : F->args())
1222    Arg.setName(Args[Idx++]);
1223
1224  return F;
1225}
1226
1227Function *FunctionAST::codegen() {
1228  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1229  // reference to it for use below.
1230  auto &P = *Proto;
1231  FunctionProtos[Proto->getName()] = std::move(Proto);
1232  Function *TheFunction = getFunction(P.getName());
1233  if (!TheFunction)
1234    return nullptr;
1235
1236  // If this is an operator, install it.
1237  if (P.isBinaryOp())
1238    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1239
1240  // Create a new basic block to start insertion into.
1241  BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
1242  Builder.SetInsertPoint(BB);
1243
1244  // Create a subprogram DIE for this function.
1245  DIFile *Unit = DBuilder->createFile(KSDbgInfo.TheCU->getFilename(),
1246                                      KSDbgInfo.TheCU->getDirectory());
1247  DIScope *FContext = Unit;
1248  unsigned LineNo = P.getLine();
1249  unsigned ScopeLine = LineNo;
1250  DISubprogram *SP = DBuilder->createFunction(
1251      FContext, P.getName(), StringRef(), Unit, LineNo,
1252      CreateFunctionType(TheFunction->arg_size(), Unit),
1253      false /* internal linkage */, true /* definition */, ScopeLine,
1254      DINode::FlagPrototyped, false);
1255  TheFunction->setSubprogram(SP);
1256
1257  // Push the current scope.
1258  KSDbgInfo.LexicalBlocks.push_back(SP);
1259
1260  // Unset the location for the prologue emission (leading instructions with no
1261  // location in a function are considered part of the prologue and the debugger
1262  // will run past them when breaking on a function)
1263  KSDbgInfo.emitLocation(nullptr);
1264
1265  // Record the function arguments in the NamedValues map.
1266  NamedValues.clear();
1267  unsigned ArgIdx = 0;
1268  for (auto &Arg : TheFunction->args()) {
1269    // Create an alloca for this variable.
1270    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1271
1272    // Create a debug descriptor for the variable.
1273    DILocalVariable *D = DBuilder->createParameterVariable(
1274        SP, Arg.getName(), ++ArgIdx, Unit, LineNo, KSDbgInfo.getDoubleTy(),
1275        true);
1276
1277    DBuilder->insertDeclare(Alloca, D, DBuilder->createExpression(),
1278                            DebugLoc::get(LineNo, 0, SP),
1279                            Builder.GetInsertBlock());
1280
1281    // Store the initial value into the alloca.
1282    Builder.CreateStore(&Arg, Alloca);
1283
1284    // Add arguments to variable symbol table.
1285    NamedValues[Arg.getName()] = Alloca;
1286  }
1287
1288  KSDbgInfo.emitLocation(Body.get());
1289
1290  if (Value *RetVal = Body->codegen()) {
1291    // Finish off the function.
1292    Builder.CreateRet(RetVal);
1293
1294    // Pop off the lexical block for the function.
1295    KSDbgInfo.LexicalBlocks.pop_back();
1296
1297    // Validate the generated code, checking for consistency.
1298    verifyFunction(*TheFunction);
1299
1300    return TheFunction;
1301  }
1302
1303  // Error reading body, remove function.
1304  TheFunction->eraseFromParent();
1305
1306  if (P.isBinaryOp())
1307    BinopPrecedence.erase(Proto->getOperatorName());
1308
1309  // Pop off the lexical block for the function since we added it
1310  // unconditionally.
1311  KSDbgInfo.LexicalBlocks.pop_back();
1312
1313  return nullptr;
1314}
1315
1316//===----------------------------------------------------------------------===//
1317// Top-Level parsing and JIT Driver
1318//===----------------------------------------------------------------------===//
1319
1320static void InitializeModule() {
1321  // Open a new module.
1322  TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
1323  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
1324}
1325
1326static void HandleDefinition() {
1327  if (auto FnAST = ParseDefinition()) {
1328    if (!FnAST->codegen())
1329      fprintf(stderr, "Error reading function definition:");
1330  } else {
1331    // Skip token for error recovery.
1332    getNextToken();
1333  }
1334}
1335
1336static void HandleExtern() {
1337  if (auto ProtoAST = ParseExtern()) {
1338    if (!ProtoAST->codegen())
1339      fprintf(stderr, "Error reading extern");
1340    else
1341      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1342  } else {
1343    // Skip token for error recovery.
1344    getNextToken();
1345  }
1346}
1347
1348static void HandleTopLevelExpression() {
1349  // Evaluate a top-level expression into an anonymous function.
1350  if (auto FnAST = ParseTopLevelExpr()) {
1351    if (!FnAST->codegen()) {
1352      fprintf(stderr, "Error generating code for top level expr");
1353    }
1354  } else {
1355    // Skip token for error recovery.
1356    getNextToken();
1357  }
1358}
1359
1360/// top ::= definition | external | expression | ';'
1361static void MainLoop() {
1362  while (1) {
1363    switch (CurTok) {
1364    case tok_eof:
1365      return;
1366    case ';': // ignore top-level semicolons.
1367      getNextToken();
1368      break;
1369    case tok_def:
1370      HandleDefinition();
1371      break;
1372    case tok_extern:
1373      HandleExtern();
1374      break;
1375    default:
1376      HandleTopLevelExpression();
1377      break;
1378    }
1379  }
1380}
1381
1382//===----------------------------------------------------------------------===//
1383// "Library" functions that can be "extern'd" from user code.
1384//===----------------------------------------------------------------------===//
1385
/// putchard - write the character whose code is X (converted to char) to
/// stderr. Always returns 0.0 so it can be called from Kaleidoscope code.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
1391
/// printd - print X to stderr using the "%f\n" format. Always returns 0.0
/// so it can be called from Kaleidoscope code.
extern "C" double printd(double X) {
  const char *const Format = "%f\n";
  fprintf(stderr, Format, X);
  return 0.0;
}
1397
1398//===----------------------------------------------------------------------===//
1399// Main driver code.
1400//===----------------------------------------------------------------------===//
1401
1402int main() {
1403  InitializeNativeTarget();
1404  InitializeNativeTargetAsmPrinter();
1405  InitializeNativeTargetAsmParser();
1406
1407  // Install standard binary operators.
1408  // 1 is lowest precedence.
1409  BinopPrecedence['='] = 2;
1410  BinopPrecedence['<'] = 10;
1411  BinopPrecedence['+'] = 20;
1412  BinopPrecedence['-'] = 20;
1413  BinopPrecedence['*'] = 40; // highest.
1414
1415  // Prime the first token.
1416  getNextToken();
1417
1418  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1419
1420  InitializeModule();
1421
1422  // Add the current debug info version into the module.
1423  TheModule->addModuleFlag(Module::Warning, "Debug Info Version",
1424                           DEBUG_METADATA_VERSION);
1425
1426  // Darwin only supports dwarf2.
1427  if (Triple(sys::getProcessTriple()).isOSDarwin())
1428    TheModule->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2);
1429
1430  // Construct the DIBuilder, we do this here because we need the module.
1431  DBuilder = llvm::make_unique<DIBuilder>(*TheModule);
1432
1433  // Create the compile unit for the module.
1434  // Currently down as "fib.ks" as a filename since we're redirecting stdin
1435  // but we'd like actual source locations.
1436  KSDbgInfo.TheCU = DBuilder->createCompileUnit(
1437      dwarf::DW_LANG_C, "fib.ks", ".", "Kaleidoscope Compiler", 0, "", 0);
1438
1439  // Run the main "interpreter loop" now.
1440  MainLoop();
1441
1442  // Finalize the debug info.
1443  DBuilder->finalize();
1444
1445  // Print out all of the generated code.
1446  TheModule->dump();
1447
1448  return 0;
1449}
1450