1#include "llvm/ADT/APFloat.h" 2#include "llvm/ADT/STLExtras.h" 3#include "llvm/IR/BasicBlock.h" 4#include "llvm/IR/Constants.h" 5#include "llvm/IR/DerivedTypes.h" 6#include "llvm/IR/Function.h" 7#include "llvm/IR/IRBuilder.h" 8#include "llvm/IR/LLVMContext.h" 9#include "llvm/IR/Module.h" 10#include "llvm/IR/Type.h" 11#include "llvm/IR/Verifier.h" 12#include <cctype> 13#include <cstdio> 14#include <cstdlib> 15#include <map> 16#include <memory> 17#include <string> 18#include <vector> 19 20using namespace llvm; 21 22//===----------------------------------------------------------------------===// 23// Lexer 24//===----------------------------------------------------------------------===// 25 26// The lexer returns tokens [0-255] if it is an unknown character, otherwise one 27// of these for known things. 28enum Token { 29 tok_eof = -1, 30 31 // commands 32 tok_def = -2, 33 tok_extern = -3, 34 35 // primary 36 tok_identifier = -4, 37 tok_number = -5 38}; 39 40static std::string IdentifierStr; // Filled in if tok_identifier 41static double NumVal; // Filled in if tok_number 42 43/// gettok - Return the next token from standard input. 44static int gettok() { 45 static int LastChar = ' '; 46 47 // Skip any whitespace. 48 while (isspace(LastChar)) 49 LastChar = getchar(); 50 51 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 52 IdentifierStr = LastChar; 53 while (isalnum((LastChar = getchar()))) 54 IdentifierStr += LastChar; 55 56 if (IdentifierStr == "def") 57 return tok_def; 58 if (IdentifierStr == "extern") 59 return tok_extern; 60 return tok_identifier; 61 } 62 63 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 64 std::string NumStr; 65 do { 66 NumStr += LastChar; 67 LastChar = getchar(); 68 } while (isdigit(LastChar) || LastChar == '.'); 69 70 NumVal = strtod(NumStr.c_str(), nullptr); 71 return tok_number; 72 } 73 74 if (LastChar == '#') { 75 // Comment until end of line. 76 do 77 LastChar = getchar(); 78 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 79 80 if (LastChar != EOF) 81 return gettok(); 82 } 83 84 // Check for end of file. Don't eat the EOF. 85 if (LastChar == EOF) 86 return tok_eof; 87 88 // Otherwise, just return the character as its ascii value. 89 int ThisChar = LastChar; 90 LastChar = getchar(); 91 return ThisChar; 92} 93 94//===----------------------------------------------------------------------===// 95// Abstract Syntax Tree (aka Parse Tree) 96//===----------------------------------------------------------------------===// 97namespace { 98/// ExprAST - Base class for all expression nodes. 99class ExprAST { 100public: 101 virtual ~ExprAST() {} 102 virtual Value *codegen() = 0; 103}; 104 105/// NumberExprAST - Expression class for numeric literals like "1.0". 106class NumberExprAST : public ExprAST { 107 double Val; 108 109public: 110 NumberExprAST(double Val) : Val(Val) {} 111 Value *codegen() override; 112}; 113 114/// VariableExprAST - Expression class for referencing a variable, like "a". 115class VariableExprAST : public ExprAST { 116 std::string Name; 117 118public: 119 VariableExprAST(const std::string &Name) : Name(Name) {} 120 Value *codegen() override; 121}; 122 123/// BinaryExprAST - Expression class for a binary operator. 124class BinaryExprAST : public ExprAST { 125 char Op; 126 std::unique_ptr<ExprAST> LHS, RHS; 127 128public: 129 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS, 130 std::unique_ptr<ExprAST> RHS) 131 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {} 132 Value *codegen() override; 133}; 134 135/// CallExprAST - Expression class for function calls. 136class CallExprAST : public ExprAST { 137 std::string Callee; 138 std::vector<std::unique_ptr<ExprAST>> Args; 139 140public: 141 CallExprAST(const std::string &Callee, 142 std::vector<std::unique_ptr<ExprAST>> Args) 143 : Callee(Callee), Args(std::move(Args)) {} 144 Value *codegen() override; 145}; 146 147/// PrototypeAST - This class represents the "prototype" for a function, 148/// which captures its name, and its argument names (thus implicitly the number 149/// of arguments the function takes). 150class PrototypeAST { 151 std::string Name; 152 std::vector<std::string> Args; 153 154public: 155 PrototypeAST(const std::string &Name, std::vector<std::string> Args) 156 : Name(Name), Args(std::move(Args)) {} 157 Function *codegen(); 158 const std::string &getName() const { return Name; } 159}; 160 161/// FunctionAST - This class represents a function definition itself. 162class FunctionAST { 163 std::unique_ptr<PrototypeAST> Proto; 164 std::unique_ptr<ExprAST> Body; 165 166public: 167 FunctionAST(std::unique_ptr<PrototypeAST> Proto, 168 std::unique_ptr<ExprAST> Body) 169 : Proto(std::move(Proto)), Body(std::move(Body)) {} 170 Function *codegen(); 171}; 172} // end anonymous namespace 173 174//===----------------------------------------------------------------------===// 175// Parser 176//===----------------------------------------------------------------------===// 177 178/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 179/// token the parser is looking at. getNextToken reads another token from the 180/// lexer and updates CurTok with its results. 181static int CurTok; 182static int getNextToken() { return CurTok = gettok(); } 183 184/// BinopPrecedence - This holds the precedence for each binary operator that is 185/// defined. 186static std::map<char, int> BinopPrecedence; 187 188/// GetTokPrecedence - Get the precedence of the pending binary operator token. 189static int GetTokPrecedence() { 190 if (!isascii(CurTok)) 191 return -1; 192 193 // Make sure it's a declared binop. 194 int TokPrec = BinopPrecedence[CurTok]; 195 if (TokPrec <= 0) 196 return -1; 197 return TokPrec; 198} 199 200/// LogError* - These are little helper functions for error handling. 201std::unique_ptr<ExprAST> LogError(const char *Str) { 202 fprintf(stderr, "Error: %s\n", Str); 203 return nullptr; 204} 205 206std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) { 207 LogError(Str); 208 return nullptr; 209} 210 211static std::unique_ptr<ExprAST> ParseExpression(); 212 213/// numberexpr ::= number 214static std::unique_ptr<ExprAST> ParseNumberExpr() { 215 auto Result = llvm::make_unique<NumberExprAST>(NumVal); 216 getNextToken(); // consume the number 217 return std::move(Result); 218} 219 220/// parenexpr ::= '(' expression ')' 221static std::unique_ptr<ExprAST> ParseParenExpr() { 222 getNextToken(); // eat (. 223 auto V = ParseExpression(); 224 if (!V) 225 return nullptr; 226 227 if (CurTok != ')') 228 return LogError("expected ')'"); 229 getNextToken(); // eat ). 230 return V; 231} 232 233/// identifierexpr 234/// ::= identifier 235/// ::= identifier '(' expression* ')' 236static std::unique_ptr<ExprAST> ParseIdentifierExpr() { 237 std::string IdName = IdentifierStr; 238 239 getNextToken(); // eat identifier. 240 241 if (CurTok != '(') // Simple variable ref. 242 return llvm::make_unique<VariableExprAST>(IdName); 243 244 // Call. 245 getNextToken(); // eat ( 246 std::vector<std::unique_ptr<ExprAST>> Args; 247 if (CurTok != ')') { 248 while (true) { 249 if (auto Arg = ParseExpression()) 250 Args.push_back(std::move(Arg)); 251 else 252 return nullptr; 253 254 if (CurTok == ')') 255 break; 256 257 if (CurTok != ',') 258 return LogError("Expected ')' or ',' in argument list"); 259 getNextToken(); 260 } 261 } 262 263 // Eat the ')'. 264 getNextToken(); 265 266 return llvm::make_unique<CallExprAST>(IdName, std::move(Args)); 267} 268 269/// primary 270/// ::= identifierexpr 271/// ::= numberexpr 272/// ::= parenexpr 273static std::unique_ptr<ExprAST> ParsePrimary() { 274 switch (CurTok) { 275 default: 276 return LogError("unknown token when expecting an expression"); 277 case tok_identifier: 278 return ParseIdentifierExpr(); 279 case tok_number: 280 return ParseNumberExpr(); 281 case '(': 282 return ParseParenExpr(); 283 } 284} 285 286/// binoprhs 287/// ::= ('+' primary)* 288static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec, 289 std::unique_ptr<ExprAST> LHS) { 290 // If this is a binop, find its precedence. 291 while (true) { 292 int TokPrec = GetTokPrecedence(); 293 294 // If this is a binop that binds at least as tightly as the current binop, 295 // consume it, otherwise we are done. 296 if (TokPrec < ExprPrec) 297 return LHS; 298 299 // Okay, we know this is a binop. 300 int BinOp = CurTok; 301 getNextToken(); // eat binop 302 303 // Parse the primary expression after the binary operator. 304 auto RHS = ParsePrimary(); 305 if (!RHS) 306 return nullptr; 307 308 // If BinOp binds less tightly with RHS than the operator after RHS, let 309 // the pending operator take RHS as its LHS. 310 int NextPrec = GetTokPrecedence(); 311 if (TokPrec < NextPrec) { 312 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS)); 313 if (!RHS) 314 return nullptr; 315 } 316 317 // Merge LHS/RHS. 318 LHS = 319 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS)); 320 } 321} 322 323/// expression 324/// ::= primary binoprhs 325/// 326static std::unique_ptr<ExprAST> ParseExpression() { 327 auto LHS = ParsePrimary(); 328 if (!LHS) 329 return nullptr; 330 331 return ParseBinOpRHS(0, std::move(LHS)); 332} 333 334/// prototype 335/// ::= id '(' id* ')' 336static std::unique_ptr<PrototypeAST> ParsePrototype() { 337 if (CurTok != tok_identifier) 338 return LogErrorP("Expected function name in prototype"); 339 340 std::string FnName = IdentifierStr; 341 getNextToken(); 342 343 if (CurTok != '(') 344 return LogErrorP("Expected '(' in prototype"); 345 346 std::vector<std::string> ArgNames; 347 while (getNextToken() == tok_identifier) 348 ArgNames.push_back(IdentifierStr); 349 if (CurTok != ')') 350 return LogErrorP("Expected ')' in prototype"); 351 352 // success. 353 getNextToken(); // eat ')'. 354 355 return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames)); 356} 357 358/// definition ::= 'def' prototype expression 359static std::unique_ptr<FunctionAST> ParseDefinition() { 360 getNextToken(); // eat def. 361 auto Proto = ParsePrototype(); 362 if (!Proto) 363 return nullptr; 364 365 if (auto E = ParseExpression()) 366 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 367 return nullptr; 368} 369 370/// toplevelexpr ::= expression 371static std::unique_ptr<FunctionAST> ParseTopLevelExpr() { 372 if (auto E = ParseExpression()) { 373 // Make an anonymous proto. 374 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr", 375 std::vector<std::string>()); 376 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E)); 377 } 378 return nullptr; 379} 380 381/// external ::= 'extern' prototype 382static std::unique_ptr<PrototypeAST> ParseExtern() { 383 getNextToken(); // eat extern. 384 return ParsePrototype(); 385} 386 387//===----------------------------------------------------------------------===// 388// Code Generation 389//===----------------------------------------------------------------------===// 390 391static LLVMContext TheContext; 392static IRBuilder<> Builder(TheContext); 393static std::unique_ptr<Module> TheModule; 394static std::map<std::string, Value *> NamedValues; 395 396Value *LogErrorV(const char *Str) { 397 LogError(Str); 398 return nullptr; 399} 400 401Value *NumberExprAST::codegen() { 402 return ConstantFP::get(TheContext, APFloat(Val)); 403} 404 405Value *VariableExprAST::codegen() { 406 // Look this variable up in the function. 407 Value *V = NamedValues[Name]; 408 if (!V) 409 return LogErrorV("Unknown variable name"); 410 return V; 411} 412 413Value *BinaryExprAST::codegen() { 414 Value *L = LHS->codegen(); 415 Value *R = RHS->codegen(); 416 if (!L || !R) 417 return nullptr; 418 419 switch (Op) { 420 case '+': 421 return Builder.CreateFAdd(L, R, "addtmp"); 422 case '-': 423 return Builder.CreateFSub(L, R, "subtmp"); 424 case '*': 425 return Builder.CreateFMul(L, R, "multmp"); 426 case '<': 427 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 428 // Convert bool 0/1 to double 0.0 or 1.0 429 return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp"); 430 default: 431 return LogErrorV("invalid binary operator"); 432 } 433} 434 435Value *CallExprAST::codegen() { 436 // Look up the name in the global module table. 437 Function *CalleeF = TheModule->getFunction(Callee); 438 if (!CalleeF) 439 return LogErrorV("Unknown function referenced"); 440 441 // If argument mismatch error. 442 if (CalleeF->arg_size() != Args.size()) 443 return LogErrorV("Incorrect # arguments passed"); 444 445 std::vector<Value *> ArgsV; 446 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 447 ArgsV.push_back(Args[i]->codegen()); 448 if (!ArgsV.back()) 449 return nullptr; 450 } 451 452 return Builder.CreateCall(CalleeF, ArgsV, "calltmp"); 453} 454 455Function *PrototypeAST::codegen() { 456 // Make the function type: double(double,double) etc. 457 std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext)); 458 FunctionType *FT = 459 FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false); 460 461 Function *F = 462 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get()); 463 464 // Set names for all arguments. 465 unsigned Idx = 0; 466 for (auto &Arg : F->args()) 467 Arg.setName(Args[Idx++]); 468 469 return F; 470} 471 472Function *FunctionAST::codegen() { 473 // First, check for an existing function from a previous 'extern' declaration. 474 Function *TheFunction = TheModule->getFunction(Proto->getName()); 475 476 if (!TheFunction) 477 TheFunction = Proto->codegen(); 478 479 if (!TheFunction) 480 return nullptr; 481 482 // Create a new basic block to start insertion into. 483 BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction); 484 Builder.SetInsertPoint(BB); 485 486 // Record the function arguments in the NamedValues map. 487 NamedValues.clear(); 488 for (auto &Arg : TheFunction->args()) 489 NamedValues[Arg.getName()] = &Arg; 490 491 if (Value *RetVal = Body->codegen()) { 492 // Finish off the function. 493 Builder.CreateRet(RetVal); 494 495 // Validate the generated code, checking for consistency. 496 verifyFunction(*TheFunction); 497 498 return TheFunction; 499 } 500 501 // Error reading body, remove function. 502 TheFunction->eraseFromParent(); 503 return nullptr; 504} 505 506//===----------------------------------------------------------------------===// 507// Top-Level parsing and JIT Driver 508//===----------------------------------------------------------------------===// 509 510static void HandleDefinition() { 511 if (auto FnAST = ParseDefinition()) { 512 if (auto *FnIR = FnAST->codegen()) { 513 fprintf(stderr, "Read function definition:"); 514 FnIR->dump(); 515 } 516 } else { 517 // Skip token for error recovery. 518 getNextToken(); 519 } 520} 521 522static void HandleExtern() { 523 if (auto ProtoAST = ParseExtern()) { 524 if (auto *FnIR = ProtoAST->codegen()) { 525 fprintf(stderr, "Read extern: "); 526 FnIR->dump(); 527 } 528 } else { 529 // Skip token for error recovery. 530 getNextToken(); 531 } 532} 533 534static void HandleTopLevelExpression() { 535 // Evaluate a top-level expression into an anonymous function. 536 if (auto FnAST = ParseTopLevelExpr()) { 537 if (auto *FnIR = FnAST->codegen()) { 538 fprintf(stderr, "Read top-level expression:"); 539 FnIR->dump(); 540 } 541 } else { 542 // Skip token for error recovery. 543 getNextToken(); 544 } 545} 546 547/// top ::= definition | external | expression | ';' 548static void MainLoop() { 549 while (true) { 550 fprintf(stderr, "ready> "); 551 switch (CurTok) { 552 case tok_eof: 553 return; 554 case ';': // ignore top-level semicolons. 555 getNextToken(); 556 break; 557 case tok_def: 558 HandleDefinition(); 559 break; 560 case tok_extern: 561 HandleExtern(); 562 break; 563 default: 564 HandleTopLevelExpression(); 565 break; 566 } 567 } 568} 569 570//===----------------------------------------------------------------------===// 571// Main driver code. 572//===----------------------------------------------------------------------===// 573 574int main() { 575 // Install standard binary operators. 576 // 1 is lowest precedence. 577 BinopPrecedence['<'] = 10; 578 BinopPrecedence['+'] = 20; 579 BinopPrecedence['-'] = 20; 580 BinopPrecedence['*'] = 40; // highest. 581 582 // Prime the first token. 583 fprintf(stderr, "ready> "); 584 getNextToken(); 585 586 // Make the module, which holds all the code. 587 TheModule = llvm::make_unique<Module>("my cool jit", TheContext); 588 589 // Run the main "interpreter loop" now. 590 MainLoop(); 591 592 // Print out all of the generated code. 593 TheModule->dump(); 594 595 return 0; 596} 597