Shader: Fix support of conditional functions

This commit is contained in:
Jérôme Leclercq
2021-08-02 11:12:34 +02:00
parent c8e7fa5063
commit a5b71f33b9
5 changed files with 57 additions and 28 deletions

View File

@@ -81,10 +81,22 @@ namespace Nz::ShaderAst
{
if (statementPtr->GetType() == NodeType::DeclareFunctionStatement)
DeclareFunction(static_cast<DeclareFunctionStatement&>(*statementPtr));
else if (statementPtr->GetType() == NodeType::ConditionalStatement)
{
const ConditionalStatement& condStatement = static_cast<const ConditionalStatement&>(*statementPtr);
if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement)
DeclareFunction(static_cast<DeclareFunctionStatement&>(*condStatement.statement));
}
}
}
else if (statement.GetType() == NodeType::DeclareFunctionStatement)
DeclareFunction(static_cast<DeclareFunctionStatement&>(statement));
else if (statement.GetType() == NodeType::ConditionalStatement)
{
const ConditionalStatement& condStatement = static_cast<const ConditionalStatement&>(statement);
if (condStatement.statement->GetType() == NodeType::DeclareFunctionStatement)
DeclareFunction(static_cast<DeclareFunctionStatement&>(*condStatement.statement));
}
try
{
@@ -1193,7 +1205,8 @@ namespace Nz::ShaderAst
{
assert(funcIndex < m_context->functions.size());
auto& funcData = m_context->functions[funcIndex];
assert(funcData.defined);
if (!funcData.defined)
return;
funcData.flags |= flags;

View File

@@ -15,6 +15,16 @@
namespace Nz
{
namespace
{
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
{
auto it = map.find(id);
assert(it != map.end());
return it->second;
}
}
UInt32 SpirvAstVisitor::AllocateResultId()
{
return m_writer.AllocateResultId();
@@ -30,9 +40,7 @@ namespace Nz
auto SpirvAstVisitor::GetVariable(std::size_t varIndex) const -> const Variable&
{
assert(varIndex < m_variables.size());
assert(m_variables[varIndex]);
return *m_variables[varIndex];
return Retrieve(m_variables, varIndex);
}
void SpirvAstVisitor::Visit(ShaderAst::AccessIndexExpression& node)
@@ -415,9 +423,9 @@ namespace Nz
std::size_t functionIndex = std::get<std::size_t>(node.targetFunction);
UInt32 funcId = 0;
for (const auto& func : m_funcData)
for (const auto& [funcIndex, func] : m_funcData)
{
if (func.funcIndex == functionIndex)
if (funcIndex == functionIndex)
{
funcId = func.funcId;
break;
@@ -425,7 +433,7 @@ namespace Nz
}
assert(funcId != 0);
const FuncData& funcData = m_funcData[m_funcIndex];
const FuncData& funcData = Retrieve(m_funcData, m_funcIndex);
const auto& funcCall = funcData.funcCalls[m_funcCallIndex++];
StackArray<UInt32> parameterIds = NazaraStackArrayNoInit(UInt32, node.parameters.size());

View File

@@ -35,6 +35,20 @@ namespace Nz
SpirvBuiltIn decoration;
};
template<typename T> T& Retrieve(std::unordered_map<std::size_t, T>& map, std::size_t id)
{
auto it = map.find(id);
assert(it != map.end());
return it->second;
}
template<typename T> const T& Retrieve(const std::unordered_map<std::size_t, T>& map, std::size_t id)
{
auto it = map.find(id);
assert(it != map.end());
return it->second;
}
std::unordered_map<ShaderAst::BuiltinEntry, Builtin> s_builtinMapping = {
{ ShaderAst::BuiltinEntry::FragCoord, { "FragmentCoordinates", ShaderStageType::Fragment, SpirvBuiltIn::FragCoord } },
{ ShaderAst::BuiltinEntry::FragDepth, { "FragmentDepth", ShaderStageType::Fragment, SpirvBuiltIn::FragDepth } },
@@ -59,7 +73,7 @@ namespace Nz
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
using StructContainer = std::vector<ShaderAst::StructDescription*>;
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector<SpirvAstVisitor::FuncData>& funcs) :
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& funcs) :
m_states(conditions),
m_constantCache(constantCache),
m_externalBlockIndex(0),
@@ -91,7 +105,7 @@ namespace Nz
AstRecursiveVisitor::Visit(node);
assert(m_funcIndex);
auto& func = m_funcs[*m_funcIndex];
auto& func = Retrieve(m_funcs, *m_funcIndex);
auto& funcCall = func.funcCalls.emplace_back();
funcCall.firstVarIndex = func.variables.size();
@@ -152,9 +166,6 @@ namespace Nz
assert(node.funcIndex);
std::size_t funcIndex = *node.funcIndex;
if (funcIndex >= m_funcs.size())
m_funcs.resize(funcIndex + 1);
auto& funcData = m_funcs[funcIndex];
funcData.name = node.name;
funcData.funcIndex = funcIndex;
@@ -409,7 +420,7 @@ namespace Nz
SpirvConstantCache& m_constantCache;
std::optional<std::size_t> m_funcIndex;
std::size_t m_externalBlockIndex;
std::vector<SpirvAstVisitor::FuncData>& m_funcs;
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData>& m_funcs;
};
}
@@ -429,7 +440,7 @@ namespace Nz
std::unordered_map<std::string, UInt32> extensionInstructionSet;
std::unordered_map<std::string, UInt32> varToResult;
std::vector<SpirvAstVisitor::FuncData> funcs;
std::unordered_map<std::size_t, SpirvAstVisitor::FuncData> funcs;
std::vector<UInt32> resultIds;
UInt32 nextVarIndex = 1;
SpirvConstantCache constantTypeCache; //< init after nextVarIndex
@@ -488,7 +499,7 @@ namespace Nz
state.extensionInstructionSet[extInst] = AllocateResultId();
// Assign function ID (required for forward declaration)
for (auto& func : state.funcs)
for (auto&& [funcIndex, func] : state.funcs)
func.funcId = AllocateResultId();
SpirvAstVisitor visitor(*this, state.instructions, state.funcs);
@@ -548,7 +559,7 @@ namespace Nz
m_currentState->header.Append(SpirvOp::OpMemoryModel, SpirvAddressingModel::Logical, SpirvMemoryModel::GLSL450);
for (auto& func : m_currentState->funcs)
for (auto&& [funcIndex, func] : m_currentState->funcs)
{
m_currentState->debugInfo.Append(SpirvOp::OpName, func.funcId, func.name);
@@ -588,7 +599,7 @@ namespace Nz
}
// Write execution modes
for (auto& func : m_currentState->funcs)
for (auto&& [funcIndex, func] : m_currentState->funcs)
{
if (func.entryPointData)
{