Last active
August 29, 2015 14:17
-
-
Save riyadparvez/a2c157b24579c6552466 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "llvm/ADT/SetVector.h" | |
#include "llvm/ADT/APSInt.h" | |
#include "clang/Driver/Options.h" | |
#include "clang/AST/AST.h" | |
#include "clang/AST/ASTContext.h" | |
#include "clang/AST/ASTConsumer.h" | |
#include "clang/AST/Expr.h" | |
#include "clang/AST/OperationKinds.h" | |
#include "clang/AST/RecursiveASTVisitor.h" | |
#include "clang/Frontend/ASTConsumers.h" | |
#include "clang/Frontend/FrontendActions.h" | |
#include "clang/Frontend/CompilerInstance.h" | |
#include "clang/Lex/Lexer.h" | |
#include "clang/Rewrite/Core/Rewriter.h" | |
#include "clang/Rewrite/Frontend/FrontendActions.h" | |
#include "clang/Tooling/CommonOptionsParser.h" | |
#include "clang/Tooling/Refactoring.h" | |
#include "clang/Tooling/Tooling.h" | |
#include <iostream> | |
#include <set> | |
using namespace std; | |
using namespace clang; | |
using namespace clang::driver; | |
using namespace clang::tooling; | |
using namespace llvm; | |
// Plain record holding a source line number (not referenced elsewhere in the
// code shown here).
struct s {
    int line_no;
};
// File-scope state shared by the visitor, the frontend action and main().
// `rewriter` accumulates all source edits; it is (re)bound to the current
// file's SourceManager and LangOptions whenever an S2EInstrumentationVisitor
// is constructed.
Rewriter rewriter;
// Hand-configured language options (main() switches on GNU/C++ flags). These
// are independent of the options of the file actually being parsed, which the
// rewriter receives from the ASTContext instead.
LangOptions languageOptions;
/// AST visitor that instruments heap allocations for S2E.
///
/// For every variable that receives the result of a C-linkage `malloc` or
/// `calloc` call (in a declaration initializer or a plain assignment) it inserts
///     s2e_concretize_fork(<var>, sizeof(<var>), 0);
/// right after the terminating ';' using the file-scope `rewriter`. It also
/// rewrites calls to `func` so they begin with `s2e`, and prints diagnostic
/// information to llvm::outs().
///
/// Every Visit* method returns true: returning false from a
/// RecursiveASTVisitor callback aborts the whole traversal.
class S2EInstrumentationVisitor : public RecursiveASTVisitor<S2EInstrumentationVisitor> {
private:
    // Per-function bookkeeping, cleared whenever a FunctionDecl is visited.
    std::set<Expr *> expressions;
    std::set<VarDecl *> declarations;
    // Last visited function if it has C language linkage, otherwise NULL.
    FunctionDecl *currentFunctionDecl;
    ASTContext *astContext; // used for getting additional AST info

    /// Returns true when `E`, after stripping parentheses and casts (C inserts
    /// an implicit cast around `malloc`'s void* result), is a direct call to a
    /// C-linkage `malloc` or `calloc`.
    static bool isAllocationCall(Expr *E) {
        if (!E)
            return false;
        CallExpr *Call = dyn_cast<CallExpr>(E->IgnoreParenCasts());
        if (!Call)
            return false;
        // Calls through function pointers have no direct callee.
        FunctionDecl *FD = Call->getDirectCallee();
        if (!FD || !FD->isExternC())
            return false;
        std::string fname = FD->getNameInfo().getAsString();
        return fname == "malloc" || fname == "calloc";
    }

    /// Builds the instrumentation statement inserted after an allocation.
    // NOTE(review): sizeof(name) is the size of the variable itself (e.g. the
    // pointer), not of the allocated block -- confirm this is what
    // s2e_concretize_fork expects.
    static std::string makeConcretizeCall(const std::string &name) {
        return "\ns2e_concretize_fork(" + name + ", " + "sizeof(" + name + "), " + "0" + ");\n";
    }

    /// Inserts `str` immediately after the ';' that follows the token starting
    /// at `lastTokenLoc`. If the next token is not ';' (e.g. an assignment
    /// nested in a condition, or a declarator followed by ','), nothing is
    /// inserted -- inserting there would produce invalid code -- and a warning
    /// is printed. Returns whether the text was inserted.
    bool insertAfterTerminatingSemi(SourceLocation lastTokenLoc, const std::string &str) {
        SourceManager &SM = rewriter.getSourceMgr();
        SourceLocation afterSemi = Lexer::findLocationAfterToken(
            lastTokenLoc, tok::semi, SM, rewriter.getLangOpts(),
            /*SkipTrailingWhitespaceAndNewLine=*/false);
        if (afterSemi.isInvalid()) {
            llvm::outs() << "warning: no ';' after " << lastTokenLoc.printToString(SM)
                         << "; instrumentation skipped\n";
            return false;
        }
        rewriter.InsertText(afterSemi, str, true, true);
        return true;
    }

public:
    explicit S2EInstrumentationVisitor(CompilerInstance *CI)
        : currentFunctionDecl(NULL), astContext(&(CI->getASTContext()))
    {
        // Point the shared rewriter at the file this visitor will traverse.
        rewriter.setSourceMgr(astContext->getSourceManager(), astContext->getLangOpts());
    }

    /// Tracks the current function and instruments variable declarations whose
    /// initializer is a malloc/calloc call.
    bool VisitDecl(Decl *Declaration)
    {
        if (FunctionDecl *funcDecl = dyn_cast<FunctionDecl>(Declaration)) {
            // Only remember functions with C language linkage.
            currentFunctionDecl = funcDecl->isExternC() ? funcDecl : NULL;
            expressions.clear();
            declarations.clear();
        }
        VarDecl *varDecl = dyn_cast<VarDecl>(Declaration);
        // Function parameters are not instrumented.
        if (!varDecl || isa<ParmVarDecl>(varDecl) || !varDecl->hasInit())
            return true;
        Expr *varInit = varDecl->getInit();
        if (!varInit->isRValue() || !isAllocationCall(varInit))
            return true;
        declarations.insert(varDecl);
        InstrumentStmtAfter(varDecl, makeConcretizeCall(varDecl->getNameAsString()));
        return true;
    }

    /// If `s` is an assignment whose left-hand side is a plain variable,
    /// stores that variable's qualified name in `name` and returns true.
    bool GetAssignedVar(Stmt *s, std::string& name) {
        BinaryOperator *BinOp = dyn_cast_or_null<BinaryOperator>(s);
        if (!BinOp || !BinOp->isAssignmentOp())
            return false;
        DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinOp->getLHS());
        if (!DRE)
            return false;
        VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl());
        if (!VD)
            return false;
        name = VD->getQualifiedNameAsString();
        return true;
    }

    /// Called for every statement; no statement-level instrumentation yet.
    bool VisitStmt(Stmt *s) {
        return true; // returning false aborts the traversal
    }

    /// Rewrites calls to `func` so the call text starts with `s2e`, and prints
    /// the call's range and each argument's source text.
    virtual bool VisitCallExpr(CallExpr *CallE) {
        FunctionDecl *FD = CallE->getDirectCallee();
        // Calls through function pointers have no direct callee.
        if (!FD)
            return true;
        if (FD->getNameInfo().getAsString() != "func")
            return true;

        SourceManager &SM = rewriter.getSourceMgr();
        /** Replace function **/
        SourceRange range = CallE->getSourceRange();
        rewriter.ReplaceText(range.getBegin(), "s2e");
        llvm::outs() << "Begin: " << range.getBegin().printToString(SM)
                     << " End: " << range.getEnd().printToString(SM) << "\n";

        /** Print every function argument **/
        for (CallExpr::arg_iterator it = CallE->arg_begin(), ite = CallE->arg_end(); it != ite; ++it) {
            // Token range so the last token of the argument is included;
            // getSourceText also copes with macro locations.
            CharSourceRange argRange = CharSourceRange::getTokenRange((*it)->getSourceRange());
            llvm::outs() << Lexer::getSourceText(argRange, SM, astContext->getLangOpts()) << "\n";
        }
        return true;
    }

    /// Instruments `var = malloc(...)` / `var = calloc(...)` assignments.
    virtual bool VisitBinaryOperator(BinaryOperator* BinaryOp) {
        if (!BinaryOp->isAssignmentOp())
            return true;
        // Look through the implicit/explicit cast around malloc's void* result.
        CallExpr *CallE = dyn_cast<CallExpr>(BinaryOp->getRHS()->IgnoreParenCasts());
        if (!CallE)
            return true;

        std::string name;
        if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(BinaryOp->getLHS())) {
            if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                declarations.insert(VD);
                name = VD->getQualifiedNameAsString();
            }
        }
        // The generated call needs a variable name; skip `s->p = malloc(n)` etc.
        if (name.empty() || !isAllocationCall(CallE))
            return true;

        CallE->dumpColor();
        llvm::outs() << "\n";
        InstrumentStmtAfter(BinaryOp, makeConcretizeCall(name));
        return true;
    }

    /// Returns true if `Cond` is a simple boolean test of a variable, i.e.
    /// `if (p)` or `if (!p)` (parentheses and implicit casts ignored), and
    /// prints that variable's qualified name. Returns false otherwise.
    virtual bool VisitBooleanCondition(Expr *Cond) {
        if (!Cond)
            return false;
        Expr *Var = Cond->IgnoreParenImpCasts();
        if (UnaryOperator *UnaryOp = dyn_cast<UnaryOperator>(Var)) {
            // Only logical negation (`!p`) qualifies; `~p` is UO_Not.
            if (UnaryOp->getOpcode() != UO_LNot)
                return false;
            Var = UnaryOp->getSubExpr()->IgnoreParenImpCasts();
        }
        DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Var);
        VarDecl *VD = DRE ? dyn_cast<VarDecl>(DRE->getDecl()) : NULL;
        if (!VD)
            return false;
        llvm::outs() << VD->getQualifiedNameAsString() << "\n";
        return true;
    }

    /// Inspects if-conditions; comparison-based concretization is still
    /// disabled work in progress.
    virtual bool VisitIfStmt(IfStmt* If) {
        Expr *Cond = If->getCond();
        if (VisitBooleanCondition(Cond)) {
            return true;
        }
        // NOTE(review): disabled work in progress. Before enabling: BinaryOp is
        // NULL for non-binary conditions, VarD may be read uninitialized, and
        // `Result++` / `Result--` are post-increments, so concreteValue gets the
        // unmodified literal for BO_NE / BO_LT / BO_GT.
#if 0
        BinaryOperator *BinaryOp = dyn_cast<BinaryOperator>(Cond);
        if (!BinaryOp->isEqualityOp() &&
            !BinaryOp->isRelationalOp() &&
            !BinaryOp->isComparisonOp()) {
            return true;
        }
        Expr *Lhs = BinaryOp->getLHS();
        std::string name;
        VarDecl *VarD;
        // a == 5, p == NULL, s.x == 0
        if (ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(Lhs)) {
            Expr *OriginalCast = Cast->getSubExpr();
            if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(OriginalCast)) {
                if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                    VarD = VD;
                    name = VD->getQualifiedNameAsString();
                }
            } else if (MemberExpr *Member = dyn_cast<MemberExpr>(OriginalCast)) {
                // Doesn't work
                ValueDecl *ValueD = Member->getMemberDecl();
                DeclarationNameInfo Name = Member->getMemberNameInfo();
                name = Name.getAsString();
                //llvm::outs() << "LHS is " << name << "\n";
                if (VarDecl *VD = dyn_cast<VarDecl>(Member->getMemberDecl())) {
                    VarD = VD;
                    name = VD->getQualifiedNameAsString();
                }
            }
        }
        else if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Lhs)) {
            if (VarDecl *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
                VarD = VD;
                name = VD->getQualifiedNameAsString();
            }
        }
        Expr *Rhs = BinaryOp->getRHS();
        if (isa<IntegerLiteral>(Rhs)) {
            IntegerLiteral *IntLit = cast<IntegerLiteral>(Rhs);
            llvm::APSInt Result;
            Rhs->EvaluateAsInt(Result, *astContext);
            //int64_t result = Result.getExtValue();
            llvm::APSInt concreteValue;
            BinaryOperatorKind opc = BinaryOp->getOpcode();
            switch (opc) {
            case (BO_LE) :
            case (BO_GE) :
            case (BO_EQ) : {
                concreteValue = Result;
                break;
            }
            case (BO_NE) : {
                concreteValue = Result++;
                break;
            }
            case (BO_LT) : {
                concreteValue = Result--;
                break;
            }
            case (BO_GT) : {
                concreteValue = Result++;
                break;
            }
            default : {
                break;
            }
            }
            llvm::outs() << "Set value " << concreteValue.toString(10) << "\n";
            //llvm::outs() << "Integer Literal: " << result.toString(10) << "\n";
            //InstrumentStmtBefore(If);
        } else if (isa<CharacterLiteral>(Rhs)) {
            CharacterLiteral *CharLit = cast<CharacterLiteral>(Rhs);
        } else if (Rhs->isNullPointerConstant(*astContext, Expr::NPC_ValueDependentIsNotNull)) {
            // Works
            //InstrumentStmtBefore(If);
        }
#endif
        return true;
    }

    /// Inserts `str` at the start of statement `s`.
    void InstrumentStmtBefore(Stmt *s, const std::string& str) {
        rewriter.InsertText(s->getLocStart(), str, true, true);
    }

    /// Inserts `str` after the ';' that terminates statement `s`. Compound
    /// statements are left alone.
    void InstrumentStmtAfter(Stmt *s, const std::string& str) {
        if (isa<CompoundStmt>(s))
            return;
        insertAfterTerminatingSemi(s->getLocEnd(), str);
    }

    /// Inserts `str` after the ';' that terminates declaration `d`.
    void InstrumentStmtAfter(Decl *d, const std::string& str) {
        insertAfterTerminatingSemi(d->getLocEnd(), str);
    }
};
class S2EInstrumentationASTConsumer : public ASTConsumer { | |
private: | |
S2EInstrumentationVisitor *visitor; // doesn't have to be private | |
public: | |
// override the constructor in order to pass CI | |
explicit S2EInstrumentationASTConsumer(CompilerInstance *CI) | |
: visitor(new S2EInstrumentationVisitor(CI)) // initialize the visitor | |
{ } | |
#if 0 | |
// override this to call our ExampleVisitor on the entire source file | |
virtual void HandleTranslationUnit(ASTContext &Context) { | |
/* we can use ASTContext to get the TranslationUnitDecl, which is | |
a single Decl that collectively represents the entire source file */ | |
visitor->TraverseDecl(Context.getTranslationUnitDecl()); | |
//visitor->TraverseStmt(); | |
} | |
#endif | |
// override this to call our ExampleVisitor on each top-level Decl | |
virtual bool HandleTopLevelDecl(DeclGroupRef DG) { | |
// a DeclGroupRef may have multiple Decls, so we iterate through each one | |
for (DeclGroupRef::iterator i = DG.begin(), e = DG.end(); i != e; i++) { | |
Decl *D = *i; | |
visitor->TraverseDecl(D); // recursively visit each AST node in Decl "D" | |
//D->dump(); | |
} | |
return true; | |
} | |
}; | |
class S2EInstrumentationFrontendAction : public ASTFrontendAction { | |
public: | |
virtual ASTConsumer *CreateASTConsumer(CompilerInstance &CI, StringRef file) { | |
return new S2EInstrumentationASTConsumer(&CI); // pass CI pointer to ASTConsumer | |
} | |
}; | |
int main(int argc, const char **argv) { | |
// Parse the command-line args passed to your code | |
CommonOptionsParser op(argc, argv); | |
// Create a new Clang Tool instance (a LibTooling environment) | |
ClangTool Tool(op.getCompilations(), op.getSourcePathList()); | |
languageOptions.GNUMode = 1; | |
languageOptions.CXXExceptions = 1; | |
languageOptions.RTTI = 1; | |
languageOptions.Bool = 1; | |
languageOptions.CPlusPlus = 1; | |
// Run the Clang Tool, creating a new FrontendAction (explained below) | |
//int result = Tool.run(newFrontendActionFactory<RewriteMacrosAction>()); | |
int result = Tool.run(newFrontendActionFactory<S2EInstrumentationFrontendAction>()); | |
//result = Tool.run(newFrontendActionFactory<S2EInstrumentationFrontendAction>()); | |
// Print out the rewritten source code ("rewriter" is a global var.) | |
rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(errs()); | |
//rewriter.getEditBuffer(rewriter.getSourceMgr().getMainFileID()).write(outs()); | |
return result; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment