diff --git a/include/Nazara/Renderer/ShaderValidator.hpp b/include/Nazara/Renderer/ShaderValidator.hpp new file mode 100644 index 000000000..50075aaf0 --- /dev/null +++ b/include/Nazara/Renderer/ShaderValidator.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Renderer module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERVALIDATOR_HPP +#define NAZARA_SHADERVALIDATOR_HPP + +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + class NAZARA_RENDERER_API ShaderValidator : public ShaderVisitor + { + public: + ShaderValidator() = default; + ShaderValidator(const ShaderValidator&) = delete; + ShaderValidator(ShaderValidator&&) = delete; + ~ShaderValidator() = default; + + bool Validate(const StatementPtr& shader, std::string* error = nullptr); + + private: + const ExpressionPtr& MandatoryExpr(const ExpressionPtr& node); + const NodePtr& MandatoryNode(const NodePtr& node); + void TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right); + + using ShaderVisitor::Visit; + void Visit(const AssignOp& node) override; + void Visit(const BinaryFunc& node) override; + void Visit(const BinaryOp& node) override; + void Visit(const Branch& node) override; + void Visit(const BuiltinVariable& node) override; + void Visit(const Cast& node) override; + void Visit(const Constant& node) override; + void Visit(const DeclareVariable& node) override; + void Visit(const ExpressionStatement& node) override; + void Visit(const NamedVariable& node) override; + void Visit(const Sample2D& node) override; + void Visit(const StatementBlock& node) override; + void Visit(const SwizzleOp& node) override; + }; + + NAZARA_RENDERER_API bool Validate(const StatementPtr& shader, std::string* error = nullptr); +} + +#include + +#endif diff --git a/include/Nazara/Renderer/ShaderValidator.inl b/include/Nazara/Renderer/ShaderValidator.inl new file mode 100644 index 000000000..2fb43d576 --- /dev/null +++ b/include/Nazara/Renderer/ShaderValidator.inl @@ -0,0 +1,12 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Renderer module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ +} + +#include diff --git a/src/Nazara/Renderer/GlslWriter.cpp b/src/Nazara/Renderer/GlslWriter.cpp index e7b8f6155..b982dab90 100644 --- a/src/Nazara/Renderer/GlslWriter.cpp +++ b/src/Nazara/Renderer/GlslWriter.cpp @@ -4,6 +4,8 @@ #include #include +#include +#include #include namespace Nz @@ -17,6 +19,10 @@ namespace Nz String GlslWriter::Generate(const ShaderAst::StatementPtr& node) { + std::string error; + if (!ShaderAst::Validate(node, &error)) + throw std::runtime_error("Invalid shader AST: " + error); + State state; m_currentState = &state; CallOnExit onExit([this]() diff --git a/src/Nazara/Renderer/ShaderValidator.cpp b/src/Nazara/Renderer/ShaderValidator.cpp new file mode 100644 index 000000000..2a4d38f25 --- /dev/null +++ b/src/Nazara/Renderer/ShaderValidator.cpp @@ -0,0 +1,234 @@ +// Copyright (C) 2015 Jérôme Leclercq +// This file is part of the "Nazara Engine - Renderer module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ + struct AstError + { + std::string errMsg; + }; + + bool ShaderValidator::Validate(const StatementPtr& shader, std::string* error) + { + try + { + shader->Visit(*this); + return true; + } + catch (const AstError& e) + { + if (error) + *error = e.errMsg; + + return false; + } + } + + const ExpressionPtr& ShaderValidator::MandatoryExpr(const ExpressionPtr& node) + { + MandatoryNode(node); + + return node; + } + + const NodePtr& ShaderValidator::MandatoryNode(const NodePtr& node) + { + if (!node) + throw AstError{ "Invalid node" }; + + return node; + } + + void ShaderValidator::TypeMustMatch(const ExpressionPtr& left, const ExpressionPtr& right) + { + if (left->GetExpressionType() != right->GetExpressionType()) + throw AstError{ "Left expression type must match right expression type" }; + } + + void ShaderValidator::Visit(const AssignOp& node) + { + MandatoryNode(node.left); + MandatoryNode(node.right); + TypeMustMatch(node.left, node.right); + + Visit(node.left); + Visit(node.right); + } + + void ShaderValidator::Visit(const BinaryFunc& node) + { + MandatoryNode(node.left); + MandatoryNode(node.right); + TypeMustMatch(node.left, node.right); + + switch (node.intrinsic) + { + case BinaryIntrinsic::CrossProduct: + { + if (node.left->GetExpressionType() != ExpressionType::Float3) + throw AstError{ "CrossProduct only works with Float3 expressions" }; + } + + case BinaryIntrinsic::DotProduct: + break; + } + + Visit(node.left); + Visit(node.right); + } + + void ShaderValidator::Visit(const BinaryOp& node) + { + MandatoryNode(node.left); + MandatoryNode(node.right); + + ExpressionType leftType = node.left->GetExpressionType(); + ExpressionType rightType = node.right->GetExpressionType(); + + switch (node.op) + { + case BinaryType::Add: + case BinaryType::Equality: + case BinaryType::Substract: + TypeMustMatch(node.left, node.right); + break; + + case BinaryType::Multiply: + case BinaryType::Divide: + { + switch (leftType) + { + case ExpressionType::Float2: + case ExpressionType::Float3: + case ExpressionType::Float4: + { + if (leftType != rightType && rightType != ExpressionType::Float1) + throw AstError{ "Left expression type is not compatible with right expression type" }; + + break; + } + + case ExpressionType::Mat4x4: + { + switch (rightType) + { + case ExpressionType::Float1: + case ExpressionType::Float4: + case ExpressionType::Mat4x4: + break; + + default: + TypeMustMatch(node.left, node.right); + } + + break; + } + + default: + TypeMustMatch(node.left, node.right); + } + } + } + + Visit(node.left); + Visit(node.right); + } + + void ShaderValidator::Visit(const Branch& node) + { + for (const auto& condStatement : node.condStatements) + { + Visit(MandatoryNode(condStatement.condition)); + Visit(MandatoryNode(condStatement.statement)); + } + } + + void ShaderValidator::Visit(const BuiltinVariable& /*node*/) + { + } + + void ShaderValidator::Visit(const Cast& node) + { + unsigned int componentCount = 0; + unsigned int requiredComponents = node.GetComponentCount(node.exprType); + for (const auto& exprPtr : node.expressions) + { + if (!exprPtr) + break; + + componentCount += node.GetComponentCount(exprPtr->GetExpressionType()); + Visit(exprPtr); + } + + if (componentCount != requiredComponents) + throw AstError{ "Component count doesn't match required component count" }; + } + + void ShaderValidator::Visit(const Constant& /*node*/) + { + } + + void ShaderValidator::Visit(const DeclareVariable& node) + { + Visit(MandatoryNode(node.expression)); + } + + void ShaderValidator::Visit(const ExpressionStatement& node) + { + Visit(MandatoryNode(node.expression)); + } + + void ShaderValidator::Visit(const NamedVariable& node) + { + if (node.name.empty()) + throw AstError{ "Variable has empty name" }; + } + + void ShaderValidator::Visit(const Sample2D& node) + { + if (MandatoryExpr(node.sampler)->GetExpressionType() != ExpressionType::Sampler2D) + throw AstError{ "Sampler must be a Sampler2D" }; + + if (MandatoryExpr(node.coordinates)->GetExpressionType() != ExpressionType::Float2) + throw AstError{ "Coordinates must be a Float2" }; + + Visit(node.sampler); + Visit(node.coordinates); + } + + void ShaderValidator::Visit(const StatementBlock& node) + { + for (const auto& statement : node.statements) + Visit(MandatoryNode(statement)); + } + + void ShaderValidator::Visit(const SwizzleOp& node) + { + if (node.componentCount > 4) + throw AstError{ "Cannot swizzle more than four elements" }; + + switch (MandatoryExpr(node.expression)->GetExpressionType()) + { + case ExpressionType::Float1: + case ExpressionType::Float2: + case ExpressionType::Float3: + case ExpressionType::Float4: + break; + + default: + throw AstError{ "Cannot swizzle this type" }; + } + + Visit(node.expression); + } + + bool Validate(const StatementPtr& shader, std::string* error) + { + ShaderValidator validator; + return validator.Validate(shader, error); + } +}