// Copyright (C) 2020 Jérôme Leclercq // This file is part of the "Nazara Engine - Shader generator" // For conditions of distribution and use, see copyright notice in Config.hpp #include #include #include #include #include #include namespace Nz { namespace { template struct is_complete_helper { template static auto test(U*)->std::integral_constant; static auto test(...) -> std::false_type; using type = decltype(test((T*)0)); }; template struct is_complete : is_complete_helper::type {}; template inline constexpr bool is_complete_v = is_complete::value; template struct PropagateConstantType; // CompEq template struct CompEqBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs == rhs); } }; template struct CompEq; template struct PropagateConstantType { using Op = typename CompEq; }; // CompGe template struct CompGeBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs >= rhs); } }; template struct CompGe; template struct PropagateConstantType { using Op = typename CompGe; }; // CompGt template struct CompGtBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs > rhs); } }; template struct CompGt; template struct PropagateConstantType { using Op = typename CompGt; }; // CompLe template struct CompLeBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs <= rhs); } }; template struct CompLe; template struct PropagateConstantType { using Op = typename CompLe; }; // CompLt template struct CompLtBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs < rhs); } }; template struct CompLt; template struct PropagateConstantType { using Op = typename CompLe; }; // CompNe template struct CompNeBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return 
ShaderBuilder::Constant(lhs != rhs); } }; template struct CompNe; template struct PropagateConstantType { using Op = typename CompNe; }; // Addition template struct AdditionBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs + rhs); } }; template struct Addition; template struct PropagateConstantType { using Op = typename Addition; }; // Division template struct DivisionBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs / rhs); } }; template struct Division; template struct PropagateConstantType { using Op = typename Division; }; // Multiplication template struct MultiplicationBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs * rhs); } }; template struct Multiplication; template struct PropagateConstantType { using Op = typename Multiplication; }; // Subtraction template struct SubtractionBase { ShaderNodes::ExpressionPtr operator()(const T1& lhs, const T2& rhs) { return ShaderBuilder::Constant(lhs - rhs); } }; template struct Subtraction; template struct PropagateConstantType { using Op = typename Subtraction; }; #define EnableOptimisation(Op, T1, T2) template<> struct Op : Op##Base {} EnableOptimisation(CompEq, bool, bool); EnableOptimisation(CompEq, double, double); EnableOptimisation(CompEq, float, float); EnableOptimisation(CompEq, Nz::Int32, Nz::Int32); EnableOptimisation(CompEq, Nz::Vector2f, Nz::Vector2f); EnableOptimisation(CompEq, Nz::Vector3f, Nz::Vector3f); EnableOptimisation(CompEq, Nz::Vector4f, Nz::Vector4f); EnableOptimisation(CompEq, Nz::Vector2i32, Nz::Vector2i32); EnableOptimisation(CompEq, Nz::Vector3i32, Nz::Vector3i32); EnableOptimisation(CompEq, Nz::Vector4i32, Nz::Vector4i32); EnableOptimisation(CompGe, bool, bool); EnableOptimisation(CompGe, double, double); EnableOptimisation(CompGe, float, float); EnableOptimisation(CompGe, Nz::Int32, Nz::Int32); 
EnableOptimisation(CompGe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompGe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompGe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompGe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompGe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompGe, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(CompGt, bool, bool);
EnableOptimisation(CompGt, double, double);
EnableOptimisation(CompGt, float, float);
EnableOptimisation(CompGt, Nz::Int32, Nz::Int32);
EnableOptimisation(CompGt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompGt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompGt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompGt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompGt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompGt, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(CompLe, bool, bool);
EnableOptimisation(CompLe, double, double);
EnableOptimisation(CompLe, float, float);
EnableOptimisation(CompLe, Nz::Int32, Nz::Int32);
EnableOptimisation(CompLe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompLe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompLe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompLe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompLe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompLe, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(CompLt, bool, bool);
EnableOptimisation(CompLt, double, double);
EnableOptimisation(CompLt, float, float);
EnableOptimisation(CompLt, Nz::Int32, Nz::Int32);
EnableOptimisation(CompLt, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompLt, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompLt, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompLt, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompLt, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompLt, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(CompNe, bool, bool);
EnableOptimisation(CompNe, double, double);
EnableOptimisation(CompNe, float, float);
EnableOptimisation(CompNe, Nz::Int32, Nz::Int32);
EnableOptimisation(CompNe, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(CompNe, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(CompNe, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(CompNe, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(CompNe, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(CompNe, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(Addition, double, double);
EnableOptimisation(Addition, float, float);
EnableOptimisation(Addition, Nz::Int32, Nz::Int32);
EnableOptimisation(Addition, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Addition, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Addition, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Addition, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Addition, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Addition, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(Division, double, double);
EnableOptimisation(Division, double, Nz::Vector2d);
EnableOptimisation(Division, double, Nz::Vector3d);
EnableOptimisation(Division, double, Nz::Vector4d);
EnableOptimisation(Division, float, float);
EnableOptimisation(Division, float, Nz::Vector2f);
EnableOptimisation(Division, float, Nz::Vector3f);
EnableOptimisation(Division, float, Nz::Vector4f);
EnableOptimisation(Division, Nz::Int32, Nz::Int32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(Division, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(Division, Nz::Vector2f, float);
EnableOptimisation(Division, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Division, Nz::Vector3f, float);
EnableOptimisation(Division, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Division, Nz::Vector4f, float);
EnableOptimisation(Division, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Division, Nz::Vector2d, double);
EnableOptimisation(Division, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(Division, Nz::Vector3d, double);
EnableOptimisation(Division, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(Division, Nz::Vector4d, double);
EnableOptimisation(Division, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(Division, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Division, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Division, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(Division, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(Multiplication, double, double);
EnableOptimisation(Multiplication, double, Nz::Vector2d);
EnableOptimisation(Multiplication, double, Nz::Vector3d);
EnableOptimisation(Multiplication, double, Nz::Vector4d);
EnableOptimisation(Multiplication, float, float);
EnableOptimisation(Multiplication, float, Nz::Vector2f);
EnableOptimisation(Multiplication, float, Nz::Vector3f);
EnableOptimisation(Multiplication, float, Nz::Vector4f);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector2i32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector3i32);
EnableOptimisation(Multiplication, Nz::Int32, Nz::Vector4i32);
EnableOptimisation(Multiplication, Nz::Vector2f, float);
EnableOptimisation(Multiplication, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Multiplication, Nz::Vector3f, float);
EnableOptimisation(Multiplication, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Multiplication, Nz::Vector4f, float);
EnableOptimisation(Multiplication, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Multiplication, Nz::Vector2d, double);
EnableOptimisation(Multiplication, Nz::Vector2d, Nz::Vector2d);
EnableOptimisation(Multiplication, Nz::Vector3d, double);
EnableOptimisation(Multiplication, Nz::Vector3d, Nz::Vector3d);
EnableOptimisation(Multiplication, Nz::Vector4d, double);
EnableOptimisation(Multiplication, Nz::Vector4d, Nz::Vector4d);
EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Int32);
EnableOptimisation(Multiplication, Nz::Vector4i32, Nz::Vector4i32);

EnableOptimisation(Subtraction, double, double);
EnableOptimisation(Subtraction, float, float);
EnableOptimisation(Subtraction, Nz::Int32, Nz::Int32);
EnableOptimisation(Subtraction, Nz::Vector2f, Nz::Vector2f);
EnableOptimisation(Subtraction, Nz::Vector3f, Nz::Vector3f);
EnableOptimisation(Subtraction, Nz::Vector4f, Nz::Vector4f);
EnableOptimisation(Subtraction, Nz::Vector2i32, Nz::Vector2i32);
EnableOptimisation(Subtraction, Nz::Vector3i32, Nz::Vector3i32);
EnableOptimisation(Subtraction, Nz::Vector4i32, Nz::Vector4i32);

#undef EnableOptimisation
}

// Returns an optimised clone of the statement, without any shader context:
// conditional nodes are handed over to ShaderAstCloner as-is (see the m_shaderAst checks below).
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement)
{
	m_shaderAst = nullptr;

	return CloneStatement(statement);
}

// Returns an optimised clone of the statement, resolving conditional nodes against the shader:
// a condition counts as enabled when the bit at its index is set in enabledConditions.
// The shader is stored by address, so it has to outlive this call.
ShaderNodes::StatementPtr ShaderAstOptimizer::Optimise(const ShaderNodes::StatementPtr& statement, const ShaderAst& shader, UInt64 enabledConditions)
{
	m_shaderAst = &shader;
	m_enabledConditions = enabledConditions;

	return CloneStatement(statement);
}

// Optimises both operands, then folds the operation when both of them ended up constant.
void ShaderAstOptimizer::Visit(ShaderNodes::BinaryOp& node)
{
	auto lhs = CloneExpression(node.left);
	auto rhs = CloneExpression(node.right);

	if (lhs->GetType() == ShaderNodes::NodeType::Constant && rhs->GetType() == ShaderNodes::NodeType::Constant)
	{
		auto lhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(lhs);
		auto rhsConstant = std::static_pointer_cast<ShaderNodes::Constant>(rhs);

		// PropagateConstant pushes either the folded constant or the rebuilt operation
		switch (node.op)
		{
			case ShaderNodes::BinaryType::Add:
				return PropagateConstant<ShaderNodes::BinaryType::Add>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::Subtract:
				return PropagateConstant<ShaderNodes::BinaryType::Subtract>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::Multiply:
				return PropagateConstant<ShaderNodes::BinaryType::Multiply>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::Divide:
				return PropagateConstant<ShaderNodes::BinaryType::Divide>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompEq:
				return PropagateConstant<ShaderNodes::BinaryType::CompEq>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompGe:
				return PropagateConstant<ShaderNodes::BinaryType::CompGe>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompGt:
				return PropagateConstant<ShaderNodes::BinaryType::CompGt>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompLe:
				return PropagateConstant<ShaderNodes::BinaryType::CompLe>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompLt:
				return PropagateConstant<ShaderNodes::BinaryType::CompLt>(lhsConstant, rhsConstant);

			case ShaderNodes::BinaryType::CompNe:
				return PropagateConstant<ShaderNodes::BinaryType::CompNe>(lhsConstant, rhsConstant);
		}
	}

	// Reuse the operands cloned above instead of delegating to ShaderAstCloner::Visit(node):
	// NOTE(review): presumably the base visitor clones both operands again, which would double
	// the work at every nesting level - confirm against ShaderAstCloner
	PushExpression(ShaderNodes::BinaryOp::Build(node.op, std::move(lhs), std::move(rhs)));
}

// Strips the branches whose condition is a compile-time constant:
// constant-false branches are dropped, and the first constant-true branch ends the chain.
void ShaderAstOptimizer::Visit(ShaderNodes::Branch& node)
{
	std::vector<ShaderNodes::Branch::ConditionalStatement> statements;
	ShaderNodes::StatementPtr elseStatement;

	for (auto& condStatement : node.condStatements)
	{
		auto cond = CloneExpression(condStatement.condition);

		if (cond->GetType() == ShaderNodes::NodeType::Constant)
		{
			auto constant = std::static_pointer_cast<ShaderNodes::Constant>(cond);

			assert(IsBasicType(cond->GetExpressionType()));
			assert(std::get<ShaderNodes::BasicType>(cond->GetExpressionType()) == ShaderNodes::BasicType::Boolean);

			bool cValue = std::get<bool>(constant->value);
			if (!cValue)
				continue; // Never taken, drop this branch

			if (statements.empty())
			{
				// First condition is true, dismiss the branch
				Visit(condStatement.statement);
				return;
			}
			else
			{
				// Some condition after the first one is true, make it the else statement and stop there
				elseStatement = CloneStatement(condStatement.statement);
				break;
			}
		}
		else
		{
			auto& c = statements.emplace_back();
			c.condition = std::move(cond);
			c.statement = CloneStatement(condStatement.statement);
		}
	}

	if (statements.empty())
	{
		// All conditions have been removed, replace by else statement or no-op
		if (node.elseStatement)
			return Visit(node.elseStatement);
		else
			return PushStatement(ShaderNodes::NoOp::Build());
	}

	// No constant-true branch took over the else slot, keep the original else statement
	if (!elseStatement)
		elseStatement = CloneStatement(node.elseStatement);

	PushStatement(ShaderNodes::Branch::Build(std::move(statements), std::move(elseStatement)));
}

// Replaces the conditional expression by its true or false path, depending on whether the
// named condition is enabled. Without shader context the node is handed over to ShaderAstCloner.
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalExpression& node)
{
	if (!m_shaderAst)
		return ShaderAstCloner::Visit(node);

	std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
	assert(conditionIndex != ShaderAst::InvalidCondition);

	if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
		Visit(node.truePath);
	else
		Visit(node.falsePath);
}

// Replaces the conditional statement by its inner statement when the named condition is enabled,
// and by a no-op otherwise. Without shader context the node is handed over to ShaderAstCloner.
void ShaderAstOptimizer::Visit(ShaderNodes::ConditionalStatement& node)
{
	if (!m_shaderAst)
		return ShaderAstCloner::Visit(node);

	std::size_t conditionIndex = m_shaderAst->FindConditionByName(node.conditionName);
	assert(conditionIndex != ShaderAst::InvalidCondition);

	if (TestBit<Nz::UInt64>(m_enabledConditions, conditionIndex))
		Visit(node.statement);
	else
		PushStatement(ShaderNodes::NoOp::Build()); // Always push a statement, as the Branch visitor does when nothing is left
}

// Pushes the folded result of "lhs Type rhs" when a folding functor is enabled for the types
// currently held by both constants, or the rebuilt binary operation otherwise.
template<ShaderNodes::BinaryType Type>
void ShaderAstOptimizer::PropagateConstant(const std::shared_ptr<ShaderNodes::Constant>& lhs, const std::shared_ptr<ShaderNodes::Constant>& rhs)
{
	ShaderNodes::ExpressionPtr optimized;

	// Dispatch on the alternatives held by both constant values at once
	std::visit([&](auto&& arg1, auto&& arg2)
	{
		using T1 = std::decay_t<decltype(arg1)>;
		using T2 = std::decay_t<decltype(arg2)>;
		using PCType = PropagateConstantType<Type, T1, T2>;

		// PCType is complete when the operation has a folding functor at all,
		// and Op is complete when that functor was enabled for this exact operand pair
		if constexpr (is_complete_v<PCType>)
		{
			using Op = typename PCType::Op;
			if constexpr (is_complete_v<Op>)
				optimized = Op{}(arg1, arg2);
		}
	}, lhs->value, rhs->value);

	// optimized is still null when nothing could be folded
	if (optimized)
		PushExpression(std::move(optimized));
	else
		PushExpression(ShaderNodes::BinaryOp::Build(Type, lhs, rhs));
}
}