Shader: Rework scope handling

This commit is contained in:
Jérôme Leclercq
2021-04-04 20:31:09 +02:00
parent feffcfa6e5
commit f93a5bbdc1
23 changed files with 661 additions and 755 deletions

View File

@@ -25,18 +25,26 @@ namespace Nz
{
namespace
{
class PreVisitor : public ShaderAst::AstRecursiveVisitor
class PreVisitor : public ShaderAst::AstScopedVisitor
{
public:
using ExtInstList = std::unordered_set<std::string>;
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_cache(cache),
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
m_conditions(conditions),
m_constantCache(constantCache)
{
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
{
const Identifier* identifier = FindIdentifier(identifierName);
if (!identifier)
throw std::runtime_error("invalid identifier " + identifierName);
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
});
}
void Visit(ShaderAst::AccessMemberExpression& node) override
@@ -74,7 +82,7 @@ namespace Nz
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
}, node.value);
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareFunctionStatement& node) override
@@ -87,11 +95,13 @@ namespace Nz
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::DeclareStructStatement& node) override
{
AstScopedVisitor::Visit(node);
SpirvConstantCache::Structure sType;
sType.name = node.description.name;
@@ -107,21 +117,21 @@ namespace Nz
void Visit(ShaderAst::DeclareVariableStatement& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
AstScopedVisitor::Visit(node);
AstRecursiveVisitor::Visit(node);
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
}
void Visit(ShaderAst::IdentifierExpression& node) override
{
m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache)));
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
}
void Visit(ShaderAst::IntrinsicExpression& node) override
{
AstRecursiveVisitor::Visit(node);
AstScopedVisitor::Visit(node);
switch (node.intrinsic)
{
@@ -140,7 +150,6 @@ namespace Nz
FunctionContainer funcs;
private:
ShaderAst::AstCache* m_cache;
const SpirvWriter::States& m_conditions;
SpirvConstantCache& m_constantCache;
};
@@ -214,7 +223,7 @@ namespace Nz
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
{
std::string error;
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
if (!ShaderAst::ValidateAst(shader, &error))
throw std::runtime_error("Invalid shader AST: " + error);
m_context.states = &conditions;
@@ -229,7 +238,7 @@ namespace Nz
ShaderAst::AstCloner cloner;
// Register all extended instruction sets
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
PreVisitor preVisitor(conditions, state.constantTypeCache);
shader->Visit(preVisitor);
for (const std::string& extInst : preVisitor.extInsts)
@@ -397,7 +406,7 @@ namespace Nz
state.parameterIds.emplace(param.name, std::move(parameterData));
}
SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache);
SpirvAstVisitor visitor(*this, state.functionBlocks);
for (const auto& statement : func.statements)
statement->Visit(visitor);
@@ -419,7 +428,7 @@ namespace Nz
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
{
const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
if (!statement)
continue;
@@ -462,7 +471,7 @@ namespace Nz
});
if (stage == ShaderStageType::Fragment)
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
}
std::vector<UInt32> ret;