Renderer/ShaderValidator: Use ShaderVarVisitor instead of switch
This commit is contained in:
parent
a02dd3bf05
commit
2271432748
|
|
@ -13,10 +13,11 @@
|
|||
#include <Nazara/Renderer/Config.hpp>
|
||||
#include <Nazara/Renderer/ShaderAst.hpp>
|
||||
#include <Nazara/Renderer/ShaderRecursiveVisitor.hpp>
|
||||
#include <Nazara/Renderer/ShaderVarVisitor.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor
|
||||
class NAZARA_RENDERER_API ShaderValidator : public ShaderRecursiveVisitor, public ShaderVarVisitor
|
||||
{
|
||||
public:
|
||||
inline ShaderValidator(const ShaderAst& shader);
|
||||
|
|
@ -47,6 +48,14 @@ namespace Nz
|
|||
void Visit(const ShaderNodes::StatementBlock& node) override;
|
||||
void Visit(const ShaderNodes::SwizzleOp& node) override;
|
||||
|
||||
using ShaderVarVisitor::Visit;
|
||||
void Visit(const ShaderNodes::BuiltinVariable& var) override;
|
||||
void Visit(const ShaderNodes::InputVariable& var) override;
|
||||
void Visit(const ShaderNodes::LocalVariable& var) override;
|
||||
void Visit(const ShaderNodes::OutputVariable& var) override;
|
||||
void Visit(const ShaderNodes::ParameterVariable& var) override;
|
||||
void Visit(const ShaderNodes::UniformVariable& var) override;
|
||||
|
||||
struct Context;
|
||||
|
||||
const ShaderAst& m_shader;
|
||||
|
|
|
|||
|
|
@ -251,94 +251,7 @@ namespace Nz
|
|||
if (!node.var)
|
||||
throw AstError{ "Invalid variable" };
|
||||
|
||||
//< FIXME: Use variable visitor
|
||||
switch (node.var->GetType())
|
||||
{
|
||||
case ShaderNodes::VariableType::BuiltinVariable:
|
||||
break;
|
||||
|
||||
case ShaderNodes::VariableType::InputVariable:
|
||||
{
|
||||
auto& namedVar = static_cast<ShaderNodes::InputVariable&>(*node.var);
|
||||
|
||||
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
|
||||
{
|
||||
const auto& input = m_shader.GetInput(i);
|
||||
if (input.name == namedVar.name)
|
||||
{
|
||||
TypeMustMatch(input.type, namedVar.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Input not found" };
|
||||
}
|
||||
|
||||
case ShaderNodes::VariableType::LocalVariable:
|
||||
{
|
||||
auto& localVar = static_cast<ShaderNodes::LocalVariable&>(*node.var);
|
||||
const auto& vars = m_context->declaredLocals;
|
||||
|
||||
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& var) { return var.name == localVar.name; });
|
||||
if (it == vars.end())
|
||||
throw AstError{ "Local variable not found in this block" };
|
||||
|
||||
TypeMustMatch(it->type, localVar.type);
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::VariableType::OutputVariable:
|
||||
{
|
||||
auto& outputVar = static_cast<ShaderNodes::OutputVariable&>(*node.var);
|
||||
|
||||
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
|
||||
{
|
||||
const auto& input = m_shader.GetOutput(i);
|
||||
if (input.name == outputVar.name)
|
||||
{
|
||||
TypeMustMatch(input.type, outputVar.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Output not found" };
|
||||
}
|
||||
|
||||
case ShaderNodes::VariableType::ParameterVariable:
|
||||
{
|
||||
assert(m_context->currentFunction);
|
||||
|
||||
auto& parameter = static_cast<ShaderNodes::ParameterVariable&>(*node.var);
|
||||
const auto& parameters = m_context->currentFunction->parameters;
|
||||
|
||||
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == parameter.name; });
|
||||
if (it == parameters.end())
|
||||
throw AstError{ "Parameter not found in function" };
|
||||
|
||||
TypeMustMatch(it->type, parameter.type);
|
||||
break;
|
||||
}
|
||||
|
||||
case ShaderNodes::VariableType::UniformVariable:
|
||||
{
|
||||
auto& uniformVar = static_cast<ShaderNodes::UniformVariable&>(*node.var);
|
||||
|
||||
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
|
||||
{
|
||||
const auto& uniform = m_shader.GetUniform(i);
|
||||
if (uniform.name == uniformVar.name)
|
||||
{
|
||||
TypeMustMatch(uniform.type, uniformVar.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Uniform not found" };
|
||||
}
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
Visit(node.var);
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::IntrinsicCall& node)
|
||||
|
|
@ -433,6 +346,80 @@ namespace Nz
|
|||
ShaderRecursiveVisitor::Visit(node);
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::BuiltinVariable& /*var*/)
|
||||
{
|
||||
/* Nothing to do */
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::InputVariable& var)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetInputCount(); ++i)
|
||||
{
|
||||
const auto& input = m_shader.GetInput(i);
|
||||
if (input.name == var.name)
|
||||
{
|
||||
TypeMustMatch(input.type, var.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Input not found" };
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::LocalVariable& var)
|
||||
{
|
||||
const auto& vars = m_context->declaredLocals;
|
||||
|
||||
auto it = std::find_if(vars.begin(), vars.end(), [&](const auto& v) { return v.name == var.name; });
|
||||
if (it == vars.end())
|
||||
throw AstError{ "Local variable not found in this block" };
|
||||
|
||||
TypeMustMatch(it->type, var.type);
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::OutputVariable& var)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetOutputCount(); ++i)
|
||||
{
|
||||
const auto& input = m_shader.GetOutput(i);
|
||||
if (input.name == var.name)
|
||||
{
|
||||
TypeMustMatch(input.type, var.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Output not found" };
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::ParameterVariable& var)
|
||||
{
|
||||
assert(m_context->currentFunction);
|
||||
|
||||
const auto& parameters = m_context->currentFunction->parameters;
|
||||
|
||||
auto it = std::find_if(parameters.begin(), parameters.end(), [&](const auto& parameter) { return parameter.name == var.name; });
|
||||
if (it == parameters.end())
|
||||
throw AstError{ "Parameter not found in function" };
|
||||
|
||||
TypeMustMatch(it->type, var.type);
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::UniformVariable& var)
|
||||
{
|
||||
for (std::size_t i = 0; i < m_shader.GetUniformCount(); ++i)
|
||||
{
|
||||
const auto& uniform = m_shader.GetUniform(i);
|
||||
if (uniform.name == var.name)
|
||||
{
|
||||
TypeMustMatch(uniform.type, var.type);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
throw AstError{ "Uniform not found" };
|
||||
}
|
||||
|
||||
bool ValidateShader(const ShaderAst& shader, std::string* error)
|
||||
{
|
||||
ShaderValidator validator(shader);
|
||||
|
|
|
|||
Loading…
Reference in New Issue