Commit c5b7415d authored by Scott Pillow's avatar Scott Pillow Committed by gbsbuild

Added initial waveMultiPrefixBitcount implementation

Change-Id: I516533fa1c20c73406fef000f14112c81b3413a9
parent 84bc3a4f
......@@ -386,12 +386,10 @@ void GenIntrinsicsTTIImpl::getUnrollingPreferences(Loop *L,
bool GenIntrinsicsTTIImpl::isProfitableToHoist(Instruction *I)
{
if (llvm::GenIntrinsicInst* pIntrinsic = llvm::dyn_cast<llvm::GenIntrinsicInst>(I))
if (auto *CI = dyn_cast<CallInst>(I))
{
if (unsafeToHoist(pIntrinsic->getIntrinsicID()))
{
if (unsafeToHoist(CI))
return false;
}
}
return BaseT::isProfitableToHoist(I);
}
......
......@@ -30,19 +30,54 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include <llvm/IR/Instruction.h>
#include <llvm/IR/BasicBlock.h>
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/InstIterator.h"
#include "common/LLVMWarningsPop.hpp"
#include "GenISAIntrinsics/GenIntrinsics.h"
#include "GenISAIntrinsics/GenIntrinsicInst.h"
using namespace llvm;
namespace IGC
{
bool unsafeToHoist(llvm::GenISAIntrinsic::ID id)
bool unsafeToHoist(const CallInst *CI)
{
return CI->isConvergent() &&
#if LLVM_VERSION_MAJOR >= 7
CI->onlyAccessesInaccessibleMemory();
#else
CI->hasFnAttr(Attribute::InaccessibleMemOnly);
#endif
}
// We currently use the combination of 'convergent' and
// 'inaccessiblememonly' to prevent code motion of
// wave intrinsics. Removing 'readnone' from a callsite
// is not sufficient to stop LICM from looking back up to the
// function definition for the attribute. We can short circuit that
// by creating an operand bundle. The name "nohoist" is not
// significant; anything will do.
CallInst* setUnsafeToHoistAttr(CallInst *CI)
{
return id == llvm::GenISAIntrinsic::GenISA_WaveBallot ||
id == llvm::GenISAIntrinsic::GenISA_WaveAll ||
id == llvm::GenISAIntrinsic::GenISA_WavePrefix ||
id == llvm::GenISAIntrinsic::GenISA_QuadPrefix;
CI->setConvergent();
#if LLVM_VERSION_MAJOR >= 7
CI->setOnlyAccessesInaccessibleMemory();
CI->removeAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
#else
CI->addAttribute(
AttributeSet::FunctionIndex, Attribute::InaccessibleMemOnly);
CI->removeAttribute(AttributeSet::FunctionIndex, Attribute::ReadNone);
#endif
OperandBundleDef OpDef("nohoist", None);
// An operand bundle cannot be appended onto a call after creation.
// clone the instruction but add our operandbundle on as well.
SmallVector<OperandBundleDef, 1> OpBundles;
CI->getOperandBundlesAsDefs(OpBundles);
OpBundles.push_back(OpDef);
CallInst *NewCall = CallInst::Create(CI, OpBundles, CI);
CI->replaceAllUsesWith(NewCall);
return NewCall;
}
class WaveIntrinsicWAPass : public llvm::ModulePass
......@@ -71,22 +106,19 @@ namespace IGC
uint32_t counter = 0;
for (llvm::Function& F : M)
{
for (llvm::BasicBlock& BB : F)
for (auto &I : instructions(F))
{
for (llvm::Instruction& inst : BB)
if (auto *CI = dyn_cast<CallInst>(&I))
{
if (llvm::GenIntrinsicInst* genIntrinsic = llvm::dyn_cast<llvm::GenIntrinsicInst>(&inst))
if (unsafeToHoist(CI))
{
if (unsafeToHoist(genIntrinsic->getIntrinsicID()))
{
changed = true;
llvm::FunctionType* voidFuncType = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), false);
std::string asmText = "; " + std::to_string(counter++);
llvm::CallInst::Create(llvm::InlineAsm::get(voidFuncType, asmText, "", true), "", &inst);
asmText = "; " + std::to_string(counter++);
auto asmAfterIntrinsic = llvm::CallInst::Create(llvm::InlineAsm::get(voidFuncType, asmText, "", true));
asmAfterIntrinsic->insertAfter(&inst);
}
changed = true;
llvm::FunctionType* voidFuncType = llvm::FunctionType::get(llvm::Type::getVoidTy(ctx), false);
std::string asmText = "; " + std::to_string(counter++);
llvm::CallInst::Create(llvm::InlineAsm::get(voidFuncType, asmText, "", true), "", &I);
asmText = "; " + std::to_string(counter++);
auto asmAfterIntrinsic = llvm::CallInst::Create(llvm::InlineAsm::get(voidFuncType, asmText, "", true));
asmAfterIntrinsic->insertAfter(&I);
}
}
}
......
......@@ -33,5 +33,6 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
namespace IGC
{
llvm::ModulePass* createWaveIntrinsicWAPass();
bool unsafeToHoist(llvm::GenISAIntrinsic::ID id);
bool unsafeToHoist(const llvm::CallInst *CI);
llvm::CallInst* setUnsafeToHoistAttr(llvm::CallInst *CI);
}
\ No newline at end of file
......@@ -764,7 +764,7 @@ public:
llvm::Value* create_uavSerializeAll();
llvm::Value* create_discard(llvm::Value* condition);
llvm::Value* create_runtime(llvm::Value* offset);
llvm::Value* create_countbits(llvm::Value* src);
llvm::CallInst* create_countbits(llvm::Value* src);
llvm::Value* create_waveBallot(llvm::Value* src);
llvm::Value* create_waveInverseBallot(llvm::Value* src);
llvm::Value* create_waveshuffleIndex(llvm::Value* src, llvm::Value* index);
......@@ -772,12 +772,18 @@ public:
llvm::Value* create_wavePrefix(
llvm::Value* src, llvm::Value* type, bool inclusive,
llvm::Value *Mask = nullptr);
llvm::Value* create_wavePrefixBitCount(
llvm::Value* src, llvm::Value *Mask = nullptr);
llvm::Value* create_waveMatch(llvm::Instruction *inst, llvm::Value *src);
llvm::Value* create_waveMultiPrefix(
llvm::Instruction *I,
llvm::Value *Val,
llvm::Value *Mask,
IGC::WaveOps OpKind);
llvm::Value* create_waveMultiPrefixBitCount(
llvm::Instruction *I,
llvm::Value *Val,
llvm::Value *Mask);
llvm::Value* create_quadPrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false);
llvm::Value* get32BitLaneID();
llvm::Value* getSimdSize();
......
......@@ -25,11 +25,10 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
======================= end_copyright_notice ==================================*/
#include "common/debug/DebugMacros.hpp" // VALUE_NAME() definition.
#include "Compiler/WaveIntrinsicWAPass.h"
#include "common/LLVMWarningsPush.hpp"
#include "llvmWrapper/AsmParser/Parser.h"
#include "common/LLVMWarningsPop.hpp"
......@@ -4347,7 +4346,7 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_uavSeriali
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_countbits(llvm::Value* src)
inline llvm::CallInst* LLVM3DBuilder<preserveNames, T, Inserter>::create_countbits(llvm::Value* src)
{
llvm::Module* module = this->GetInsertBlock()->getParent()->getParent();
llvm::Function* pFunc = llvm::Intrinsic::getDeclaration(
......@@ -4417,6 +4416,29 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefix
return this->CreateCall4(pFunc, src, type, this->getInt1(inclusive), Mask);
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value*
LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefixBitCount(
llvm::Value* src, llvm::Value *Mask)
{
//bits = ballot(bBit);
//laneMaskLT = (1 << WaveGetLaneIndex()) - 1;
//prefixBitCount = countbits(bits & laneMaskLT);
llvm::Value* ballot = this->create_waveBallot(src);
if (Mask)
ballot = this->CreateAnd(ballot, Mask);
llvm::Value* shlLaneId = this->CreateShl(
this->getInt32(1), this->get32BitLaneID());
llvm::Value* laneMask = this->CreateSub(shlLaneId, this->getInt32(1));
llvm::Value *mask = this->CreateAnd(ballot, laneMask);
// update llvm.ctpop so it won't be hoisted/sunk out of the loop.
auto *PopCnt = this->create_countbits(mask);
auto *NoHoistPopCnt = setUnsafeToHoistAttr(PopCnt);
PopCnt->eraseFromParent();
return NoHoistPopCnt;
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveMatch(
llvm::Instruction *inst,
......@@ -4501,6 +4523,41 @@ LLVM3DBuilder<preserveNames, T, Inserter>::create_waveMultiPrefix(
return WavePrefix;
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value*
LLVM3DBuilder<preserveNames, T, Inserter>::create_waveMultiPrefixBitCount(
llvm::Instruction *I,
llvm::Value *Val,
llvm::Value *Mask)
{
// Similar structure to waveMatch and waveMultiPrefix
auto *PreHeader = I->getParent();
auto *BodyBlock = PreHeader->splitBasicBlock(I, "multiprefixbitcount-body");
auto *EndBlock = BodyBlock->splitBasicBlock(
I->getNextNode(), "multiprefixbitcount-end");
// Make sure that we set the insert point again as we've just invalidated
// it with the splitBasicBlock() calls above.
this->SetInsertPoint(I);
// Now generate the code for a single iteration of the code
auto *FirstValue = this->readFirstLane(Mask);
auto *Count = this->create_wavePrefixBitCount(Val, FirstValue);
// Replace the current terminator to either exit the loop
// or branch back for another iteration.
auto *Br = BodyBlock->getTerminator();
this->SetInsertPoint(Br);
auto *ParticipatingLanes = this->create_waveInverseBallot(FirstValue);
this->CreateCondBr(ParticipatingLanes, EndBlock, BodyBlock);
Br->eraseFromParent();
this->SetInsertPoint(&*EndBlock->getFirstInsertionPt());
return Count;
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_quadPrefix(llvm::Value* src, llvm::Value* type, bool inclusive)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment