From a037eef4c9d2dad31daf35e9000d0f9df7b7c36d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 14 Jan 2021 22:01:53 +0100 Subject: [PATCH] Shader: Remove ShaderNode::GetExpressionType (replaced by visitor) and minor stuff --- include/Nazara/Shader/ShaderAstUtils.hpp | 54 +++++++++++++++++ include/Nazara/Shader/ShaderAstUtils.inl | 17 ++++++ include/Nazara/Shader/ShaderNodes.hpp | 9 +-- src/Nazara/Shader/GlslWriter.cpp | 4 +- src/Nazara/Shader/ShaderAstUtils.cpp | 73 +++++++++++++++++++++++ src/Nazara/Shader/ShaderAstValidator.cpp | 8 ++- src/Nazara/Shader/ShaderNodes.cpp | 20 ------- src/Nazara/Shader/SpirvExpressionLoad.cpp | 2 +- src/ShaderNode/ShaderGraph.cpp | 2 +- 9 files changed, 157 insertions(+), 32 deletions(-) create mode 100644 include/Nazara/Shader/ShaderAstUtils.hpp create mode 100644 include/Nazara/Shader/ShaderAstUtils.inl create mode 100644 src/Nazara/Shader/ShaderAstUtils.cpp diff --git a/include/Nazara/Shader/ShaderAstUtils.hpp b/include/Nazara/Shader/ShaderAstUtils.hpp new file mode 100644 index 000000000..e78ee657f --- /dev/null +++ b/include/Nazara/Shader/ShaderAstUtils.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADERASTUTILS_HPP +#define NAZARA_SHADERASTUTILS_HPP + +#include +#include +#include +#include +#include + +namespace Nz +{ + class ShaderAst; + + class NAZARA_SHADER_API ShaderAstValueCategory final : public ShaderAstVisitorExcept + { + public: + ShaderAstValueCategory() = default; + ShaderAstValueCategory(const ShaderAstValueCategory&) = delete; + ShaderAstValueCategory(ShaderAstValueCategory&&) = delete; + ~ShaderAstValueCategory() = default; + + ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression); + + ShaderAstValueCategory& operator=(const ShaderAstValueCategory&) = delete; + ShaderAstValueCategory& operator=(ShaderAstValueCategory&&) = delete; + + private: + using ShaderAstVisitorExcept::Visit; + void Visit(ShaderNodes::AccessMember& node) override; + void Visit(ShaderNodes::AssignOp& node) override; + void Visit(ShaderNodes::BinaryOp& node) override; + void Visit(ShaderNodes::Cast& node) override; + void Visit(ShaderNodes::ConditionalExpression& node) override; + void Visit(ShaderNodes::Constant& node) override; + void Visit(ShaderNodes::Identifier& node) override; + void Visit(ShaderNodes::IntrinsicCall& node) override; + void Visit(ShaderNodes::Sample2D& node) override; + void Visit(ShaderNodes::SwizzleOp& node) override; + + ShaderNodes::ExpressionCategory m_expressionCategory; + }; + + inline ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression); +} + +#include + +#endif diff --git a/include/Nazara/Shader/ShaderAstUtils.inl b/include/Nazara/Shader/ShaderAstUtils.inl new file mode 100644 index 000000000..852b2e685 --- /dev/null +++ b/include/Nazara/Shader/ShaderAstUtils.inl @@ -0,0 +1,17 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + ShaderNodes::ExpressionCategory GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression) + { + ShaderAstValueCategory visitor; + return visitor.GetExpressionCategory(expression); + } +} + +#include diff --git a/include/Nazara/Shader/ShaderNodes.hpp b/include/Nazara/Shader/ShaderNodes.hpp index 5a8f28ce7..e2e082b97 100644 --- a/include/Nazara/Shader/ShaderNodes.hpp +++ b/include/Nazara/Shader/ShaderNodes.hpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -55,12 +56,11 @@ namespace Nz using ExpressionPtr = std::shared_ptr; - class NAZARA_SHADER_API Expression : public Node + class NAZARA_SHADER_API Expression : public Node, public std::enable_shared_from_this { public: inline Expression(NodeType type); - virtual ExpressionCategory GetExpressionCategory() const; virtual ShaderExpressionType GetExpressionType() const = 0; }; @@ -68,7 +68,7 @@ namespace Nz using StatementPtr = std::shared_ptr; - class NAZARA_SHADER_API Statement : public Node + class NAZARA_SHADER_API Statement : public Node, public std::enable_shared_from_this { public: inline Statement(NodeType type); @@ -136,7 +136,6 @@ namespace Nz { inline Identifier(); - ExpressionCategory GetExpressionCategory() const override; ShaderExpressionType GetExpressionType() const override; void Visit(ShaderAstVisitor& visitor) override; @@ -149,7 +148,6 @@ namespace Nz { inline AccessMember(); - ExpressionCategory GetExpressionCategory() const override; ShaderExpressionType GetExpressionType() const override; void Visit(ShaderAstVisitor& visitor) override; @@ -265,7 +263,6 @@ namespace Nz { inline SwizzleOp(); - ExpressionCategory GetExpressionCategory() const override; ShaderExpressionType GetExpressionType() const override; void Visit(ShaderAstVisitor& visitor) override; diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 6179f094b..e1f32d414 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -349,7 +350,7 @@ namespace Nz void GlslWriter::Visit(ShaderNodes::ExpressionPtr& expr, bool encloseIfRequired) { - bool enclose = encloseIfRequired && (expr->GetExpressionCategory() != ShaderNodes::ExpressionCategory::LValue); + bool enclose = encloseIfRequired && (GetExpressionCategory(expr) != ShaderNodes::ExpressionCategory::LValue); if (enclose) Append("("); @@ -461,7 +462,6 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(ShaderNodes::ConditionalExpression& node) { std::size_t conditionIndex = m_context.shader->FindConditionByName(node.conditionName); diff --git a/src/Nazara/Shader/ShaderAstUtils.cpp b/src/Nazara/Shader/ShaderAstUtils.cpp new file mode 100644 index 000000000..d13571ded --- /dev/null +++ b/src/Nazara/Shader/ShaderAstUtils.cpp @@ -0,0 +1,73 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + ShaderNodes::ExpressionCategory ShaderAstValueCategory::GetExpressionCategory(const ShaderNodes::ExpressionPtr& expression) + { + Visit(expression); + return m_expressionCategory; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::AccessMember& node) + { + Visit(node.structExpr); + } + + void ShaderAstValueCategory::Visit(ShaderNodes::AssignOp& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::BinaryOp& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::Cast& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::ConditionalExpression& node) + { + Visit(node.truePath); + ShaderNodes::ExpressionCategory trueExprCategory = m_expressionCategory; + Visit(node.falsePath); + ShaderNodes::ExpressionCategory falseExprCategory = m_expressionCategory; + + if (trueExprCategory == ShaderNodes::ExpressionCategory::RValue || falseExprCategory == ShaderNodes::ExpressionCategory::RValue) + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + else + m_expressionCategory = ShaderNodes::ExpressionCategory::LValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::Constant& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::Identifier& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::LValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::IntrinsicCall& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::Sample2D& node) + { + m_expressionCategory = ShaderNodes::ExpressionCategory::RValue; + } + + void ShaderAstValueCategory::Visit(ShaderNodes::SwizzleOp& node) + { + Visit(node.expression); + } +} diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index 90f8aeef2..ec2d9a7d7 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -126,7 +127,7 @@ namespace Nz MandatoryNode(node.right); TypeMustMatch(node.left, node.right); - if (node.left->GetExpressionCategory() != ShaderNodes::ExpressionCategory::LValue) + if (GetExpressionCategory(node.left) != ShaderNodes::ExpressionCategory::LValue) throw AstError { "Assignation is only possible with a l-value" }; ShaderAstRecursiveVisitor::Visit(node); @@ -221,7 +222,10 @@ namespace Nz { for (const auto& condStatement : node.condStatements) { - MandatoryNode(condStatement.condition); + const ShaderExpressionType& condType = MandatoryExpr(condStatement.condition)->GetExpressionType(); + if (!IsBasicType(condType) || std::get(condType) != ShaderNodes::BasicType::Boolean) + throw AstError{ "if expression must resolve to boolean type" }; + MandatoryNode(condStatement.statement); } diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index 8d96c40ea..0e617297a 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -13,11 +13,6 @@ namespace Nz::ShaderNodes { Node::~Node() = default; - ExpressionCategory Expression::GetExpressionCategory() const - { - return ExpressionCategory::RValue; - } - void ExpressionStatement::Visit(ShaderAstVisitor& visitor) { visitor.Visit(*this); @@ -48,11 +43,6 @@ namespace Nz::ShaderNodes } - ExpressionCategory Identifier::GetExpressionCategory() const - { - return ExpressionCategory::LValue; - } - ShaderExpressionType Identifier::GetExpressionType() const { assert(var); @@ -64,11 +54,6 @@ namespace Nz::ShaderNodes visitor.Visit(*this); } - ExpressionCategory AccessMember::GetExpressionCategory() const - { - return structExpr->GetExpressionCategory(); - } - ShaderExpressionType AccessMember::GetExpressionType() const { return exprType; @@ -231,11 +216,6 @@ namespace Nz::ShaderNodes } - ExpressionCategory SwizzleOp::GetExpressionCategory() const - { - return expression->GetExpressionCategory(); - } - ShaderExpressionType SwizzleOp::GetExpressionType() const { const ShaderExpressionType& exprType = expression->GetExpressionType(); diff --git a/src/Nazara/Shader/SpirvExpressionLoad.cpp b/src/Nazara/Shader/SpirvExpressionLoad.cpp index b82d9b046..93af139e0 100644 --- a/src/Nazara/Shader/SpirvExpressionLoad.cpp +++ b/src/Nazara/Shader/SpirvExpressionLoad.cpp @@ -13,7 +13,7 @@ namespace Nz namespace { template struct overloaded : Ts... { using Ts::operator()...; }; - template overloaded(Ts...)->overloaded; + template overloaded(Ts...) -> overloaded; } UInt32 SpirvExpressionLoad::Evaluate(ShaderNodes::Expression& node) diff --git a/src/ShaderNode/ShaderGraph.cpp b/src/ShaderNode/ShaderGraph.cpp index d385b098d..21be5ec9a 100644 --- a/src/ShaderNode/ShaderGraph.cpp +++ b/src/ShaderNode/ShaderGraph.cpp @@ -531,7 +531,7 @@ Nz::ShaderNodes::StatementPtr ShaderGraph::ToAst() auto expression = std::static_pointer_cast(astNode); Nz::ShaderNodes::ExpressionPtr varExpression; - if (expression->GetExpressionCategory() == Nz::ShaderNodes::ExpressionCategory::RValue) + if (Nz::GetExpressionCategory(expression) == Nz::ShaderNodes::ExpressionCategory::RValue) { std::string name; if (variableName.empty())