From 00ac6e8a0a886e16b1161c1c7ae7623c03a779dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Thu, 14 Jan 2021 22:01:05 +0100 Subject: [PATCH] Shader: Add ShaderAstOptimizer --- include/Nazara/Shader/ShaderAstCloner.hpp | 2 + include/Nazara/Shader/ShaderAstOptimizer.hpp | 50 ++ include/Nazara/Shader/ShaderAstOptimizer.inl | 12 + src/Nazara/Shader/ShaderAstOptimizer.cpp | 544 +++++++++++++++++++ 4 files changed, 608 insertions(+) create mode 100644 include/Nazara/Shader/ShaderAstOptimizer.hpp create mode 100644 include/Nazara/Shader/ShaderAstOptimizer.inl create mode 100644 src/Nazara/Shader/ShaderAstOptimizer.cpp diff --git a/include/Nazara/Shader/ShaderAstCloner.hpp b/include/Nazara/Shader/ShaderAstCloner.hpp index ddd55e4ba..94b68cb0a 100644 --- a/include/Nazara/Shader/ShaderAstCloner.hpp +++ b/include/Nazara/Shader/ShaderAstCloner.hpp @@ -33,6 +33,7 @@ namespace Nz ShaderNodes::StatementPtr CloneStatement(const ShaderNodes::StatementPtr& statement); ShaderNodes::VariablePtr CloneVariable(const ShaderNodes::VariablePtr& statement); + using ShaderAstVisitor::Visit; void Visit(ShaderNodes::AccessMember& node) override; void Visit(ShaderNodes::AssignOp& node) override; void Visit(ShaderNodes::BinaryOp& node) override; @@ -51,6 +52,7 @@ namespace Nz void Visit(ShaderNodes::StatementBlock& node) override; void Visit(ShaderNodes::SwizzleOp& node) override; + using ShaderVarVisitor::Visit; void Visit(ShaderNodes::BuiltinVariable& var) override; void Visit(ShaderNodes::InputVariable& var) override; void Visit(ShaderNodes::LocalVariable& var) override; diff --git a/include/Nazara/Shader/ShaderAstOptimizer.hpp b/include/Nazara/Shader/ShaderAstOptimizer.hpp new file mode 100644 index 000000000..d7d56b7cf --- /dev/null +++ b/include/Nazara/Shader/ShaderAstOptimizer.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright 
notice in Config.hpp
+
+#pragma once
+
+#ifndef NAZARA_SHADERASTOPTIMISER_HPP
+#define NAZARA_SHADERASTOPTIMISER_HPP
+
+#include
+#include
+#include
+#include
+
+namespace Nz
+{
+    class ShaderAst;
+
+    // AST cloner that simplifies the tree while copying it:
+    // - binary operations whose two operands are constants are folded into a single constant,
+    // - branches whose conditions are constant booleans are pruned,
+    // - when a ShaderAst and a condition mask are supplied, conditional expressions/statements
+    //   are resolved to the selected path.
+    // Instances are neither copyable nor movable.
+    class NAZARA_SHADER_API ShaderAstOptimizer : public ShaderAstCloner
+    {
+        public:
+            ShaderAstOptimizer() = default;
+            ShaderAstOptimizer(const ShaderAstOptimizer&) = delete;
+            ShaderAstOptimizer(ShaderAstOptimizer&&) = delete;
+            ~ShaderAstOptimizer() = default;
+
+            // Returns an optimized clone of `statement`; conditional nodes are cloned untouched.
+            ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement);
+
+            // Returns an optimized clone of `statement`, resolving conditional nodes: each condition
+            // name is looked up in `shader` and its index is tested as a bit of `enabledConditions`.
+            // `shader` is referenced (not copied) for the duration of the call.
+            ShaderNodes::StatementPtr Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions);
+
+            ShaderAstOptimizer& operator=(const ShaderAstOptimizer&) = delete;
+            ShaderAstOptimizer& operator=(ShaderAstOptimizer&&) = delete;
+
+        protected:
+            using ShaderAstCloner::Visit;
+            void Visit(ShaderNodes::BinaryOp& node) override;
+            void Visit(ShaderNodes::Branch& node) override;
+            void Visit(ShaderNodes::ConditionalExpression& node) override;
+            void Visit(ShaderNodes::ConditionalStatement& node) override;
+
+            // Pushes the folded result of `lhs Type rhs` when that operand type pair is supported,
+            // otherwise pushes a regular binary operation built from both constants.
+            template<ShaderNodes::BinaryType Type> void PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs);
+
+        private:
+            // Initialized so a default-constructed optimizer never holds indeterminate values
+            const ShaderAst* m_shaderAst = nullptr; // null means "do not resolve conditions"
+            UInt64 m_enabledConditions = 0;         // bitmask indexed by condition index
+    };
+}
+
+#include
+
+#endif
diff --git a/include/Nazara/Shader/ShaderAstOptimizer.inl b/include/Nazara/Shader/ShaderAstOptimizer.inl
new file mode 100644
index 000000000..17cd14c12
--- /dev/null
+++ b/include/Nazara/Shader/ShaderAstOptimizer.inl
@@ -0,0 +1,12 @@
+// Copyright (C) 2020 Jérôme Leclercq
+// This file is part of the "Nazara Engine - Shader generator"
+// For conditions of distribution and use, see copyright notice in Config.hpp
+
+#include
+#include
+
+namespace Nz
+{
+}
+
+#include
diff --git a/src/Nazara/Shader/ShaderAstOptimizer.cpp b/src/Nazara/Shader/ShaderAstOptimizer.cpp
new file mode 100644
index 000000000..6bc61c1e4
--- /dev/null
+++ b/src/Nazara/Shader/ShaderAstOptimizer.cpp
@@ -0,0 +1,544 @@
+// 
Copyright (C) 2020 Jérôme Leclercq
+// This file is part of the "Nazara Engine - Shader generator"
+// For conditions of distribution and use, see copyright notice in Config.hpp
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace Nz
+{
+    namespace
+    {
+        // Compile-time test telling whether T is a complete type at the point of use.
+        // sizeof(U) is only valid for complete types, so the first overload drops out (SFINAE)
+        // for types that were declared but never defined/specialized.
+        template<typename T>
+        struct is_complete_helper
+        {
+            template<typename U> static auto test(U*) -> std::integral_constant<bool, sizeof(U) == sizeof(U)>;
+            static auto test(...) -> std::false_type;
+
+            using type = decltype(test(static_cast<T*>(nullptr)));
+        };
+
+        template<typename T>
+        struct is_complete : is_complete_helper<T>::type {};
+
+        template<typename T>
+        inline constexpr bool is_complete_v = is_complete<T>::value;
+
+
+        // Maps a binary operation to its folding functor (member alias Op).
+        // The functor itself is only declared here: it becomes a complete type solely for the
+        // operand type pairs enabled with EnableOptimisation below.
+        template<ShaderNodes::BinaryType Type, typename T1, typename T2>
+        struct PropagateConstantType;
+
+        // CompEq
+        template<typename T1, typename T2>
+        struct CompEqBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs == rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompEq;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompEq, T1, T2>
+        {
+            using Op = CompEq<T1, T2>;
+        };
+
+        // CompGe
+        template<typename T1, typename T2>
+        struct CompGeBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs >= rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompGe;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompGe, T1, T2>
+        {
+            using Op = CompGe<T1, T2>;
+        };
+
+        // CompGt
+        template<typename T1, typename T2>
+        struct CompGtBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs > rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompGt;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompGt, T1, T2>
+        {
+            using Op = CompGt<T1, T2>;
+        };
+
+        // CompLe
+        template<typename T1, typename T2>
+        struct CompLeBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs <= rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompLe;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompLe, T1, T2>
+        {
+            using Op = CompLe<T1, T2>;
+        };
+
+        // CompLt
+        template<typename T1, typename T2>
+        struct CompLtBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs < rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompLt;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompLt, T1, T2>
+        {
+            // Must be CompLt: mapping to CompLe would fold `a < b` as `a <= b`
+            using Op = CompLt<T1, T2>;
+        };
+
+        // CompNe
+        template<typename T1, typename T2>
+        struct CompNeBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs != rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct CompNe;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::CompNe, T1, T2>
+        {
+            using Op = CompNe<T1, T2>;
+        };
+
+        // Addition
+        template<typename T1, typename T2>
+        struct AdditionBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs + rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct Addition;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::Add, T1, T2>
+        {
+            using Op = Addition<T1, T2>;
+        };
+
+        // Division
+        template<typename T1, typename T2>
+        struct DivisionBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                // Dividing by an integer zero is undefined behaviour in the optimizer itself:
+                // return null so the caller keeps the operation unfolded.
+                // (Only scalar integer divisors are checked; vector divisors are folded as before.)
+                if constexpr (std::is_integral_v<T2>)
+                {
+                    if (rhs == 0)
+                        return nullptr;
+                }
+
+                return ShaderBuilder::Constant(lhs / rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct Division;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::Divide, T1, T2>
+        {
+            using Op = Division<T1, T2>;
+        };
+
+        // Multiplication
+        template<typename T1, typename T2>
+        struct MultiplicationBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs * rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct Multiplication;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::Multiply, T1, T2>
+        {
+            using Op = Multiplication<T1, T2>;
+        };
+
+        // Subtraction
+        template<typename T1, typename T2>
+        struct SubtractionBase
+        {
+            ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs)
+            {
+                return ShaderBuilder::Constant(lhs - rhs);
+            }
+        };
+
+        template<typename T1, typename T2>
+        struct Subtraction;
+
+        template<typename T1, typename T2>
+        struct PropagateConstantType<ShaderNodes::BinaryType::Subtract, T1, T2>
+        {
+            using Op = Subtraction<T1, T2>;
+        };
+
+// Completes Op<T1, T2> (by inheriting from Op##Base<T1, T2>), which enables folding for that pair
+#define EnableOptimisation(Op, T1, T2) template<> struct Op<T1, T2> : Op##Base<T1, T2> {}
+
+        EnableOptimisation(CompEq, bool, bool);
+        EnableOptimisation(CompEq, double, double);
+        EnableOptimisation(CompEq, float, float);
+        EnableOptimisation(CompEq, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompEq, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompEq, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompEq, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompEq, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompEq, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompEq, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(CompGe, bool, bool);
+        EnableOptimisation(CompGe, double, double);
+        EnableOptimisation(CompGe, float, float);
+        EnableOptimisation(CompGe, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompGe, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompGe, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompGe, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompGe, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompGe, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompGe, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(CompGt, bool, bool);
+        EnableOptimisation(CompGt, double, double);
+        EnableOptimisation(CompGt, float, float);
+        EnableOptimisation(CompGt, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompGt, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompGt, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompGt, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompGt, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompGt, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompGt, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(CompLe, bool, bool);
+        EnableOptimisation(CompLe, double, double);
+        EnableOptimisation(CompLe, float, float);
+        EnableOptimisation(CompLe, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompLe, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompLe, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompLe, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompLe, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompLe, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompLe, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(CompLt, bool, bool);
+        EnableOptimisation(CompLt, double, double);
+        EnableOptimisation(CompLt, float, float);
+        EnableOptimisation(CompLt, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompLt, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompLt, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompLt, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompLt, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompLt, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompLt, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(CompNe, bool, bool);
+        EnableOptimisation(CompNe, double, double);
+        EnableOptimisation(CompNe, float, float);
+        EnableOptimisation(CompNe, Nz::Int32, Nz::Int32);
+        EnableOptimisation(CompNe, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(CompNe, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(CompNe, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(CompNe, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(CompNe, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(CompNe, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(Addition, double, double);
+        EnableOptimisation(Addition, float, float);
+        EnableOptimisation(Addition, Nz::Int32, Nz::Int32);
+        EnableOptimisation(Addition, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(Addition, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(Addition, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(Addition, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(Addition, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(Addition, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(Division, double, double);
+        EnableOptimisation(Division, double, Nz::Vector2d);
+        EnableOptimisation(Division, double, Nz::Vector3d);
+        EnableOptimisation(Division, double, Nz::Vector4d);
+        EnableOptimisation(Division, float, float);
+        EnableOptimisation(Division, float, Nz::Vector2f);
+        EnableOptimisation(Division, float, Nz::Vector3f);
+        EnableOptimisation(Division, float, Nz::Vector4f);
+        EnableOptimisation(Division, Nz::Int32, Nz::Int32);
+        EnableOptimisation(Division, Nz::Int32, Nz::Vector2i32);
+        EnableOptimisation(Division, Nz::Int32, Nz::Vector3i32);
+        EnableOptimisation(Division, Nz::Int32, Nz::Vector4i32);
+        EnableOptimisation(Division, Nz::Vector2f, float);
+        EnableOptimisation(Division, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(Division, Nz::Vector3f, float);
+        EnableOptimisation(Division, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(Division, Nz::Vector4f, float);
+        EnableOptimisation(Division, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(Division, Nz::Vector2d, double);
+        EnableOptimisation(Division, Nz::Vector2d, Nz::Vector2d);
+        EnableOptimisation(Division, Nz::Vector3d, double);
+        EnableOptimisation(Division, Nz::Vector3d, Nz::Vector3d);
+        EnableOptimisation(Division, Nz::Vector4d, double);
+        EnableOptimisation(Division, Nz::Vector4d, Nz::Vector4d);
+        EnableOptimisation(Division, Nz::Vector2i32, Nz::Int32);
+        EnableOptimisation(Division, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(Division, Nz::Vector3i32, Nz::Int32);
+        EnableOptimisation(Division, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(Division, Nz::Vector4i32, Nz::Int32);
+        EnableOptimisation(Division, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(Multiplication, double, double);
+        EnableOptimisation(Multiplication, double, Nz::Vector2d);
+        EnableOptimisation(Multiplication, double, Nz::Vector3d);
+        EnableOptimisation(Multiplication, double, Nz::Vector4d);
+        EnableOptimisation(Multiplication, float, float);
+        EnableOptimisation(Multiplication, float, Nz::Vector2f);
+        EnableOptimisation(Multiplication, float, Nz::Vector3f);
+        EnableOptimisation(Multiplication, float, Nz::Vector4f);
+        EnableOptimisation(Multiplication, Nz::Int32, Nz::Int32);
+        EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector2i32);
+        EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector3i32);
+        EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector4i32);
+        EnableOptimisation(Multiplication, Nz::Vector2f, float);
+        EnableOptimisation(Multiplication, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(Multiplication, Nz::Vector3f, float);
+        EnableOptimisation(Multiplication, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(Multiplication, Nz::Vector4f, float);
+        EnableOptimisation(Multiplication, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(Multiplication, Nz::Vector2d, double);
+        EnableOptimisation(Multiplication, Nz::Vector2d, Nz::Vector2d);
+        EnableOptimisation(Multiplication, Nz::Vector3d, double);
+        EnableOptimisation(Multiplication, Nz::Vector3d, Nz::Vector3d);
+        EnableOptimisation(Multiplication, Nz::Vector4d, double);
+        EnableOptimisation(Multiplication, Nz::Vector4d, Nz::Vector4d);
+        EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Int32);
+        EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Int32);
+        EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Int32);
+        EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Vector4i32);
+
+        EnableOptimisation(Subtraction, double, double);
+        EnableOptimisation(Subtraction, float, float);
+        EnableOptimisation(Subtraction, Nz::Int32, Nz::Int32);
+        EnableOptimisation(Subtraction, Nz::Vector2f, Nz::Vector2f);
+        EnableOptimisation(Subtraction, Nz::Vector3f, Nz::Vector3f);
+        EnableOptimisation(Subtraction, Nz::Vector4f, Nz::Vector4f);
+        EnableOptimisation(Subtraction, Nz::Vector2i32, Nz::Vector2i32);
+        EnableOptimisation(Subtraction, Nz::Vector3i32, Nz::Vector3i32);
+        EnableOptimisation(Subtraction, Nz::Vector4i32, Nz::Vector4i32);
+
+#undef EnableOptimisation
+    }
+
+    // Returns an optimized clone of `statement` without resolving conditional nodes
+    ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement)
+    {
+        m_shaderAst = nullptr;
+        m_enabledConditions = 0; // do not keep a mask from a previous call around
+
+        return CloneStatement(statement);
+    }
+
+    // Returns an optimized clone of `statement`; conditional nodes are resolved by looking their
+    // condition name up in `shader` and testing the matching bit of `enabledConditions`
+    ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions)
+    {
+        m_shaderAst = &shader;
+        m_enabledConditions = enabledConditions;
+
+        return CloneStatement(statement);
+    }
+
+    // Clones (and thereby optimizes) both operands, then folds the operation when both are constants
+    void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node)
+    {
+        auto lhs = CloneExpression(node.left);
+        auto rhs = CloneExpression(node.right);
+
+        if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant)
+        {
+            auto lhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(lhs);
+            auto rhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(rhs);
+
+            switch (node.op)
+            {
+                case ShaderNodes::BinaryType::Add:
+                    return PropagateConstant<ShaderNodes::BinaryType::Add>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::Subtract:
+                    return PropagateConstant<ShaderNodes::BinaryType::Subtract>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::Multiply:
+                    return PropagateConstant<ShaderNodes::BinaryType::Multiply>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::Divide:
+                    return PropagateConstant<ShaderNodes::BinaryType::Divide>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompEq:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompEq>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompGe:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompGe>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompGt:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompGt>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompLe:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompLe>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompLt:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompLt>(lhsConstant, rhsConstant);
+
+                case ShaderNodes::BinaryType::CompNe:
+                    return PropagateConstant<ShaderNodes::BinaryType::CompNe>(lhsConstant, rhsConstant);
+
+                default:
+                    break; // unknown operation: emitted unfolded below
+            }
+        }
+
+        // Reuse the operands cloned above instead of delegating to ShaderAstCloner::Visit(node),
+        // which would traverse (and optimize) both subtrees a second time at every nesting level
+        PushExpression(ShaderNodes::BinaryOp::Build(node.op, std::move(lhs), std::move(rhs)));
+    }
+
+    // Prunes conditions that are constant booleans:
+    // - constant false: the condition and its statement are dropped,
+    // - constant true: the statement replaces the whole branch when nothing precedes it,
+    //   otherwise it becomes the else statement and the remaining conditions are dropped,
+    // - anything else: kept, with its condition and statement cloned.
+    void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node)
+    {
+        decltype(node.condStatements) statements; // same container type as the source branch
+        ShaderNodes::StatementPtr elseStatement;
+
+        for (auto& condStatement : node.condStatements)
+        {
+            auto cond = CloneExpression(condStatement.condition);
+
+            if (cond->GetType() == ShaderNodes::NodeType::Constant)
+            {
+                auto constant = std::static_pointer_cast<ShaderNodes::Constant>(cond);
+
+                assert(IsBasicType(cond->GetExpressionType()));
+                assert(std::get<ShaderNodes::BasicType>(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean);
+
+                bool cValue = std::get<bool>(constant->value);
+                if (!cValue)
+                    continue;
+
+                if (statements.empty())
+                {
+                    // First condition is true, dismiss the branch
+                    Visit(condStatement.statement);
+                    return;
+                }
+                else
+                {
+                    // Some condition after the first one is true, make it the else statement and stop there
+                    elseStatement = CloneStatement(condStatement.statement);
+                    break;
+                }
+            }
+            else
+            {
+                // Runtime condition: it has to survive, otherwise every branch that is not
+                // fully constant would collapse to its else statement
+                auto& kept = statements.emplace_back();
+                kept.condition = std::move(cond);
+                kept.statement = CloneStatement(condStatement.statement);
+            }
+        }
+
+        if (statements.empty())
+        {
+            // All conditions have been removed, replace by else statement or no-op
+            if (node.elseStatement)
+                return Visit(node.elseStatement);
+            else
+                return PushStatement(ShaderNodes::NoOp::Build());
+        }
+
+        if (!elseStatement && node.elseStatement)
+            elseStatement = CloneStatement(node.elseStatement);
+
+        PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement)));
+    }
+
+    // Replaces the conditional expression by its true or false path, depending on whether the
+    // condition bit is set; plain clone when no shader was given to Optimise
+    void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node)
+    {
+        if (!m_shaderAst)
+            return ShaderAstCloner::Visit(node);
+
+        std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
+        assert(conditionIndex != ShaderAst::InvalidCondition);
+
+        if (TestBit(m_enabledConditions, conditionIndex))
+            Visit(node.truePath);
+        else
+            Visit(node.falsePath);
+    }
+
+    // Replaces the conditional statement by its inner statement when the condition bit is set,
+    // by a no-op otherwise; plain clone when no shader was given to Optimise
+    void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node)
+    {
+        if (!m_shaderAst)
+            return ShaderAstCloner::Visit(node);
+
+        std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
+        assert(conditionIndex != ShaderAst::InvalidCondition);
+
+        if (TestBit(m_enabledConditions, conditionIndex))
+            Visit(node.statement);
+        else
+        {
+            // Every other statement visit in this file pushes a result; presumably the cloner
+            // expects exactly one pushed statement per visit (TODO confirm against ShaderAstCloner),
+            // so a disabled condition yields a no-op, like an emptied Branch does
+            PushStatement(ShaderNodes::NoOp::Build());
+        }
+    }
+
+    // Folds `lhs Type rhs`: both constant values are visited to recover their concrete types,
+    // and the folding functor is invoked only if it has been enabled for that type pair.
+    // Falls back to a regular binary operation when no folded value was produced.
+    template<ShaderNodes::BinaryType Type>
+    void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs)
+    {
+        ShaderNodes::ExpressionPtr optimized;
+        std::visit([&](auto&& arg1)
+        {
+            using T1 = std::decay_t<decltype(arg1)>;
+
+            std::visit([&](auto&& arg2)
+            {
+                using T2 = std::decay_t<decltype(arg2)>;
+                using PCType = PropagateConstantType<Type, T1, T2>;
+
+                if constexpr (is_complete_v<PCType>)
+                {
+                    using Op = typename PCType::Op;
+                    if constexpr (is_complete_v<Op>)
+                        optimized = Op{}(arg1, arg2);
+                }
+
+            }, rhs->value);
+        }, lhs->value);
+
+        if (optimized)
+            PushExpression(optimized);
+        else
+            PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs));
+    }
+}