Shader: Fix support of conditional functions
This commit is contained in:
parent
c8e7fa5063
commit
a5b71f33b9
|
|
@ -27,7 +27,7 @@ namespace Nz
|
|||
struct FuncData;
|
||||
struct Variable;
|
||||
|
||||
inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector<FuncData>& funcData);
|
||||
inline SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::unordered_map<std::size_t, FuncData>& funcData);
|
||||
SpirvAstVisitor(const SpirvAstVisitor&) = delete;
|
||||
SpirvAstVisitor(SpirvAstVisitor&&) = delete;
|
||||
~SpirvAstVisitor() = default;
|
||||
|
|
@ -147,10 +147,10 @@ namespace Nz
|
|||
std::size_t m_extVarIndex;
|
||||
std::size_t m_funcCallIndex;
|
||||
std::size_t m_funcIndex;
|
||||
std::unordered_map<std::size_t, FuncData>& m_funcData;
|
||||
std::unordered_map<std::size_t, ShaderAst::StructDescription*> m_structs;
|
||||
std::unordered_map<std::size_t, Variable> m_variables;
|
||||
std::vector<std::size_t> m_scopeSizes;
|
||||
std::vector<FuncData>& m_funcData;
|
||||
std::vector<ShaderAst::StructDescription*> m_structs;
|
||||
std::vector<std::optional<Variable>> m_variables;
|
||||
std::vector<SpirvBlock> m_functionBlocks;
|
||||
std::vector<UInt32> m_resultIds;
|
||||
SpirvBlock* m_currentBlock;
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@
|
|||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/SpirvAstVisitor.hpp>
|
||||
#include <cassert>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::vector<FuncData>& funcData) :
|
||||
inline SpirvAstVisitor::SpirvAstVisitor(SpirvWriter& writer, SpirvSection& instructions, std::unordered_map<std::size_t, FuncData>& funcData) :
|
||||
m_extVarIndex(0),
|
||||
m_funcIndex(0),
|
||||
m_funcData(funcData),
|
||||
|
|
@ -27,17 +28,13 @@ namespace Nz
|
|||
|
||||
inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* structDesc)
|
||||
{
|
||||
if (structIndex >= m_structs.size())
|
||||
m_structs.resize(structIndex + 1);
|
||||
|
||||
assert(m_structs.find(structIndex) == m_structs.end());
|
||||
m_structs[structIndex] = structDesc;
|
||||
}
|
||||
|
||||
inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass)
|
||||
{
|
||||
if (varIndex >= m_variables.size())
|
||||
m_variables.resize(varIndex + 1);
|
||||
|
||||
assert(m_variables.find(varIndex) == m_variables.end());
|
||||
m_variables[varIndex] = Variable{
|
||||
storageClass,
|
||||
pointerId,
|
||||
|
|
|
|||
|
|
@ -81,10 +81,22 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement&>(*statementPtr));
|
||||
else if (statementPtr->GetType() == NodeType::ConditionalStatement)
|
||||
{
|
||||
const ConditionalStatement& condStatement = static_cast<const ConditionalStatement&>(*statementPtr);
|
||||
if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement)
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement&>(*condStatement.statement));
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (statement.GetType() == NodeType::DeclareFunctionStatement)
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement&>(statement));
|
||||
else if (statement.GetType() == NodeType::ConditionalStatement)
|
||||
{
|
||||
const ConditionalStatement& condStatement = static_cast<const ConditionalStatement&>(statement);
|
||||
if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement)
|
||||
DeclareFunction(static_cast<DeclareFunctionStatement&>(*condStatement.statement));
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
|
|
@ -1193,7 +1205,8 @@ namespace Nz::ShaderAst
|
|||
{
|
||||
assert(funcIndex < m_context->functions.size());
|
||||
auto& funcData = m_context->functions[funcIndex];
|
||||
assert(funcData.defined);
|
||||
if (!funcData.defined)
|
||||
return;
|
||||
|
||||
funcData.flags |= flags;
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,16 @@
|
|||
|
||||
namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
|
||||
{
|
||||
auto it = map.find(id);
|
||||
assert(it != map.end());
|
||||
return it->second;
|
||||
}
|
||||
}
|
||||
|
||||
UInt32 SpirvAstVisitor::AllocateResultId()
|
||||
{
|
||||
return m_writer.AllocateResultId();
|
||||
|
|
@ -30,9 +40,7 @@ namespace Nz
|
|||
|
||||
auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable&
|
||||
{
|
||||
assert(varIndex < m_variables.size());
|
||||
assert(m_variables[varIndex]);
|
||||
return *m_variables[varIndex];
|
||||
return Retrieve(m_variables, varIndex);
|
||||
}
|
||||
|
||||
void SpirvAstVisitor::Visit(ShaderAst::AccessIndexExpression& node)
|
||||
|
|
@ -415,9 +423,9 @@ namespace Nz
|
|||
std::size_t functionIndex = std::get<std::size_t>(node.targetFunction);
|
||||
|
||||
UInt32 funcId = 0;
|
||||
for (const auto& func : m_funcData)
|
||||
for (const auto& [funcIndex, func] : m_funcData)
|
||||
{
|
||||
if (func.funcIndex == functionIndex)
|
||||
if (funcIndex == functionIndex)
|
||||
{
|
||||
funcId = func.funcId;
|
||||
break;
|
||||
|
|
@ -425,7 +433,7 @@ namespace Nz
|
|||
}
|
||||
assert(funcId != 0);
|
||||
|
||||
const FuncData& funcData = m_funcData[m_funcIndex];
|
||||
const FuncData& funcData = Retrieve(m_funcData, m_funcIndex);
|
||||
const auto& funcCall = funcData.funcCalls[m_funcCallIndex++];
|
||||
|
||||
StackArray<UInt32> parameterIds = NazaraStackArrayNoInit(UInt32, node.parameters.size());
|
||||
|
|
|
|||
|
|
@ -35,6 +35,20 @@ namespace Nz
|
|||
SpirvBuiltIn decoration;
|
||||
};
|
||||
|
||||
template<typename T> T& Retrieve(std::unordered_map<std::size_t, T>& map, std::size_t id)
|
||||
{
|
||||
auto it = map.find(id);
|
||||
assert(it != map.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
|
||||
{
|
||||
auto it = map.find(id);
|
||||
assert(it != map.end());
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
|
||||
{ ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } },
|
||||
{ ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } },
|
||||
|
|
@ -59,7 +73,7 @@ namespace Nz
|
|||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
using StructContainer = std::vector<ShaderAst::StructDescription*>;
|
||||
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& funcs) :
|
||||
m_states(conditions),
|
||||
m_constantCache(constantCache),
|
||||
m_externalBlockIndex(0),
|
||||
|
|
@ -91,7 +105,7 @@ namespace Nz
|
|||
AstRecursiveVisitor::Visit(node);
|
||||
|
||||
assert(m_funcIndex);
|
||||
auto& func = m_funcs[*m_funcIndex];
|
||||
auto& func = Retrieve(m_funcs, *m_funcIndex);
|
||||
|
||||
auto& funcCall = func.funcCalls.emplace_back();
|
||||
funcCall.firstVarIndex = func.variables.size();
|
||||
|
|
@ -152,9 +166,6 @@ namespace Nz
|
|||
assert(node.funcIndex);
|
||||
std::size_t funcIndex = *node.funcIndex;
|
||||
|
||||
if (funcIndex >= m_funcs.size())
|
||||
m_funcs.resize(funcIndex + 1);
|
||||
|
||||
auto& funcData = m_funcs[funcIndex];
|
||||
funcData.name = node.name;
|
||||
funcData.funcIndex = funcIndex;
|
||||
|
|
@ -409,7 +420,7 @@ namespace Nz
|
|||
SpirvConstantCache& m_constantCache;
|
||||
std::optional<std::size_t> m_funcIndex;
|
||||
std::size_t m_externalBlockIndex;
|
||||
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
|
||||
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& m_funcs;
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -429,7 +440,7 @@ namespace Nz
|
|||
|
||||
std::unordered_map<std::string, UInt32> extensionInstructionSet;
|
||||
std::unordered_map<std::string, UInt32> varToResult;
|
||||
std::vector<SpirvAstVisitor::FuncData> funcs;
|
||||
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData> funcs;
|
||||
std::vector<UInt32> resultIds;
|
||||
UInt32 nextVarIndex = 1;
|
||||
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
|
||||
|
|
@ -488,7 +499,7 @@ namespace Nz
|
|||
state.extensionInstructionSet[extInst] = AllocateResultId();
|
||||
|
||||
// Assign function ID (required for forward declaration)
|
||||
for (auto& func : state.funcs)
|
||||
for (auto&& [funcIndex, func] : state.funcs)
|
||||
func.funcId = AllocateResultId();
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
|
||||
|
|
@ -548,7 +559,7 @@ namespace Nz
|
|||
|
||||
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
|
||||
|
||||
for (auto& func : m_currentState->funcs)
|
||||
for (auto&& [funcIndex, func] : m_currentState->funcs)
|
||||
{
|
||||
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
|
||||
|
||||
|
|
@ -588,7 +599,7 @@ namespace Nz
|
|||
}
|
||||
|
||||
// Write execution modes
|
||||
for (auto& func : m_currentState->funcs)
|
||||
for (auto&& [funcIndex, func] : m_currentState->funcs)
|
||||
{
|
||||
if (func.entryPointData)
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue