//===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This pass identifies/eliminate Redundant TLS Loads if related option is set. // The example: Please refer to the comment at the head of TLSVariableHoist.h. // //===----------------------------------------------------------------------===// #include "llvm/ADT/SmallVector.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/TLSVariableHoist.h" #include #include #include #include #include using namespace llvm; using namespace tlshoist; #define DEBUG_TYPE "tlshoist" static cl::opt TLSLoadHoist( "tls-load-hoist", cl::init(false), cl::Hidden, cl::desc("hoist the TLS loads in PIC model to eliminate redundant " "TLS address calculation.")); namespace { /// The TLS Variable hoist pass. class TLSVariableHoistLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid TLSVariableHoistLegacyPass() : FunctionPass(ID) { initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &Fn) override; StringRef getPassName() const override { return "TLS Variable Hoist"; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); AU.addRequired(); } private: TLSVariableHoistPass Impl; }; } // end anonymous namespace char TLSVariableHoistLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist", "TLS Variable Hoist", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist", "TLS Variable Hoist", false, false) FunctionPass *llvm::createTLSVariableHoistPass() { return new TLSVariableHoistLegacyPass(); } /// Perform the TLS Variable Hoist optimization for the given function. bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) { if (skipFunction(Fn)) return false; LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n"); LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n'); bool MadeChange = Impl.runImpl(Fn, getAnalysis().getDomTree(), getAnalysis().getLoopInfo()); if (MadeChange) { LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: " << Fn.getName() << '\n'); LLVM_DEBUG(dbgs() << Fn); } LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n"); return MadeChange; } void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) { // Skip all cast instructions. They are visited indirectly later on. if (Inst->isCast()) return; // Scan all operands. for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) { auto *GV = dyn_cast(Inst->getOperand(Idx)); if (!GV || !GV->isThreadLocal()) continue; // Add Candidate to TLSCandMap (GV --> Candidate). TLSCandMap[GV].addUser(Inst, Idx); } } void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) { // First, quickly check if there is TLS Variable. Module *M = Fn.getParent(); bool HasTLS = llvm::any_of( M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); }); // If non, directly return. if (!HasTLS) return; TLSCandMap.clear(); // Then, collect TLS Variable info. for (BasicBlock &BB : Fn) { // Ignore unreachable basic blocks. if (!DT->isReachableFromEntry(&BB)) continue; for (Instruction &Inst : BB) collectTLSCandidate(&Inst); } } static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) { if (Cand.Users.size() != 1) return false; BasicBlock *BB = Cand.Users[0].Inst->getParent(); if (LI->getLoopFor(BB)) return false; return true; } Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB, Loop *L) { assert(L && "Unexcepted Loop status!"); // Get the outermost loop. while (Loop *Parent = L->getParentLoop()) L = Parent; BasicBlock *PreHeader = L->getLoopPreheader(); // There is unique predecessor outside the loop. if (PreHeader) return PreHeader->getTerminator(); BasicBlock *Header = L->getHeader(); BasicBlock *Dom = Header; for (BasicBlock *PredBB : predecessors(Header)) Dom = DT->findNearestCommonDominator(Dom, PredBB); assert(Dom && "Not find dominator BB!"); Instruction *Term = Dom->getTerminator(); return Term; } Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1, Instruction *I2) { if (!I1) return I2; return DT->findNearestCommonDominator(I1, I2); } BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn, GlobalVariable *GV, BasicBlock *&PosBB) { tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; // We should hoist the TLS use out of loop, so choose its nearest instruction // which dominate the loop and the outside loops (if exist). Instruction *LastPos = nullptr; for (auto &User : Cand.Users) { BasicBlock *BB = User.Inst->getParent(); Instruction *Pos = User.Inst; if (Loop *L = LI->getLoopFor(BB)) { Pos = getNearestLoopDomInst(BB, L); assert(Pos && "Not find insert position out of loop!"); } Pos = getDomInst(LastPos, Pos); LastPos = Pos; } assert(LastPos && "Unexpected insert position!"); BasicBlock *Parent = LastPos->getParent(); PosBB = Parent; return LastPos->getIterator(); } // Generate a bitcast (no type change) to replace the uses of TLS Candidate. Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn, GlobalVariable *GV) { BasicBlock *PosBB = &Fn.getEntryBlock(); BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB); Type *Ty = GV->getType(); auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast"); CastInst->insertInto(PosBB, Iter); return CastInst; } bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn, GlobalVariable *GV) { tlshoist::TLSCandidate &Cand = TLSCandMap[GV]; // If only used 1 time and not in loops, we no need to replace it. if (oneUseOutsideLoop(Cand, LI)) return false; // Generate a bitcast (no type change) auto *CastInst = genBitCastInst(Fn, GV); // to replace the uses of TLS Candidate for (auto &User : Cand.Users) User.Inst->setOperand(User.OpndIdx, CastInst); return true; } bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) { if (TLSCandMap.empty()) return false; bool Replaced = false; for (auto &GV2Cand : TLSCandMap) { GlobalVariable *GV = GV2Cand.first; Replaced |= tryReplaceTLSCandidate(Fn, GV); } return Replaced; } /// Optimize expensive TLS variables in the given function. bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT, LoopInfo &LI) { if (Fn.hasOptNone()) return false; if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist")) return false; this->LI = &LI; this->DT = &DT; assert(this->LI && this->DT && "Unexcepted requirement!"); // Collect all TLS variable candidates. collectTLSCandidates(Fn); bool MadeChange = tryReplaceTLSCandidates(Fn); return MadeChange; } PreservedAnalyses TLSVariableHoistPass::run(Function &F, FunctionAnalysisManager &AM) { auto &LI = AM.getResult(F); auto &DT = AM.getResult(F); if (!runImpl(F, DT, LI)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet(); return PA; }