From 22714327485f83d3042fc1434bf658f3a7f56c47 Mon Sep 17 00:00:00 2001 From: Lynix Date: Wed, 29 Jul 2020 14:39:34 +0200 Subject: [PATCH] Renderer/ShaderValidator: Use ShaderVarVisitor instead of switch --- include/Nazara/Renderer/ShaderValidator.hpp | 11 +- src/Nazara/Renderer/ShaderValidator.cpp | 163 +++++++++----------- 2 files changed, 85 insertions(+), 89 deletions(-) diff --git a/include/Nazara/Renderer/ShaderValidator.hpp b/include/Nazara/Renderer/ShaderValidator.hpp index 237ecaa73..edf54370f 100644 --- a/include/Nazara/Renderer/ShaderValidator.hpp +++ b/include/Nazara/Renderer/ShaderValidator.hpp @@ -13,10 +13,11 @@ #include #include #include +#include namespace Nz { - class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor + class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor, public ShaderVarVisitor { public: inline ShaderValidator(const ShaderAst& shader); @@ -47,6 +48,14 @@ namespace Nz void Visit(const ShaderNodes::StatementBlock& node) override; void Visit(const ShaderNodes::SwizzleOp& node) override; + using ShaderVarVisitor::Visit; + void Visit(const ShaderNodes::BuiltinVariable& var) override; + void Visit(const ShaderNodes::InputVariable& var) override; + void Visit(const ShaderNodes::LocalVariable& var) override; + void Visit(const ShaderNodes::OutputVariable& var) override; + void Visit(const ShaderNodes::ParameterVariable& var) override; + void Visit(const ShaderNodes::UniformVariable& var) override; + struct Context; const ShaderAst& m_shader; diff --git a/src/Nazara/Renderer/ShaderValidator.cpp b/src/Nazara/Renderer/ShaderValidator.cpp index 01f038217..351a46839 100644 --- a/src/Nazara/Renderer/ShaderValidator.cpp +++ b/src/Nazara/Renderer/ShaderValidator.cpp @@ -251,94 +251,7 @@ namespace Nz if (!node.var) throw AstError{ "Invalid variable" }; - //< FIXME: Use variable visitor - switch (node.var->GetType()) - { - case ShaderNodes::VariableType::BuiltinVariable: - break; - - case ShaderNodes::VariableType::InputVariable: - { - auto& namedVar = static_cast(*node.var); - - for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i) - { - const auto& input = m_shader.GetInput(i); - if (input.name == namedVar.name) - { - TypeMustMatch(input.type, namedVar.type); - return; - } - } - - throw AstError{ "Input not found" }; - } - - case ShaderNodes::VariableType::LocalVariable: - { - auto& localVar = static_cast(*node.var); - const auto& vars = m_context->declaredLocals; - - auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& var) { return var.name == localVar.name; }); - if (it == vars.end()) - throw AstError{ "Local variable not found in this block" }; - - TypeMustMatch(it->type, localVar.type); - break; - } - - case ShaderNodes::VariableType::OutputVariable: - { - auto& outputVar = static_cast(*node.var); - - for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i) - { - const auto& input = m_shader.GetOutput(i); - if (input.name == outputVar.name) - { - TypeMustMatch(input.type, outputVar.type); - return; - } - } - - throw AstError{ "Output not found" }; - } - - case ShaderNodes::VariableType::ParameterVariable: - { - assert(m_context->currentFunction); - - auto& parameter = static_cast(*node.var); - const auto& parameters = m_context->currentFunction->parameters; - - auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == parameter.name; }); - if (it == parameters.end()) - throw AstError{ "Parameter not found in function" }; - - TypeMustMatch(it->type, parameter.type); - break; - } - - case ShaderNodes::VariableType::UniformVariable: - { - auto& uniformVar = static_cast(*node.var); - - for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i) - { - const auto& uniform = m_shader.GetUniform(i); - if (uniform.name == uniformVar.name) - { - TypeMustMatch(uniform.type, uniformVar.type); - return; - } - } - - throw AstError{ "Uniform not found" }; - } - - default: - break; - } + Visit(node.var); } void ShaderValidator::Visit(const ShaderNodes::IntrinsicCall& node) @@ -433,6 +346,80 @@ namespace Nz ShaderRecursiveVisitor::Visit(node); } + void ShaderValidator::Visit(const ShaderNodes::BuiltinVariable& /*var*/) + { + /* Nothing to do */ + } + + void ShaderValidator::Visit(const ShaderNodes::InputVariable& var) + { + for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i) + { + const auto& input = m_shader.GetInput(i); + if (input.name == var.name) + { + TypeMustMatch(input.type, var.type); + return; + } + } + + throw AstError{ "Input not found" }; + } + + void ShaderValidator::Visit(const ShaderNodes::LocalVariable& var) + { + const auto& vars = m_context->declaredLocals; + + auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; }); + if (it == vars.end()) + throw AstError{ "Local variable not found in this block" }; + + TypeMustMatch(it->type, var.type); + } + + void ShaderValidator::Visit(const ShaderNodes::OutputVariable& var) + { + for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i) + { + const auto& input = m_shader.GetOutput(i); + if (input.name == var.name) + { + TypeMustMatch(input.type, var.type); + return; + } + } + + throw AstError{ "Output not found" }; + } + + void ShaderValidator::Visit(const ShaderNodes::ParameterVariable& var) + { + assert(m_context->currentFunction); + + const auto& parameters = m_context->currentFunction->parameters; + + auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; }); + if (it == parameters.end()) + throw AstError{ "Parameter not found in function" }; + + TypeMustMatch(it->type, var.type); + } + + void ShaderValidator::Visit(const ShaderNodes::UniformVariable& var) + { + for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i) + { + const auto& uniform = m_shader.GetUniform(i); + if (uniform.name == var.name) + { + TypeMustMatch(uniform.type, var.type); + return; + } + } + + throw AstError{ "Uniform not found" }; + } + bool ValidateShader(const ShaderAst& shader, std::string* error) { ShaderValidator validator(shader);