Shader: Rework scope handling
This commit is contained in:
@@ -25,18 +25,26 @@ namespace Nz
|
||||
{
|
||||
namespace
|
||||
{
|
||||
class PreVisitor : public ShaderAst::AstRecursiveVisitor
|
||||
class PreVisitor : public ShaderAst::AstScopedVisitor
|
||||
{
|
||||
public:
|
||||
using ExtInstList = std::unordered_set<std::string>;
|
||||
using LocalContainer = std::unordered_set<ShaderAst::ExpressionType>;
|
||||
using FunctionContainer = std::vector<std::reference_wrapper<ShaderAst::DeclareFunctionStatement>>;
|
||||
|
||||
PreVisitor(ShaderAst::AstCache* cache, const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_cache(cache),
|
||||
PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache) :
|
||||
m_conditions(conditions),
|
||||
m_constantCache(constantCache)
|
||||
{
|
||||
m_constantCache.SetIdentifierCallback([&](const std::string& identifierName)
|
||||
{
|
||||
const Identifier* identifier = FindIdentifier(identifierName);
|
||||
if (!identifier)
|
||||
throw std::runtime_error("invalid identifier " + identifierName);
|
||||
|
||||
assert(std::holds_alternative<ShaderAst::StructDescription>(identifier->value));
|
||||
return SpirvConstantCache::BuildType(std::get<ShaderAst::StructDescription>(identifier->value));
|
||||
});
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::AccessMemberExpression& node) override
|
||||
@@ -74,7 +82,7 @@ namespace Nz
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildConstant(arg));
|
||||
}, node.value);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
@@ -87,11 +95,13 @@ namespace Nz
|
||||
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildFunctionType(node.returnType, parameterTypes));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::DeclareStructStatement& node) override
|
||||
{
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
SpirvConstantCache::Structure sType;
|
||||
sType.name = node.description.name;
|
||||
|
||||
@@ -107,21 +117,21 @@ namespace Nz
|
||||
|
||||
void Visit(ShaderAst::DeclareVariableStatement& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.varType));
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IdentifierExpression& node) override
|
||||
{
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(GetExpressionType(node, m_cache)));
|
||||
m_constantCache.Register(*SpirvConstantCache::BuildType(node.cachedExpressionType.value()));
|
||||
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void Visit(ShaderAst::IntrinsicExpression& node) override
|
||||
{
|
||||
AstRecursiveVisitor::Visit(node);
|
||||
AstScopedVisitor::Visit(node);
|
||||
|
||||
switch (node.intrinsic)
|
||||
{
|
||||
@@ -140,7 +150,6 @@ namespace Nz
|
||||
FunctionContainer funcs;
|
||||
|
||||
private:
|
||||
ShaderAst::AstCache* m_cache;
|
||||
const SpirvWriter::States& m_conditions;
|
||||
SpirvConstantCache& m_constantCache;
|
||||
};
|
||||
@@ -214,7 +223,7 @@ namespace Nz
|
||||
std::vector<UInt32> SpirvWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error, &m_context.cache))
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.states = &conditions;
|
||||
@@ -229,7 +238,7 @@ namespace Nz
|
||||
ShaderAst::AstCloner cloner;
|
||||
|
||||
// Register all extended instruction sets
|
||||
PreVisitor preVisitor(&m_context.cache, conditions, state.constantTypeCache);
|
||||
PreVisitor preVisitor(conditions, state.constantTypeCache);
|
||||
shader->Visit(preVisitor);
|
||||
|
||||
for (const std::string& extInst : preVisitor.extInsts)
|
||||
@@ -397,7 +406,7 @@ namespace Nz
|
||||
state.parameterIds.emplace(param.name, std::move(parameterData));
|
||||
}
|
||||
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks, &m_context.cache);
|
||||
SpirvAstVisitor visitor(*this, state.functionBlocks);
|
||||
for (const auto& statement : func.statements)
|
||||
statement->Visit(visitor);
|
||||
|
||||
@@ -419,7 +428,7 @@ namespace Nz
|
||||
|
||||
for (std::size_t i = 0; i < ShaderStageTypeCount; ++i)
|
||||
{
|
||||
const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
/*const ShaderAst::DeclareFunctionStatement* statement = m_context.cache.entryFunctions[i];
|
||||
if (!statement)
|
||||
continue;
|
||||
|
||||
@@ -462,7 +471,7 @@ namespace Nz
|
||||
});
|
||||
|
||||
if (stage == ShaderStageType::Fragment)
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);
|
||||
state.header.Append(SpirvOp::OpExecutionMode, entryFunc.id, SpirvExecutionMode::OriginUpperLeft);*/
|
||||
}
|
||||
|
||||
std::vector<UInt32> ret;
|
||||
|
||||
Reference in New Issue
Block a user