Commit 6a6a81d9 authored by Scott Pillow's avatar Scott Pillow Committed by gbsbuild

Add a mask argument to wavePrefix, and add the waveInverseBallot intrinsic. This is in support of upcoming commits.

Change-Id: Ic524f8083e729796041b3f277ab5a75c4f03dd5f
parent fc83b6f2
This diff is collapsed.
...@@ -259,6 +259,7 @@ public: ...@@ -259,6 +259,7 @@ public:
bool negateSrc, bool negateSrc,
CVariable* src, CVariable* src,
CVariable* result[2], CVariable* result[2],
CVariable* Flag = nullptr,
bool isPrefix = false, bool isPrefix = false,
bool isQuad = false); bool isQuad = false);
...@@ -269,7 +270,8 @@ public: ...@@ -269,7 +270,8 @@ public:
bool negateSrc, bool negateSrc,
CVariable* src, CVariable* src,
CVariable* result[2], CVariable* result[2],
bool isPrefix = false); CVariable* Flag,
bool isPrefix);
bool IsUniformAtomic(llvm::Instruction* pInst); bool IsUniformAtomic(llvm::Instruction* pInst);
void emitAtomicRaw(llvm::GenIntrinsicInst* pInst); void emitAtomicRaw(llvm::GenIntrinsicInst* pInst);
...@@ -360,8 +362,10 @@ public: ...@@ -360,8 +362,10 @@ public:
// CrossLane Instructions // CrossLane Instructions
void emitWaveBallot(llvm::GenIntrinsicInst* inst); void emitWaveBallot(llvm::GenIntrinsicInst* inst);
void emitWaveInverseBallot(llvm::GenIntrinsicInst* inst);
void emitWaveShuffleIndex(llvm::GenIntrinsicInst* inst); void emitWaveShuffleIndex(llvm::GenIntrinsicInst* inst);
void emitWavePrefix(llvm::GenIntrinsicInst* inst, bool isQuad = false); void emitWavePrefix(llvm::WavePrefixIntrinsic* I);
void emitQuadPrefix(llvm::QuadPrefixIntrinsic* I);
void emitWaveAll(llvm::GenIntrinsicInst* inst); void emitWaveAll(llvm::GenIntrinsicInst* inst);
// Those three "vector" version shall be combined with // Those three "vector" version shall be combined with
...@@ -501,6 +505,9 @@ private: ...@@ -501,6 +505,9 @@ private:
void emitSetMessagePhaseType(llvm::GenIntrinsicInst* inst, VISA_Type type); void emitSetMessagePhaseType(llvm::GenIntrinsicInst* inst, VISA_Type type);
void emitSetMessagePhaseType_legacy(llvm::GenIntrinsicInst* inst, VISA_Type type); void emitSetMessagePhaseType_legacy(llvm::GenIntrinsicInst* inst, VISA_Type type);
void emitScan(llvm::Value *Src, IGC::WaveOps Op,
bool isInclusiveScan, llvm::Value *Mask, bool isQuad);
// Cached per lane offset variables. This is a per basic block data // Cached per lane offset variables. This is a per basic block data
// structure. For each entry, the first item is the scalar type size in // structure. For each entry, the first item is the scalar type size in
// bytes, the second item is the corresponding symbol. // bytes, the second item is the corresponding symbol.
......
...@@ -463,7 +463,7 @@ void SubGroupFuncsResolution::subGroupScan(WaveOps op, CallInst &CI) ...@@ -463,7 +463,7 @@ void SubGroupFuncsResolution::subGroupScan(WaveOps op, CallInst &CI)
IRBuilder<> IRB(&CI); IRBuilder<> IRB(&CI);
Value* arg = CI.getArgOperand(0); Value* arg = CI.getArgOperand(0);
Value* opVal = IRB.getInt8((uint8_t)op); Value* opVal = IRB.getInt8((uint8_t)op);
Value* args[3] = { arg, opVal, IRB.getInt1(false) }; Value* args[] = { arg, opVal, IRB.getInt1(false), IRB.getInt1(true) };
Function* waveScan = GenISAIntrinsic::getDeclaration(CI.getCalledFunction()->getParent(), Function* waveScan = GenISAIntrinsic::getDeclaration(CI.getCalledFunction()->getParent(),
GenISAIntrinsic::GenISA_WavePrefix, GenISAIntrinsic::GenISA_WavePrefix,
arg->getType()); arg->getType());
......
...@@ -96,7 +96,7 @@ public: ...@@ -96,7 +96,7 @@ public:
return isa<CallInst>(V) && classof(cast<CallInst>(V)); return isa<CallInst>(V) && classof(cast<CallInst>(V));
} }
uint64_t getImm64Operand(unsigned idx) { uint64_t getImm64Operand(unsigned idx) const {
assert(isa<ConstantInt>(getOperand(idx))); assert(isa<ConstantInt>(getOperand(idx)));
return valueToImm64(getOperand(idx)); return valueToImm64(getOperand(idx));
} }
...@@ -531,6 +531,51 @@ public: ...@@ -531,6 +531,51 @@ public:
} }
}; };
/// Typed view over a call to the GenISA_WavePrefix intrinsic.
///
/// Operand layout read by the accessors below:
///   0 - source value
///   1 - operation type, stored as an immediate and read back as IGC::WaveOps
///   2 - immediate flag: non-zero selects an inclusive scan
///   3 - mask selecting the subset of lanes that participate
class WavePrefixIntrinsic : public GenIntrinsicInst
{
public:
    /// Operand 0: the value being scanned.
    Value *getSrc() const
    {
        return getOperand(0);
    }

    /// Operand 1, read as an immediate and converted to IGC::WaveOps.
    IGC::WaveOps getOpKind() const
    {
        const uint64_t RawOpKind = getImm64Operand(1);
        return static_cast<IGC::WaveOps>(RawOpKind);
    }

    /// True when immediate operand 2 is non-zero.
    bool isInclusiveScan() const
    {
        const bool IsZero = (getImm64Operand(2) == 0);
        return !IsZero;
    }

    /// Operand 3: the lane mask.
    Value *getMask() const
    {
        return getOperand(3);
    }

    // Type inquiry hooks used by isa<>, cast<> and dyn_cast<>.

    /// Matches only intrinsic calls whose ID is GenISA_WavePrefix.
    static inline bool classof(const GenIntrinsicInst *I)
    {
        return GenISAIntrinsic::GenISA_WavePrefix == I->getIntrinsicID();
    }

    /// Rejects anything that is not a GenIntrinsicInst, then defers to the
    /// ID check above.
    static inline bool classof(const Value *V)
    {
        if (!isa<GenIntrinsicInst>(V))
            return false;
        return classof(cast<GenIntrinsicInst>(V));
    }
};
/// Typed view over a call to the GenISA_QuadPrefix intrinsic.
///
/// Operand layout read by the accessors below:
///   0 - source value
///   1 - operation type, stored as an immediate and read back as IGC::WaveOps
///   2 - immediate flag: non-zero selects an inclusive scan
class QuadPrefixIntrinsic : public GenIntrinsicInst
{
public:
    /// Operand 0: the value being scanned.
    Value *getSrc() const
    {
        return getOperand(0);
    }

    /// Operand 1, read as an immediate and converted to IGC::WaveOps.
    IGC::WaveOps getOpKind() const
    {
        const uint64_t RawOpKind = getImm64Operand(1);
        return static_cast<IGC::WaveOps>(RawOpKind);
    }

    /// True when immediate operand 2 is non-zero.
    bool isInclusiveScan() const
    {
        const bool IsZero = (getImm64Operand(2) == 0);
        return !IsZero;
    }

    // Type inquiry hooks used by isa<>, cast<> and dyn_cast<>.

    /// Matches only intrinsic calls whose ID is GenISA_QuadPrefix.
    static inline bool classof(const GenIntrinsicInst *I)
    {
        return GenISAIntrinsic::GenISA_QuadPrefix == I->getIntrinsicID();
    }

    /// Rejects anything that is not a GenIntrinsicInst, then defers to the
    /// ID check above.
    static inline bool classof(const Value *V)
    {
        if (!isa<GenIntrinsicInst>(V))
            return false;
        return classof(cast<GenIntrinsicInst>(V));
    }
};
template <class X, class Y> template <class X, class Y>
inline bool isa(const Y &Val, GenISAIntrinsic::ID id) inline bool isa(const Y &Val, GenISAIntrinsic::ID id)
{ {
......
...@@ -265,9 +265,18 @@ Imported_Intrinsics = \ ...@@ -265,9 +265,18 @@ Imported_Intrinsics = \
"GenISA_pair_to_ptr": ["anyptr",["int","int"],"NoMem"], "GenISA_pair_to_ptr": ["anyptr",["int","int"],"NoMem"],
"GenISA_ptr_to_pair": [["int","int"],["anyptr"],"NoMem"], "GenISA_ptr_to_pair": [["int","int"],["anyptr"],"NoMem"],
"GenISA_WaveBallot": ["int",["bool"],"Convergent,InaccessibleMemOnly"], "GenISA_WaveBallot": ["int",["bool"],"Convergent,InaccessibleMemOnly"],
# Arg 0 - Mask value
# Return - assigns each lane the value of its corresponding bit.
"GenISA_WaveInverseBallot": ["bool",["int"],"Convergent,InaccessibleMemOnly"],
"GenISA_WaveShuffleIndex": ["anyint",[0,"int"],"Convergent,NoMem"], "GenISA_WaveShuffleIndex": ["anyint",[0,"int"],"Convergent,NoMem"],
"GenISA_WaveAll": ["anyint",[0,"char"],"Convergent,InaccessibleMemOnly"], "GenISA_WaveAll": ["anyint",[0,"char"],"Convergent,InaccessibleMemOnly"],
"GenISA_WavePrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"], # Arg 0 - Src value
# Arg 1 - Operation type
# Arg 2 - Is the operation inclusive (1) or exclusive (0)?
# Arg 3 - a mask that specifies a subset of lanes to participate
# in the computation.
# Return - The computed prefix/postfix result
"GenISA_WavePrefix": ["anyint",[0,"char","bool","bool"],"Convergent,InaccessibleMemOnly"],
"GenISA_QuadPrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"], "GenISA_QuadPrefix": ["anyint",[0,"char","bool"],"Convergent,InaccessibleMemOnly"],
"GenISA_InitDiscardMask": ["bool",[],"None"], "GenISA_InitDiscardMask": ["bool",[],"None"],
"GenISA_UpdateDiscardMask": ["bool",["bool","bool"],"None"], "GenISA_UpdateDiscardMask": ["bool",["bool","bool"],"None"],
......
...@@ -1238,6 +1238,7 @@ GenISA_WavePrefix,,"// Description: Accumulate and keep the intermediate results ...@@ -1238,6 +1238,7 @@ GenISA_WavePrefix,,"// Description: Accumulate and keep the intermediate results
,anyint, ,anyint,
,0, ,0,
,char,specify the type: sum / Prod / Min/ Max ,char,specify the type: sum / Prod / Min/ Max
,bool,a mask that specifies a subset of lanes to participate in the computation.
GenISA_setMessagePhaseV,, GenISA_setMessagePhaseV,,
,anyvector,new message phase result ,anyvector,new message phase result
,0,cur message phase ,0,cur message phase
......
...@@ -766,9 +766,12 @@ public: ...@@ -766,9 +766,12 @@ public:
llvm::Value* create_runtime(llvm::Value* offset); llvm::Value* create_runtime(llvm::Value* offset);
llvm::Value* create_countbits(llvm::Value* src); llvm::Value* create_countbits(llvm::Value* src);
llvm::Value* create_waveBallot(llvm::Value* src); llvm::Value* create_waveBallot(llvm::Value* src);
llvm::Value* create_waveInverseBallot(llvm::Value* src);
llvm::Value* create_waveshuffleIndex(llvm::Value* src, llvm::Value* index); llvm::Value* create_waveshuffleIndex(llvm::Value* src, llvm::Value* index);
llvm::Value* create_waveAll(llvm::Value* src, llvm::Value* type); llvm::Value* create_waveAll(llvm::Value* src, llvm::Value* type);
llvm::Value* create_wavePrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false); llvm::Value* create_wavePrefix(
llvm::Value* src, llvm::Value* type, bool inclusive,
llvm::Value *Mask = nullptr);
llvm::Value* create_waveMatch(llvm::Instruction *inst, llvm::Value *src); llvm::Value* create_waveMatch(llvm::Instruction *inst, llvm::Value *src);
llvm::Value* create_quadPrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false); llvm::Value* create_quadPrefix(llvm::Value* src, llvm::Value* type, bool inclusive = false);
llvm::Value* get32BitLaneID(); llvm::Value* get32BitLaneID();
......
...@@ -4357,6 +4357,18 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_countbits( ...@@ -4357,6 +4357,18 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_countbits(
return this->CreateCall(pFunc, src); return this->CreateCall(pFunc, src);
} }
/// Emits a call to the GenISA_WaveInverseBallot intrinsic at the current
/// insertion point.
///
/// The intrinsic is declared in the module that owns the current insert
/// block. Per the intrinsic table, its single argument is a mask value and
/// the result assigns each lane the value of its corresponding bit.
///
/// @param src  Mask value passed as the intrinsic's only argument.
/// @return     The created call.
template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value*
LLVM3DBuilder<preserveNames, T, Inserter>::create_waveInverseBallot(
    llvm::Value* src)
{
    // Insert block -> enclosing function -> enclosing module.
    llvm::Function* inverseBallotFn =
        llvm::GenISAIntrinsic::getDeclaration(
            this->GetInsertBlock()->getParent()->getParent(),
            llvm::GenISAIntrinsic::GenISA_WaveInverseBallot);

    return this->CreateCall(inverseBallotFn, src);
}
template<bool preserveNames, typename T, typename Inserter> template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveBallot(llvm::Value* src) inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveBallot(llvm::Value* src)
{ {
...@@ -4390,14 +4402,19 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveAll(ll ...@@ -4390,14 +4402,19 @@ inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_waveAll(ll
} }
template<bool preserveNames, typename T, typename Inserter> template<bool preserveNames, typename T, typename Inserter>
inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefix(llvm::Value* src, llvm::Value* type, bool inclusive) inline llvm::Value* LLVM3DBuilder<preserveNames, T, Inserter>::create_wavePrefix(
llvm::Value* src, llvm::Value* type, bool inclusive, llvm::Value *Mask)
{ {
// If a nullptr is passed in for 'Mask' (as is the default), just include
// all lanes.
Mask = Mask ? Mask : this->getInt1(true);
llvm::Module* module = this->GetInsertBlock()->getParent()->getParent(); llvm::Module* module = this->GetInsertBlock()->getParent()->getParent();
llvm::Function* pFunc = llvm::GenISAIntrinsic::getDeclaration( llvm::Function* pFunc = llvm::GenISAIntrinsic::getDeclaration(
module, module,
llvm::GenISAIntrinsic::GenISA_WavePrefix, llvm::GenISAIntrinsic::GenISA_WavePrefix,
src->getType()); src->getType());
return this->CreateCall3(pFunc, src, type, this->getInt1(inclusive)); return this->CreateCall4(pFunc, src, type, this->getInt1(inclusive), Mask);
} }
template<bool preserveNames, typename T, typename Inserter> template<bool preserveNames, typename T, typename Inserter>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment