toy.cpp revision de2d8694e25a814696358e95141f4b1aa4d8847e
1#include "llvm/ADT/APFloat.h"
2#include "llvm/ADT/STLExtras.h"
3#include "llvm/IR/BasicBlock.h"
4#include "llvm/IR/Constants.h"
5#include "llvm/IR/DerivedTypes.h"
6#include "llvm/IR/Function.h"
7#include "llvm/IR/Instructions.h"
8#include "llvm/IR/IRBuilder.h"
9#include "llvm/IR/LLVMContext.h"
10#include "llvm/IR/LegacyPassManager.h"
11#include "llvm/IR/Module.h"
12#include "llvm/IR/Type.h"
13#include "llvm/IR/Verifier.h"
14#include "llvm/Support/TargetSelect.h"
15#include "llvm/Target/TargetMachine.h"
16#include "llvm/Transforms/Scalar.h"
17#include "llvm/Transforms/Scalar/GVN.h"
18#include "KaleidoscopeJIT.h"
19#include <cassert>
20#include <cctype>
21#include <cstdint>
22#include <cstdio>
23#include <cstdlib>
24#include <map>
25#include <memory>
26#include <string>
27#include <utility>
28#include <vector>
29
30using namespace llvm;
31using namespace llvm::orc;
32
33//===----------------------------------------------------------------------===//
34// Lexer
35//===----------------------------------------------------------------------===//
36
// Tokens produced by the lexer. A character the lexer does not recognise is
// returned as its own value in [0, 255]; the negative values below stand for
// the multi-character tokens the language knows about.
enum Token {
  // End of input.
  tok_eof = -1,

  // Commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control flow.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // User-defined operators.
  tok_binary = -11,
  tok_unary = -12,

