Commit 7167475d authored by Thomas's avatar Thomas Committed by gbsbuild

combine shift + bitcast to vector + extract into extract

Change-Id: I893f227970850fdfedc41eebccf55d5f93d4558d
parent 35e635c9
......@@ -729,6 +729,66 @@ void IGC::CustomSafeOptPass::visitSampleBptr(llvm::SampleIntrinsic* sampleInst)
}
}
void CustomSafeOptPass::visitExtractElementInst(ExtractElementInst &I)
{
// convert:
// %1 = lshr i32 %0, 16,
// %2 = bitcast i32 %1 to <2 x half>
// %3 = extractelement <2 x half> %2, i32 0
// to ->
// %2 = bitcast i32 %0 to <2 x half>
// %3 = extractelement <2 x half> %2, i32 1
BitCastInst* bitCast = dyn_cast<BitCastInst>(I.getVectorOperand());
ConstantInt* cstIndex = dyn_cast<ConstantInt>(I.getIndexOperand());
if(bitCast && cstIndex)
{
// skip intermediate bitcast
while(isa<BitCastInst>(bitCast->getOperand(0)))
{
bitCast = cast<BitCastInst>(bitCast->getOperand(0));
}
if(BinaryOperator* binOp = dyn_cast<BinaryOperator>(bitCast->getOperand(0)))
{
unsigned int bitShift = 0;
bool rightShift = false;
if(binOp->getOpcode() == Instruction::LShr)
{
if(ConstantInt* cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
{
bitShift = (unsigned int)cstShift->getZExtValue();
rightShift = true;
}
}
else if(binOp->getOpcode() == Instruction::Shl)
{
if(ConstantInt* cstShift = dyn_cast<ConstantInt>(binOp->getOperand(1)))
{
bitShift = (unsigned int)cstShift->getZExtValue();
}
}
if(bitShift != 0)
{
Type* vecType = I.getVectorOperand()->getType();
unsigned int eltSize = vecType->getVectorElementType()->getPrimitiveSizeInBits();
if(bitShift % eltSize == 0)
{
int elOffset = (int)(bitShift / eltSize);
elOffset = rightShift ? elOffset : -elOffset;
unsigned int newIndex = (unsigned int)((int)cstIndex->getZExtValue() + elOffset);
if(newIndex < vecType->getVectorNumElements())
{
IRBuilder<> builder(&I);
Value* newBitCast = builder.CreateBitCast(binOp->getOperand(0), vecType);
Value* newScalar = builder.CreateExtractElement(newBitCast, newIndex);
I.replaceAllUsesWith(newScalar);
I.eraseFromParent();
}
}
}
}
}
}
// Register pass to igc-opt
#define PASS_FLAG2 "igc-gen-specific-pattern"
#define PASS_DESCRIPTION2 "LastPatternMatch Pass"
......
......@@ -74,6 +74,7 @@ namespace IGC
void visitMulH(llvm::CallInst* inst, bool isSigned);
void visitFPToUIInst(llvm::FPToUIInst& FPUII);
void visitFPTruncInst(llvm::FPTruncInst &I);
void visitExtractElementInst(llvm::ExtractElementInst& I);
//
// IEEE Floating point arithmetic is not associative. Any pattern
// match that changes the order or paramters is unsafe.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment