1#include "llvm/ADT/STLExtras.h" 2#include "llvm/Analysis/Passes.h" 3#include "llvm/IR/IRBuilder.h" 4#include "llvm/IR/LLVMContext.h" 5#include "llvm/IR/LegacyPassManager.h" 6#include "llvm/IR/Module.h" 7#include "llvm/IR/Verifier.h" 8#include "llvm/Support/TargetSelect.h" 9#include "llvm/Transforms/Scalar.h" 10#include <cctype> 11#include <cstdio> 12#include <map> 13#include <string> 14#include <vector> 15#include "../include/KaleidoscopeJIT.h" 16 17using namespace llvm; 18using namespace llvm::orc; 19 20//===----------------------------------------------------------------------===// 21// Lexer 22//===----------------------------------------------------------------------===// 23 24// The lexer returns tokens [0-255] if it is an unknown character, otherwise one 25// of these for known things. 26enum Token { 27 tok_eof = -1, 28 29 // commands 30 tok_def = -2, 31 tok_extern = -3, 32 33 // primary 34 tok_identifier = -4, 35 tok_number = -5 36}; 37 38static std::string IdentifierStr; // Filled in if tok_identifier 39static double NumVal; // Filled in if tok_number 40 41/// gettok - Return the next token from standard input. 42static int gettok() { 43 static int LastChar = ' '; 44 45 // Skip any whitespace. 46 while (isspace(LastChar)) 47 LastChar = getchar(); 48 49 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 50 IdentifierStr = LastChar; 51 while (isalnum((LastChar = getchar()))) 52 IdentifierStr += LastChar; 53 54 if (IdentifierStr == "def") 55 return tok_def; 56 if (IdentifierStr == "extern") 57 return tok_extern; 58 return tok_identifier; 59 } 60 61 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 62 std::string NumStr; 63 do { 64 NumStr += LastChar; 65 LastChar = getchar(); 66 } while (isdigit(LastChar) || LastChar == '.'); 67 68 NumVal = strtod(NumStr.c_str(), nullptr); 69 return tok_number; 70 } 71 72 if (LastChar == '#') { 73 // Comment until end of line. 74 do 75 LastChar = getchar(); 76 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 77 78 if (LastChar != EOF) 79 return gettok(); 80 } 81 82 // Check for end of file. Don't eat the EOF. 83 if (LastChar == EOF) 84 return tok_eof; 85 86 // Otherwise, just return the character as its ascii value. 87 int ThisChar = LastChar; 88 LastChar = getchar(); 89 return ThisChar; 90} 91 92//===----------------------------------------------------------------------===// 93// Abstract Syntax Tree (aka Parse Tree) 94//===----------------------------------------------------------------------===// 95namespace { 96/// ExprAST - Base class for all expression nodes. 97class ExprAST { 98public: 99 virtual ~ExprAST() {} 100 virtual Value *codegen() = 0; 101}; 102 103/// NumberExprAST - Expression class for numeric literals like "1.0". 104class NumberExprAST : public ExprAST { 105 double Val; 106 107public: 108 NumberExprAST(double Val) : Val(Val) {} 109 Value *codegen() override; 110}; 111 112/// VariableExprAST - Expression class for referencing a variable, like "a". 113class VariableExprAST : public ExprAST { 114 std::string Name; 115 116public: 117 VariableExprAST(const std::string &Name) : Name(Name) {} 118 Value *codegen() override; 119}; 120 121/// BinaryExprAST - Expression class for a binary operator. 122class BinaryExprAST : public ExprAST { 123 char Op; 124 std::unique_ptr<ExprAST> LHS, RHS; 125 126public: 127 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS, 128 std::unique_ptr<ExprAST> RHS) 129 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 130 Value *codegen() override; 131}; 132 133/// CallExprAST - Expression class for function calls. 134class CallExprAST : public ExprAST { 135 std::string Callee; 136 std::vector<std::unique_ptr<ExprAST>> Args; 137 138public: 139 CallExprAST(const std::string &Callee, 140 std::vector<std::unique_ptr<ExprAST>> Args) 141 : Callee(Callee), Args(std::move(Args)) {} 142 Value *codegen() override; 143}; 144 145/// PrototypeAST - This class represents the "prototype" for a function, 146/// which captures its name, and its argument names (thus implicitly the number 147/// of arguments the function takes). 148class PrototypeAST { 149 std::string Name; 150 std::vector<std::string> Args; 151 152public: 153 PrototypeAST(const std::string &Name, std::vector<std::string> Args) 154 : Name(Name), Args(std::move(Args)) {} 155 Function *codegen(); 156 const std::string &getName() const { return Name; } 157}; 158 159/// FunctionAST - This class represents a function definition itself. 160class FunctionAST { 161 std::unique_ptr<PrototypeAST> Proto; 162 std::unique_ptr<ExprAST> Body; 163 164public: 165 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 166 std::unique_ptr<ExprAST> Body) 167 : Proto(std::move(Proto)), Body(std::move(Body)) {} 168 Function *codegen(); 169}; 170} // end anonymous namespace 171 172//===----------------------------------------------------------------------===// 173// Parser 174//===----------------------------------------------------------------------===// 175 176/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 177/// token the parser is looking at. getNextToken reads another token from the 178/// lexer and updates CurTok with its results. 179static int CurTok; 180static int getNextToken() { return CurTok = gettok(); } 181 182/// BinopPrecedence - This holds the precedence for each binary operator that is 183/// defined. 184static std::map<char, int> BinopPrecedence; 185 186/// GetTokPrecedence - Get the precedence of the pending binary operator token. 187static int GetTokPrecedence() { 188 if (!isascii(CurTok)) 189 return -1; 190 191 // Make sure it's a declared binop. 192 int TokPrec = BinopPrecedence[CurTok]; 193 if (TokPrec <= 0) 194 return -1; 195 return TokPrec; 196} 197 198/// Error* - These are little helper functions for error handling. 199std::unique_ptr<ExprAST> Error(const char *Str) { 200 fprintf(stderr, "Error: %s\n", Str); 201 return nullptr; 202} 203 204std::unique_ptr<PrototypeAST> ErrorP(const char *Str) { 205 Error(Str); 206 return nullptr; 207} 208 209static std::unique_ptr<ExprAST> ParseExpression(); 210 211/// numberexpr ::= number 212static std::unique_ptr<ExprAST> ParseNumberExpr() { 213 auto Result = llvm::make_unique<NumberExprAST>(NumVal); 214 getNextToken(); // consume the number 215 return std::move(Result); 216} 217 218/// parenexpr ::= '(' expression ')' 219static std::unique_ptr<ExprAST> ParseParenExpr() { 220 getNextToken(); // eat (. 221 auto V = ParseExpression(); 222 if (!V) 223 return nullptr; 224 225 if (CurTok != ')') 226 return Error("expected ')'"); 227 getNextToken(); // eat ). 228 return V; 229} 230 231/// identifierexpr 232/// ::= identifier 233/// ::= identifier '(' expression* ')' 234static std::unique_ptr<ExprAST> ParseIdentifierExpr() { 235 std::string IdName = IdentifierStr; 236 237 getNextToken(); // eat identifier. 238 239 if (CurTok != '(') // Simple variable ref. 240 return llvm::make_unique<VariableExprAST>(IdName); 241 242 // Call. 243 getNextToken(); // eat ( 244 std::vector<std::unique_ptr<ExprAST>> Args; 245 if (CurTok != ')') { 246 while (1) { 247 if (auto Arg = ParseExpression()) 248 Args.push_back(std::move(Arg)); 249 else 250 return nullptr; 251 252 if (CurTok == ')') 253 break; 254 255 if (CurTok != ',') 256 return Error("Expected ')' or ',' in argument list"); 257 getNextToken(); 258 } 259 } 260 261 // Eat the ')'. 262 getNextToken(); 263 264 return llvm::make_unique<CallExprAST>(IdName, std::move(Args)); 265} 266 267/// primary 268/// ::= identifierexpr 269/// ::= numberexpr 270/// ::= parenexpr 271static std::unique_ptr<ExprAST> ParsePrimary() { 272 switch (CurTok) { 273 default: 274 return Error("unknown token when expecting an expression"); 275 case tok_identifier: 276 return ParseIdentifierExpr(); 277 case tok_number: 278 return ParseNumberExpr(); 279 case '(': 280 return ParseParenExpr(); 281 } 282} 283 284/// binoprhs 285/// ::= ('+' primary)* 286static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, 287 std::unique_ptr<ExprAST> LHS) { 288 // If this is a binop, find its precedence. 289 while (1) { 290 int TokPrec = GetTokPrecedence(); 291 292 // If this is a binop that binds at least as tightly as the current binop, 293 // consume it, otherwise we are done. 294 if (TokPrec < ExprPrec) 295 return LHS; 296 297 // Okay, we know this is a binop. 298 int BinOp = CurTok; 299 getNextToken(); // eat binop 300 301 // Parse the primary expression after the binary operator. 302 auto RHS = ParsePrimary(); 303 if (!RHS) 304 return nullptr; 305 306 // If BinOp binds less tightly with RHS than the operator after RHS, let 307 // the pending operator take RHS as its LHS. 308 int NextPrec = GetTokPrecedence(); 309 if (TokPrec < NextPrec) { 310 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); 311 if (!RHS) 312 return nullptr; 313 } 314 315 // Merge LHS/RHS. 316 LHS = 317 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS)); 318 } 319} 320 321/// expression 322/// ::= primary binoprhs 323/// 324static std::unique_ptr<ExprAST> ParseExpression() { 325 auto LHS = ParsePrimary(); 326 if (!LHS) 327 return nullptr; 328 329 return ParseBinOpRHS(0, std::move(LHS)); 330} 331 332/// prototype 333/// ::= id '(' id* ')' 334static std::unique_ptr<PrototypeAST> ParsePrototype() { 335 if (CurTok != tok_identifier) 336 return ErrorP("Expected function name in prototype"); 337 338 std::string FnName = IdentifierStr; 339 getNextToken(); 340 341 if (CurTok != '(') 342 return ErrorP("Expected '(' in prototype"); 343 344 std::vector<std::string> ArgNames; 345 while (getNextToken() == tok_identifier) 346 ArgNames.push_back(IdentifierStr); 347 if (CurTok != ')') 348 return ErrorP("Expected ')' in prototype"); 349 350 // success. 351 getNextToken(); // eat ')'. 352 353 return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames)); 354} 355 356/// definition ::= 'def' prototype expression 357static std::unique_ptr<FunctionAST> ParseDefinition() { 358 getNextToken(); // eat def. 359 auto Proto = ParsePrototype(); 360 if (!Proto) 361 return nullptr; 362 363 if (auto E = ParseExpression()) 364 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 365 return nullptr; 366} 367 368/// toplevelexpr ::= expression 369static std::unique_ptr<FunctionAST> ParseTopLevelExpr() { 370 if (auto E = ParseExpression()) { 371 // Make an anonymous proto. 372 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr", 373 std::vector<std::string>()); 374 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 375 } 376 return nullptr; 377} 378 379/// external ::= 'extern' prototype 380static std::unique_ptr<PrototypeAST> ParseExtern() { 381 getNextToken(); // eat extern. 382 return ParsePrototype(); 383} 384 385//===----------------------------------------------------------------------===// 386// Code Generation 387//===----------------------------------------------------------------------===// 388 389static std::unique_ptr<Module> TheModule; 390static IRBuilder<> Builder(getGlobalContext()); 391static std::map<std::string, Value *> NamedValues; 392static std::unique_ptr<legacy::FunctionPassManager> TheFPM; 393static std::unique_ptr<KaleidoscopeJIT> TheJIT; 394static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos; 395 396Value *ErrorV(const char *Str) { 397 Error(Str); 398 return nullptr; 399} 400 401Function *getFunction(std::string Name) { 402 // First, see if the function has already been added to the current module. 403 if (auto *F = TheModule->getFunction(Name)) 404 return F; 405 406 // If not, check whether we can codegen the declaration from some existing 407 // prototype. 408 auto FI = FunctionProtos.find(Name); 409 if (FI != FunctionProtos.end()) 410 return FI->second->codegen(); 411 412 // If no existing prototype exists, return null. 413 return nullptr; 414} 415 416Value *NumberExprAST::codegen() { 417 return ConstantFP::get(getGlobalContext(), APFloat(Val)); 418} 419 420Value *VariableExprAST::codegen() { 421 // Look this variable up in the function. 422 Value *V = NamedValues[Name]; 423 if (!V) 424 return ErrorV("Unknown variable name"); 425 return V; 426} 427 428Value *BinaryExprAST::codegen() { 429 Value *L = LHS->codegen(); 430 Value *R = RHS->codegen(); 431 if (!L || !R) 432 return nullptr; 433 434 switch (Op) { 435 case '+': 436 return Builder.CreateFAdd(L, R, "addtmp"); 437 case '-': 438 return Builder.CreateFSub(L, R, "subtmp"); 439 case '*': 440 return Builder.CreateFMul(L, R, "multmp"); 441 case '<': 442 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 443 // Convert bool 0/1 to double 0.0 or 1.0 444 return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()), 445 "booltmp"); 446 default: 447 return ErrorV("invalid binary operator"); 448 } 449} 450 451Value *CallExprAST::codegen() { 452 // Look up the name in the global module table. 453 Function *CalleeF = getFunction(Callee); 454 if (!CalleeF) 455 return ErrorV("Unknown function referenced"); 456 457 // If argument mismatch error. 458 if (CalleeF->arg_size() != Args.size()) 459 return ErrorV("Incorrect # arguments passed"); 460 461 std::vector<Value *> ArgsV; 462 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 463 ArgsV.push_back(Args[i]->codegen()); 464 if (!ArgsV.back()) 465 return nullptr; 466 } 467 468 return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); 469} 470 471Function *PrototypeAST::codegen() { 472 // Make the function type: double(double,double) etc. 473 std::vector<Type *> Doubles(Args.size(), 474 Type::getDoubleTy(getGlobalContext())); 475 FunctionType *FT = 476 FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false); 477 478 Function *F = 479 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); 480 481 // Set names for all arguments. 482 unsigned Idx = 0; 483 for (auto &Arg : F->args()) 484 Arg.setName(Args[Idx++]); 485 486 return F; 487} 488 489Function *FunctionAST::codegen() { 490 // Transfer ownership of the prototype to the FunctionProtos map, but keep a 491 // reference to it for use below. 492 auto &P = *Proto; 493 FunctionProtos[Proto->getName()] = std::move(Proto); 494 Function *TheFunction = getFunction(P.getName()); 495 if (!TheFunction) 496 return nullptr; 497 498 // Create a new basic block to start insertion into. 499 BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction); 500 Builder.SetInsertPoint(BB); 501 502 // Record the function arguments in the NamedValues map. 503 NamedValues.clear(); 504 for (auto &Arg : TheFunction->args()) 505 NamedValues[Arg.getName()] = &Arg; 506 507 if (Value *RetVal = Body->codegen()) { 508 // Finish off the function. 509 Builder.CreateRet(RetVal); 510 511 // Validate the generated code, checking for consistency. 512 verifyFunction(*TheFunction); 513 514 // Run the optimizer on the function. 515 TheFPM->run(*TheFunction); 516 517 return TheFunction; 518 } 519 520 // Error reading body, remove function. 521 TheFunction->eraseFromParent(); 522 return nullptr; 523} 524 525//===----------------------------------------------------------------------===// 526// Top-Level parsing and JIT Driver 527//===----------------------------------------------------------------------===// 528 529static void InitializeModuleAndPassManager() { 530 // Open a new module. 531 TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext()); 532 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout()); 533 534 // Create a new pass manager attached to it. 535 TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get()); 536 537 // Do simple "peephole" optimizations and bit-twiddling optzns. 538 TheFPM->add(createInstructionCombiningPass()); 539 // Reassociate expressions. 540 TheFPM->add(createReassociatePass()); 541 // Eliminate Common SubExpressions. 542 TheFPM->add(createGVNPass()); 543 // Simplify the control flow graph (deleting unreachable blocks, etc). 544 TheFPM->add(createCFGSimplificationPass()); 545 546 TheFPM->doInitialization(); 547} 548 549static void HandleDefinition() { 550 if (auto FnAST = ParseDefinition()) { 551 if (auto *FnIR = FnAST->codegen()) { 552 fprintf(stderr, "Read function definition:"); 553 FnIR->dump(); 554 TheJIT->addModule(std::move(TheModule)); 555 InitializeModuleAndPassManager(); 556 } 557 } else { 558 // Skip token for error recovery. 559 getNextToken(); 560 } 561} 562 563static void HandleExtern() { 564 if (auto ProtoAST = ParseExtern()) { 565 if (auto *FnIR = ProtoAST->codegen()) { 566 fprintf(stderr, "Read extern: "); 567 FnIR->dump(); 568 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST); 569 } 570 } else { 571 // Skip token for error recovery. 572 getNextToken(); 573 } 574} 575 576static void HandleTopLevelExpression() { 577 // Evaluate a top-level expression into an anonymous function. 578 if (auto FnAST = ParseTopLevelExpr()) { 579 if (FnAST->codegen()) { 580 581 // JIT the module containing the anonymous expression, keeping a handle so 582 // we can free it later. 583 auto H = TheJIT->addModule(std::move(TheModule)); 584 InitializeModuleAndPassManager(); 585 586 // Search the JIT for the __anon_expr symbol. 587 auto ExprSymbol = TheJIT->findSymbol("__anon_expr"); 588 assert(ExprSymbol && "Function not found"); 589 590 // Get the symbol's address and cast it to the right type (takes no 591 // arguments, returns a double) so we can call it as a native function. 592 double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress(); 593 fprintf(stderr, "Evaluated to %f\n", FP()); 594 595 // Delete the anonymous expression module from the JIT. 596 TheJIT->removeModule(H); 597 } 598 } else { 599 // Skip token for error recovery. 600 getNextToken(); 601 } 602} 603 604/// top ::= definition | external | expression | ';' 605static void MainLoop() { 606 while (1) { 607 fprintf(stderr, "ready> "); 608 switch (CurTok) { 609 case tok_eof: 610 return; 611 case ';': // ignore top-level semicolons. 612 getNextToken(); 613 break; 614 case tok_def: 615 HandleDefinition(); 616 break; 617 case tok_extern: 618 HandleExtern(); 619 break; 620 default: 621 HandleTopLevelExpression(); 622 break; 623 } 624 } 625} 626 627//===----------------------------------------------------------------------===// 628// "Library" functions that can be "extern'd" from user code. 629//===----------------------------------------------------------------------===// 630 631/// putchard - putchar that takes a double and returns 0. 632extern "C" double putchard(double X) { 633 fputc((char)X, stderr); 634 return 0; 635} 636 637/// printd - printf that takes a double prints it as "%f\n", returning 0. 638extern "C" double printd(double X) { 639 fprintf(stderr, "%f\n", X); 640 return 0; 641} 642 643//===----------------------------------------------------------------------===// 644// Main driver code. 645//===----------------------------------------------------------------------===// 646 647int main() { 648 InitializeNativeTarget(); 649 InitializeNativeTargetAsmPrinter(); 650 InitializeNativeTargetAsmParser(); 651 652 // Install standard binary operators. 653 // 1 is lowest precedence. 654 BinopPrecedence['<'] = 10; 655 BinopPrecedence['+'] = 20; 656 BinopPrecedence['-'] = 20; 657 BinopPrecedence['*'] = 40; // highest. 658 659 // Prime the first token. 660 fprintf(stderr, "ready> "); 661 getNextToken(); 662 663 TheJIT = llvm::make_unique<KaleidoscopeJIT>(); 664 665 InitializeModuleAndPassManager(); 666 667 // Run the main "interpreter loop" now. 668 MainLoop(); 669 670 return 0; 671} 672