1#include "llvm/ADT/STLExtras.h"
2#include "llvm/Analysis/BasicAliasAnalysis.h"
3#include "llvm/Analysis/Passes.h"
4#include "llvm/IR/DIBuilder.h"
5#include "llvm/IR/IRBuilder.h"
6#include "llvm/IR/LLVMContext.h"
7#include "llvm/IR/LegacyPassManager.h"
8#include "llvm/IR/Module.h"
9#include "llvm/IR/Verifier.h"
10#include "llvm/Support/TargetSelect.h"
11#include "llvm/Transforms/Scalar.h"
12#include <cctype>
13#include <cstdio>
14#include <map>
15#include <string>
16#include <vector>
17#include "../include/KaleidoscopeJIT.h"
18
19using namespace llvm;
20using namespace llvm::orc;
21
22//===----------------------------------------------------------------------===//
23// Lexer
24//===----------------------------------------------------------------------===//
25
// Tokens produced by the lexer.  A character that does not begin a keyword,
// identifier, number or comment is returned as its own character code
// (0-255); everything the lexer recognises is one of the negative values
// below, so the two ranges never overlap.
enum Token {
  tok_eof = -1, // end of input

  // Commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control flow.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // User-defined operators.
  tok_binary = -11,
  tok_unary = -12,

  // Variable definition.
  tok_var = -13
};

/// getTokName - Return a printable name for token \p Tok.  Enumerated tokens
/// map to their keyword/category name; any other value is rendered as the
/// single character it encodes.
std::string getTokName(int Tok) {
  // Indexed by -Tok: entry 1 is tok_eof (-1) ... entry 13 is tok_var (-13).
  static const char *const Names[] = {
      nullptr, "eof",  "def",  "extern", "identifier", "number", "if",
      "then",  "else", "for",  "in",     "binary",     "unary",  "var"};

  if (Tok <= tok_eof && Tok >= tok_var)
    return Names[-Tok];
  return std::string(1, static_cast<char>(Tok));
}
85
86namespace {
87class PrototypeAST;
88class ExprAST;
89}
90static LLVMContext TheContext;
91static IRBuilder<> Builder(TheContext);
92struct DebugInfo {
93  DICompileUnit *TheCU;
94  DIType *DblTy;
95  std::vector<DIScope *> LexicalBlocks;
96
97  void emitLocation(ExprAST *AST);
98  DIType *getDoubleTy();
99} KSDbgInfo;
100
/// SourceLocation - A (line, column) position in the input stream.
struct SourceLocation {
  int Line;
  int Col;
};
// Position of the first character of the token most recently returned by
// gettok.
static SourceLocation CurLoc;
// Position after the last character read by advance(); input starts on line 1.
static SourceLocation LexLoc = {1, 0};

/// advance - Read one character from stdin and update LexLoc.  A '\n' or '\r'
/// moves to the next line and resets the column to 0; any other value
/// (including EOF) moves one column to the right.
static int advance() {
  const int Ch = getchar();
  const bool IsNewline = (Ch == '\n') || (Ch == '\r');
  if (IsNewline) {
    ++LexLoc.Line;
    LexLoc.Col = 0;
  } else {
    ++LexLoc.Col;
  }
  return Ch;
}
118
119static std::string IdentifierStr; // Filled in if tok_identifier
120static double NumVal;             // Filled in if tok_number
121
122/// gettok - Return the next token from standard input.
123static int gettok() {
124  static int LastChar = ' ';
125
126  // Skip any whitespace.
127  while (isspace(LastChar))
128    LastChar = advance();
129
130  CurLoc = LexLoc;
131
132  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
133    IdentifierStr = LastChar;
134    while (isalnum((LastChar = advance())))
135      IdentifierStr += LastChar;
136
137    if (IdentifierStr == "def")
138      return tok_def;
139    if (IdentifierStr == "extern")
140      return tok_extern;
141    if (IdentifierStr == "if")
142      return tok_if;
143    if (IdentifierStr == "then")
144      return tok_then;
145    if (IdentifierStr == "else")
146      return tok_else;
147    if (IdentifierStr == "for")
148      return tok_for;
149    if (IdentifierStr == "in")
150      return tok_in;
151    if (IdentifierStr == "binary")
152      return tok_binary;
153    if (IdentifierStr == "unary")
154      return tok_unary;
155    if (IdentifierStr == "var")
156      return tok_var;
157    return tok_identifier;
158  }
159
160  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
161    std::string NumStr;
162    do {
163      NumStr += LastChar;
164      LastChar = advance();
165    } while (isdigit(LastChar) || LastChar == '.');
166
167    NumVal = strtod(NumStr.c_str(), nullptr);
168    return tok_number;
169  }
170
171  if (LastChar == '#') {
172    // Comment until end of line.
173    do
174      LastChar = advance();
175    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
176
177    if (LastChar != EOF)
178      return gettok();
179  }
180
181  // Check for end of file.  Don't eat the EOF.
182  if (LastChar == EOF)
183    return tok_eof;
184
185  // Otherwise, just return the character as its ascii value.
186  int ThisChar = LastChar;
187  LastChar = advance();
188  return ThisChar;
189}
190
191//===----------------------------------------------------------------------===//
192// Abstract Syntax Tree (aka Parse Tree)
193//===----------------------------------------------------------------------===//
194namespace {
195
196raw_ostream &indent(raw_ostream &O, int size) {
197  return O << std::string(size, ' ');
198}
199
200/// ExprAST - Base class for all expression nodes.
201class ExprAST {
202  SourceLocation Loc;
203
204public:
205  ExprAST(SourceLocation Loc = CurLoc) : Loc(Loc) {}
206  virtual ~ExprAST() {}
207  virtual Value *codegen() = 0;
208  int getLine() const { return Loc.Line; }
209  int getCol() const { return Loc.Col; }
210  virtual raw_ostream &dump(raw_ostream &out, int ind) {
211    return out << ':' << getLine() << ':' << getCol() << '\n';
212  }
213};
214
215/// NumberExprAST - Expression class for numeric literals like "1.0".
216class NumberExprAST : public ExprAST {
217  double Val;
218
219public:
220  NumberExprAST(double Val) : Val(Val) {}
221  raw_ostream &dump(raw_ostream &out, int ind) override {
222    return ExprAST::dump(out << Val, ind);
223  }
224  Value *codegen() override;
225};
226
227/// VariableExprAST - Expression class for referencing a variable, like "a".
228class VariableExprAST : public ExprAST {
229  std::string Name;
230
231public:
232  VariableExprAST(SourceLocation Loc, const std::string &Name)
233      : ExprAST(Loc), Name(Name) {}
234  const std::string &getName() const { return Name; }
235  Value *codegen() override;
236  raw_ostream &dump(raw_ostream &out, int ind) override {
237    return ExprAST::dump(out << Name, ind);
238  }
239};
240
241/// UnaryExprAST - Expression class for a unary operator.
242class UnaryExprAST : public ExprAST {
243  char Opcode;
244  std::unique_ptr<ExprAST> Operand;
245
246public:
247  UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
248      : Opcode(Opcode), Operand(std::move(Operand)) {}
249  Value *codegen() override;
250  raw_ostream &dump(raw_ostream &out, int ind) override {
251    ExprAST::dump(out << "unary" << Opcode, ind);
252    Operand->dump(out, ind + 1);
253    return out;
254  }
255};
256
257/// BinaryExprAST - Expression class for a binary operator.
258class BinaryExprAST : public ExprAST {
259  char Op;
260  std::unique_ptr<ExprAST> LHS, RHS;
261
262public:
263  BinaryExprAST(SourceLocation Loc, char Op, std::unique_ptr<ExprAST> LHS,
264                std::unique_ptr<ExprAST> RHS)
265      : ExprAST(Loc), Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
266  Value *codegen() override;
267  raw_ostream &dump(raw_ostream &out, int ind) override {
268    ExprAST::dump(out << "binary" << Op, ind);
269    LHS->dump(indent(out, ind) << "LHS:", ind + 1);
270    RHS->dump(indent(out, ind) << "RHS:", ind + 1);
271    return out;
272  }
273};
274
275/// CallExprAST - Expression class for function calls.
276class CallExprAST : public ExprAST {
277  std::string Callee;
278  std::vector<std::unique_ptr<ExprAST>> Args;
279
280public:
281  CallExprAST(SourceLocation Loc, const std::string &Callee,
282              std::vector<std::unique_ptr<ExprAST>> Args)
283      : ExprAST(Loc), Callee(Callee), Args(std::move(Args)) {}
284  Value *codegen() override;
285  raw_ostream &dump(raw_ostream &out, int ind) override {
286    ExprAST::dump(out << "call " << Callee, ind);
287    for (const auto &Arg : Args)
288      Arg->dump(indent(out, ind + 1), ind + 1);
289    return out;
290  }
291};
292
293/// IfExprAST - Expression class for if/then/else.
294class IfExprAST : public ExprAST {
295  std::unique_ptr<ExprAST> Cond, Then, Else;
296
297public:
298  IfExprAST(SourceLocation Loc, std::unique_ptr<ExprAST> Cond,
299            std::unique_ptr<ExprAST> Then, std::unique_ptr<ExprAST> Else)
300      : ExprAST(Loc), Cond(std::move(Cond)), Then(std::move(Then)),
301        Else(std::move(Else)) {}
302  Value *codegen() override;
303  raw_ostream &dump(raw_ostream &out, int ind) override {
304    ExprAST::dump(out << "if", ind);
305    Cond->dump(indent(out, ind) << "Cond:", ind + 1);
306    Then->dump(indent(out, ind) << "Then:", ind + 1);
307    Else->dump(indent(out, ind) << "Else:", ind + 1);
308    return out;
309  }
310};
311
312/// ForExprAST - Expression class for for/in.
313class ForExprAST : public ExprAST {
314  std::string VarName;
315  std::unique_ptr<ExprAST> Start, End, Step, Body;
316
317public:
318  ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
319             std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
320             std::unique_ptr<ExprAST> Body)
321      : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
322        Step(std::move(Step)), Body(std::move(Body)) {}
323  Value *codegen() override;
324  raw_ostream &dump(raw_ostream &out, int ind) override {
325    ExprAST::dump(out << "for", ind);
326    Start->dump(indent(out, ind) << "Cond:", ind + 1);
327    End->dump(indent(out, ind) << "End:", ind + 1);
328    Step->dump(indent(out, ind) << "Step:", ind + 1);
329    Body->dump(indent(out, ind) << "Body:", ind + 1);
330    return out;
331  }
332};
333
334/// VarExprAST - Expression class for var/in
335class VarExprAST : public ExprAST {
336  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
337  std::unique_ptr<ExprAST> Body;
338
339public:
340  VarExprAST(
341      std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
342      std::unique_ptr<ExprAST> Body)
343      : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
344  Value *codegen() override;
345  raw_ostream &dump(raw_ostream &out, int ind) override {
346    ExprAST::dump(out << "var", ind);
347    for (const auto &NamedVar : VarNames)
348      NamedVar.second->dump(indent(out, ind) << NamedVar.first << ':', ind + 1);
349    Body->dump(indent(out, ind) << "Body:", ind + 1);
350    return out;
351  }
352};
353
354/// PrototypeAST - This class represents the "prototype" for a function,
355/// which captures its name, and its argument names (thus implicitly the number
356/// of arguments the function takes), as well as if it is an operator.
357class PrototypeAST {
358  std::string Name;
359  std::vector<std::string> Args;
360  bool IsOperator;
361  unsigned Precedence; // Precedence if a binary op.
362  int Line;
363
364public:
365  PrototypeAST(SourceLocation Loc, const std::string &Name,
366               std::vector<std::string> Args, bool IsOperator = false,
367               unsigned Prec = 0)
368      : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
369        Precedence(Prec), Line(Loc.Line) {}
370  Function *codegen();
371  const std::string &getName() const { return Name; }
372
373  bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
374  bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
375
376  char getOperatorName() const {
377    assert(isUnaryOp() || isBinaryOp());
378    return Name[Name.size() - 1];
379  }
380
381  unsigned getBinaryPrecedence() const { return Precedence; }
382  int getLine() const { return Line; }
383};
384
385/// FunctionAST - This class represents a function definition itself.
386class FunctionAST {
387  std::unique_ptr<PrototypeAST> Proto;
388  std::unique_ptr<ExprAST> Body;
389
390public:
391  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
392              std::unique_ptr<ExprAST> Body)
393      : Proto(std::move(Proto)), Body(std::move(Body)) {}
394  Function *codegen();
395  raw_ostream &dump(raw_ostream &out, int ind) {
396    indent(out, ind) << "FunctionAST\n";
397    ++ind;
398    indent(out, ind) << "Body:";
399    return Body ? Body->dump(out, ind) : out << "null\n";
400  }
401};
402} // end anonymous namespace
403
404//===----------------------------------------------------------------------===//
405// Parser
406//===----------------------------------------------------------------------===//
407
408/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
409/// token the parser is looking at.  getNextToken reads another token from the
410/// lexer and updates CurTok with its results.
411static int CurTok;
412static int getNextToken() { return CurTok = gettok(); }
413
414/// BinopPrecedence - This holds the precedence for each binary operator that is
415/// defined.
416static std::map<char, int> BinopPrecedence;
417
418/// GetTokPrecedence - Get the precedence of the pending binary operator token.
419static int GetTokPrecedence() {
420  if (!isascii(CurTok))
421    return -1;
422
423  // Make sure it's a declared binop.
424  int TokPrec = BinopPrecedence[CurTok];
425  if (TokPrec <= 0)
426    return -1;
427  return TokPrec;
428}
429
430/// LogError* - These are little helper functions for error handling.
431std::unique_ptr<ExprAST> LogError(const char *Str) {
432  fprintf(stderr, "Error: %s\n", Str);
433  return nullptr;
434}
435
436std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
437  LogError(Str);
438  return nullptr;
439}
440
441static std::unique_ptr<ExprAST> ParseExpression();
442
443/// numberexpr ::= number
444static std::unique_ptr<ExprAST> ParseNumberExpr() {
445  auto Result = llvm::make_unique<NumberExprAST>(NumVal);
446  getNextToken(); // consume the number
447  return std::move(Result);
448}
449
450/// parenexpr ::= '(' expression ')'
451static std::unique_ptr<ExprAST> ParseParenExpr() {
452  getNextToken(); // eat (.
453  auto V = ParseExpression();
454  if (!V)
455    return nullptr;
456
457  if (CurTok != ')')
458    return LogError("expected ')'");
459  getNextToken(); // eat ).
460  return V;
461}
462
463/// identifierexpr
464///   ::= identifier
465///   ::= identifier '(' expression* ')'
466static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
467  std::string IdName = IdentifierStr;
468
469  SourceLocation LitLoc = CurLoc;
470
471  getNextToken(); // eat identifier.
472
473  if (CurTok != '(') // Simple variable ref.
474    return llvm::make_unique<VariableExprAST>(LitLoc, IdName);
475
476  // Call.
477  getNextToken(); // eat (
478  std::vector<std::unique_ptr<ExprAST>> Args;
479  if (CurTok != ')') {
480    while (1) {
481      if (auto Arg = ParseExpression())
482        Args.push_back(std::move(Arg));
483      else
484        return nullptr;
485
486      if (CurTok == ')')
487        break;
488
489      if (CurTok != ',')
490        return LogError("Expected ')' or ',' in argument list");
491      getNextToken();
492    }
493  }
494
495  // Eat the ')'.
496  getNextToken();
497
498  return llvm::make_unique<CallExprAST>(LitLoc, IdName, std::move(Args));
499}
500
501/// ifexpr ::= 'if' expression 'then' expression 'else' expression
502static std::unique_ptr<ExprAST> ParseIfExpr() {
503  SourceLocation IfLoc = CurLoc;
504
505  getNextToken(); // eat the if.
506
507  // condition.
508  auto Cond = ParseExpression();
509  if (!Cond)
510    return nullptr;
511
512  if (CurTok != tok_then)
513    return LogError("expected then");
514  getNextToken(); // eat the then
515
516  auto Then = ParseExpression();
517  if (!Then)
518    return nullptr;
519
520  if (CurTok != tok_else)
521    return LogError("expected else");
522
523  getNextToken();
524
525  auto Else = ParseExpression();
526  if (!Else)
527    return nullptr;
528
529  return llvm::make_unique<IfExprAST>(IfLoc, std::move(Cond), std::move(Then),
530                                      std::move(Else));
531}
532
533/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
534static std::unique_ptr<ExprAST> ParseForExpr() {
535  getNextToken(); // eat the for.
536
537  if (CurTok != tok_identifier)
538    return LogError("expected identifier after for");
539
540  std::string IdName = IdentifierStr;
541  getNextToken(); // eat identifier.
542
543  if (CurTok != '=')
544    return LogError("expected '=' after for");
545  getNextToken(); // eat '='.
546
547  auto Start = ParseExpression();
548  if (!Start)
549    return nullptr;
550  if (CurTok != ',')
551    return LogError("expected ',' after for start value");
552  getNextToken();
553
554  auto End = ParseExpression();
555  if (!End)
556    return nullptr;
557
558  // The step value is optional.
559  std::unique_ptr<ExprAST> Step;
560  if (CurTok == ',') {
561    getNextToken();
562    Step = ParseExpression();
563    if (!Step)
564      return nullptr;
565  }
566
567  if (CurTok != tok_in)
568    return LogError("expected 'in' after for");
569  getNextToken(); // eat 'in'.
570
571  auto Body = ParseExpression();
572  if (!Body)
573    return nullptr;
574
575  return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
576                                       std::move(Step), std::move(Body));
577}
578
579/// varexpr ::= 'var' identifier ('=' expression)?
580//                    (',' identifier ('=' expression)?)* 'in' expression
581static std::unique_ptr<ExprAST> ParseVarExpr() {
582  getNextToken(); // eat the var.
583
584  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
585
586  // At least one variable name is required.
587  if (CurTok != tok_identifier)
588    return LogError("expected identifier after var");
589
590  while (1) {
591    std::string Name = IdentifierStr;
592    getNextToken(); // eat identifier.
593
594    // Read the optional initializer.
595    std::unique_ptr<ExprAST> Init = nullptr;
596    if (CurTok == '=') {
597      getNextToken(); // eat the '='.
598
599      Init = ParseExpression();
600      if (!Init)
601        return nullptr;
602    }
603
604    VarNames.push_back(std::make_pair(Name, std::move(Init)));
605
606    // End of var list, exit loop.
607    if (CurTok != ',')
608      break;
609    getNextToken(); // eat the ','.
610
611    if (CurTok != tok_identifier)
612      return LogError("expected identifier list after var");
613  }
614
615  // At this point, we have to have 'in'.
616  if (CurTok != tok_in)
617    return LogError("expected 'in' keyword after 'var'");
618  getNextToken(); // eat 'in'.
619
620  auto Body = ParseExpression();
621  if (!Body)
622    return nullptr;
623
624  return llvm::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
625}
626
627/// primary
628///   ::= identifierexpr
629///   ::= numberexpr
630///   ::= parenexpr
631///   ::= ifexpr
632///   ::= forexpr
633///   ::= varexpr
634static std::unique_ptr<ExprAST> ParsePrimary() {
635  switch (CurTok) {
636  default:
637    return LogError("unknown token when expecting an expression");
638  case tok_identifier:
639    return ParseIdentifierExpr();
640  case tok_number:
641    return ParseNumberExpr();
642  case '(':
643    return ParseParenExpr();
644  case tok_if:
645    return ParseIfExpr();
646  case tok_for:
647    return ParseForExpr();
648  case tok_var:
649    return ParseVarExpr();
650  }
651}
652
653/// unary
654///   ::= primary
655///   ::= '!' unary
656static std::unique_ptr<ExprAST> ParseUnary() {
657  // If the current token is not an operator, it must be a primary expr.
658  if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
659    return ParsePrimary();
660
661  // If this is a unary operator, read it.
662  int Opc = CurTok;
663  getNextToken();
664  if (auto Operand = ParseUnary())
665    return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
666  return nullptr;
667}
668
669/// binoprhs
670///   ::= ('+' unary)*
671static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
672                                              std::unique_ptr<ExprAST> LHS) {
673  // If this is a binop, find its precedence.
674  while (1) {
675    int TokPrec = GetTokPrecedence();
676
677    // If this is a binop that binds at least as tightly as the current binop,
678    // consume it, otherwise we are done.
679    if (TokPrec < ExprPrec)
680      return LHS;
681
682    // Okay, we know this is a binop.
683    int BinOp = CurTok;
684    SourceLocation BinLoc = CurLoc;
685    getNextToken(); // eat binop
686
687    // Parse the unary expression after the binary operator.
688    auto RHS = ParseUnary();
689    if (!RHS)
690      return nullptr;
691
692    // If BinOp binds less tightly with RHS than the operator after RHS, let
693    // the pending operator take RHS as its LHS.
694    int NextPrec = GetTokPrecedence();
695    if (TokPrec < NextPrec) {
696      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
697      if (!RHS)
698        return nullptr;
699    }
700
701    // Merge LHS/RHS.
702    LHS = llvm::make_unique<BinaryExprAST>(BinLoc, BinOp, std::move(LHS),
703                                           std::move(RHS));
704  }
705}
706
707/// expression
708///   ::= unary binoprhs
709///
710static std::unique_ptr<ExprAST> ParseExpression() {
711  auto LHS = ParseUnary();
712  if (!LHS)
713    return nullptr;
714
715  return ParseBinOpRHS(0, std::move(LHS));
716}
717
718/// prototype
719///   ::= id '(' id* ')'
720///   ::= binary LETTER number? (id, id)
721///   ::= unary LETTER (id)
722static std::unique_ptr<PrototypeAST> ParsePrototype() {
723  std::string FnName;
724
725  SourceLocation FnLoc = CurLoc;
726
727  unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
728  unsigned BinaryPrecedence = 30;
729
730  switch (CurTok) {
731  default:
732    return LogErrorP("Expected function name in prototype");
733  case tok_identifier:
734    FnName = IdentifierStr;
735    Kind = 0;
736    getNextToken();
737    break;
738  case tok_unary:
739    getNextToken();
740    if (!isascii(CurTok))
741      return LogErrorP("Expected unary operator");
742    FnName = "unary";
743    FnName += (char)CurTok;
744    Kind = 1;
745    getNextToken();
746    break;
747  case tok_binary:
748    getNextToken();
749    if (!isascii(CurTok))
750      return LogErrorP("Expected binary operator");
751    FnName = "binary";
752    FnName += (char)CurTok;
753    Kind = 2;
754    getNextToken();
755
756    // Read the precedence if present.
757    if (CurTok == tok_number) {
758      if (NumVal < 1 || NumVal > 100)
759        return LogErrorP("Invalid precedecnce: must be 1..100");
760      BinaryPrecedence = (unsigned)NumVal;
761      getNextToken();
762    }
763    break;
764  }
765
766  if (CurTok != '(')
767    return LogErrorP("Expected '(' in prototype");
768
769  std::vector<std::string> ArgNames;
770  while (getNextToken() == tok_identifier)
771    ArgNames.push_back(IdentifierStr);
772  if (CurTok != ')')
773    return LogErrorP("Expected ')' in prototype");
774
775  // success.
776  getNextToken(); // eat ')'.
777
778  // Verify right number of names for operator.
779  if (Kind && ArgNames.size() != Kind)
780    return LogErrorP("Invalid number of operands for operator");
781
782  return llvm::make_unique<PrototypeAST>(FnLoc, FnName, ArgNames, Kind != 0,
783                                         BinaryPrecedence);
784}
785
786/// definition ::= 'def' prototype expression
787static std::unique_ptr<FunctionAST> ParseDefinition() {
788  getNextToken(); // eat def.
789  auto Proto = ParsePrototype();
790  if (!Proto)
791    return nullptr;
792
793  if (auto E = ParseExpression())
794    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
795  return nullptr;
796}
797
798/// toplevelexpr ::= expression
799static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
800  SourceLocation FnLoc = CurLoc;
801  if (auto E = ParseExpression()) {
802    // Make an anonymous proto.
803    auto Proto = llvm::make_unique<PrototypeAST>(FnLoc, "__anon_expr",
804                                                 std::vector<std::string>());
805    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
806  }
807  return nullptr;
808}
809
810/// external ::= 'extern' prototype
811static std::unique_ptr<PrototypeAST> ParseExtern() {
812  getNextToken(); // eat extern.
813  return ParsePrototype();
814}
815
816//===----------------------------------------------------------------------===//
817// Debug Info Support
818//===----------------------------------------------------------------------===//
819
820static std::unique_ptr<DIBuilder> DBuilder;
821
822DIType *DebugInfo::getDoubleTy() {
823  if (DblTy)
824    return DblTy;
825
826  DblTy = DBuilder->createBasicType("double", 64, 64, dwarf::DW_ATE_float);
827  return DblTy;
828}
829
830void DebugInfo::emitLocation(ExprAST *AST) {
831  if (!AST)
832    return Builder.SetCurrentDebugLocation(DebugLoc());
833  DIScope *Scope;
834  if (LexicalBlocks.empty())
835    Scope = TheCU;
836  else
837    Scope = LexicalBlocks.back();
838  Builder.SetCurrentDebugLocation(
839      DebugLoc::get(AST->getLine(), AST->getCol(), Scope));
840}
841
842static DISubroutineType *CreateFunctionType(unsigned NumArgs, DIFile *Unit) {
843  SmallVector<Metadata *, 8> EltTys;
844  DIType *DblTy = KSDbgInfo.getDoubleTy();
845
846  // Add the result type.
847  EltTys.push_back(DblTy);
848
849  for (unsigned i = 0, e = NumArgs; i != e; ++i)
850    EltTys.push_back(DblTy);
851
852  return DBuilder->createSubroutineType(DBuilder->getOrCreateTypeArray(EltTys));
853}
854
855//===----------------------------------------------------------------------===//
856// Code Generation
857//===----------------------------------------------------------------------===//
858
859static std::unique_ptr<Module> TheModule;
860static std::map<std::string, AllocaInst *> NamedValues;
861static std::unique_ptr<KaleidoscopeJIT> TheJIT;
862static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
863
864Value *LogErrorV(const char *Str) {
865  LogError(Str);
866  return nullptr;
867}
868
869Function *getFunction(std::string Name) {
870  // First, see if the function has already been added to the current module.
871  if (auto *F = TheModule->getFunction(Name))
872    return F;
873
874  // If not, check whether we can codegen the declaration from some existing
875  // prototype.
876  auto FI = FunctionProtos.find(Name);
877  if (FI != FunctionProtos.end())
878    return FI->second->codegen();
879
880  // If no existing prototype exists, return null.
881  return nullptr;
882}
883
884/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
885/// the function.  This is used for mutable variables etc.
886static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
887                                          const std::string &VarName) {
888  IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
889                   TheFunction->getEntryBlock().begin());
890  return TmpB.CreateAlloca(Type::getDoubleTy(TheContext), nullptr,
891                           VarName.c_str());
892}
893
894Value *NumberExprAST::codegen() {
895  KSDbgInfo.emitLocation(this);
896  return ConstantFP::get(TheContext, APFloat(Val));
897}
898
899Value *VariableExprAST::codegen() {
900  // Look this variable up in the function.
901  Value *V = NamedValues[Name];
902  if (!V)
903    return LogErrorV("Unknown variable name");
904
905  KSDbgInfo.emitLocation(this);
906  // Load the value.
907  return Builder.CreateLoad(V, Name.c_str());
908}
909
910Value *UnaryExprAST::codegen() {
911  Value *OperandV = Operand->codegen();
912  if (!OperandV)
913    return nullptr;
914
915  Function *F = getFunction(std::string("unary") + Opcode);
916  if (!F)
917    return LogErrorV("Unknown unary operator");
918
919  KSDbgInfo.emitLocation(this);
920  return Builder.CreateCall(F, OperandV, "unop");
921}
922
923Value *BinaryExprAST::codegen() {
924  KSDbgInfo.emitLocation(this);
925
926  // Special case '=' because we don't want to emit the LHS as an expression.
927  if (Op == '=') {
928    // Assignment requires the LHS to be an identifier.
929    // This assume we're building without RTTI because LLVM builds that way by
930    // default.  If you build LLVM with RTTI this can be changed to a
931    // dynamic_cast for automatic error checking.
932    VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
933    if (!LHSE)
934      return LogErrorV("destination of '=' must be a variable");
935    // Codegen the RHS.
936    Value *Val = RHS->codegen();
937    if (!Val)
938      return nullptr;
939
940    // Look up the name.
941    Value *Variable = NamedValues[LHSE->getName()];
942    if (!Variable)
943      return LogErrorV("Unknown variable name");
944
945    Builder.CreateStore(Val, Variable);
946    return Val;
947  }
948
949  Value *L = LHS->codegen();
950  Value *R = RHS->codegen();
951  if (!L || !R)
952    return nullptr;
953
954  switch (Op) {
955  case '+':
956    return Builder.CreateFAdd(L, R, "addtmp");
957  case '-':
958    return Builder.CreateFSub(L, R, "subtmp");
959  case '*':
960    return Builder.CreateFMul(L, R, "multmp");
961  case '<':
962    L = Builder.CreateFCmpULT(L, R, "cmptmp");
963    // Convert bool 0/1 to double 0.0 or 1.0
964    return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
965  default:
966    break;
967  }
968
969  // If it wasn't a builtin binary operator, it must be a user defined one. Emit
970  // a call to it.
971  Function *F = getFunction(std::string("binary") + Op);
972  assert(F && "binary operator not found!");
973
974  Value *Ops[] = {L, R};
975  return Builder.CreateCall(F, Ops, "binop");
976}
977
978Value *CallExprAST::codegen() {
979  KSDbgInfo.emitLocation(this);
980
981  // Look up the name in the global module table.
982  Function *CalleeF = getFunction(Callee);
983  if (!CalleeF)
984    return LogErrorV("Unknown function referenced");
985
986  // If argument mismatch error.
987  if (CalleeF->arg_size() != Args.size())
988    return LogErrorV("Incorrect # arguments passed");
989
990  std::vector<Value *> ArgsV;
991  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
992    ArgsV.push_back(Args[i]->codegen());
993    if (!ArgsV.back())
994      return nullptr;
995  }
996
997  return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
998}
999
/// IfExprAST::codegen - Lower if/then/else into a conditional branch to a
/// "then" block and an "else" block, which both branch to an "ifcont" block
/// where a PHI node selects the value of whichever arm ran.
///
/// Returns that PHI node, or null if the condition or either arm fails to
/// generate.
///
/// NOTE(review): on the early-return error paths, ElseBB and/or MergeBB were
/// created without a parent and may not yet have been inserted into
/// TheFunction, so nothing here frees them.
Value *IfExprAST::codegen() {
  KSDbgInfo.emitLocation(this);

  Value *CondV = Cond->codegen();
  if (!CondV)
    return nullptr;

  // Convert condition to a bool: true when it is ordered and not equal to 0.0
  // (fcmp "one"), so both 0.0 and NaN select the else branch.
  CondV = Builder.CreateFCmpONE(
      CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");

  Function *TheFunction = Builder.GetInsertBlock()->getParent();

  // Create blocks for the then and else cases.  Insert the 'then' block at the
  // end of the function; 'else' and 'ifcont' are appended later, once the
  // preceding arms have been emitted.
  BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
  BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
  BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");

  Builder.CreateCondBr(CondV, ThenBB, ElseBB);

  // Emit then value.
  Builder.SetInsertPoint(ThenBB);

  Value *ThenV = Then->codegen();
  if (!ThenV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
  ThenBB = Builder.GetInsertBlock();

  // Emit else block.
  TheFunction->getBasicBlockList().push_back(ElseBB);
  Builder.SetInsertPoint(ElseBB);

  Value *ElseV = Else->codegen();
  if (!ElseV)
    return nullptr;

  Builder.CreateBr(MergeBB);
  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
  ElseBB = Builder.GetInsertBlock();

  // Emit merge block.
  TheFunction->getBasicBlockList().push_back(MergeBB);
  Builder.SetInsertPoint(MergeBB);
  // Two incoming edges: one from the end of each arm.
  PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");

  PN->addIncoming(ThenV, ThenBB);
  PN->addIncoming(ElseV, ElseBB);
  return PN;
}
1053
1054// Output for-loop as:
1055//   var = alloca double
1056//   ...
1057//   start = startexpr
1058//   store start -> var
1059//   goto loop
1060// loop:
1061//   ...
1062//   bodyexpr
1063//   ...
1064// loopend:
1065//   step = stepexpr
1066//   endcond = endexpr
1067//
1068//   curvar = load var
1069//   nextvar = curvar + step
1070//   store nextvar -> var
1071//   br endcond, loop, endloop
1072// outloop:
1073Value *ForExprAST::codegen() {
1074  Function *TheFunction = Builder.GetInsertBlock()->getParent();
1075
1076  // Create an alloca for the variable in the entry block.
1077  AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
1078
1079  KSDbgInfo.emitLocation(this);
1080
1081  // Emit the start code first, without 'variable' in scope.
1082  Value *StartVal = Start->codegen();
1083  if (!StartVal)
1084    return nullptr;
1085
1086  // Store the value into the alloca.
1087  Builder.CreateStore(StartVal, Alloca);
1088
1089  // Make the new basic block for the loop header, inserting after current
1090  // block.
1091  BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
1092
1093  // Insert an explicit fall through from the current block to the LoopBB.
1094  Builder.CreateBr(LoopBB);
1095
1096  // Start insertion in LoopBB.
1097  Builder.SetInsertPoint(LoopBB);
1098
1099  // Within the loop, the variable is defined equal to the PHI node.  If it
1100  // shadows an existing variable, we have to restore it, so save it now.
1101  AllocaInst *OldVal = NamedValues[VarName];
1102  NamedValues[VarName] = Alloca;
1103
1104  // Emit the body of the loop.  This, like any other expr, can change the
1105  // current BB.  Note that we ignore the value computed by the body, but don't
1106  // allow an error.
1107  if (!Body->codegen())
1108    return nullptr;
1109
1110  // Emit the step value.
1111  Value *StepVal = nullptr;
1112  if (Step) {
1113    StepVal = Step->codegen();
1114    if (!StepVal)
1115      return nullptr;
1116  } else {
1117    // If not specified, use 1.0.
1118    StepVal = ConstantFP::get(TheContext, APFloat(1.0));
1119  }
1120
1121  // Compute the end condition.
1122  Value *EndCond = End->codegen();
1123  if (!EndCond)
1124    return nullptr;
1125
1126  // Reload, increment, and restore the alloca.  This handles the case where
1127  // the body of the loop mutates the variable.
1128  Value *CurVar = Builder.CreateLoad(Alloca, VarName.c_str());
1129  Value *NextVar = Builder.CreateFAdd(CurVar, StepVal, "nextvar");
1130  Builder.CreateStore(NextVar, Alloca);
1131
1132  // Convert condition to a bool by comparing equal to 0.0.
1133  EndCond = Builder.CreateFCmpONE(
1134      EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
1135
1136  // Create the "after loop" block and insert it.
1137  BasicBlock *AfterBB =
1138      BasicBlock::Create(TheContext, "afterloop", TheFunction);
1139
1140  // Insert the conditional branch into the end of LoopEndBB.
1141  Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
1142
1143  // Any new code will be inserted in AfterBB.
1144  Builder.SetInsertPoint(AfterBB);
1145
1146  // Restore the unshadowed variable.
1147  if (OldVal)
1148    NamedValues[VarName] = OldVal;
1149  else
1150    NamedValues.erase(VarName);
1151
1152  // for expr always returns 0.0.
1153  return Constant::getNullValue(Type::getDoubleTy(TheContext));
1154}
1155
1156Value *VarExprAST::codegen() {
1157  std::vector<AllocaInst *> OldBindings;
1158
1159  Function *TheFunction = Builder.GetInsertBlock()->getParent();
1160
1161  // Register all variables and emit their initializer.
1162  for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
1163    const std::string &VarName = VarNames[i].first;
1164    ExprAST *Init = VarNames[i].second.get();
1165
1166    // Emit the initializer before adding the variable to scope, this prevents
1167    // the initializer from referencing the variable itself, and permits stuff
1168    // like this:
1169    //  var a = 1 in
1170    //    var a = a in ...   # refers to outer 'a'.
1171    Value *InitVal;
1172    if (Init) {
1173      InitVal = Init->codegen();
1174      if (!InitVal)
1175        return nullptr;
1176    } else { // If not specified, use 0.0.
1177      InitVal = ConstantFP::get(TheContext, APFloat(0.0));
1178    }
1179
1180    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
1181    Builder.CreateStore(InitVal, Alloca);
1182
1183    // Remember the old variable binding so that we can restore the binding when
1184    // we unrecurse.
1185    OldBindings.push_back(NamedValues[VarName]);
1186
1187    // Remember this binding.
1188    NamedValues[VarName] = Alloca;
1189  }
1190
1191  KSDbgInfo.emitLocation(this);
1192
1193  // Codegen the body, now that all vars are in scope.
1194  Value *BodyVal = Body->codegen();
1195  if (!BodyVal)
1196    return nullptr;
1197
1198  // Pop all our variables from scope.
1199  for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1200    NamedValues[VarNames[i].first] = OldBindings[i];
1201
1202  // Return the body computation.
1203  return BodyVal;
1204}
1205
1206Function *PrototypeAST::codegen() {
1207  // Make the function type:  double(double,double) etc.
1208  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
1209  FunctionType *FT =
1210      FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
1211
1212  Function *F =
1213      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1214
1215  // Set names for all arguments.
1216  unsigned Idx = 0;
1217  for (auto &Arg : F->args())
1218    Arg.setName(Args[Idx++]);
1219
1220  return F;
1221}
1222
1223Function *FunctionAST::codegen() {
1224  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1225  // reference to it for use below.
1226  auto &P = *Proto;
1227  FunctionProtos[Proto->getName()] = std::move(Proto);
1228  Function *TheFunction = getFunction(P.getName());
1229  if (!TheFunction)
1230    return nullptr;
1231
1232  // If this is an operator, install it.
1233  if (P.isBinaryOp())
1234    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1235
1236  // Create a new basic block to start insertion into.
1237  BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
1238  Builder.SetInsertPoint(BB);
1239
1240  // Create a subprogram DIE for this function.
1241  DIFile *Unit = DBuilder->createFile(KSDbgInfo.TheCU->getFilename(),
1242                                      KSDbgInfo.TheCU->getDirectory());
1243  DIScope *FContext = Unit;
1244  unsigned LineNo = P.getLine();
1245  unsigned ScopeLine = LineNo;
1246  DISubprogram *SP = DBuilder->createFunction(
1247      FContext, P.getName(), StringRef(), Unit, LineNo,
1248      CreateFunctionType(TheFunction->arg_size(), Unit),
1249      false /* internal linkage */, true /* definition */, ScopeLine,
1250      DINode::FlagPrototyped, false);
1251  TheFunction->setSubprogram(SP);
1252
1253  // Push the current scope.
1254  KSDbgInfo.LexicalBlocks.push_back(SP);
1255
1256  // Unset the location for the prologue emission (leading instructions with no
1257  // location in a function are considered part of the prologue and the debugger
1258  // will run past them when breaking on a function)
1259  KSDbgInfo.emitLocation(nullptr);
1260
1261  // Record the function arguments in the NamedValues map.
1262  NamedValues.clear();
1263  unsigned ArgIdx = 0;
1264  for (auto &Arg : TheFunction->args()) {
1265    // Create an alloca for this variable.
1266    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1267
1268    // Create a debug descriptor for the variable.
1269    DILocalVariable *D = DBuilder->createParameterVariable(
1270        SP, Arg.getName(), ++ArgIdx, Unit, LineNo, KSDbgInfo.getDoubleTy(),
1271        true);
1272
1273    DBuilder->insertDeclare(Alloca, D, DBuilder->createExpression(),
1274                            DebugLoc::get(LineNo, 0, SP),
1275                            Builder.GetInsertBlock());
1276
1277    // Store the initial value into the alloca.
1278    Builder.CreateStore(&Arg, Alloca);
1279
1280    // Add arguments to variable symbol table.
1281    NamedValues[Arg.getName()] = Alloca;
1282  }
1283
1284  KSDbgInfo.emitLocation(Body.get());
1285
1286  if (Value *RetVal = Body->codegen()) {
1287    // Finish off the function.
1288    Builder.CreateRet(RetVal);
1289
1290    // Pop off the lexical block for the function.
1291    KSDbgInfo.LexicalBlocks.pop_back();
1292
1293    // Validate the generated code, checking for consistency.
1294    verifyFunction(*TheFunction);
1295
1296    return TheFunction;
1297  }
1298
1299  // Error reading body, remove function.
1300  TheFunction->eraseFromParent();
1301
1302  if (P.isBinaryOp())
1303    BinopPrecedence.erase(Proto->getOperatorName());
1304
1305  // Pop off the lexical block for the function since we added it
1306  // unconditionally.
1307  KSDbgInfo.LexicalBlocks.pop_back();
1308
1309  return nullptr;
1310}
1311
1312//===----------------------------------------------------------------------===//
1313// Top-Level parsing and JIT Driver
1314//===----------------------------------------------------------------------===//
1315
1316static void InitializeModule() {
1317  // Open a new module.
1318  TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
1319  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
1320}
1321
1322static void HandleDefinition() {
1323  if (auto FnAST = ParseDefinition()) {
1324    if (!FnAST->codegen())
1325      fprintf(stderr, "Error reading function definition:");
1326  } else {
1327    // Skip token for error recovery.
1328    getNextToken();
1329  }
1330}
1331
1332static void HandleExtern() {
1333  if (auto ProtoAST = ParseExtern()) {
1334    if (!ProtoAST->codegen())
1335      fprintf(stderr, "Error reading extern");
1336    else
1337      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1338  } else {
1339    // Skip token for error recovery.
1340    getNextToken();
1341  }
1342}
1343
1344static void HandleTopLevelExpression() {
1345  // Evaluate a top-level expression into an anonymous function.
1346  if (auto FnAST = ParseTopLevelExpr()) {
1347    if (!FnAST->codegen()) {
1348      fprintf(stderr, "Error generating code for top level expr");
1349    }
1350  } else {
1351    // Skip token for error recovery.
1352    getNextToken();
1353  }
1354}
1355
1356/// top ::= definition | external | expression | ';'
1357static void MainLoop() {
1358  while (1) {
1359    switch (CurTok) {
1360    case tok_eof:
1361      return;
1362    case ';': // ignore top-level semicolons.
1363      getNextToken();
1364      break;
1365    case tok_def:
1366      HandleDefinition();
1367      break;
1368    case tok_extern:
1369      HandleExtern();
1370      break;
1371    default:
1372      HandleTopLevelExpression();
1373      break;
1374    }
1375  }
1376}
1377
1378//===----------------------------------------------------------------------===//
1379// "Library" functions that can be "extern'd" from user code.
1380//===----------------------------------------------------------------------===//
1381
/// putchard - writes X, converted to a char, to stderr and returns 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  std::fputc(Ch, stderr);
  return 0.0;
}
1387
/// printd - prints X to stderr with the format "%f\n" and returns 0.
extern "C" double printd(double X) {
  std::fprintf(stderr, "%f\n", X);
  return 0.0;
}
1393
1394//===----------------------------------------------------------------------===//
1395// Main driver code.
1396//===----------------------------------------------------------------------===//
1397
1398int main() {
1399  InitializeNativeTarget();
1400  InitializeNativeTargetAsmPrinter();
1401  InitializeNativeTargetAsmParser();
1402
1403  // Install standard binary operators.
1404  // 1 is lowest precedence.
1405  BinopPrecedence['='] = 2;
1406  BinopPrecedence['<'] = 10;
1407  BinopPrecedence['+'] = 20;
1408  BinopPrecedence['-'] = 20;
1409  BinopPrecedence['*'] = 40; // highest.
1410
1411  // Prime the first token.
1412  getNextToken();
1413
1414  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1415
1416  InitializeModule();
1417
1418  // Add the current debug info version into the module.
1419  TheModule->addModuleFlag(Module::Warning, "Debug Info Version",
1420                           DEBUG_METADATA_VERSION);
1421
1422  // Darwin only supports dwarf2.
1423  if (Triple(sys::getProcessTriple()).isOSDarwin())
1424    TheModule->addModuleFlag(llvm::Module::Warning, "Dwarf Version", 2);
1425
1426  // Construct the DIBuilder, we do this here because we need the module.
1427  DBuilder = llvm::make_unique<DIBuilder>(*TheModule);
1428
1429  // Create the compile unit for the module.
1430  // Currently down as "fib.ks" as a filename since we're redirecting stdin
1431  // but we'd like actual source locations.
1432  KSDbgInfo.TheCU = DBuilder->createCompileUnit(
1433      dwarf::DW_LANG_C, "fib.ks", ".", "Kaleidoscope Compiler", 0, "", 0);
1434
1435  // Run the main "interpreter loop" now.
1436  MainLoop();
1437
1438  // Finalize the debug info.
1439  DBuilder->finalize();
1440
1441  // Print out all of the generated code.
1442  TheModule->dump();
1443
1444  return 0;
1445}
1446