Renderer/ShaderValidator: Use ShaderVarVisitor instead of switch

This commit is contained in:
Lynix 2020-07-29 14:39:34 +02:00
parent a02dd3bf05
commit 2271432748
2 changed files with 85 additions and 89 deletions

View File

@ -13,10 +13,11 @@
#include <Nazara/Renderer/Config.hpp>
#include <Nazara/Renderer/ShaderAst.hpp>
#include <Nazara/Renderer/ShaderRecursiveVisitor.hpp>
#include <Nazara/Renderer/ShaderVarVisitor.hpp>
namespace Nz
{
class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor
class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor, public ShaderVarVisitor
{
public:
inline ShaderValidator(const ShaderAst& shader);
@ -47,6 +48,14 @@ namespace Nz
void Visit(const ShaderNodes::StatementBlock& node) override;
void Visit(const ShaderNodes::SwizzleOp& node) override;
using ShaderVarVisitor::Visit;
void Visit(const ShaderNodes::BuiltinVariable& var) override;
void Visit(const ShaderNodes::InputVariable& var) override;
void Visit(const ShaderNodes::LocalVariable& var) override;
void Visit(const ShaderNodes::OutputVariable& var) override;
void Visit(const ShaderNodes::ParameterVariable& var) override;
void Visit(const ShaderNodes::UniformVariable& var) override;
struct Context;
const ShaderAst& m_shader;

View File

@ -251,94 +251,7 @@ namespace Nz
if (!node.var)
throw AstError{ "Invalid variable" };
//< FIXME: Use variable visitor
switch (node.var->GetType())
{
case ShaderNodes::VariableType::BuiltinVariable:
break;
case ShaderNodes::VariableType::InputVariable:
{
auto& namedVar = static_cast<ShaderNodes::InputVariable&>(*node.var);
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
{
const auto& input = m_shader.GetInput(i);
if (input.name == namedVar.name)
{
TypeMustMatch(input.type, namedVar.type);
return;
}
}
throw AstError{ "Input not found" };
}
case ShaderNodes::VariableType::LocalVariable:
{
auto& localVar = static_cast<ShaderNodes::LocalVariable&>(*node.var);
const auto& vars = m_context->declaredLocals;
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& var) { return var.name == localVar.name; });
if (it == vars.end())
throw AstError{ "Local variable not found in this block" };
TypeMustMatch(it->type, localVar.type);
break;
}
case ShaderNodes::VariableType::OutputVariable:
{
auto& outputVar = static_cast<ShaderNodes::OutputVariable&>(*node.var);
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
{
const auto& input = m_shader.GetOutput(i);
if (input.name == outputVar.name)
{
TypeMustMatch(input.type, outputVar.type);
return;
}
}
throw AstError{ "Output not found" };
}
case ShaderNodes::VariableType::ParameterVariable:
{
assert(m_context->currentFunction);
auto& parameter = static_cast<ShaderNodes::ParameterVariable&>(*node.var);
const auto& parameters = m_context->currentFunction->parameters;
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == parameter.name; });
if (it == parameters.end())
throw AstError{ "Parameter not found in function" };
TypeMustMatch(it->type, parameter.type);
break;
}
case ShaderNodes::VariableType::UniformVariable:
{
auto& uniformVar = static_cast<ShaderNodes::UniformVariable&>(*node.var);
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
{
const auto& uniform = m_shader.GetUniform(i);
if (uniform.name == uniformVar.name)
{
TypeMustMatch(uniform.type, uniformVar.type);
return;
}
}
throw AstError{ "Uniform not found" };
}
default:
break;
}
Visit(node.var);
}
void ShaderValidator::Visit(const ShaderNodes::IntrinsicCall& node)
@ -433,6 +346,80 @@ namespace Nz
ShaderRecursiveVisitor::Visit(node);
}
void ShaderValidator::Visit(const ShaderNodes::BuiltinVariable& /*var*/)
{
/* Nothing to do */
}
void ShaderValidator::Visit(const ShaderNodes::InputVariable& var)
{
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
{
const auto& input = m_shader.GetInput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
}
throw AstError{ "Input not found" };
}
void ShaderValidator::Visit(const ShaderNodes::LocalVariable& var)
{
const auto& vars = m_context->declaredLocals;
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; });
if (it == vars.end())
throw AstError{ "Local variable not found in this block" };
TypeMustMatch(it->type, var.type);
}
void ShaderValidator::Visit(const ShaderNodes::OutputVariable& var)
{
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
{
const auto& input = m_shader.GetOutput(i);
if (input.name == var.name)
{
TypeMustMatch(input.type, var.type);
return;
}
}
throw AstError{ "Output not found" };
}
void ShaderValidator::Visit(const ShaderNodes::ParameterVariable& var)
{
assert(m_context->currentFunction);
const auto& parameters = m_context->currentFunction->parameters;
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; });
if (it == parameters.end())
throw AstError{ "Parameter not found in function" };
TypeMustMatch(it->type, var.type);
}
void ShaderValidator::Visit(const ShaderNodes::UniformVariable& var)
{
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
{
const auto& uniform = m_shader.GetUniform(i);
if (uniform.name == var.name)
{
TypeMustMatch(uniform.type, var.type);
return;
}
}
throw AstError{ "Uniform not found" };
}
bool ValidateShader(const ShaderAst& shader, std::string* error)
{
ShaderValidator validator(shader);