// toy.cpp revision 31c6c5d58a6d2254063e8a18fd32b851a06e2ddf
1#include "llvm/DerivedTypes.h" 2#include "llvm/ExecutionEngine/ExecutionEngine.h" 3#include "llvm/ExecutionEngine/Interpreter.h" 4#include "llvm/ExecutionEngine/JIT.h" 5#include "llvm/LLVMContext.h" 6#include "llvm/Module.h" 7#include "llvm/ModuleProvider.h" 8#include "llvm/PassManager.h" 9#include "llvm/Analysis/Verifier.h" 10#include "llvm/Target/TargetData.h" 11#include "llvm/Target/TargetSelect.h" 12#include "llvm/Transforms/Scalar.h" 13#include "llvm/Support/IRBuilder.h" 14#include <cstdio> 15#include <string> 16#include <map> 17#include <vector> 18using namespace llvm; 19 20//===----------------------------------------------------------------------===// 21// Lexer 22//===----------------------------------------------------------------------===// 23 24// The lexer returns tokens [0-255] if it is an unknown character, otherwise one 25// of these for known things. 26enum Token { 27 tok_eof = -1, 28 29 // commands 30 tok_def = -2, tok_extern = -3, 31 32 // primary 33 tok_identifier = -4, tok_number = -5 34}; 35 36static std::string IdentifierStr; // Filled in if tok_identifier 37static double NumVal; // Filled in if tok_number 38 39/// gettok - Return the next token from standard input. 40static int gettok() { 41 static int LastChar = ' '; 42 43 // Skip any whitespace. 44 while (isspace(LastChar)) 45 LastChar = getchar(); 46 47 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 48 IdentifierStr = LastChar; 49 while (isalnum((LastChar = getchar()))) 50 IdentifierStr += LastChar; 51 52 if (IdentifierStr == "def") return tok_def; 53 if (IdentifierStr == "extern") return tok_extern; 54 return tok_identifier; 55 } 56 57 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 58 std::string NumStr; 59 do { 60 NumStr += LastChar; 61 LastChar = getchar(); 62 } while (isdigit(LastChar) || LastChar == '.'); 63 64 NumVal = strtod(NumStr.c_str(), 0); 65 return tok_number; 66 } 67 68 if (LastChar == '#') { 69 // Comment until end of line. 
70 do LastChar = getchar(); 71 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 72 73 if (LastChar != EOF) 74 return gettok(); 75 } 76 77 // Check for end of file. Don't eat the EOF. 78 if (LastChar == EOF) 79 return tok_eof; 80 81 // Otherwise, just return the character as its ascii value. 82 int ThisChar = LastChar; 83 LastChar = getchar(); 84 return ThisChar; 85} 86 87//===----------------------------------------------------------------------===// 88// Abstract Syntax Tree (aka Parse Tree) 89//===----------------------------------------------------------------------===// 90 91/// ExprAST - Base class for all expression nodes. 92class ExprAST { 93public: 94 virtual ~ExprAST() {} 95 virtual Value *Codegen() = 0; 96}; 97 98/// NumberExprAST - Expression class for numeric literals like "1.0". 99class NumberExprAST : public ExprAST { 100 double Val; 101public: 102 NumberExprAST(double val) : Val(val) {} 103 virtual Value *Codegen(); 104}; 105 106/// VariableExprAST - Expression class for referencing a variable, like "a". 107class VariableExprAST : public ExprAST { 108 std::string Name; 109public: 110 VariableExprAST(const std::string &name) : Name(name) {} 111 virtual Value *Codegen(); 112}; 113 114/// BinaryExprAST - Expression class for a binary operator. 115class BinaryExprAST : public ExprAST { 116 char Op; 117 ExprAST *LHS, *RHS; 118public: 119 BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs) 120 : Op(op), LHS(lhs), RHS(rhs) {} 121 virtual Value *Codegen(); 122}; 123 124/// CallExprAST - Expression class for function calls. 
125class CallExprAST : public ExprAST { 126 std::string Callee; 127 std::vector<ExprAST*> Args; 128public: 129 CallExprAST(const std::string &callee, std::vector<ExprAST*> &args) 130 : Callee(callee), Args(args) {} 131 virtual Value *Codegen(); 132}; 133 134/// PrototypeAST - This class represents the "prototype" for a function, 135/// which captures its name, and its argument names (thus implicitly the number 136/// of arguments the function takes). 137class PrototypeAST { 138 std::string Name; 139 std::vector<std::string> Args; 140public: 141 PrototypeAST(const std::string &name, const std::vector<std::string> &args) 142 : Name(name), Args(args) {} 143 144 Function *Codegen(); 145}; 146 147/// FunctionAST - This class represents a function definition itself. 148class FunctionAST { 149 PrototypeAST *Proto; 150 ExprAST *Body; 151public: 152 FunctionAST(PrototypeAST *proto, ExprAST *body) 153 : Proto(proto), Body(body) {} 154 155 Function *Codegen(); 156}; 157 158//===----------------------------------------------------------------------===// 159// Parser 160//===----------------------------------------------------------------------===// 161 162/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 163/// token the parser is looking at. getNextToken reads another token from the 164/// lexer and updates CurTok with its results. 165static int CurTok; 166static int getNextToken() { 167 return CurTok = gettok(); 168} 169 170/// BinopPrecedence - This holds the precedence for each binary operator that is 171/// defined. 172static std::map<char, int> BinopPrecedence; 173 174/// GetTokPrecedence - Get the precedence of the pending binary operator token. 175static int GetTokPrecedence() { 176 if (!isascii(CurTok)) 177 return -1; 178 179 // Make sure it's a declared binop. 180 int TokPrec = BinopPrecedence[CurTok]; 181 if (TokPrec <= 0) return -1; 182 return TokPrec; 183} 184 185/// Error* - These are little helper functions for error handling. 
186ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;} 187PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; } 188FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; } 189 190static ExprAST *ParseExpression(); 191 192/// identifierexpr 193/// ::= identifier 194/// ::= identifier '(' expression* ')' 195static ExprAST *ParseIdentifierExpr() { 196 std::string IdName = IdentifierStr; 197 198 getNextToken(); // eat identifier. 199 200 if (CurTok != '(') // Simple variable ref. 201 return new VariableExprAST(IdName); 202 203 // Call. 204 getNextToken(); // eat ( 205 std::vector<ExprAST*> Args; 206 if (CurTok != ')') { 207 while (1) { 208 ExprAST *Arg = ParseExpression(); 209 if (!Arg) return 0; 210 Args.push_back(Arg); 211 212 if (CurTok == ')') break; 213 214 if (CurTok != ',') 215 return Error("Expected ')' or ',' in argument list"); 216 getNextToken(); 217 } 218 } 219 220 // Eat the ')'. 221 getNextToken(); 222 223 return new CallExprAST(IdName, Args); 224} 225 226/// numberexpr ::= number 227static ExprAST *ParseNumberExpr() { 228 ExprAST *Result = new NumberExprAST(NumVal); 229 getNextToken(); // consume the number 230 return Result; 231} 232 233/// parenexpr ::= '(' expression ')' 234static ExprAST *ParseParenExpr() { 235 getNextToken(); // eat (. 236 ExprAST *V = ParseExpression(); 237 if (!V) return 0; 238 239 if (CurTok != ')') 240 return Error("expected ')'"); 241 getNextToken(); // eat ). 
242 return V; 243} 244 245/// primary 246/// ::= identifierexpr 247/// ::= numberexpr 248/// ::= parenexpr 249static ExprAST *ParsePrimary() { 250 switch (CurTok) { 251 default: return Error("unknown token when expecting an expression"); 252 case tok_identifier: return ParseIdentifierExpr(); 253 case tok_number: return ParseNumberExpr(); 254 case '(': return ParseParenExpr(); 255 } 256} 257 258/// binoprhs 259/// ::= ('+' primary)* 260static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) { 261 // If this is a binop, find its precedence. 262 while (1) { 263 int TokPrec = GetTokPrecedence(); 264 265 // If this is a binop that binds at least as tightly as the current binop, 266 // consume it, otherwise we are done. 267 if (TokPrec < ExprPrec) 268 return LHS; 269 270 // Okay, we know this is a binop. 271 int BinOp = CurTok; 272 getNextToken(); // eat binop 273 274 // Parse the primary expression after the binary operator. 275 ExprAST *RHS = ParsePrimary(); 276 if (!RHS) return 0; 277 278 // If BinOp binds less tightly with RHS than the operator after RHS, let 279 // the pending operator take RHS as its LHS. 280 int NextPrec = GetTokPrecedence(); 281 if (TokPrec < NextPrec) { 282 RHS = ParseBinOpRHS(TokPrec+1, RHS); 283 if (RHS == 0) return 0; 284 } 285 286 // Merge LHS/RHS. 
287 LHS = new BinaryExprAST(BinOp, LHS, RHS); 288 } 289} 290 291/// expression 292/// ::= primary binoprhs 293/// 294static ExprAST *ParseExpression() { 295 ExprAST *LHS = ParsePrimary(); 296 if (!LHS) return 0; 297 298 return ParseBinOpRHS(0, LHS); 299} 300 301/// prototype 302/// ::= id '(' id* ')' 303static PrototypeAST *ParsePrototype() { 304 if (CurTok != tok_identifier) 305 return ErrorP("Expected function name in prototype"); 306 307 std::string FnName = IdentifierStr; 308 getNextToken(); 309 310 if (CurTok != '(') 311 return ErrorP("Expected '(' in prototype"); 312 313 std::vector<std::string> ArgNames; 314 while (getNextToken() == tok_identifier) 315 ArgNames.push_back(IdentifierStr); 316 if (CurTok != ')') 317 return ErrorP("Expected ')' in prototype"); 318 319 // success. 320 getNextToken(); // eat ')'. 321 322 return new PrototypeAST(FnName, ArgNames); 323} 324 325/// definition ::= 'def' prototype expression 326static FunctionAST *ParseDefinition() { 327 getNextToken(); // eat def. 328 PrototypeAST *Proto = ParsePrototype(); 329 if (Proto == 0) return 0; 330 331 if (ExprAST *E = ParseExpression()) 332 return new FunctionAST(Proto, E); 333 return 0; 334} 335 336/// toplevelexpr ::= expression 337static FunctionAST *ParseTopLevelExpr() { 338 if (ExprAST *E = ParseExpression()) { 339 // Make an anonymous proto. 340 PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>()); 341 return new FunctionAST(Proto, E); 342 } 343 return 0; 344} 345 346/// external ::= 'extern' prototype 347static PrototypeAST *ParseExtern() { 348 getNextToken(); // eat extern. 
349 return ParsePrototype(); 350} 351 352//===----------------------------------------------------------------------===// 353// Code Generation 354//===----------------------------------------------------------------------===// 355 356static Module *TheModule; 357static IRBuilder<> Builder(getGlobalContext()); 358static std::map<std::string, Value*> NamedValues; 359static FunctionPassManager *TheFPM; 360 361Value *ErrorV(const char *Str) { Error(Str); return 0; } 362 363Value *NumberExprAST::Codegen() { 364 return ConstantFP::get(getGlobalContext(), APFloat(Val)); 365} 366 367Value *VariableExprAST::Codegen() { 368 // Look this variable up in the function. 369 Value *V = NamedValues[Name]; 370 return V ? V : ErrorV("Unknown variable name"); 371} 372 373Value *BinaryExprAST::Codegen() { 374 Value *L = LHS->Codegen(); 375 Value *R = RHS->Codegen(); 376 if (L == 0 || R == 0) return 0; 377 378 switch (Op) { 379 case '+': return Builder.CreateAdd(L, R, "addtmp"); 380 case '-': return Builder.CreateSub(L, R, "subtmp"); 381 case '*': return Builder.CreateMul(L, R, "multmp"); 382 case '<': 383 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 384 // Convert bool 0/1 to double 0.0 or 1.0 385 return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()), 386 "booltmp"); 387 default: return ErrorV("invalid binary operator"); 388 } 389} 390 391Value *CallExprAST::Codegen() { 392 // Look up the name in the global module table. 393 Function *CalleeF = TheModule->getFunction(Callee); 394 if (CalleeF == 0) 395 return ErrorV("Unknown function referenced"); 396 397 // If argument mismatch error. 
398 if (CalleeF->arg_size() != Args.size()) 399 return ErrorV("Incorrect # arguments passed"); 400 401 std::vector<Value*> ArgsV; 402 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 403 ArgsV.push_back(Args[i]->Codegen()); 404 if (ArgsV.back() == 0) return 0; 405 } 406 407 return Builder.CreateCall(CalleeF, ArgsV.begin(), ArgsV.end(), "calltmp"); 408} 409 410Function *PrototypeAST::Codegen() { 411 // Make the function type: double(double,double) etc. 412 std::vector<const Type*> Doubles(Args.size(), 413 Type::getDoubleTy(getGlobalContext())); 414 FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()), 415 Doubles, false); 416 417 Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule); 418 419 // If F conflicted, there was already something named 'Name'. If it has a 420 // body, don't allow redefinition or reextern. 421 if (F->getName() != Name) { 422 // Delete the one we just made and get the existing one. 423 F->eraseFromParent(); 424 F = TheModule->getFunction(Name); 425 426 // If F already has a body, reject this. 427 if (!F->empty()) { 428 ErrorF("redefinition of function"); 429 return 0; 430 } 431 432 // If F took a different number of args, reject. 433 if (F->arg_size() != Args.size()) { 434 ErrorF("redefinition of function with different # args"); 435 return 0; 436 } 437 } 438 439 // Set names for all arguments. 440 unsigned Idx = 0; 441 for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size(); 442 ++AI, ++Idx) { 443 AI->setName(Args[Idx]); 444 445 // Add arguments to variable symbol table. 446 NamedValues[Args[Idx]] = AI; 447 } 448 449 return F; 450} 451 452Function *FunctionAST::Codegen() { 453 NamedValues.clear(); 454 455 Function *TheFunction = Proto->Codegen(); 456 if (TheFunction == 0) 457 return 0; 458 459 // Create a new basic block to start insertion into. 
460 BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction); 461 Builder.SetInsertPoint(BB); 462 463 if (Value *RetVal = Body->Codegen()) { 464 // Finish off the function. 465 Builder.CreateRet(RetVal); 466 467 // Validate the generated code, checking for consistency. 468 verifyFunction(*TheFunction); 469 470 // Optimize the function. 471 TheFPM->run(*TheFunction); 472 473 return TheFunction; 474 } 475 476 // Error reading body, remove function. 477 TheFunction->eraseFromParent(); 478 return 0; 479} 480 481//===----------------------------------------------------------------------===// 482// Top-Level parsing and JIT Driver 483//===----------------------------------------------------------------------===// 484 485static ExecutionEngine *TheExecutionEngine; 486 487static void HandleDefinition() { 488 if (FunctionAST *F = ParseDefinition()) { 489 if (Function *LF = F->Codegen()) { 490 fprintf(stderr, "Read function definition:"); 491 LF->dump(); 492 } 493 } else { 494 // Skip token for error recovery. 495 getNextToken(); 496 } 497} 498 499static void HandleExtern() { 500 if (PrototypeAST *P = ParseExtern()) { 501 if (Function *F = P->Codegen()) { 502 fprintf(stderr, "Read extern: "); 503 F->dump(); 504 } 505 } else { 506 // Skip token for error recovery. 507 getNextToken(); 508 } 509} 510 511static void HandleTopLevelExpression() { 512 // Evaluate a top-level expression into an anonymous function. 513 if (FunctionAST *F = ParseTopLevelExpr()) { 514 if (Function *LF = F->Codegen()) { 515 // JIT the function, returning a function pointer. 516 void *FPtr = TheExecutionEngine->getPointerToFunction(LF); 517 518 // Cast it to the right type (takes no arguments, returns a double) so we 519 // can call it as a native function. 520 double (*FP)() = (double (*)())(intptr_t)FPtr; 521 fprintf(stderr, "Evaluated to %f\n", FP()); 522 } 523 } else { 524 // Skip token for error recovery. 
525 getNextToken(); 526 } 527} 528 529/// top ::= definition | external | expression | ';' 530static void MainLoop() { 531 while (1) { 532 fprintf(stderr, "ready> "); 533 switch (CurTok) { 534 case tok_eof: return; 535 case ';': getNextToken(); break; // ignore top-level semicolons. 536 case tok_def: HandleDefinition(); break; 537 case tok_extern: HandleExtern(); break; 538 default: HandleTopLevelExpression(); break; 539 } 540 } 541} 542 543//===----------------------------------------------------------------------===// 544// "Library" functions that can be "extern'd" from user code. 545//===----------------------------------------------------------------------===// 546 547/// putchard - putchar that takes a double and returns 0. 548extern "C" 549double putchard(double X) { 550 putchar((char)X); 551 return 0; 552} 553 554//===----------------------------------------------------------------------===// 555// Main driver code. 556//===----------------------------------------------------------------------===// 557 558int main() { 559 InitializeNativeTarget(); 560 LLVMContext &Context = getGlobalContext(); 561 562 // Install standard binary operators. 563 // 1 is lowest precedence. 564 BinopPrecedence['<'] = 10; 565 BinopPrecedence['+'] = 20; 566 BinopPrecedence['-'] = 20; 567 BinopPrecedence['*'] = 40; // highest. 568 569 // Prime the first token. 570 fprintf(stderr, "ready> "); 571 getNextToken(); 572 573 // Make the module, which holds all the code. 574 TheModule = new Module("my cool jit", Context); 575 576 ExistingModuleProvider *OurModuleProvider = 577 new ExistingModuleProvider(TheModule); 578 579 // Create the JIT. This takes ownership of the module and module provider. 580 TheExecutionEngine = EngineBuilder(OurModuleProvider).create(); 581 582 FunctionPassManager OurFPM(OurModuleProvider); 583 584 // Set up the optimizer pipeline. Start with registering info about how the 585 // target lays out data structures. 
586 OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData())); 587 // Do simple "peephole" optimizations and bit-twiddling optzns. 588 OurFPM.add(createInstructionCombiningPass()); 589 // Reassociate expressions. 590 OurFPM.add(createReassociatePass()); 591 // Eliminate Common SubExpressions. 592 OurFPM.add(createGVNPass()); 593 // Simplify the control flow graph (deleting unreachable blocks, etc). 594 OurFPM.add(createCFGSimplificationPass()); 595 596 OurFPM.doInitialization(); 597 598 // Set the global so the code gen can use this. 599 TheFPM = &OurFPM; 600 601 // Run the main "interpreter loop" now. 602 MainLoop(); 603 604 TheFPM = 0; 605 606 // Print out all of the generated code. 607 TheModule->dump(); 608 609 return 0; 610} 611