  // Mutable variable definition.
  tok_var = -13
};
64
/// Spelling of the most recent identifier; meaningful after tok_identifier.
static std::string IdentifierStr;
/// Value of the most recent numeric literal; meaningful after tok_number.
static double NumVal;
67
68/// gettok - Return the next token from standard input.
69static int gettok() {
70  static int LastChar = ' ';
71
72  // Skip any whitespace.
73  while (isspace(LastChar))
74    LastChar = getchar();
75
76  if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
77    IdentifierStr = LastChar;
78    while (isalnum((LastChar = getchar())))
79      IdentifierStr += LastChar;
80
81    if (IdentifierStr == "def")
82      return tok_def;
83    if (IdentifierStr == "extern")
84      return tok_extern;
85    if (IdentifierStr == "if")
86      return tok_if;
87    if (IdentifierStr == "then")
88      return tok_then;
89    if (IdentifierStr == "else")
90      return tok_else;
91    if (IdentifierStr == "for")
92      return tok_for;
93    if (IdentifierStr == "in")
94      return tok_in;
95    if (IdentifierStr == "binary")
96      return tok_binary;
97    if (IdentifierStr == "unary")
98      return tok_unary;
99    if (IdentifierStr == "var")
100      return tok_var;
101    return tok_identifier;
102  }
103
104  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
105    std::string NumStr;
106    do {
107      NumStr += LastChar;
108      LastChar = getchar();
109    } while (isdigit(LastChar) || LastChar == '.');
110
111    NumVal = strtod(NumStr.c_str(), nullptr);
112    return tok_number;
113  }
114
115  if (LastChar == '#') {
116    // Comment until end of line.
117    do
118      LastChar = getchar();
119    while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
120
121    if (LastChar != EOF)
122      return gettok();
123  }
124
125  // Check for end of file.  Don't eat the EOF.
126  if (LastChar == EOF)
127    return tok_eof;
128
129  // Otherwise, just return the character as its ascii value.
130  int ThisChar = LastChar;
131  LastChar = getchar();
132  return ThisChar;
133}
134
135//===----------------------------------------------------------------------===//
136// Abstract Syntax Tree (aka Parse Tree)
137//===----------------------------------------------------------------------===//
138namespace {
139/// ExprAST - Base class for all expression nodes.
140class ExprAST {
141public:
142  virtual ~ExprAST() {}
143  virtual Value *codegen() = 0;
144};
145
146/// NumberExprAST - Expression class for numeric literals like "1.0".
147class NumberExprAST : public ExprAST {
148  double Val;
149
150public:
151  NumberExprAST(double Val) : Val(Val) {}
152  Value *codegen() override;
153};
154
155/// VariableExprAST - Expression class for referencing a variable, like "a".
156class VariableExprAST : public ExprAST {
157  std::string Name;
158
159public:
160  VariableExprAST(const std::string &Name) : Name(Name) {}
161  const std::string &getName() const { return Name; }
162  Value *codegen() override;
163};
164
165/// UnaryExprAST - Expression class for a unary operator.
166class UnaryExprAST : public ExprAST {
167  char Opcode;
168  std::unique_ptr<ExprAST> Operand;
169
170public:
171  UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
172      : Opcode(Opcode), Operand(std::move(Operand)) {}
173  Value *codegen() override;
174};
175
176/// BinaryExprAST - Expression class for a binary operator.
177class BinaryExprAST : public ExprAST {
178  char Op;
179  std::unique_ptr<ExprAST> LHS, RHS;
180
181public:
182  BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
183                std::unique_ptr<ExprAST> RHS)
184      : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
185  Value *codegen() override;
186};
187
188/// CallExprAST - Expression class for function calls.
189class CallExprAST : public ExprAST {
190  std::string Callee;
191  std::vector<std::unique_ptr<ExprAST>> Args;
192
193public:
194  CallExprAST(const std::string &Callee,
195              std::vector<std::unique_ptr<ExprAST>> Args)
196      : Callee(Callee), Args(std::move(Args)) {}
197  Value *codegen() override;
198};
199
200/// IfExprAST - Expression class for if/then/else.
201class IfExprAST : public ExprAST {
202  std::unique_ptr<ExprAST> Cond, Then, Else;
203
204public:
205  IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
206            std::unique_ptr<ExprAST> Else)
207      : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
208  Value *codegen() override;
209};
210
211/// ForExprAST - Expression class for for/in.
212class ForExprAST : public ExprAST {
213  std::string VarName;
214  std::unique_ptr<ExprAST> Start, End, Step, Body;
215
216public:
217  ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
218             std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
219             std::unique_ptr<ExprAST> Body)
220      : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
221        Step(std::move(Step)), Body(std::move(Body)) {}
222  Value *codegen() override;
223};
224
225/// VarExprAST - Expression class for var/in
226class VarExprAST : public ExprAST {
227  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
228  std::unique_ptr<ExprAST> Body;
229
230public:
231  VarExprAST(
232      std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
233      std::unique_ptr<ExprAST> Body)
234      : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
235  Value *codegen() override;
236};
237
238/// PrototypeAST - This class represents the "prototype" for a function,
239/// which captures its name, and its argument names (thus implicitly the number
240/// of arguments the function takes), as well as if it is an operator.
241class PrototypeAST {
242  std::string Name;
243  std::vector<std::string> Args;
244  bool IsOperator;
245  unsigned Precedence; // Precedence if a binary op.
246
247public:
248  PrototypeAST(const std::string &Name, std::vector<std::string> Args,
249               bool IsOperator = false, unsigned Prec = 0)
250      : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
251        Precedence(Prec) {}
252  Function *codegen();
253  const std::string &getName() const { return Name; }
254
255  bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
256  bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
257
258  char getOperatorName() const {
259    assert(isUnaryOp() || isBinaryOp());
260    return Name[Name.size() - 1];
261  }
262
263  unsigned getBinaryPrecedence() const { return Precedence; }
264};
265
266/// FunctionAST - This class represents a function definition itself.
267class FunctionAST {
268  std::unique_ptr<PrototypeAST> Proto;
269  std::unique_ptr<ExprAST> Body;
270
271public:
272  FunctionAST(std::unique_ptr<PrototypeAST> Proto,
273              std::unique_ptr<ExprAST> Body)
274      : Proto(std::move(Proto)), Body(std::move(Body)) {}
275  Function *codegen();
276};
277} // end anonymous namespace
278
279//===----------------------------------------------------------------------===//
280// Parser
281//===----------------------------------------------------------------------===//
282
283/// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
284/// token the parser is looking at.  getNextToken reads another token from the
285/// lexer and updates CurTok with its results.
286static int CurTok;
287static int getNextToken() { return CurTok = gettok(); }
288
289/// BinopPrecedence - This holds the precedence for each binary operator that is
290/// defined.
291static std::map<char, int> BinopPrecedence;
292
293/// GetTokPrecedence - Get the precedence of the pending binary operator token.
294static int GetTokPrecedence() {
295  if (!isascii(CurTok))
296    return -1;
297
298  // Make sure it's a declared binop.
299  int TokPrec = BinopPrecedence[CurTok];
300  if (TokPrec <= 0)
301    return -1;
302  return TokPrec;
303}
304
305/// LogError* - These are little helper functions for error handling.
306std::unique_ptr<ExprAST> LogError(const char *Str) {
307  fprintf(stderr, "Error: %s\n", Str);
308  return nullptr;
309}
310
311std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
312  LogError(Str);
313  return nullptr;
314}
315
316static std::unique_ptr<ExprAST> ParseExpression();
317
318/// numberexpr ::= number
319static std::unique_ptr<ExprAST> ParseNumberExpr() {
320  auto Result = llvm::make_unique<NumberExprAST>(NumVal);
321  getNextToken(); // consume the number
322  return std::move(Result);
323}
324
325/// parenexpr ::= '(' expression ')'
326static std::unique_ptr<ExprAST> ParseParenExpr() {
327  getNextToken(); // eat (.
328  auto V = ParseExpression();
329  if (!V)
330    return nullptr;
331
332  if (CurTok != ')')
333    return LogError("expected ')'");
334  getNextToken(); // eat ).
335  return V;
336}
337
338/// identifierexpr
339///   ::= identifier
340///   ::= identifier '(' expression* ')'
341static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
342  std::string IdName = IdentifierStr;
343
344  getNextToken(); // eat identifier.
345
346  if (CurTok != '(') // Simple variable ref.
347    return llvm::make_unique<VariableExprAST>(IdName);
348
349  // Call.
350  getNextToken(); // eat (
351  std::vector<std::unique_ptr<ExprAST>> Args;
352  if (CurTok != ')') {
353    while (true) {
354      if (auto Arg = ParseExpression())
355        Args.push_back(std::move(Arg));
356      else
357        return nullptr;
358
359      if (CurTok == ')')
360        break;
361
362      if (CurTok != ',')
363        return LogError("Expected ')' or ',' in argument list");
364      getNextToken();
365    }
366  }
367
368  // Eat the ')'.
369  getNextToken();
370
371  return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
372}
373
374/// ifexpr ::= 'if' expression 'then' expression 'else' expression
375static std::unique_ptr<ExprAST> ParseIfExpr() {
376  getNextToken(); // eat the if.
377
378  // condition.
379  auto Cond = ParseExpression();
380  if (!Cond)
381    return nullptr;
382
383  if (CurTok != tok_then)
384    return LogError("expected then");
385  getNextToken(); // eat the then
386
387  auto Then = ParseExpression();
388  if (!Then)
389    return nullptr;
390
391  if (CurTok != tok_else)
392    return LogError("expected else");
393
394  getNextToken();
395
396  auto Else = ParseExpression();
397  if (!Else)
398    return nullptr;
399
400  return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
401                                      std::move(Else));
402}
403
404/// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
405static std::unique_ptr<ExprAST> ParseForExpr() {
406  getNextToken(); // eat the for.
407
408  if (CurTok != tok_identifier)
409    return LogError("expected identifier after for");
410
411  std::string IdName = IdentifierStr;
412  getNextToken(); // eat identifier.
413
414  if (CurTok != '=')
415    return LogError("expected '=' after for");
416  getNextToken(); // eat '='.
417
418  auto Start = ParseExpression();
419  if (!Start)
420    return nullptr;
421  if (CurTok != ',')
422    return LogError("expected ',' after for start value");
423  getNextToken();
424
425  auto End = ParseExpression();
426  if (!End)
427    return nullptr;
428
429  // The step value is optional.
430  std::unique_ptr<ExprAST> Step;
431  if (CurTok == ',') {
432    getNextToken();
433    Step = ParseExpression();
434    if (!Step)
435      return nullptr;
436  }
437
438  if (CurTok != tok_in)
439    return LogError("expected 'in' after for");
440  getNextToken(); // eat 'in'.
441
442  auto Body = ParseExpression();
443  if (!Body)
444    return nullptr;
445
446  return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
447                                       std::move(Step), std::move(Body));
448}
449
450/// varexpr ::= 'var' identifier ('=' expression)?
451//                    (',' identifier ('=' expression)?)* 'in' expression
452static std::unique_ptr<ExprAST> ParseVarExpr() {
453  getNextToken(); // eat the var.
454
455  std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
456
457  // At least one variable name is required.
458  if (CurTok != tok_identifier)
459    return LogError("expected identifier after var");
460
461  while (true) {
462    std::string Name = IdentifierStr;
463    getNextToken(); // eat identifier.
464
465    // Read the optional initializer.
466    std::unique_ptr<ExprAST> Init = nullptr;
467    if (CurTok == '=') {
468      getNextToken(); // eat the '='.
469
470      Init = ParseExpression();
471      if (!Init)
472        return nullptr;
473    }
474
475    VarNames.push_back(std::make_pair(Name, std::move(Init)));
476
477    // End of var list, exit loop.
478    if (CurTok != ',')
479      break;
480    getNextToken(); // eat the ','.
481
482    if (CurTok != tok_identifier)
483      return LogError("expected identifier list after var");
484  }
485
486  // At this point, we have to have 'in'.
487  if (CurTok != tok_in)
488    return LogError("expected 'in' keyword after 'var'");
489  getNextToken(); // eat 'in'.
490
491  auto Body = ParseExpression();
492  if (!Body)
493    return nullptr;
494
495  return llvm::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
496}
497
498/// primary
499///   ::= identifierexpr
500///   ::= numberexpr
501///   ::= parenexpr
502///   ::= ifexpr
503///   ::= forexpr
504///   ::= varexpr
505static std::unique_ptr<ExprAST> ParsePrimary() {
506  switch (CurTok) {
507  default:
508    return LogError("unknown token when expecting an expression");
509  case tok_identifier:
510    return ParseIdentifierExpr();
511  case tok_number:
512    return ParseNumberExpr();
513  case '(':
514    return ParseParenExpr();
515  case tok_if:
516    return ParseIfExpr();
517  case tok_for:
518    return ParseForExpr();
519  case tok_var:
520    return ParseVarExpr();
521  }
522}
523
524/// unary
525///   ::= primary
526///   ::= '!' unary
527static std::unique_ptr<ExprAST> ParseUnary() {
528  // If the current token is not an operator, it must be a primary expr.
529  if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
530    return ParsePrimary();
531
532  // If this is a unary operator, read it.
533  int Opc = CurTok;
534  getNextToken();
535  if (auto Operand = ParseUnary())
536    return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
537  return nullptr;
538}
539
540/// binoprhs
541///   ::= ('+' unary)*
542static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
543                                              std::unique_ptr<ExprAST> LHS) {
544  // If this is a binop, find its precedence.
545  while (true) {
546    int TokPrec = GetTokPrecedence();
547
548    // If this is a binop that binds at least as tightly as the current binop,
549    // consume it, otherwise we are done.
550    if (TokPrec < ExprPrec)
551      return LHS;
552
553    // Okay, we know this is a binop.
554    int BinOp = CurTok;
555    getNextToken(); // eat binop
556
557    // Parse the unary expression after the binary operator.
558    auto RHS = ParseUnary();
559    if (!RHS)
560      return nullptr;
561
562    // If BinOp binds less tightly with RHS than the operator after RHS, let
563    // the pending operator take RHS as its LHS.
564    int NextPrec = GetTokPrecedence();
565    if (TokPrec < NextPrec) {
566      RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
567      if (!RHS)
568        return nullptr;
569    }
570
571    // Merge LHS/RHS.
572    LHS =
573        llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
574  }
575}
576
577/// expression
578///   ::= unary binoprhs
579///
580static std::unique_ptr<ExprAST> ParseExpression() {
581  auto LHS = ParseUnary();
582  if (!LHS)
583    return nullptr;
584
585  return ParseBinOpRHS(0, std::move(LHS));
586}
587
588/// prototype
589///   ::= id '(' id* ')'
590///   ::= binary LETTER number? (id, id)
591///   ::= unary LETTER (id)
592static std::unique_ptr<PrototypeAST> ParsePrototype() {
593  std::string FnName;
594
595  unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
596  unsigned BinaryPrecedence = 30;
597
598  switch (CurTok) {
599  default:
600    return LogErrorP("Expected function name in prototype");
601  case tok_identifier:
602    FnName = IdentifierStr;
603    Kind = 0;
604    getNextToken();
605    break;
606  case tok_unary:
607    getNextToken();
608    if (!isascii(CurTok))
609      return LogErrorP("Expected unary operator");
610    FnName = "unary";
611    FnName += (char)CurTok;
612    Kind = 1;
613    getNextToken();
614    break;
615  case tok_binary:
616    getNextToken();
617    if (!isascii(CurTok))
618      return LogErrorP("Expected binary operator");
619    FnName = "binary";
620    FnName += (char)CurTok;
621    Kind = 2;
622    getNextToken();
623
624    // Read the precedence if present.
625    if (CurTok == tok_number) {
626      if (NumVal < 1 || NumVal > 100)
627        return LogErrorP("Invalid precedecnce: must be 1..100");
628      BinaryPrecedence = (unsigned)NumVal;
629      getNextToken();
630    }
631    break;
632  }
633
634  if (CurTok != '(')
635    return LogErrorP("Expected '(' in prototype");
636
637  std::vector<std::string> ArgNames;
638  while (getNextToken() == tok_identifier)
639    ArgNames.push_back(IdentifierStr);
640  if (CurTok != ')')
641    return LogErrorP("Expected ')' in prototype");
642
643  // success.
644  getNextToken(); // eat ')'.
645
646  // Verify right number of names for operator.
647  if (Kind && ArgNames.size() != Kind)
648    return LogErrorP("Invalid number of operands for operator");
649
650  return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
651                                         BinaryPrecedence);
652}
653
654/// definition ::= 'def' prototype expression
655static std::unique_ptr<FunctionAST> ParseDefinition() {
656  getNextToken(); // eat def.
657  auto Proto = ParsePrototype();
658  if (!Proto)
659    return nullptr;
660
661  if (auto E = ParseExpression())
662    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
663  return nullptr;
664}
665
666/// toplevelexpr ::= expression
667static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
668  if (auto E = ParseExpression()) {
669    // Make an anonymous proto.
670    auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
671                                                 std::vector<std::string>());
672    return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
673  }
674  return nullptr;
675}
676
677/// external ::= 'extern' prototype
678static std::unique_ptr<PrototypeAST> ParseExtern() {
679  getNextToken(); // eat extern.
680  return ParsePrototype();
681}
682
683//===----------------------------------------------------------------------===//
684// Code Generation
685//===----------------------------------------------------------------------===//
686
687static LLVMContext TheContext;
688static IRBuilder<> Builder(TheContext);
689static std::unique_ptr<Module> TheModule;
690static std::map<std::string, AllocaInst *> NamedValues;
691static std::unique_ptr<KaleidoscopeJIT> TheJIT;
692static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
693
694Value *LogErrorV(const char *Str) {
695  LogError(Str);
696  return nullptr;
697}
698
699Function *getFunction(std::string Name) {
700  // First, see if the function has already been added to the current module.
701  if (auto *F = TheModule->getFunction(Name))
702    return F;
703
704  // If not, check whether we can codegen the declaration from some existing
705  // prototype.
706  auto FI = FunctionProtos.find(Name);
707  if (FI != FunctionProtos.end())
708    return FI->second->codegen();
709
710  // If no existing prototype exists, return null.
711  return nullptr;
712}
713
714/// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
715/// the function.  This is used for mutable variables etc.
716static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
717                                          const std::string &VarName) {
718  IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
719                   TheFunction->getEntryBlock().begin());
720  return TmpB.CreateAlloca(Type::getDoubleTy(TheContext), nullptr, VarName);
721}
722
723Value *NumberExprAST::codegen() {
724  return ConstantFP::get(TheContext, APFloat(Val));
725}
726
727Value *VariableExprAST::codegen() {
728  // Look this variable up in the function.
729  Value *V = NamedValues[Name];
730  if (!V)
731    return LogErrorV("Unknown variable name");
732
733  // Load the value.
734  return Builder.CreateLoad(V, Name.c_str());
735}
736
737Value *UnaryExprAST::codegen() {
738  Value *OperandV = Operand->codegen();
739  if (!OperandV)
740    return nullptr;
741
742  Function *F = getFunction(std::string("unary") + Opcode);
743  if (!F)
744    return LogErrorV("Unknown unary operator");
745
746  return Builder.CreateCall(F, OperandV, "unop");
747}
748
749Value *BinaryExprAST::codegen() {
750  // Special case '=' because we don't want to emit the LHS as an expression.
751  if (Op == '=') {
752    // Assignment requires the LHS to be an identifier.
753    // This assume we're building without RTTI because LLVM builds that way by
754    // default.  If you build LLVM with RTTI this can be changed to a
755    // dynamic_cast for automatic error checking.
756    VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
757    if (!LHSE)
758      return LogErrorV("destination of '=' must be a variable");
759    // Codegen the RHS.
760    Value *Val = RHS->codegen();
761    if (!Val)
762      return nullptr;
763
764    // Look up the name.
765    Value *Variable = NamedValues[LHSE->getName()];
766    if (!Variable)
767      return LogErrorV("Unknown variable name");
768
769    Builder.CreateStore(Val, Variable);
770    return Val;
771  }
772
773  Value *L = LHS->codegen();
774  Value *R = RHS->codegen();
775  if (!L || !R)
776    return nullptr;
777
778  switch (Op) {
779  case '+':
780    return Builder.CreateFAdd(L, R, "addtmp");
781  case '-':
782    return Builder.CreateFSub(L, R, "subtmp");
783  case '*':
784    return Builder.CreateFMul(L, R, "multmp");
785  case '<':
786    L = Builder.CreateFCmpULT(L, R, "cmptmp");
787    // Convert bool 0/1 to double 0.0 or 1.0
788    return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
789  default:
790    break;
791  }
792
793  // If it wasn't a builtin binary operator, it must be a user defined one. Emit
794  // a call to it.
795  Function *F = getFunction(std::string("binary") + Op);
796  assert(F && "binary operator not found!");
797
798  Value *Ops[] = {L, R};
799  return Builder.CreateCall(F, Ops, "binop");
800}
801
802Value *CallExprAST::codegen() {
803  // Look up the name in the global module table.
804  Function *CalleeF = getFunction(Callee);
805  if (!CalleeF)
806    return LogErrorV("Unknown function referenced");
807
808  // If argument mismatch error.
809  if (CalleeF->arg_size() != Args.size())
810    return LogErrorV("Incorrect # arguments passed");
811
812  std::vector<Value *> ArgsV;
813  for (unsigned i = 0, e = Args.size(); i != e; ++i) {
814    ArgsV.push_back(Args[i]->codegen());
815    if (!ArgsV.back())
816      return nullptr;
817  }
818
819  return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
820}
821
822Value *IfExprAST::codegen() {
823  Value *CondV = Cond->codegen();
824  if (!CondV)
825    return nullptr;
826
827  // Convert condition to a bool by comparing equal to 0.0.
828  CondV = Builder.CreateFCmpONE(
829      CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
830
831  Function *TheFunction = Builder.GetInsertBlock()->getParent();
832
833  // Create blocks for the then and else cases.  Insert the 'then' block at the
834  // end of the function.
835  BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
836  BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
837  BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");
838
839  Builder.CreateCondBr(CondV, ThenBB, ElseBB);
840
841  // Emit then value.
842  Builder.SetInsertPoint(ThenBB);
843
844  Value *ThenV = Then->codegen();
845  if (!ThenV)
846    return nullptr;
847
848  Builder.CreateBr(MergeBB);
849  // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
850  ThenBB = Builder.GetInsertBlock();
851
852  // Emit else block.
853  TheFunction->getBasicBlockList().push_back(ElseBB);
854  Builder.SetInsertPoint(ElseBB);
855
856  Value *ElseV = Else->codegen();
857  if (!ElseV)
858    return nullptr;
859
860  Builder.CreateBr(MergeBB);
861  // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
862  ElseBB = Builder.GetInsertBlock();
863
864  // Emit merge block.
865  TheFunction->getBasicBlockList().push_back(MergeBB);
866  Builder.SetInsertPoint(MergeBB);
867  PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
868
869  PN->addIncoming(ThenV, ThenBB);
870  PN->addIncoming(ElseV, ElseBB);
871  return PN;
872}
873
874// Output for-loop as:
875//   var = alloca double
876//   ...
877//   start = startexpr
878//   store start -> var
879//   goto loop
880// loop:
881//   ...
882//   bodyexpr
883//   ...
884// loopend:
885//   step = stepexpr
886//   endcond = endexpr
887//
888//   curvar = load var
889//   nextvar = curvar + step
890//   store nextvar -> var
891//   br endcond, loop, endloop
892// outloop:
893Value *ForExprAST::codegen() {
894  Function *TheFunction = Builder.GetInsertBlock()->getParent();
895
896  // Create an alloca for the variable in the entry block.
897  AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
898
899  // Emit the start code first, without 'variable' in scope.
900  Value *StartVal = Start->codegen();
901  if (!StartVal)
902    return nullptr;
903
904  // Store the value into the alloca.
905  Builder.CreateStore(StartVal, Alloca);
906
907  // Make the new basic block for the loop header, inserting after current
908  // block.
909  BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
910
911  // Insert an explicit fall through from the current block to the LoopBB.
912  Builder.CreateBr(LoopBB);
913
914  // Start insertion in LoopBB.
915  Builder.SetInsertPoint(LoopBB);
916
917  // Within the loop, the variable is defined equal to the PHI node.  If it
918  // shadows an existing variable, we have to restore it, so save it now.
919  AllocaInst *OldVal = NamedValues[VarName];
920  NamedValues[VarName] = Alloca;
921
922  // Emit the body of the loop.  This, like any other expr, can change the
923  // current BB.  Note that we ignore the value computed by the body, but don't
924  // allow an error.
925  if (!Body->codegen())
926    return nullptr;
927
928  // Emit the step value.
929  Value *StepVal = nullptr;
930  if (Step) {
931    StepVal = Step->codegen();
932    if (!StepVal)
933      return nullptr;
934  } else {
935    // If not specified, use 1.0.
936    StepVal = ConstantFP::get(TheContext, APFloat(1.0));
937  }
938
939  // Compute the end condition.
940  Value *EndCond = End->codegen();
941  if (!EndCond)
942    return nullptr;
943
944  // Reload, increment, and restore the alloca.  This handles the case where
945  // the body of the loop mutates the variable.
946  Value *CurVar = Builder.CreateLoad(Alloca, VarName.c_str());
947  Value *NextVar = Builder.CreateFAdd(CurVar, StepVal, "nextvar");
948  Builder.CreateStore(NextVar, Alloca);
949
950  // Convert condition to a bool by comparing equal to 0.0.
951  EndCond = Builder.CreateFCmpONE(
952      EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
953
954  // Create the "after loop" block and insert it.
955  BasicBlock *AfterBB =
956      BasicBlock::Create(TheContext, "afterloop", TheFunction);
957
958  // Insert the conditional branch into the end of LoopEndBB.
959  Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
960
961  // Any new code will be inserted in AfterBB.
962  Builder.SetInsertPoint(AfterBB);
963
964  // Restore the unshadowed variable.
965  if (OldVal)
966    NamedValues[VarName] = OldVal;
967  else
968    NamedValues.erase(VarName);
969
970  // for expr always returns 0.0.
971  return Constant::getNullValue(Type::getDoubleTy(TheContext));
972}
973
974Value *VarExprAST::codegen() {
975  std::vector<AllocaInst *> OldBindings;
976
977  Function *TheFunction = Builder.GetInsertBlock()->getParent();
978
979  // Register all variables and emit their initializer.
980  for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
981    const std::string &VarName = VarNames[i].first;
982    ExprAST *Init = VarNames[i].second.get();
983
984    // Emit the initializer before adding the variable to scope, this prevents
985    // the initializer from referencing the variable itself, and permits stuff
986    // like this:
987    //  var a = 1 in
988    //    var a = a in ...   # refers to outer 'a'.
989    Value *InitVal;
990    if (Init) {
991      InitVal = Init->codegen();
992      if (!InitVal)
993        return nullptr;
994    } else { // If not specified, use 0.0.
995      InitVal = ConstantFP::get(TheContext, APFloat(0.0));
996    }
997
998    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
999    Builder.CreateStore(InitVal, Alloca);
1000
1001    // Remember the old variable binding so that we can restore the binding when
1002    // we unrecurse.
1003    OldBindings.push_back(NamedValues[VarName]);
1004
1005    // Remember this binding.
1006    NamedValues[VarName] = Alloca;
1007  }
1008
1009  // Codegen the body, now that all vars are in scope.
1010  Value *BodyVal = Body->codegen();
1011  if (!BodyVal)
1012    return nullptr;
1013
1014  // Pop all our variables from scope.
1015  for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1016    NamedValues[VarNames[i].first] = OldBindings[i];
1017
1018  // Return the body computation.
1019  return BodyVal;
1020}
1021
1022Function *PrototypeAST::codegen() {
1023  // Make the function type:  double(double,double) etc.
1024  std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
1025  FunctionType *FT =
1026      FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
1027
1028  Function *F =
1029      Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1030
1031  // Set names for all arguments.
1032  unsigned Idx = 0;
1033  for (auto &Arg : F->args())
1034    Arg.setName(Args[Idx++]);
1035
1036  return F;
1037}
1038
1039Function *FunctionAST::codegen() {
1040  // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1041  // reference to it for use below.
1042  auto &P = *Proto;
1043  FunctionProtos[Proto->getName()] = std::move(Proto);
1044  Function *TheFunction = getFunction(P.getName());
1045  if (!TheFunction)
1046    return nullptr;
1047
1048  // If this is an operator, install it.
1049  if (P.isBinaryOp())
1050    BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1051
1052  // Create a new basic block to start insertion into.
1053  BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
1054  Builder.SetInsertPoint(BB);
1055
1056  // Record the function arguments in the NamedValues map.
1057  NamedValues.clear();
1058  for (auto &Arg : TheFunction->args()) {
1059    // Create an alloca for this variable.
1060    AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1061
1062    // Store the initial value into the alloca.
1063    Builder.CreateStore(&Arg, Alloca);
1064
1065    // Add arguments to variable symbol table.
1066    NamedValues[Arg.getName()] = Alloca;
1067  }
1068
1069  if (Value *RetVal = Body->codegen()) {
1070    // Finish off the function.
1071    Builder.CreateRet(RetVal);
1072
1073    // Validate the generated code, checking for consistency.
1074    verifyFunction(*TheFunction);
1075
1076    return TheFunction;
1077  }
1078
1079  // Error reading body, remove function.
1080  TheFunction->eraseFromParent();
1081
1082  if (P.isBinaryOp())
1083    BinopPrecedence.erase(Proto->getOperatorName());
1084  return nullptr;
1085}
1086
1087//===----------------------------------------------------------------------===//
1088// Top-Level parsing and JIT Driver
1089//===----------------------------------------------------------------------===//
1090
1091static void InitializeModule() {
1092  // Open a new module.
1093  TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
1094  TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
1095}
1096
1097static void HandleDefinition() {
1098  if (auto FnAST = ParseDefinition()) {
1099    if (auto *FnIR = FnAST->codegen()) {
1100      fprintf(stderr, "Read function definition:");
1101      FnIR->dump();
1102      TheJIT->addModule(std::move(TheModule));
1103      InitializeModule();
1104    }
1105  } else {
1106    // Skip token for error recovery.
1107    getNextToken();
1108  }
1109}
1110
1111static void HandleExtern() {
1112  if (auto ProtoAST = ParseExtern()) {
1113    if (auto *FnIR = ProtoAST->codegen()) {
1114      fprintf(stderr, "Read extern: ");
1115      FnIR->dump();
1116      FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1117    }
1118  } else {
1119    // Skip token for error recovery.
1120    getNextToken();
1121  }
1122}
1123
1124static void HandleTopLevelExpression() {
1125  // Evaluate a top-level expression into an anonymous function.
1126  if (auto FnAST = ParseTopLevelExpr()) {
1127    if (FnAST->codegen()) {
1128      // JIT the module containing the anonymous expression, keeping a handle so
1129      // we can free it later.
1130      auto H = TheJIT->addModule(std::move(TheModule));
1131      InitializeModule();
1132
1133      // Search the JIT for the __anon_expr symbol.
1134      auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
1135      assert(ExprSymbol && "Function not found");
1136
1137      // Get the symbol's address and cast it to the right type (takes no
1138      // arguments, returns a double) so we can call it as a native function.
1139      double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
1140      fprintf(stderr, "Evaluated to %f\n", FP());
1141
1142      // Delete the anonymous expression module from the JIT.
1143      TheJIT->removeModule(H);
1144    }
1145  } else {
1146    // Skip token for error recovery.
1147    getNextToken();
1148  }
1149}
1150
1151/// top ::= definition | external | expression | ';'
1152static void MainLoop() {
1153  while (true) {
1154    fprintf(stderr, "ready> ");
1155    switch (CurTok) {
1156    case tok_eof:
1157      return;
1158    case ';': // ignore top-level semicolons.
1159      getNextToken();
1160      break;
1161    case tok_def:
1162      HandleDefinition();
1163      break;
1164    case tok_extern:
1165      HandleExtern();
1166      break;
1167    default:
1168      HandleTopLevelExpression();
1169      break;
1170    }
1171  }
1172}
1173
1174//===----------------------------------------------------------------------===//
1175// "Library" functions that can be "extern'd" from user code.
1176//===----------------------------------------------------------------------===//
1177
/// putchard - putchar for Kaleidoscope: takes a double, writes the character
/// with that code to stderr, and returns 0.
extern "C" double putchard(double X) {
  const char Ch = static_cast<char>(X);
  fputc(Ch, stderr);
  return 0.0;
}
1183
/// printd - printf for Kaleidoscope: takes a double, prints it to stderr as
/// "%f\n", and returns 0.
extern "C" double printd(double X) {
  std::fprintf(stderr, "%f\n", X);
  return 0.0;
}
1189
1190//===----------------------------------------------------------------------===//
1191// Main driver code.
1192//===----------------------------------------------------------------------===//
1193
1194int main() {
1195  InitializeNativeTarget();
1196  InitializeNativeTargetAsmPrinter();
1197  InitializeNativeTargetAsmParser();
1198
1199  // Install standard binary operators.
1200  // 1 is lowest precedence.
1201  BinopPrecedence['='] = 2;
1202  BinopPrecedence['<'] = 10;
1203  BinopPrecedence['+'] = 20;
1204  BinopPrecedence['-'] = 20;
1205  BinopPrecedence['*'] = 40; // highest.
1206
1207  // Prime the first token.
1208  fprintf(stderr, "ready> ");
1209  getNextToken();
1210
1211  TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1212
1213  InitializeModule();
1214
1215  // Run the main "interpreter loop" now.
1216  MainLoop();
1217
1218  return 0;
1219}
1220