toy.cpp revision 2632bbf0de4637e02f333952991263d526976faf
1#include "llvm/DerivedTypes.h" 2#include "llvm/ExecutionEngine/ExecutionEngine.h" 3#include "llvm/ExecutionEngine/JIT.h" 4#include "llvm/LLVMContext.h" 5#include "llvm/Module.h" 6#include "llvm/PassManager.h" 7#include "llvm/Analysis/Verifier.h" 8#include "llvm/Target/TargetData.h" 9#include "llvm/Target/TargetSelect.h" 10#include "llvm/Transforms/Scalar.h" 11#include "llvm/Support/IRBuilder.h" 12#include <cstdio> 13#include <string> 14#include <map> 15#include <vector> 16using namespace llvm; 17 18//===----------------------------------------------------------------------===// 19// Lexer 20//===----------------------------------------------------------------------===// 21 22// The lexer returns tokens [0-255] if it is an unknown character, otherwise one 23// of these for known things. 24enum Token { 25 tok_eof = -1, 26 27 // commands 28 tok_def = -2, tok_extern = -3, 29 30 // primary 31 tok_identifier = -4, tok_number = -5 32}; 33 34static std::string IdentifierStr; // Filled in if tok_identifier 35static double NumVal; // Filled in if tok_number 36 37/// gettok - Return the next token from standard input. 38static int gettok() { 39 static int LastChar = ' '; 40 41 // Skip any whitespace. 42 while (isspace(LastChar)) 43 LastChar = getchar(); 44 45 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]* 46 IdentifierStr = LastChar; 47 while (isalnum((LastChar = getchar()))) 48 IdentifierStr += LastChar; 49 50 if (IdentifierStr == "def") return tok_def; 51 if (IdentifierStr == "extern") return tok_extern; 52 return tok_identifier; 53 } 54 55 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+ 56 std::string NumStr; 57 do { 58 NumStr += LastChar; 59 LastChar = getchar(); 60 } while (isdigit(LastChar) || LastChar == '.'); 61 62 NumVal = strtod(NumStr.c_str(), 0); 63 return tok_number; 64 } 65 66 if (LastChar == '#') { 67 // Comment until end of line. 68 do LastChar = getchar(); 69 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r'); 70 71 if (LastChar != EOF) 72 return gettok(); 73 } 74 75 // Check for end of file. Don't eat the EOF. 76 if (LastChar == EOF) 77 return tok_eof; 78 79 // Otherwise, just return the character as its ascii value. 80 int ThisChar = LastChar; 81 LastChar = getchar(); 82 return ThisChar; 83} 84 85//===----------------------------------------------------------------------===// 86// Abstract Syntax Tree (aka Parse Tree) 87//===----------------------------------------------------------------------===// 88 89/// ExprAST - Base class for all expression nodes. 90class ExprAST { 91public: 92 virtual ~ExprAST() {} 93 virtual Value *Codegen() = 0; 94}; 95 96/// NumberExprAST - Expression class for numeric literals like "1.0". 97class NumberExprAST : public ExprAST { 98 double Val; 99public: 100 NumberExprAST(double val) : Val(val) {} 101 virtual Value *Codegen(); 102}; 103 104/// VariableExprAST - Expression class for referencing a variable, like "a". 105class VariableExprAST : public ExprAST { 106 std::string Name; 107public: 108 VariableExprAST(const std::string &name) : Name(name) {} 109 virtual Value *Codegen(); 110}; 111 112/// BinaryExprAST - Expression class for a binary operator. 113class BinaryExprAST : public ExprAST { 114 char Op; 115 ExprAST *LHS, *RHS; 116public: 117 BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs) 118 : Op(op), LHS(lhs), RHS(rhs) {} 119 virtual Value *Codegen(); 120}; 121 122/// CallExprAST - Expression class for function calls. 123class CallExprAST : public ExprAST { 124 std::string Callee; 125 std::vector<ExprAST*> Args; 126public: 127 CallExprAST(const std::string &callee, std::vector<ExprAST*> &args) 128 : Callee(callee), Args(args) {} 129 virtual Value *Codegen(); 130}; 131 132/// PrototypeAST - This class represents the "prototype" for a function, 133/// which captures its name, and its argument names (thus implicitly the number 134/// of arguments the function takes). 135class PrototypeAST { 136 std::string Name; 137 std::vector<std::string> Args; 138public: 139 PrototypeAST(const std::string &name, const std::vector<std::string> &args) 140 : Name(name), Args(args) {} 141 142 Function *Codegen(); 143}; 144 145/// FunctionAST - This class represents a function definition itself. 146class FunctionAST { 147 PrototypeAST *Proto; 148 ExprAST *Body; 149public: 150 FunctionAST(PrototypeAST *proto, ExprAST *body) 151 : Proto(proto), Body(body) {} 152 153 Function *Codegen(); 154}; 155 156//===----------------------------------------------------------------------===// 157// Parser 158//===----------------------------------------------------------------------===// 159 160/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current 161/// token the parser is looking at. getNextToken reads another token from the 162/// lexer and updates CurTok with its results. 163static int CurTok; 164static int getNextToken() { 165 return CurTok = gettok(); 166} 167 168/// BinopPrecedence - This holds the precedence for each binary operator that is 169/// defined. 170static std::map<char, int> BinopPrecedence; 171 172/// GetTokPrecedence - Get the precedence of the pending binary operator token. 173static int GetTokPrecedence() { 174 if (!isascii(CurTok)) 175 return -1; 176 177 // Make sure it's a declared binop. 178 int TokPrec = BinopPrecedence[CurTok]; 179 if (TokPrec <= 0) return -1; 180 return TokPrec; 181} 182 183/// Error* - These are little helper functions for error handling. 184ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;} 185PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; } 186FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; } 187 188static ExprAST *ParseExpression(); 189 190/// identifierexpr 191/// ::= identifier 192/// ::= identifier '(' expression* ')' 193static ExprAST *ParseIdentifierExpr() { 194 std::string IdName = IdentifierStr; 195 196 getNextToken(); // eat identifier. 197 198 if (CurTok != '(') // Simple variable ref. 199 return new VariableExprAST(IdName); 200 201 // Call. 202 getNextToken(); // eat ( 203 std::vector<ExprAST*> Args; 204 if (CurTok != ')') { 205 while (1) { 206 ExprAST *Arg = ParseExpression(); 207 if (!Arg) return 0; 208 Args.push_back(Arg); 209 210 if (CurTok == ')') break; 211 212 if (CurTok != ',') 213 return Error("Expected ')' or ',' in argument list"); 214 getNextToken(); 215 } 216 } 217 218 // Eat the ')'. 219 getNextToken(); 220 221 return new CallExprAST(IdName, Args); 222} 223 224/// numberexpr ::= number 225static ExprAST *ParseNumberExpr() { 226 ExprAST *Result = new NumberExprAST(NumVal); 227 getNextToken(); // consume the number 228 return Result; 229} 230 231/// parenexpr ::= '(' expression ')' 232static ExprAST *ParseParenExpr() { 233 getNextToken(); // eat (. 234 ExprAST *V = ParseExpression(); 235 if (!V) return 0; 236 237 if (CurTok != ')') 238 return Error("expected ')'"); 239 getNextToken(); // eat ). 240 return V; 241} 242 243/// primary 244/// ::= identifierexpr 245/// ::= numberexpr 246/// ::= parenexpr 247static ExprAST *ParsePrimary() { 248 switch (CurTok) { 249 default: return Error("unknown token when expecting an expression"); 250 case tok_identifier: return ParseIdentifierExpr(); 251 case tok_number: return ParseNumberExpr(); 252 case '(': return ParseParenExpr(); 253 } 254} 255 256/// binoprhs 257/// ::= ('+' primary)* 258static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) { 259 // If this is a binop, find its precedence. 260 while (1) { 261 int TokPrec = GetTokPrecedence(); 262 263 // If this is a binop that binds at least as tightly as the current binop, 264 // consume it, otherwise we are done. 265 if (TokPrec < ExprPrec) 266 return LHS; 267 268 // Okay, we know this is a binop. 269 int BinOp = CurTok; 270 getNextToken(); // eat binop 271 272 // Parse the primary expression after the binary operator. 273 ExprAST *RHS = ParsePrimary(); 274 if (!RHS) return 0; 275 276 // If BinOp binds less tightly with RHS than the operator after RHS, let 277 // the pending operator take RHS as its LHS. 278 int NextPrec = GetTokPrecedence(); 279 if (TokPrec < NextPrec) { 280 RHS = ParseBinOpRHS(TokPrec+1, RHS); 281 if (RHS == 0) return 0; 282 } 283 284 // Merge LHS/RHS. 285 LHS = new BinaryExprAST(BinOp, LHS, RHS); 286 } 287} 288 289/// expression 290/// ::= primary binoprhs 291/// 292static ExprAST *ParseExpression() { 293 ExprAST *LHS = ParsePrimary(); 294 if (!LHS) return 0; 295 296 return ParseBinOpRHS(0, LHS); 297} 298 299/// prototype 300/// ::= id '(' id* ')' 301static PrototypeAST *ParsePrototype() { 302 if (CurTok != tok_identifier) 303 return ErrorP("Expected function name in prototype"); 304 305 std::string FnName = IdentifierStr; 306 getNextToken(); 307 308 if (CurTok != '(') 309 return ErrorP("Expected '(' in prototype"); 310 311 std::vector<std::string> ArgNames; 312 while (getNextToken() == tok_identifier) 313 ArgNames.push_back(IdentifierStr); 314 if (CurTok != ')') 315 return ErrorP("Expected ')' in prototype"); 316 317 // success. 318 getNextToken(); // eat ')'. 319 320 return new PrototypeAST(FnName, ArgNames); 321} 322 323/// definition ::= 'def' prototype expression 324static FunctionAST *ParseDefinition() { 325 getNextToken(); // eat def. 326 PrototypeAST *Proto = ParsePrototype(); 327 if (Proto == 0) return 0; 328 329 if (ExprAST *E = ParseExpression()) 330 return new FunctionAST(Proto, E); 331 return 0; 332} 333 334/// toplevelexpr ::= expression 335static FunctionAST *ParseTopLevelExpr() { 336 if (ExprAST *E = ParseExpression()) { 337 // Make an anonymous proto. 338 PrototypeAST *Proto = new PrototypeAST("", std::vector<std::string>()); 339 return new FunctionAST(Proto, E); 340 } 341 return 0; 342} 343 344/// external ::= 'extern' prototype 345static PrototypeAST *ParseExtern() { 346 getNextToken(); // eat extern. 347 return ParsePrototype(); 348} 349 350//===----------------------------------------------------------------------===// 351// Code Generation 352//===----------------------------------------------------------------------===// 353 354static Module *TheModule; 355static IRBuilder<> Builder(getGlobalContext()); 356static std::map<std::string, Value*> NamedValues; 357static FunctionPassManager *TheFPM; 358 359Value *ErrorV(const char *Str) { Error(Str); return 0; } 360 361Value *NumberExprAST::Codegen() { 362 return ConstantFP::get(getGlobalContext(), APFloat(Val)); 363} 364 365Value *VariableExprAST::Codegen() { 366 // Look this variable up in the function. 367 Value *V = NamedValues[Name]; 368 return V ? V : ErrorV("Unknown variable name"); 369} 370 371Value *BinaryExprAST::Codegen() { 372 Value *L = LHS->Codegen(); 373 Value *R = RHS->Codegen(); 374 if (L == 0 || R == 0) return 0; 375 376 switch (Op) { 377 case '+': return Builder.CreateFAdd(L, R, "addtmp"); 378 case '-': return Builder.CreateFSub(L, R, "subtmp"); 379 case '*': return Builder.CreateFMul(L, R, "multmp"); 380 case '<': 381 L = Builder.CreateFCmpULT(L, R, "cmptmp"); 382 // Convert bool 0/1 to double 0.0 or 1.0 383 return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()), 384 "booltmp"); 385 default: return ErrorV("invalid binary operator"); 386 } 387} 388 389Value *CallExprAST::Codegen() { 390 // Look up the name in the global module table. 391 Function *CalleeF = TheModule->getFunction(Callee); 392 if (CalleeF == 0) 393 return ErrorV("Unknown function referenced"); 394 395 // If argument mismatch error. 396 if (CalleeF->arg_size() != Args.size()) 397 return ErrorV("Incorrect # arguments passed"); 398 399 std::vector<Value*> ArgsV; 400 for (unsigned i = 0, e = Args.size(); i != e; ++i) { 401 ArgsV.push_back(Args[i]->Codegen()); 402 if (ArgsV.back() == 0) return 0; 403 } 404 405 return Builder.CreateCall(CalleeF, ArgsV.begin(), ArgsV.end(), "calltmp"); 406} 407 408Function *PrototypeAST::Codegen() { 409 // Make the function type: double(double,double) etc. 410 std::vector<const Type*> Doubles(Args.size(), 411 Type::getDoubleTy(getGlobalContext())); 412 FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()), 413 Doubles, false); 414 415 Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule); 416 417 // If F conflicted, there was already something named 'Name'. If it has a 418 // body, don't allow redefinition or reextern. 419 if (F->getName() != Name) { 420 // Delete the one we just made and get the existing one. 421 F->eraseFromParent(); 422 F = TheModule->getFunction(Name); 423 424 // If F already has a body, reject this. 425 if (!F->empty()) { 426 ErrorF("redefinition of function"); 427 return 0; 428 } 429 430 // If F took a different number of args, reject. 431 if (F->arg_size() != Args.size()) { 432 ErrorF("redefinition of function with different # args"); 433 return 0; 434 } 435 } 436 437 // Set names for all arguments. 438 unsigned Idx = 0; 439 for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size(); 440 ++AI, ++Idx) { 441 AI->setName(Args[Idx]); 442 443 // Add arguments to variable symbol table. 444 NamedValues[Args[Idx]] = AI; 445 } 446 447 return F; 448} 449 450Function *FunctionAST::Codegen() { 451 NamedValues.clear(); 452 453 Function *TheFunction = Proto->Codegen(); 454 if (TheFunction == 0) 455 return 0; 456 457 // Create a new basic block to start insertion into. 458 BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction); 459 Builder.SetInsertPoint(BB); 460 461 if (Value *RetVal = Body->Codegen()) { 462 // Finish off the function. 463 Builder.CreateRet(RetVal); 464 465 // Validate the generated code, checking for consistency. 466 verifyFunction(*TheFunction); 467 468 // Optimize the function. 469 TheFPM->run(*TheFunction); 470 471 return TheFunction; 472 } 473 474 // Error reading body, remove function. 475 TheFunction->eraseFromParent(); 476 return 0; 477} 478 479//===----------------------------------------------------------------------===// 480// Top-Level parsing and JIT Driver 481//===----------------------------------------------------------------------===// 482 483static ExecutionEngine *TheExecutionEngine; 484 485static void HandleDefinition() { 486 if (FunctionAST *F = ParseDefinition()) { 487 if (Function *LF = F->Codegen()) { 488 fprintf(stderr, "Read function definition:"); 489 LF->dump(); 490 } 491 } else { 492 // Skip token for error recovery. 493 getNextToken(); 494 } 495} 496 497static void HandleExtern() { 498 if (PrototypeAST *P = ParseExtern()) { 499 if (Function *F = P->Codegen()) { 500 fprintf(stderr, "Read extern: "); 501 F->dump(); 502 } 503 } else { 504 // Skip token for error recovery. 505 getNextToken(); 506 } 507} 508 509static void HandleTopLevelExpression() { 510 // Evaluate a top-level expression into an anonymous function. 511 if (FunctionAST *F = ParseTopLevelExpr()) { 512 if (Function *LF = F->Codegen()) { 513 // JIT the function, returning a function pointer. 514 void *FPtr = TheExecutionEngine->getPointerToFunction(LF); 515 516 // Cast it to the right type (takes no arguments, returns a double) so we 517 // can call it as a native function. 518 double (*FP)() = (double (*)())(intptr_t)FPtr; 519 fprintf(stderr, "Evaluated to %f\n", FP()); 520 } 521 } else { 522 // Skip token for error recovery. 523 getNextToken(); 524 } 525} 526 527/// top ::= definition | external | expression | ';' 528static void MainLoop() { 529 while (1) { 530 fprintf(stderr, "ready> "); 531 switch (CurTok) { 532 case tok_eof: return; 533 case ';': getNextToken(); break; // ignore top-level semicolons. 534 case tok_def: HandleDefinition(); break; 535 case tok_extern: HandleExtern(); break; 536 default: HandleTopLevelExpression(); break; 537 } 538 } 539} 540 541//===----------------------------------------------------------------------===// 542// "Library" functions that can be "extern'd" from user code. 543//===----------------------------------------------------------------------===// 544 545/// putchard - putchar that takes a double and returns 0. 546extern "C" 547double putchard(double X) { 548 putchar((char)X); 549 return 0; 550} 551 552//===----------------------------------------------------------------------===// 553// Main driver code. 554//===----------------------------------------------------------------------===// 555 556int main() { 557 InitializeNativeTarget(); 558 LLVMContext &Context = getGlobalContext(); 559 560 // Install standard binary operators. 561 // 1 is lowest precedence. 562 BinopPrecedence['<'] = 10; 563 BinopPrecedence['+'] = 20; 564 BinopPrecedence['-'] = 20; 565 BinopPrecedence['*'] = 40; // highest. 566 567 // Prime the first token. 568 fprintf(stderr, "ready> "); 569 getNextToken(); 570 571 // Make the module, which holds all the code. 572 TheModule = new Module("my cool jit", Context); 573 574 // Create the JIT. This takes ownership of the module. 575 std::string ErrStr; 576 TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create(); 577 if (!TheExecutionEngine) { 578 fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str()); 579 exit(1); 580 } 581 582 FunctionPassManager OurFPM(TheModule); 583 584 // Set up the optimizer pipeline. Start with registering info about how the 585 // target lays out data structures. 586 OurFPM.add(new TargetData(*TheExecutionEngine->getTargetData())); 587 // Do simple "peephole" optimizations and bit-twiddling optzns. 588 OurFPM.add(createInstructionCombiningPass()); 589 // Reassociate expressions. 590 OurFPM.add(createReassociatePass()); 591 // Eliminate Common SubExpressions. 592 OurFPM.add(createGVNPass()); 593 // Simplify the control flow graph (deleting unreachable blocks, etc). 594 OurFPM.add(createCFGSimplificationPass()); 595 596 OurFPM.doInitialization(); 597 598 // Set the global so the code gen can use this. 599 TheFPM = &OurFPM; 600 601 // Run the main "interpreter loop" now. 602 MainLoop(); 603 604 TheFPM = 0; 605 606 // Print out all of the generated code. 607 TheModule->dump(); 608 609 return 0; 610} 611