Commit c129a9a6 authored by Andrzej Ratajewski's avatar Andrzej Ratajewski Committed by gbsbuild

Support for SPIRV specialization constants

Change-Id: If22772de9d38cccd4a841f462d6934e7843aa373
parent d61693f6
......@@ -146,6 +146,8 @@ enum E_SH_TYPE
SH_TYPE_OPENCL_DEV_DEBUG = 0xff000008, // Device debug
SH_TYPE_SPIRV = 0xff000009, // SPIRV
SH_TYPE_NON_COHERENT_DEV_BINARY = 0xff00000a, // Non-coherent Device binary
SH_TYPE_SPIRV_SC_IDS = 0xff00000b, // Specialization Constants IDs
SH_TYPE_SPIRV_SC_VALUES = 0xff00000c // Specialization Constants values
};
// E_SH_FLAG - List of section header flags.
......
......@@ -76,6 +76,7 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include "libSPIRV/SPIRVFunction.h"
#include "libSPIRV/SPIRVInstruction.h"
#include "libSPIRV/SPIRVModule.h"
#include "SPIRVInternal.h"
#include "common/MDFrameWork.h"
#include "../../AdaptorCommon/TypesLegalizationPass.hpp"
......@@ -2102,15 +2103,22 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
// Translation of non-instruction values
switch(OC) {
case OpSpecConstant:
case OpConstant: {
SPIRVConstant *BConst = static_cast<SPIRVConstant *>(BV);
SPIRVType *BT = BV->getType();
Type *LT = transType(BT);
uint64_t V = BConst->getZExtIntValue();
if(BV->hasDecorate(DecorationSpecId)) {
spirv_assert(OC == OpSpecConstant && "Only SpecConstants can be specialized!");
SPIRVWord specid = *BV->getDecorate(DecorationSpecId).begin();
V = BM->getSpecConstant(specid);
}
switch(BT->getOpCode()) {
case OpTypeBool:
case OpTypeInt:
return mapValue(BV, ConstantInt::get(LT, BConst->getZExtIntValue(),
static_cast<SPIRVTypeInt*>(BT)->isSigned()));
return mapValue(BV, ConstantInt::get(LT, V,
static_cast<SPIRVTypeInt*>(BT)->isSigned()));
case OpTypeFloat: {
const llvm::fltSemantics *FS = nullptr;
switch (BT->getFloatBitWidth()) {
......@@ -2127,7 +2135,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
spirv_assert (0 && "invalid float type");
}
return mapValue(BV, ConstantFP::get(*Context, APFloat(*FS,
APInt(BT->getFloatBitWidth(), BConst->getZExtIntValue()))));
APInt(BT->getFloatBitWidth(), V))));
}
default:
llvm_unreachable("Not implemented");
......@@ -2136,9 +2144,29 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
}
break;
case OpSpecConstantTrue:
if (BV->hasDecorate(DecorationSpecId)) {
SPIRVWord specid = *BV->getDecorate(DecorationSpecId).begin();
if(BM->getSpecConstant(specid))
return mapValue(BV, ConstantInt::getTrue(*Context));
else
return mapValue(BV, ConstantInt::getFalse(*Context));
}
// intentional fall-through: if decoration was not specified, treat this
// as a OpConstantTrue (default spec constant value)
case OpConstantTrue:
return mapValue(BV, ConstantInt::getTrue(*Context));
case OpSpecConstantFalse:
if (BV->hasDecorate(DecorationSpecId)) {
SPIRVWord specid = *BV->getDecorate(DecorationSpecId).begin();
if (BM->getSpecConstant(specid))
return mapValue(BV, ConstantInt::getTrue(*Context));
else
return mapValue(BV, ConstantInt::getFalse(*Context));
}
// intentional fall-through: if decoration was not specified, treat this
// as a OpConstantFalse (default spec constant value)
case OpConstantFalse:
return mapValue(BV, ConstantInt::getFalse(*Context));
......@@ -2149,6 +2177,7 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
return mapValue(BV, ConstantAggregateZero::get(LT));
}
case OpSpecConstantComposite:
case OpConstantComposite: {
auto BCC = static_cast<SPIRVConstantComposite*>(BV);
std::vector<Constant *> CV;
......@@ -2349,6 +2378,13 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
break;
}
// During translation of OpSpecConstantOp we create an instruction
// corresponding to the Opcode operand and then translate this instruction.
// For such instruction BB and F should be nullptr, because it is a constant
// expression declared out of scope of any basic block or function.
// All other values require valid BB pointer.
assert(((isSpecConstantOpAllowedOp(OC) && !F && !BB) || BB) && "Invalid BB");
// Creation of place holder
if (CreatePlaceHolder) {
auto GV = new GlobalVariable(*M,
......@@ -3846,10 +3882,11 @@ static void dumpSPIRVBC(const char* fname, const char* data, unsigned int size)
bool ReadSPIRV(LLVMContext &C, std::istream &IS, Module *&M,
StringRef options,
std::string &ErrMsg) {
std::string &ErrMsg,
std::unordered_map<uint32_t, uint64_t> *specConstants) {
std::unique_ptr<SPIRVModule> BM( SPIRVModule::createSPIRVModule() );
BM->setCompileFlag( options );
BM->setSpecConstantMap(specConstants);
IS >> *BM;
BM->resolveUnknownStructFields();
M = new Module( "",C );
......
......@@ -40,12 +40,15 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include "llvm/IR/Module.h"
#include <unordered_map>
namespace spv{
// Loads SPIRV from istream and translate to LLVM module.
// Returns true if succeeds.
bool ReadSPIRV(llvm::LLVMContext &C, std::istream &IS, llvm::Module *&M,
llvm::StringRef options,
std::string &ErrMsg);
std::string &ErrMsg,
std::unordered_map<uint32_t, uint64_t> *specConstants);
}
#endif
......@@ -663,10 +663,6 @@ _SPIRV_OP(SourceContinued)
_SPIRV_OP(TypeMatrix)
_SPIRV_OP(TypeRuntimeArray)
_SPIRV_OP(TypeForwardPointer)
_SPIRV_OP(SpecConstantTrue)
_SPIRV_OP(SpecConstantFalse)
_SPIRV_OP(SpecConstant)
_SPIRV_OP(SpecConstantComposite)
_SPIRV_OP(ImageTexelPointer)
_SPIRV_OP(ImageSampleDrefImplicitLod)
_SPIRV_OP(ImageSampleDrefExplicitLod)
......
......@@ -87,7 +87,8 @@ public:
InstSchema(SPIRVISCH_Default),
SrcLang(SpvSourceLanguageOpenCL_C),
SrcLangVer(12),
MemoryModel(SPIRVMemoryModelKind::MemoryModelOpenCL){
MemoryModel(SPIRVMemoryModelKind::MemoryModelOpenCL),
SCMap(nullptr) {
AddrModel = sizeof(size_t) == 32 ? AddressingModelPhysical32 : AddressingModelPhysical64;
};
virtual ~SPIRVModuleImpl();
......@@ -116,6 +117,17 @@ public:
const std::string &getCompileFlag() const { return CompileFlag;}
std::string &getCompileFlag() { return CompileFlag;}
void setCompileFlag(const std::string &options) { CompileFlag = options; }
bool isSpecConstant(SPIRVWord spec_id) const {
if(SCMap)
return SCMap->find(spec_id) != SCMap->end();
else
return false;
}
uint64_t getSpecConstant(SPIRVWord spec_id) {
spirv_assert(isSpecConstant(spec_id) && "Specialization constant was not specialized!");
return SCMap->at(spec_id);
}
void setSpecConstantMap(SPIRVSpecConstantMap *specConstants) { SCMap = specConstants; }
std::set<std::string> &getExtension() { return SPIRVExt;}
SPIRVFunction *getFunction(unsigned I) const { return FuncVec[I];}
SPIRVVariable *getVariable(unsigned I) const { return VariableVec[I];}
......@@ -245,6 +257,25 @@ public:
return globalVars;
}
virtual std::vector<SPIRVValue*> parseSpecConstants()
{
std::vector<SPIRVValue*> specConstants;
for (auto& item : IdEntryMap)
{
Op opcode = item.second->getOpCode();
if (opcode == spv::Op::OpSpecConstant ||
opcode == spv::Op::OpSpecConstantTrue ||
opcode == spv::Op::OpSpecConstantFalse)
{
auto specConstant = static_cast<SPIRVValue*>(item.second);
specConstants.push_back(specConstant);
}
}
return specConstants;
}
// I/O functions
friend std::istream & operator>>(std::istream &I, SPIRVModule& M);
......@@ -296,6 +327,7 @@ private:
SPIRVExecModelIdVecMap EntryPointVec;
SPIRVStringMap StrMap;
SPIRVCapSet CapSet;
SPIRVSpecConstantMap *SCMap;
std::map<unsigned, SPIRVTypeInt*> IntTypeMap;
std::map<unsigned, SPIRVConstant*> LiteralMap;
......
......@@ -78,7 +78,6 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
namespace spv{
class SPIRVBasicBlock;
class SPIRVConstant;
class SPIRVEntry;
class SPIRVFunction;
class SPIRVInstruction;
......@@ -110,6 +109,11 @@ class SPIRVInstTemplateBase;
typedef SPIRVBasicBlock SPIRVLabel;
struct SPIRVTypeImageDescriptor;
template <Op OP> class SPIRVConstantBase;
typedef SPIRVConstantBase<OpConstant> SPIRVConstant;
typedef std::unordered_map<uint32_t, uint64_t> SPIRVSpecConstantMap;
class SPIRVModule
{
public:
......@@ -141,6 +145,9 @@ public:
virtual SPIRVExtInstSetKind getBuiltinSet(SPIRVId) const = 0;
virtual std::string &getCompileFlag() = 0;
virtual void setCompileFlag(const std::string &options) = 0;
virtual bool isSpecConstant(SPIRVWord) const = 0;
virtual uint64_t getSpecConstant(SPIRVWord) = 0;
virtual void setSpecConstantMap(SPIRVSpecConstantMap *) = 0;
virtual const std::string &getCompileFlag() const = 0;
virtual SPIRVFunction *getEntryPoint(SPIRVExecutionModelKind, unsigned) const = 0;
virtual std::set<std::string> &getExtension() = 0;
......@@ -214,6 +221,7 @@ public:
virtual SPIRVExtInst* getCompilationUnit() const = 0;
virtual std::vector<SPIRVExtInst*> getGlobalVars() = 0;
virtual std::vector<SPIRVValue*> parseSpecConstants() = 0;
// I/O functions
friend std::istream & operator>>(std::istream &I, SPIRVModule& M);
......
......@@ -274,7 +274,6 @@ private:
SPIRVWord CompCount; // Component Count
};
class SPIRVConstant;
class SPIRVTypeArray:public SPIRVType {
public:
// Complete constructor
......
......@@ -162,32 +162,33 @@ protected:
SPIRVType *Type; // Value Type
};
class SPIRVConstant: public SPIRVValue {
template<Op OC>
class SPIRVConstantBase: public SPIRVValue {
public:
// Complete constructor for integer constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
uint64_t TheValue)
:SPIRVValue(M, 0, OpConstant, TheType, TheId){
:SPIRVValue(M, 0, OC, TheType, TheId){
Union.UInt64Val = TheValue;
recalculateWordCount();
validate();
}
// Complete constructor for float constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId, float TheValue)
:SPIRVValue(M, 0, OpConstant, TheType, TheId){
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId, float TheValue)
:SPIRVValue(M, 0, OC, TheType, TheId){
Union.FloatVal = TheValue;
recalculateWordCount();
validate();
}
// Complete constructor for double constant
SPIRVConstant(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId, double TheValue)
:SPIRVValue(M, 0, OpConstant, TheType, TheId){
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId, double TheValue)
:SPIRVValue(M, 0, OC, TheType, TheId){
Union.DoubleVal = TheValue;
recalculateWordCount();
validate();
}
// Incomplete constructor
SPIRVConstant():SPIRVValue(OpConstant), NumWords(0){}
SPIRVConstantBase():SPIRVValue(OC), NumWords(0){}
uint64_t getZExtIntValue() const { return Union.UInt64Val;}
float getFloatValue() const { return Union.FloatVal;}
double getDoubleValue() const { return Union.DoubleVal;}
......@@ -225,6 +226,9 @@ protected:
} Union;
};
typedef SPIRVConstantBase<OpConstant> SPIRVConstant;
typedef SPIRVConstantBase<OpSpecConstant> SPIRVSpecConstant;
template<Op OC>
class SPIRVConstantEmpty: public SPIRVValue {
public:
......@@ -260,6 +264,9 @@ protected:
typedef SPIRVConstantBool<OpConstantTrue> SPIRVConstantTrue;
typedef SPIRVConstantBool<OpConstantFalse> SPIRVConstantFalse;
typedef SPIRVConstantBool<OpSpecConstantTrue> SPIRVSpecConstantTrue;
typedef SPIRVConstantBool<OpSpecConstantFalse> SPIRVSpecConstantFalse;
class SPIRVConstantNull : public SPIRVConstantEmpty<OpConstantNull>
{
public:
......@@ -293,18 +300,18 @@ protected:
}
};
class SPIRVConstantComposite: public SPIRVValue {
template<Op OC>
class SPIRVConstantCompositeBase: public SPIRVValue {
public:
// Complete constructor for composite constant
SPIRVConstantComposite(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
SPIRVConstantCompositeBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
const std::vector<SPIRVValue *> TheElements)
:SPIRVValue(M, TheElements.size()+3, OpConstantComposite, TheType,
TheId){
:SPIRVValue(M, TheElements.size()+3, OC, TheType, TheId){
Elements = getIds(TheElements);
validate();
}
// Incomplete constructor
SPIRVConstantComposite():SPIRVValue(OpConstantComposite){}
SPIRVConstantCompositeBase():SPIRVValue(OC){}
std::vector<SPIRVValue*> getElements()const {
return getValues(Elements);
}
......@@ -322,6 +329,9 @@ protected:
std::vector<SPIRVId> Elements;
};
typedef SPIRVConstantCompositeBase<OpConstantComposite> SPIRVConstantComposite;
typedef SPIRVConstantCompositeBase<OpSpecConstantComposite> SPIRVSpecConstantComposite;
class SPIRVConstantSampler: public SPIRVValue {
public:
const static Op OC = OpConstantSampler;
......
......@@ -176,29 +176,35 @@ Description:
\******************************************************************************/
struct STB_TranslateInputArgs
{
char* pInput; // data to be translated
uint32_t InputSize; // size of data to be translated
const char* pOptions; // list of build/compile options
uint32_t OptionsSize; // size of options list
const char* pInternalOptions; // list of build/compile options
uint32_t InternalOptionsSize; // size of options list
void* pTracingOptions; // instrumentation options
uint32_t TracingOptionsCount; // number of instrumentation options
void* GTPinInput; // input structure for GTPin requests
bool CompileTimeStatisticsEnable;
char* pInput; // data to be translated
uint32_t InputSize; // size of data to be translated
const char* pOptions; // list of build/compile options
uint32_t OptionsSize; // size of options list
const char* pInternalOptions; // list of build/compile options
uint32_t InternalOptionsSize; // size of options list
void* pTracingOptions; // instrumentation options
uint32_t TracingOptionsCount; // number of instrumentation options
void* GTPinInput; // input structure for GTPin requests
bool CompileTimeStatisticsEnable;
const uint32_t* pSpecConstantsIds; // user-defined spec constants ids
const uint64_t* pSpecConstantsValues; // spec constants values to be translated
uint32_t SpecConstantsSize; // number of specialization constants
STB_TranslateInputArgs()
{
pInput = NULL;
InputSize = 0;
pOptions = NULL;
OptionsSize = 0;
pInternalOptions = NULL;
InternalOptionsSize = 0;
pTracingOptions = NULL;
TracingOptionsCount = 0;
GTPinInput = NULL;
pInput = NULL;
InputSize = 0;
pOptions = NULL;
OptionsSize = 0;
pInternalOptions = NULL;
InternalOptionsSize = 0;
pTracingOptions = NULL;
TracingOptionsCount = 0;
GTPinInput = NULL;
CompileTimeStatisticsEnable = false;
pSpecConstantsIds = NULL;
pSpecConstantsValues = NULL;
SpecConstantsSize = 0;
}
};
......
......@@ -65,6 +65,8 @@ SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#include "AdaptorOCL/SPIRV/SPIRVconsum.h"
#include "common/LLVMWarningsPop.hpp"
#include "AdaptorOCL/SPIRV/SPIRV-Tools/include/spirv-tools/libspirv.h"
#include "AdaptorOCL/SPIRV/libSPIRV/SPIRVModule.h"
#include "AdaptorOCL/SPIRV/libSPIRV/SPIRVValue.h"
#endif
#include "common/LLVMWarningsPush.hpp"
......@@ -294,6 +296,19 @@ bool CIGCTranslationBlock::Translate(
return false;
}
std::unordered_map<uint32_t, uint64_t> UnpackSpecConstants(
const uint32_t* pSpecConstantsIds,
const uint64_t* pSpecConstantsValues,
uint32_t size)
{
std::unordered_map<uint32_t, uint64_t> outSpecConstantsMap;
for (uint32_t i = 0; i < size; i++)
{
outSpecConstantsMap[pSpecConstantsIds[i]] = pSpecConstantsValues[i];
}
return outSpecConstantsMap;
}
bool ProcessElfInput(
STB_TranslateInputArgs &InputArgs,
STB_TranslateOutputArgs &OutputArgs,
......@@ -320,12 +335,25 @@ bool ProcessElfInput(
const CLElfLib::SElf64SectionHeader* pSectionHeader = pElfReader->GetSectionHeader(i);
assert(pSectionHeader != NULL);
char* pData = NULL;
size_t dataSize = 0;
if (pSectionHeader->Type == CLElfLib::SH_TYPE_SPIRV_SC_IDS)
{
pElfReader->GetSectionData(i, pData, dataSize);
InputArgs.pSpecConstantsIds = reinterpret_cast<const uint32_t*>(pData);
}
if (pSectionHeader->Type == CLElfLib::SH_TYPE_SPIRV_SC_VALUES)
{
pElfReader->GetSectionData(i, pData, dataSize);
InputArgs.pSpecConstantsValues = reinterpret_cast<const uint64_t*>(pData);
}
if ((pSectionHeader->Type == CLElfLib::SH_TYPE_OPENCL_LLVM_BINARY) ||
(pSectionHeader->Type == CLElfLib::SH_TYPE_OPENCL_LLVM_ARCHIVE) ||
(pSectionHeader->Type == CLElfLib::SH_TYPE_SPIRV))
{
char* pData = NULL;
size_t dataSize = 0;
pElfReader->GetSectionData(i, pData, dataSize);
// Create input module from the buffer
......@@ -344,7 +372,11 @@ bool ProcessElfInput(
if(InputArgs.OptionsSize > 0){
options = llvm::StringRef(InputArgs.pOptions, InputArgs.OptionsSize - 1);
}
bool success = spv::ReadSPIRV(*Context.getLLVMContext(), IS, pKernelModule, options, stringErrMsg);
std::unordered_map<uint32_t, uint64_t> specIDToSpecValueMap = UnpackSpecConstants(
InputArgs.pSpecConstantsIds,
InputArgs.pSpecConstantsValues,
InputArgs.SpecConstantsSize);
bool success = spv::ReadSPIRV(*Context.getLLVMContext(), IS, pKernelModule, options, stringErrMsg, &specIDToSpecValueMap);
#else
std::string stringErrMsg{ "SPIRV consumption not enabled for the TARGET." };
bool success = false;
......@@ -546,7 +578,11 @@ bool ParseInput(
if(pInputArgs->OptionsSize > 0){
options = llvm::StringRef(pInputArgs->pOptions, pInputArgs->OptionsSize);
}
bool success = spv::ReadSPIRV(oclContext, IS, pKernelModule, options, stringErrMsg);
std::unordered_map<uint32_t, uint64_t> specIDToSpecValueMap = UnpackSpecConstants(
pInputArgs->pSpecConstantsIds,
pInputArgs->pSpecConstantsValues,
pInputArgs->SpecConstantsSize);
bool success = spv::ReadSPIRV(oclContext, IS, pKernelModule, options, stringErrMsg, &specIDToSpecValueMap);
#else
std::string stringErrMsg{"SPIRV consumption not enabled for the TARGET."};
bool success = false;
......@@ -576,6 +612,42 @@ bool ParseInput(
return true;
}
#if defined(IGC_SPIRV_ENABLED)
bool ReadSpecConstantsFromSPIRV(std::istream &IS, std::vector<std::pair<uint32_t, uint32_t>> &OutSCInfo)
{
using namespace spv;
std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule());
IS >> *BM;
auto SPV = BM->parseSpecConstants();
for (auto& SC : SPV)
{
SPIRVType *type = SC->getType();
uint32_t spec_size = type->getBitWidth() / 8;
if (SC->hasDecorate(DecorationSpecId))
{
SPIRVWord spec_id = *SC->getDecorate(DecorationSpecId).begin();
Op OP = SC->getOpCode();
if (OP == OpSpecConstant ||
OP == OpSpecConstantFalse ||
OP == OpSpecConstantTrue)
{
OutSCInfo.push_back(std::make_pair(spec_id, spec_size));
}
else
{
assert("Wrong instruction opcode, shouldn't be here!");
return false;
}
}
}
return true;
}
#endif
// Dump shader (binary or text), to default directory.
// Create directory if it doesn't exist.
// Works for all OSes.
......
......@@ -93,9 +93,44 @@ protected:
void *gtPinInput);
};
CIF_DEFINE_INTERFACE_VER_WITH_COMPATIBILITY(IgcOclTranslationCtx, 3, 2) {
using IgcOclTranslationCtx<2>::TranslateImpl;
using IgcOclTranslationCtx<2>::Translate;
CIF_INHERIT_CONSTRUCTOR();
template <typename OclTranslationOutputInterface = OclTranslationOutputTagOCL>
CIF::RAII::UPtr_t<OclTranslationOutputInterface> Translate(CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *specConstantsIds,
CIF::Builtins::BufferSimple *specConstantsValues,
CIF::Builtins::BufferSimple *options,
CIF::Builtins::BufferSimple *internalOptions,
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount,
void *gtPinInput) {
auto p = TranslateImpl(OclTranslationOutputInterface::GetVersion(), src, options, internalOptions, tracingOptions, tracingOptionsCount, gtPinInput, specConstantsIds, specConstantsValues);
return CIF::RAII::Pack<OclTranslationOutputInterface>(p);
}
bool GetSpecConstantsInfoImpl(CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *outSpecConstantsIds,
CIF::Builtins::BufferSimple *outSpecConstantsSizes);
protected:
virtual OclTranslationOutputBase *TranslateImpl(CIF::Version_t outVersion,
CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *specConstantsIds,
CIF::Builtins::BufferSimple *specConstantsValues,
CIF::Builtins::BufferSimple *options,
CIF::Builtins::BufferSimple *internalOptions,
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount,
void *gtPinInput);
};
CIF_GENERATE_VERSIONS_LIST_AND_DECLARE_INTERFACE_DEPENDENCIES(IgcOclTranslationCtx, IGC::OclTranslationOutput, CIF::Builtins::Buffer);
CIF_MARK_LATEST_VERSION(IgcOclTranslationCtxLatest, IgcOclTranslationCtx);
using IgcOclTranslationCtxTagOCL = IgcOclTranslationCtxLatest; // Note : can tag with different version for
using IgcOclTranslationCtxTagOCL = IgcOclTranslationCtx<2>; // Note : can tag with different version for
// transition periods
}
......
......@@ -38,7 +38,7 @@ OclTranslationOutputBase *CIF_GET_INTERFACE_CLASS(IgcOclTranslationCtx, 1)::Tran
CIF::Builtins::BufferSimple *internalOptions,
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount) {
return CIF_GET_PIMPL()->Translate(outVersion, src, options, internalOptions, tracingOptions, tracingOptionsCount, nullptr);
return CIF_GET_PIMPL()->Translate(outVersion, src, nullptr, nullptr, options, internalOptions, tracingOptions, tracingOptionsCount, nullptr);
}
OclTranslationOutputBase *CIF_GET_INTERFACE_CLASS(IgcOclTranslationCtx, 2)::TranslateImpl(
......@@ -49,7 +49,27 @@ OclTranslationOutputBase *CIF_GET_INTERFACE_CLASS(IgcOclTranslationCtx, 2)::Tran
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount,
void *gtPinInput) {
return CIF_GET_PIMPL()->Translate(outVersion, src, options, internalOptions, tracingOptions, tracingOptionsCount, gtPinInput);
return CIF_GET_PIMPL()->Translate(outVersion, src, nullptr, nullptr, options, internalOptions, tracingOptions, tracingOptionsCount, gtPinInput);
}
bool CIF_GET_INTERFACE_CLASS(IgcOclTranslationCtx, 3)::GetSpecConstantsInfoImpl(
CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *outSpecConstantsIds,
CIF::Builtins::BufferSimple *outSpecConstantsSizes) {
return CIF_GET_PIMPL()->GetSpecConstantsInfo(src, outSpecConstantsIds, outSpecConstantsSizes);
}
OclTranslationOutputBase *CIF_GET_INTERFACE_CLASS(IgcOclTranslationCtx, 3)::TranslateImpl(
CIF::Version_t outVersion,
CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *specConstantsIds,
CIF::Builtins::BufferSimple *specConstantsValues,
CIF::Builtins::BufferSimple *options,
CIF::Builtins::BufferSimple *internalOptions,
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount,
void *gtPinInput) {
return CIF_GET_PIMPL()->Translate(outVersion, src, specConstantsIds, specConstantsValues, options, internalOptions, tracingOptions, tracingOptionsCount, gtPinInput);
}
}
......
......@@ -65,6 +65,10 @@ bool TranslateBuild(
const IGC::CPlatform &platform,
float profilingTimerResolution);
bool ReadSpecConstantsFromSPIRV(
std::istream &IS,
std::vector<std::pair<uint32_t, uint32_t>> &OutSCInfo);
}
bool enableSrcLine(void*);
......@@ -109,14 +113,49 @@ CIF_DECLARE_INTERFACE_PIMPL(IgcOclTranslationCtx) : CIF::PimplBase
return false;
}
bool GetSpecConstantsInfo(CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *outSpecConstantsIds,
CIF::Builtins::BufferSimple *outSpecConstantsSizes)
{
bool success = false;
const char* pInput = src->GetMemory<char>();
uint32_t inputSize = static_cast<uint32_t>(src->GetSizeRaw());
if(this->inType == CodeType::spirV){
llvm::StringRef strInput = llvm::StringRef(pInput, inputSize);
std::istringstream IS(strInput);
// vector of pairs [spec_id, spec_size]
std::vector<std::pair<uint32_t, uint32_t>> SCInfo;
success = TC::ReadSpecConstantsFromSPIRV(IS, SCInfo);
outSpecConstantsIds->Resize(sizeof(uint32_t) * SCInfo.size());
outSpecConstantsSizes->Resize(sizeof(uint32_t) * SCInfo.size());
uint32_t* specConstantsIds = outSpecConstantsIds->GetMemoryWriteable<uint32_t>();
uint32_t* specConstantsSizes = outSpecConstantsSizes->GetMemoryWriteable<uint32_t>();
for(uint32_t i = 0; i < SCInfo.size(); ++i){
specConstantsIds[i] = SCInfo.at(i).first;
specConstantsSizes[i] = SCInfo.at(i).second;
}
}
else{
success = false;
}
return success;
}
OclTranslationOutputBase *Translate(CIF::Version_t outVersion,
CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *src,
CIF::Builtins::BufferSimple *specConstantsIds,
CIF::Builtins::BufferSimple *specConstantsValues,
CIF::Builtins::BufferSimple *options,
CIF::Builtins::BufferSimple *internalOptions,
CIF::Builtins::BufferSimple *tracingOptions,
uint32_t tracingOptionsCount,
void *gtPinInput
) const{
void *gtPinInput) const{
// Create interface for return data
auto outputInterface = CIF::RAII::UPtr(CIF::InterfaceCreator<OclTranslationOutput>::CreateInterfaceVer(outVersion, this->outType));
if(outputInterface == nullptr){
......@@ -149,6 +188,11 @@ CIF_DECLARE_INTERFACE_PIMPL(IgcOclTranslationCtx) : CIF::PimplBase
inputArgs.pTracingOptions = tracingOptions->GetMemoryRawWriteable();
}
inputArgs.TracingOptionsCount = tracingOptionsCount;
if(specConstantsIds != nullptr && specConstantsValues != nullptr){
inputArgs.pSpecConstantsIds = specConstantsIds->GetMemory<uint32_t>();
inputArgs.SpecConstantsSize = static_cast<uint32_t>(specConstantsIds->GetSizeRaw() / sizeof(uint32_t));
inputArgs.pSpecConstantsValues = specConstantsValues->GetMemory<uint64_t>();
}
inputArgs.GTPinInput = gtPinInput;
IGC::CPlatform igcPlatform = this->globalState.GetIgcCPlatform();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment