Commit 9e4565a2 authored by Grzegorz Kluczek's avatar Grzegorz Kluczek

Refactor GET_MEMPOOL_PTR to use ThreadCount specific to HW platform

Change-Id: I3f5ba9f12e69afd1d92b8929a396d912c85414db
parent 341d76ce
......@@ -31,10 +31,11 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Use this to grab a pointer to local memory whenever you
// are treating the local memory as automatic storage.
#define GET_MEMPOOL_PTR(_ptr, _type, _nelems) \
#define GET_MEMPOOL_PTR(_ptr, _type, _allocAllWorkgroups, _additionalElems) \
__local _type* _ptr = \
(__local _type*)__builtin_IB_AllocLocalMemPool( \
(_nelems) * sizeof(_type), \
_allocAllWorkgroups, \
_additionalElems, \
sizeof(_type));
// Macro for async work copy implementation.
......
......@@ -108,7 +108,7 @@ void __builtin_IB_write_2darr_f(int, int4, float4, int);
void __builtin_IB_write_2d_f(int, int2, float4, int);
// Workgroup functions
local uchar* __builtin_IB_AllocLocalMemPool(uint, uint);
local uchar* __builtin_IB_AllocLocalMemPool(bool allocAllWorkgroups, uint numAdditionalElements, uint elementSize);
// Memory fences
// See GenISAIntrinsics.td for documentation
......
......@@ -976,7 +976,7 @@ bool __builtin_spirv_OpGroupAll_i32_i1(uint Execution, bool Predicate)
{
if (Execution == Workgroup)
{
GET_MEMPOOL_PTR(tmp, int, 1)
GET_MEMPOOL_PTR(tmp, int, false, 1)
*tmp = 0;
__builtin_spirv_OpControlBarrier_i32_i32_i32(Execution, 0, AcquireRelease | WorkgroupMemory); // Wait for tmp to be initialized
__builtin_spirv_OpAtomicOr_p3i32_i32_i32_i32((volatile local uint*)tmp, Device, Relaxed, Predicate == 0); // Set to true if predicate is zero
......@@ -1000,7 +1000,7 @@ bool __builtin_spirv_OpGroupAny_i32_i1(uint Execution, bool Predicate)
{
if (Execution == Workgroup)
{
GET_MEMPOOL_PTR(tmp, int, 1)
GET_MEMPOOL_PTR(tmp, int, false, 1)
*tmp = 0;
__builtin_spirv_OpControlBarrier_i32_i32_i32(Execution, 0, AcquireRelease | WorkgroupMemory); // Wait for tmp to be initialized
__builtin_spirv_OpAtomicOr_p3i32_i32_i32_i32((volatile local uint*)tmp, Device, Relaxed, Predicate != 0); // Set to true if predicate is non-zero
......@@ -1027,7 +1027,7 @@ bool __builtin_spirv_OpGroupAny_i32_i1(uint Execution, bool Predicate)
#define BROADCAST_WORKGROUP(type) \
{ \
GET_MEMPOOL_PTR(tmp, type, 1) \
GET_MEMPOOL_PTR(tmp, type, false, 1) \
if( (__spirv_LocalID(0) == LocalId.s0) & \
(__spirv_LocalID(1) == LocalId.s1) & \
(__spirv_LocalID(2) == LocalId.s2) ) \
......@@ -1356,7 +1356,7 @@ static double OVERLOADABLE __intel_add(double lhs, double rhs) { return lhs +
#define DEFN_WORK_GROUP_REDUCE(type, op, identity, X) \
{ \
GET_MEMPOOL_PTR(data, type, 448) \
GET_MEMPOOL_PTR(data, type, true, 0) \
uint lid = __spirv_BuiltInLocalInvocationIndex(); \
uint lsize = __spirv_WorkgroupSize(); \
data[lid] = X; \
......@@ -1380,7 +1380,7 @@ static double OVERLOADABLE __intel_add(double lhs, double rhs) { return lhs +
#define DEFN_WORK_GROUP_SCAN_INCL(type, op, identity, X) \
{ \
GET_MEMPOOL_PTR(data, type, 448) \
GET_MEMPOOL_PTR(data, type, true, 0) \
uint lid = __spirv_BuiltInLocalInvocationIndex(); \
uint lsize = __spirv_WorkgroupSize(); \
data[lid] = X; \
......@@ -1403,7 +1403,7 @@ static double OVERLOADABLE __intel_add(double lhs, double rhs) { return lhs +
#define DEFN_WORK_GROUP_SCAN_EXCL(type, op, identity, X) \
{ \
GET_MEMPOOL_PTR(data, type, 448 + 1) \
GET_MEMPOOL_PTR(data, type, true, 1) \
uint lid = __spirv_BuiltInLocalInvocationIndex(); \
uint lsize = __spirv_WorkgroupSize(); \
data[0] = identity; \
......
......@@ -563,7 +563,7 @@ atomic_flag_clear_function()
INLINE int OVERLOADABLE work_group_any(int predicate)
{
GET_MEMPOOL_PTR(tmp, int, 1)
GET_MEMPOOL_PTR(tmp, int, false, 1)
*tmp = 0;
barrier(CLK_LOCAL_MEM_FENCE); // Wait for tmp to be initialized
atomic_or(tmp, predicate != 0); // Set to true if predicate is non-zero
......@@ -573,7 +573,7 @@ INLINE int OVERLOADABLE work_group_any(int predicate)
INLINE int OVERLOADABLE work_group_all(int predicate)
{
GET_MEMPOOL_PTR(tmp, int, 1)
GET_MEMPOOL_PTR(tmp, int, false, 1)
*tmp = 0;
barrier(CLK_LOCAL_MEM_FENCE); // Wait for tmp to be initialized
atomic_or(tmp, predicate == 0); // Set to true if predicate is zero
......
......@@ -31,10 +31,11 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// Use this to grab a pointer to local memory whenever you
// are treating the local memory as automatic storage.
#define GET_MEMPOOL_PTR(_ptr, _type, _nelems) \
#define GET_MEMPOOL_PTR(_ptr, _type, _allocAllWorkgroups, _additionalElems) \
__local _type* _ptr = \
(__local _type*)__builtin_IB_AllocLocalMemPool( \
(_nelems) * sizeof(_type), \
_allocAllWorkgroups, \
_additionalElems, \
sizeof(_type));
// Macro for async work copy implementation.
......
......@@ -219,6 +219,9 @@ void InlineLocalsResolution::collectInfoOnSharedLocalMem(Module& M)
// first we collect SLM usage on GET_MEMPOOL_PTR
if (M.getFunction(BUILTIN_MEMPOOL) != nullptr)
{
const auto pCtx = getAnalysis<CodeGenContextWrapper>().getCodeGenContext();
const GT_SYSTEM_INFO platform = pCtx->platform.GetGTSystemInfo();
SmallVector<CallInst*, 8> callsToReplace;
unsigned maxBytesOnModule = 0;
unsigned maxAlignOnModule = 0;
......@@ -244,9 +247,27 @@ void InlineLocalsResolution::collectInfoOnSharedLocalMem(Module& M)
// should always be called with constant operands
assert(isa<ConstantInt>(CI->getArgOperand(0)));
assert(isa<ConstantInt>(CI->getArgOperand(1)));
assert(isa<ConstantInt>(CI->getArgOperand(2)));
const unsigned int allocAllWorkgroups = unsigned(cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue());
const unsigned int numAdditionalElements = unsigned(cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue());
const unsigned int elementSize = unsigned(cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue());
unsigned int size = unsigned(cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue());
unsigned int align = unsigned(cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue());
unsigned int numElements = numAdditionalElements;
if (allocAllWorkgroups)
{
if (platform.ThreadCount != 0)
{
numElements += platform.ThreadCount;
}
else
{
//workaround for cloc offline compiler, which currently does not pass any platform data
numElements += 448;
}
}
const unsigned int size = numElements * elementSize;
const unsigned int align = elementSize;
maxBytesOnFunc = std::max(maxBytesOnFunc, size);
maxBytesOnModule = std::max(maxBytesOnModule, size);
......
......@@ -70,6 +70,7 @@ namespace IGC
{
AU.setPreservesCFG();
AU.addRequired<MetaDataUtilsWrapper>();
AU.addRequired<CodeGenContextWrapper>();
AU.addRequired<llvm::CallGraphWrapperPass>();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment