Shader: Merge AstScopedVisitor, AstValidator and TransformVisitor to SanitizeVisitor
This commit is contained in:
@@ -8,9 +8,9 @@
|
||||
#include <Nazara/Math/Algorithm.hpp>
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <Nazara/Shader/ShaderAstCloner.hpp>
|
||||
#include <Nazara/Shader/ShaderAstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/ShaderAstUtils.hpp>
|
||||
#include <Nazara/Shader/ShaderAstValidator.hpp>
|
||||
#include <Nazara/Shader/Ast/TransformVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/SanitizeVisitor.hpp>
|
||||
#include <optional>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
@@ -31,36 +31,29 @@ namespace Nz
|
||||
return it->second;
|
||||
}
|
||||
|
||||
struct PreVisitor : ShaderAst::AstCloner
|
||||
struct PreVisitor : ShaderAst::AstRecursiveVisitor
|
||||
{
|
||||
using AstCloner::Clone;
|
||||
using AstRecursiveVisitor::Visit;
|
||||
|
||||
ShaderAst::StatementPtr Clone(ShaderAst::DeclareFunctionStatement& node) override
|
||||
void Visit(ShaderAst::DeclareFunctionStatement& node) override
|
||||
{
|
||||
auto clone = AstCloner::Clone(node);
|
||||
assert(clone->GetType() == ShaderAst::NodeType::DeclareFunctionStatement);
|
||||
|
||||
ShaderAst::DeclareFunctionStatement* func = static_cast<ShaderAst::DeclareFunctionStatement*>(clone.get());
|
||||
|
||||
// Remove function if it's an entry point of another type than the one selected
|
||||
// Dismiss function if it's an entry point of another type than the one selected
|
||||
if (selectedStage)
|
||||
{
|
||||
if (node.entryStage)
|
||||
{
|
||||
ShaderStageType stage = *node.entryStage;
|
||||
if (stage != *selectedStage)
|
||||
return ShaderBuilder::NoOp();
|
||||
return;
|
||||
|
||||
entryPoint = func;
|
||||
entryPoint = &node;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(!entryPoint);
|
||||
entryPoint = func;
|
||||
entryPoint = &node;
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
std::optional<ShaderStageType> selectedStage;
|
||||
@@ -99,17 +92,6 @@ namespace Nz
|
||||
unsigned int indentLevel = 0;
|
||||
};
|
||||
|
||||
|
||||
GlslWriter::GlslWriter() :
|
||||
m_currentState(nullptr)
|
||||
{
|
||||
}
|
||||
|
||||
std::string GlslWriter::Generate(ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
return Generate(std::nullopt, shader, conditions);
|
||||
}
|
||||
|
||||
std::string GlslWriter::Generate(std::optional<ShaderStageType> shaderStage, ShaderAst::StatementPtr& shader, const States& conditions)
|
||||
{
|
||||
State state;
|
||||
@@ -121,17 +103,11 @@ namespace Nz
|
||||
m_currentState = nullptr;
|
||||
});
|
||||
|
||||
std::string error;
|
||||
if (!ShaderAst::ValidateAst(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
ShaderAst::TransformVisitor transformVisitor;
|
||||
ShaderAst::StatementPtr transformedShader = transformVisitor.Transform(shader);
|
||||
ShaderAst::StatementPtr sanitizedAst = ShaderAst::Sanitize(shader);
|
||||
|
||||
PreVisitor previsitor;
|
||||
previsitor.selectedStage = shaderStage;
|
||||
|
||||
ShaderAst::StatementPtr adaptedShader = previsitor.Clone(transformedShader);
|
||||
sanitizedAst->Visit(previsitor);
|
||||
|
||||
if (!previsitor.entryPoint)
|
||||
throw std::runtime_error("missing entry point");
|
||||
@@ -140,7 +116,7 @@ namespace Nz
|
||||
|
||||
AppendHeader();
|
||||
|
||||
adaptedShader->Visit(*this);
|
||||
sanitizedAst->Visit(*this);
|
||||
|
||||
return state.stream.str();
|
||||
}
|
||||
@@ -361,6 +337,9 @@ namespace Nz
|
||||
|
||||
void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node)
|
||||
{
|
||||
if (m_currentState->entryFunc != &node)
|
||||
return; //< Ignore other entry points
|
||||
|
||||
HandleInOut();
|
||||
AppendLine("void main()");
|
||||
EnterScope();
|
||||
@@ -712,11 +691,10 @@ namespace Nz
|
||||
{
|
||||
NazaraAssert(m_currentState, "This function should only be called while processing an AST");
|
||||
|
||||
if (m_currentState->entryFunc == &node)
|
||||
if (node.entryStage)
|
||||
return HandleEntryPoint(node);
|
||||
|
||||
assert(node.varIndex);
|
||||
std::size_t varIndex = *node.varIndex;
|
||||
std::optional<std::size_t> varIndexOpt = node.varIndex;
|
||||
|
||||
Append(node.returnType);
|
||||
Append(" ");
|
||||
@@ -731,6 +709,8 @@ namespace Nz
|
||||
Append(" ");
|
||||
Append(node.parameters[i].name);
|
||||
|
||||
assert(varIndexOpt);
|
||||
std::size_t& varIndex = *varIndexOpt;
|
||||
RegisterVariable(varIndex++, node.parameters[i].name);
|
||||
}
|
||||
Append(")\n");
|
||||
|
||||
Reference in New Issue
Block a user