Commit 6a6a81d9 authored by Scott Pillow, committed by gbsbuild

Add a mask argument to wavePrefix and add the waveInverseBallot intrinsic. This lays the groundwork for upcoming commits.

Change-Id: Ic524f8083e729796041b3f277ab5a75c4f03dd5f
parent fc83b6f2
......@@ -7435,14 +7435,17 @@ void EmitPass::EmitGenIntrinsicMessage(llvm::GenIntrinsicInst* inst)
case GenISAIntrinsic::GenISA_WaveBallot:
emitWaveBallot(inst);
break;
case GenISAIntrinsic::GenISA_WaveInverseBallot:
emitWaveInverseBallot(inst);
break;
case GenISAIntrinsic::GenISA_WaveShuffleIndex:
emitSimdShuffle(inst);
break;
case GenISAIntrinsic::GenISA_WavePrefix:
emitWavePrefix(inst);
emitWavePrefix(cast<WavePrefixIntrinsic>(inst));
break;
case GenISAIntrinsic::GenISA_QuadPrefix:
emitWavePrefix(inst, true);
emitQuadPrefix(cast<QuadPrefixIntrinsic>(inst));
break;
case GenISAIntrinsic::GenISA_WaveAll:
emitWaveAll(inst);
......@@ -9981,16 +9984,21 @@ void EmitPass::emitReductionAll(
m_encoder->Push();
}
void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc, CVariable* pSrc, CVariable* pSrcsArr[2], bool isPrefix, bool isQuad)
void EmitPass::emitPreOrPostFixOp(
e_opcode op, uint64_t identityValue, VISA_Type type, bool negateSrc,
CVariable* pSrc, CVariable* pSrcsArr[2], CVariable *Flag,
bool isPrefix, bool isQuad)
{
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with <identity>.
if (m_currShader->m_Platform->doScalar64bScan() && CEncoder::GetCISADataTypeSize(type) == 8 && !isQuad)
{
emitPreOrPostFixOpScalar(op, identityValue, type, negateSrc, pSrc, pSrcsArr, isPrefix);
emitPreOrPostFixOpScalar(
op, identityValue, type, negateSrc,
pSrc, pSrcsArr, Flag,
isPrefix);
return;
}
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with <identity>.
bool isSimd32 = (m_currShader->m_dispatchSize == SIMDMode::SIMD32);
int counter = 1;
if (isSimd32)
......@@ -10004,18 +10012,20 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
IGC::EALIGN_GRF,
false);
// Set the GRF to 0 with no mask. This will set all the registers to 0
// Set the GRF to <identity> with no mask. This will set all the registers to <identity>
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
m_encoder->SetNoMask();
m_encoder->Copy(pSrcCopy, pIdentityValue);
m_encoder->Push();
// Now copy the src with a mask so the disabled lanes still keep their 0
// Now copy the src with a mask so the disabled lanes still keep their <identity>
if (negateSrc)
{
m_encoder->SetSrcModifier(0, EMOD_NEG);
}
m_encoder->SetSecondHalf(i == 1);
if (Flag)
m_encoder->SetPredicate(Flag);
m_encoder->Copy(pSrcCopy, pSrc);
m_encoder->Push();
......@@ -10063,7 +10073,8 @@ void EmitPass::emitPreOrPostFixOp(e_opcode op, uint64_t identityValue, VISA_Type
{
/*
Copy the adjacent elements.
for example: r10 be the register
for example: let r10 be the register
Assume we are performing addition for this example
____ ____ ____ ____
__|____|____|____|____|____|____|____|_
| 7 | 6 | 5 | 4 | 9 | 5 | 3 | 2 |
......@@ -10239,6 +10250,7 @@ void EmitPass::emitPreOrPostFixOpScalar(
bool negateSrc,
CVariable* src,
CVariable* result[2],
CVariable* Flag,
bool isPrefix)
{
// This is to handle cases when not all lanes are enabled. In that case we fill the lanes with <identity>.
......@@ -10259,19 +10271,21 @@ void EmitPass::emitPreOrPostFixOpScalar(
IGC::EALIGN_GRF,
false);
// Set the GRF to 0 with no mask. This will set all the registers to 0
// Set the GRF to <identity> with no mask. This will set all the registers to <identity>
CVariable* pIdentityValue = m_currShader->ImmToVariable(identityValue, type);
m_encoder->SetSecondHalf(i == 1);
m_encoder->SetNoMask();
m_encoder->Copy(pSrcCopy[i], pIdentityValue);
m_encoder->Push();
// Now copy the src with a mask so the disabled lanes still keep their 0
// Now copy the src with a mask so the disabled lanes still keep their <identity>
if (negateSrc)
{
m_encoder->SetSrcModifier(0, EMOD_NEG);
}
m_encoder->SetSecondHalf(i == 1);
if (Flag)
m_encoder->SetPredicate(Flag);
m_encoder->Copy(pSrcCopy[i], src);
m_encoder->Push();
......@@ -14326,6 +14340,33 @@ void EmitPass::emitWaveBallot(llvm::GenIntrinsicInst* inst)
}
}
// Emits code for GenISA_WaveInverseBallot. Operand 0 of the intrinsic is a
// mask value; the destination is written with the result of testing that
// mask, either directly (uniform mask) or lane by lane (non-uniform mask).
void EmitPass::emitWaveInverseBallot(llvm::GenIntrinsicInst* inst)
{
    CVariable* ballotBits = GetSymbol(inst->getOperand(0));

    if (ballotBits->IsUniform())
    {
        // Uniform mask: hand it straight to SetP. Nothing is emitted when
        // the encoder reports that it is working on the second half.
        if (!m_encoder->IsSecondHalf())
        {
            m_encoder->SetP(m_destination, ballotBits);
        }
        return;
    }

    // A uniform mask is expected to be by far the most common case. For
    // anything else, fall back to computing:
    //
    //     (val & (1 << id)) != 0
    //
    // NOTE(review): 'id' is whatever GetSimdOffsetBase() writes into the
    // temporary -- presumably the per-lane index; confirm against its
    // definition.
    CVariable* laneBit = m_currShader->GetNewVariable(
        numLanes(m_currShader->m_SIMDSize), ISA_TYPE_UD, EALIGN_GRF);
    m_currShader->GetSimdOffsetBase(laneBit);

    // laneBit = 1 << id
    CVariable* one = m_currShader->ImmToVariable(1, ISA_TYPE_UD);
    m_encoder->Shl(laneBit, one, laneBit);

    // laneBit = val & laneBit
    m_encoder->And(laneBit, ballotBits, laneBit);

    // destination = (laneBit != 0)
    CVariable* zero = m_currShader->ImmToVariable(0, ISA_TYPE_UD);
    m_encoder->Cmp(EPREDICATE_NE, m_destination, laneBit, zero);
}
static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcode& opcode, VISA_Type& type)
{
auto getISAType = [](Type* ty, bool isSigned = true)
......@@ -14468,17 +14509,49 @@ static void GetReductionOp(WaveOps op, Type* opndTy, uint64_t& identity, e_opcod
}
}
void EmitPass::emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad)
// Emits a GenISA_WavePrefix intrinsic by forwarding its source, operation
// kind, inclusive flag and lane mask to emitScan() as a non-quad scan.
void EmitPass::emitWavePrefix(WavePrefixIntrinsic* I)
{
    Value* Mask = I->getMask();

    // A constant mask with every bit set means all lanes participate. In
    // that case a null mask is forwarded, which tells emitScan() that no
    // predication should be emitted.
    auto* constMask = dyn_cast<ConstantInt>(Mask);
    const bool allLanes = (constMask != nullptr) && constMask->isAllOnesValue();

    emitScan(
        I->getSrc(),
        I->getOpKind(),
        I->isInclusiveScan(),
        allLanes ? nullptr : Mask,
        /*isQuad=*/false);
}
// Emits a GenISA_QuadPrefix intrinsic by forwarding its source, operation
// kind and inclusive flag to emitScan() as a quad scan with a null mask.
void EmitPass::emitQuadPrefix(QuadPrefixIntrinsic* I)
{
    auto* const src = I->getSrc();
    const auto opKind = I->getOpKind();
    const bool inclusive = I->isInclusiveScan();

    // No mask is supplied for quad scans, so a null mask is passed.
    emitScan(src, opKind, inclusive, /*Mask=*/nullptr, /*isQuad=*/true);
}
void EmitPass::emitScan(
Value *Src, IGC::WaveOps Op,
bool isInclusiveScan, Value *Mask, bool isQuad)
{
WaveOps op = static_cast<WaveOps>(cast<llvm::ConstantInt>(inst->getOperand(1))->getZExtValue());
bool isInclusiveScan = cast<llvm::ConstantInt>(inst->getOperand(2))->getZExtValue() != 0;
VISA_Type type;
e_opcode opCode;
uint64_t identity = 0;
GetReductionOp(op, inst->getOperand(0)->getType(), identity, opCode, type);
CVariable* src = GetSymbol(inst->getOperand(0));
GetReductionOp(Op, Src->getType(), identity, opCode, type);
CVariable* src = GetSymbol(Src);
CVariable *dst[2] = { nullptr, nullptr };
emitPreOrPostFixOp(opCode, identity, type, false, src, dst, !isInclusiveScan, isQuad);
CVariable *Flag = Mask ? GetSymbol(Mask) : nullptr;
emitPreOrPostFixOp(
opCode, identity, type,
false, src, dst, Flag,
!isInclusiveScan, isQuad);
// Now that we've computed the result in temporary registers,
// make sure we only write the results to lanes participating in the
// scan as specified by 'mask'.
if (Flag)
m_encoder->SetPredicate(Flag);
m_encoder->Copy(m_destination, dst[0]);
if (m_currShader->m_dispatchSize == SIMDMode::SIMD32)
{
......
......@@ -259,6 +259,7 @@ public:
bool negateSrc,
CVariable* src,
CVariable* result[2],
CVariable* Flag = nullptr,
bool isPrefix = false,
bool isQuad = false);
......@@ -269,7 +270,8 @@ public:
bool negateSrc,
CVariable* src,
CVariable* result[2],
bool isPrefix = false);
CVariable* Flag,
bool isPrefix);
bool IsUniformAtomic(llvm::Instruction* pInst);
void emitAtomicRaw(llvm::GenIntrinsicInst* pInst);
......@@ -360,8 +362,10 @@ public:
// CrossLane Instructions
void emitWaveBallot(llvm::GenIntrinsicInst* inst);
void emitWaveInverseBallot(llvm::GenIntrinsicInst* inst);
void emitWaveShuffleIndex(llvm::GenIntrinsicInst* inst);
void emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad = false);
void emitWavePrefix(llvm::WavePrefixIntrinsic* I);
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
void emitWaveAll(llvm::GenIntrinsicInst* inst);
// Those three "vector" version shall be combined with
......@@ -501,6 +505,9 @@ private:
void emitSetMessagePhaseType(llvm::GenIntrinsicInst* inst, VISA_Type type);
void emitSetMessagePhaseType_legacy(llvm::GenIntrinsicInst* inst, VISA_Type type);
void emitScan(llvm::Value *Src, IGC::WaveOps Op,
bool isInclusiveScan, llvm::Value *Mask, bool isQuad);
// Cached per lane offset variables. This is a per basic block data
// structure. For each entry, the first item is the scalar type size in
// bytes, the second item is the corresponding symbol.
......
......@@ -463,7 +463,7 @@ void SubGroupFuncsResolution::subGroupScan(WaveOps op, CallInst &CI)
IRBuilder<> IRB(&CI);
Value* arg = CI.getArgOperand(0);
Value* opVal = IRB.getInt8((uint8_t)op);
Value* args[3] = { arg, opVal, IRB.getInt1(false) };
Value* args[] = { arg, opVal, IRB.getInt1(false), IRB.getInt1(true) };
Function* waveScan = GenISAIntrinsic::getDeclaration(CI.getCalledFunction()->getParent(),
GenISAIntrinsic::GenISA_WavePrefix,
arg->getType());
......
......@@ -96,7 +96,7 @@ public:
return isa<CallInst>(V) && classof(cast<CallInst>(V));
}
uint64_t getImm64Operand(unsigned idx) {
uint64_t getImm64Operand(unsigned idx) const {
assert(isa<ConstantInt>(getOperand(idx)));
return valueToImm64(getOperand(idx));
}
......@@ -531,6 +531,51 @@ public:
}
};
// Typed accessor wrapper for calls to the GenISA_WavePrefix intrinsic.
// It gives names to the four call operands and supplies the classof()
// hooks used by isa<>, cast<> and dyn_cast<>.
class WavePrefixIntrinsic : public GenIntrinsicInst
{
    // Positions of the call operands.
    enum : unsigned
    {
        SrcOperand = 0,
        OpKindOperand = 1,
        InclusiveOperand = 2,
        MaskOperand = 3
    };

public:
    // Value the scan is computed over.
    Value* getSrc() const
    {
        return getOperand(SrcOperand);
    }

    // Operation to apply, decoded from the immediate operand.
    IGC::WaveOps getOpKind() const
    {
        const uint64_t rawKind = getImm64Operand(OpKindOperand);
        return static_cast<IGC::WaveOps>(rawKind);
    }

    // True when the immediate "inclusive" operand is non-zero.
    bool isInclusiveScan() const
    {
        const uint64_t inclusiveFlag = getImm64Operand(InclusiveOperand);
        return inclusiveFlag != 0;
    }

    // Mask operand of the call.
    Value* getMask() const
    {
        return getOperand(MaskOperand);
    }

    // Type-inquiry support for isa, cast and dyn_cast.
    static bool classof(const GenIntrinsicInst* I)
    {
        return GenISAIntrinsic::GenISA_WavePrefix == I->getIntrinsicID();
    }

    static bool classof(const Value* V)
    {
        if (!isa<GenIntrinsicInst>(V))
        {
            return false;
        }
        return classof(cast<GenIntrinsicInst>(V));
    }
};
// Typed accessor wrapper for calls to the GenISA_QuadPrefix intrinsic.
// It gives names to the three call operands and supplies the classof()
// hooks used by isa<>, cast<> and dyn_cast<>.
class QuadPrefixIntrinsic : public GenIntrinsicInst
{
    // Positions of the call operands.
    enum : unsigned
    {
        SrcOperand = 0,
        OpKindOperand = 1,
        InclusiveOperand = 2
    };

public:
    // Value the scan is computed over.
    Value* getSrc() const
    {
        return getOperand(SrcOperand);
    }

    // Operation to apply, decoded from the immediate operand.
    IGC::WaveOps getOpKind() const
    {
        const uint64_t rawKind = getImm64Operand(OpKindOperand);
        return static_cast<IGC::WaveOps>(rawKind);
    }

    // True when the immediate "inclusive" operand is non-zero.
    bool isInclusiveScan() const
    {
        const uint64_t inclusiveFlag = getImm64Operand(InclusiveOperand);
        return inclusiveFlag != 0;
    }

    // Type-inquiry support for isa, cast and dyn_cast.
    static bool classof(const GenIntrinsicInst* I)
    {
        return GenISAIntrinsic::GenISA_QuadPrefix == I->getIntrinsicID();
    }

    static bool classof(const Value* V)
    {
        if (!isa<GenIntrinsicInst>(V))
        {
            return false;
        }
        return classof(cast<GenIntrinsicInst>(V));
    }
};
template <class X, class Y>
inline bool isa(const Y &Val, GenISAIntrinsic::ID id)
{
......
......@@ -265,9 +265,18 @@ Imported_Intrinsics = \
"GenISA_pair_to_ptr": ["anyptr",["int","int"],"NoMem"],
"GenISA_ptr_to_pair": [["int","int"],["anyptr"],"NoMem"],
"GenISA_WaveBallot": ["int",["bool"],"Convergent,InaccessibleMemOnly"],
# Arg 0 - Mask value
# Return - assigns each lane the value of its corresponding bit.
"GenISA_WaveInverseBallot": ["bool",["int"],"Convergent,InaccessibleMemOnly"],
"GenISA_WaveShuffleIndex": ["anyint",[0,"int"],"Convergent,NoMem"],
"GenISA_WaveAll": ["anyint",[0,"char"],"Convergent,InaccessibleMemOnly"],
"GenISA_WavePrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"],
# Arg 0 - Src value
# Arg 1 - Operation type
# Arg 2 - Is the operation inclusive (1) or exclusive (0)?
# Arg 3 - a mask that specifies a subset of lanes to participate
# in the computation.
# Return - The computed prefix/postfix result
"GenISA_WavePrefix": ["anyint",[0,"char","bool","bool"],"Convergent,InaccessibleMemOnly"],
"GenISA_QuadPrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"],
"GenISA_InitDiscardMask": ["bool",[],"None"],
"GenISA_UpdateDiscardMask": ["bool",["bool","bool"],"None"],
......
......@@ -1238,6 +1238,7 @@ GenISA_WavePrefix,,"// Description: Accumulate and keep the intermediate results
,anyint,
,0,
,char,specify the type: sum / Prod / Min/ Max
,bool,a mask that specifies a subset of lanes to participate in the computation.
GenISA_setMessagePhaseV,,
,anyvector,new message phase result
,0,cur message phase
......
......@@ -766,9 +766,12 @@ public:
llvm::Value* create_runtime(llvm::Value* offset);
llvm::Value* create_countbits(llvm::Value* src);
llvm::Value* create_waveBallot(llvm::Value* src);
llvm::Value* create_waveInverseBallot(llvm::Value* src);
llvm::Value* create_waveshuffleIndex(llvm::Value* src, llvm::Value* index);
llvm::Value* create_waveAll(llvm::Value* src, llvm::Value* type);
llvm::Value* create_wavePrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false);
llvm::Value* create_wavePrefix(
llvm::Value* src, llvm::Value* type, bool inclusive,
llvm::Value *Mask = nullptr);
llvm::Value* create_waveMatch(llvm::Instruction *inst, llvm::Value *src);
llvm::Value* create_quadPrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false);
llvm::Value* get32BitLaneID();
......
......@@ -4357,6 +4357,18 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_countbits(
return this->CreateCall(pFunc, src);
}
// Builds a call to the GenISA_WaveInverseBallot intrinsic with `src` as its
// only argument and returns the call.
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value*
LLVM3DBuilder<preserveNames, T, Inserter>::create_waveInverseBallot(
    llvm::Value* src)
{
    // The declaration is looked up in the module that owns the current
    // insertion block.
    llvm::Function* const inverseBallotFn =
        llvm::GenISAIntrinsic::getDeclaration(
            this->GetInsertBlock()->getParent()->getParent(),
            llvm::GenISAIntrinsic::GenISA_WaveInverseBallot);

    return this->CreateCall(inverseBallotFn, src);
}
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveBallot(llvm::Value* src)
{
......@@ -4390,14 +4402,19 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveAll(ll
}
// Builds a call to the GenISA_WavePrefix intrinsic and returns the call.
//   src       - value passed as the first call argument; its type is also
//               passed to getDeclaration().
//   type      - value passed as the second call argument (the operation type).
//   inclusive - converted with getInt1() and passed as the third argument.
//   Mask      - lanes that participate; a null pointer is replaced by a
//               constant true so that all lanes are included.
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefix(llvm::Value* src, llvm::Value* type, bool inclusive)
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefix(
llvm::Value* src, llvm::Value* type, bool inclusive, llvm::Value *Mask)
{
// If a nullptr is passed in for 'Mask' (as is the default), just include
// all lanes.
Mask = Mask ? Mask : this->getInt1(true);
// The declaration is looked up in the module that owns the current
// insertion block, using the type of 'src'.
llvm::Module* module = this->GetInsertBlock()->getParent()->getParent();
llvm::Function* pFunc = llvm::GenISAIntrinsic::getDeclaration(
module,
llvm::GenISAIntrinsic::GenISA_WavePrefix,
src->getType());
return this->CreateCall3(pFunc, src, type, this->getInt1(inclusive));
return this->CreateCall4(pFunc, src, type, this->getInt1(inclusive), Mask);
}
template<bool preserveNames, typename T, typename Inserter>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment