From 72edff30c7f6594f0e7da6ae7e1395947206ead2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 7 Jul 2021 15:23:39 +0200 Subject: [PATCH] Shader: Add support for logical and/or --- include/Nazara/Shader/Ast/Enums.hpp | 22 ++-- include/Nazara/Shader/ShaderLangTokenList.hpp | 2 + src/Nazara/Shader/Ast/AstOptimizer.cpp | 53 ++++++++- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 23 ++++ src/Nazara/Shader/GlslWriter.cpp | 3 + src/Nazara/Shader/LangWriter.cpp | 23 ++-- src/Nazara/Shader/ShaderLangLexer.cpp | 28 +++++ src/Nazara/Shader/ShaderLangParser.cpp | 4 + src/Nazara/Shader/SpirvAstVisitor.cpp | 110 +++++++++--------- 9 files changed, 192 insertions(+), 76 deletions(-) diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index f6b30890b..f87e8fc19 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -34,16 +34,18 @@ namespace Nz enum class BinaryType { - Add, //< + - CompEq, //< == - CompGe, //< >= - CompGt, //< > - CompLe, //< <= - CompLt, //< < - CompNe, //< <= - Divide, //< / - Multiply, //< * - Subtract, //< - + Add, //< + + CompEq, //< == + CompGe, //< >= + CompGt, //< > + CompLe, //< <= + CompLt, //< < + CompNe, //< <= + Divide, //< / + Multiply, //< * + LogicalAnd, //< && + LogicalOr, //< || + Subtract, //< - }; enum class BuiltinEntry diff --git a/include/Nazara/Shader/ShaderLangTokenList.hpp b/include/Nazara/Shader/ShaderLangTokenList.hpp index 20b0f1d7d..1bb6a3898 100644 --- a/include/Nazara/Shader/ShaderLangTokenList.hpp +++ b/include/Nazara/Shader/ShaderLangTokenList.hpp @@ -36,6 +36,8 @@ NAZARA_SHADERLANG_TOKEN(If) NAZARA_SHADERLANG_TOKEN(LessThan) NAZARA_SHADERLANG_TOKEN(LessThanEqual) NAZARA_SHADERLANG_TOKEN(Let) +NAZARA_SHADERLANG_TOKEN(LogicalAnd) +NAZARA_SHADERLANG_TOKEN(LogicalOr) NAZARA_SHADERLANG_TOKEN(Multiply) NAZARA_SHADERLANG_TOKEN(Minus) NAZARA_SHADERLANG_TOKEN(NotEqual) diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 102d18012..766090ba9 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -152,6 +152,44 @@ namespace Nz::ShaderAst using Op = BinaryCompNe; }; + // LogicalAnd + template + struct BinaryLogicalAndBase + { + std::unique_ptr operator()(const T1& lhs, const T2& rhs) + { + return ShaderBuilder::Constant(lhs && rhs); + } + }; + + template + struct BinaryLogicalAnd; + + template + struct BinaryConstantPropagation + { + using Op = BinaryLogicalAnd; + }; + + // LogicalOr + template + struct BinaryLogicalOrBase + { + std::unique_ptr operator()(const T1& lhs, const T2& rhs) + { + return ShaderBuilder::Constant(lhs || rhs); + } + }; + + template + struct BinaryLogicalOr; + + template + struct BinaryConstantPropagation + { + using Op = BinaryLogicalOr; + }; + // Addition template struct BinaryAdditionBase @@ -325,7 +363,6 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompEq, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompEq, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompGe, bool, bool); EnableOptimisation(BinaryCompGe, double, double); EnableOptimisation(BinaryCompGe, float, float); EnableOptimisation(BinaryCompGe, Int32, Int32); @@ -336,7 +373,6 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompGe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompGe, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompGt, bool, bool); EnableOptimisation(BinaryCompGt, double, double); EnableOptimisation(BinaryCompGt, float, float); EnableOptimisation(BinaryCompGt, Int32, Int32); @@ -347,7 +383,6 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompGt, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompGt, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompLe, bool, bool); EnableOptimisation(BinaryCompLe, double, double); EnableOptimisation(BinaryCompLe, float, float); EnableOptimisation(BinaryCompLe, Int32, Int32); @@ -358,7 +393,6 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompLe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompLe, Vector4i32, Vector4i32); - EnableOptimisation(BinaryCompLt, bool, bool); EnableOptimisation(BinaryCompLt, double, double); EnableOptimisation(BinaryCompLt, float, float); EnableOptimisation(BinaryCompLt, Int32, Int32); @@ -380,6 +414,9 @@ namespace Nz::ShaderAst EnableOptimisation(BinaryCompNe, Vector3i32, Vector3i32); EnableOptimisation(BinaryCompNe, Vector4i32, Vector4i32); + EnableOptimisation(BinaryLogicalAnd, bool, bool); + EnableOptimisation(BinaryLogicalOr, bool, bool); + EnableOptimisation(BinaryAddition, double, double); EnableOptimisation(BinaryAddition, float, float); EnableOptimisation(BinaryAddition, Int32, Int32); @@ -583,6 +620,14 @@ namespace Nz::ShaderAst case BinaryType::CompNe: optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); break; + + case BinaryType::LogicalAnd: + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; + + case BinaryType::LogicalOr: + optimized = PropagateBinaryConstant(std::move(lhsConstant), std::move(rhsConstant)); + break; } if (optimized) diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 88c684553..a48681575 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -314,7 +314,18 @@ namespace Nz::ShaderAst default: throw AstError{ "incompatible types" }; } + break; } + + case BinaryType::LogicalAnd: + case BinaryType::LogicalOr: + if (leftType != PrimitiveType::Boolean) + throw AstError{ "logical and/or are only supported on booleans" }; + + TypeMustMatch(clone->left, clone->right); + + clone->cachedExpressionType = PrimitiveType::Boolean; + break; } } else if (IsMatrixType(leftExprType)) @@ -363,7 +374,13 @@ namespace Nz::ShaderAst } else throw AstError{ "incompatible types" }; + + break; } + + case BinaryType::LogicalAnd: + case BinaryType::LogicalOr: + throw AstError{ "logical and/or are only supported on booleans" }; } } else if (IsVectorType(leftExprType)) @@ -402,7 +419,13 @@ namespace Nz::ShaderAst } else throw AstError{ "incompatible types" }; + + break; } + + case BinaryType::LogicalAnd: + case BinaryType::LogicalOr: + throw AstError{ "logical and/or are only supported on booleans" }; } } diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index e236d8017..b28eaec2d 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -763,6 +763,9 @@ namespace Nz case ShaderAst::BinaryType::CompLe: Append(" <= "); break; case ShaderAst::BinaryType::CompLt: Append(" < "); break; case ShaderAst::BinaryType::CompNe: Append(" != "); break; + + case ShaderAst::BinaryType::LogicalAnd: Append(" && "); break; + case ShaderAst::BinaryType::LogicalOr: Append(" || "); break; } Visit(node.right, true); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index d71ea6d58..d0fb77839 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -584,17 +584,20 @@ namespace Nz switch (node.op) { - case ShaderAst::BinaryType::Add: Append(" + "); break; - case ShaderAst::BinaryType::Subtract: Append(" - "); break; - case ShaderAst::BinaryType::Multiply: Append(" * "); break; - case ShaderAst::BinaryType::Divide: Append(" / "); break; + case ShaderAst::BinaryType::Add: Append(" + "); break; + case ShaderAst::BinaryType::Subtract: Append(" - "); break; + case ShaderAst::BinaryType::Multiply: Append(" * "); break; + case ShaderAst::BinaryType::Divide: Append(" / "); break; - case ShaderAst::BinaryType::CompEq: Append(" == "); break; - case ShaderAst::BinaryType::CompGe: Append(" >= "); break; - case ShaderAst::BinaryType::CompGt: Append(" > "); break; - case ShaderAst::BinaryType::CompLe: Append(" <= "); break; - case ShaderAst::BinaryType::CompLt: Append(" < "); break; - case ShaderAst::BinaryType::CompNe: Append(" != "); break; + case ShaderAst::BinaryType::CompEq: Append(" == "); break; + case ShaderAst::BinaryType::CompGe: Append(" >= "); break; + case ShaderAst::BinaryType::CompGt: Append(" > "); break; + case ShaderAst::BinaryType::CompLe: Append(" <= "); break; + case ShaderAst::BinaryType::CompLt: Append(" < "); break; + case ShaderAst::BinaryType::CompNe: Append(" != "); break; + + case ShaderAst::BinaryType::LogicalAnd: Append(" && "); break; + case ShaderAst::BinaryType::LogicalOr: Append(" || "); break; } Visit(node.right, true); diff --git a/src/Nazara/Shader/ShaderLangLexer.cpp b/src/Nazara/Shader/ShaderLangLexer.cpp index 6213c95da..2aed4846e 100644 --- a/src/Nazara/Shader/ShaderLangLexer.cpp +++ b/src/Nazara/Shader/ShaderLangLexer.cpp @@ -246,6 +246,34 @@ namespace Nz::ShaderLang break; } + case '|': + { + char next = Peek(); + if (next == '|') + { + currentPos++; + tokenType = TokenType::LogicalOr; + } + else + throw UnrecognizedToken{}; //< TODO: Add BOR (a | b) + + break; + } + + case '&': + { + char next = Peek(); + if (next == '&') + { + currentPos++; + tokenType = TokenType::LogicalAnd; + } + else + throw UnrecognizedToken{}; //< TODO: Add BAND (a & b) + + break; + } + case '<': { char next = Peek(); diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index d6e52ae85..4a5cba701 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -917,6 +917,8 @@ namespace Nz::ShaderLang case TokenType::Equal: binaryType = ShaderAst::BinaryType::CompEq; break; case TokenType::LessThan: binaryType = ShaderAst::BinaryType::CompLt; break; case TokenType::LessThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; + case TokenType::LogicalAnd: binaryType = ShaderAst::BinaryType::LogicalAnd; break; + case TokenType::LogicalOr: binaryType = ShaderAst::BinaryType::LogicalOr; break; case TokenType::GreaterThan: binaryType = ShaderAst::BinaryType::CompLt; break; case TokenType::GreaterThanEqual: binaryType = ShaderAst::BinaryType::CompLe; break; case TokenType::Minus: binaryType = ShaderAst::BinaryType::Subtract; break; @@ -1140,6 +1142,8 @@ namespace Nz::ShaderLang case TokenType::Equal: return 50; case TokenType::LessThan: return 40; case TokenType::LessThanEqual: return 40; + case TokenType::LogicalAnd: return 120; + case TokenType::LogicalOr: return 140; case TokenType::GreaterThan: return 40; case TokenType::GreaterThanEqual: return 40; case TokenType::Multiply: return 80; diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index a6d2e27d4..d3aea9f3c 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -139,6 +139,60 @@ namespace Nz break; } + case ShaderAst::BinaryType::Multiply: + { + switch (leftTypeBase) + { + case ShaderAst::PrimitiveType::Float32: + { + if (IsPrimitiveType(leftType)) + { + // Handle float * matrix|vector as matrix|vector * float + if (IsMatrixType(rightType)) + { + swapOperands = true; + return SpirvOp::OpMatrixTimesScalar; + } + else if (IsVectorType(rightType)) + { + swapOperands = true; + return SpirvOp::OpVectorTimesScalar; + } + } + else if (IsPrimitiveType(rightType)) + { + if (IsMatrixType(leftType)) + return SpirvOp::OpMatrixTimesScalar; + else if (IsVectorType(leftType)) + return SpirvOp::OpVectorTimesScalar; + } + else if (IsMatrixType(leftType)) + { + if (IsMatrixType(rightType)) + return SpirvOp::OpMatrixTimesMatrix; + else if (IsVectorType(rightType)) + return SpirvOp::OpMatrixTimesVector; + } + else if (IsMatrixType(rightType)) + { + assert(IsVectorType(leftType)); + return SpirvOp::OpVectorTimesMatrix; + } + + return SpirvOp::OpFMul; + } + + case ShaderAst::PrimitiveType::Int32: + case ShaderAst::PrimitiveType::UInt32: + return SpirvOp::OpIMul; + + default: + break; + } + + break; + } + case ShaderAst::BinaryType::CompEq: { switch (leftTypeBase) @@ -255,59 +309,11 @@ namespace Nz break; } - case ShaderAst::BinaryType::Multiply: - { - switch (leftTypeBase) - { - case ShaderAst::PrimitiveType::Float32: - { - if (IsPrimitiveType(leftType)) - { - // Handle float * matrix|vector as matrix|vector * float - if (IsMatrixType(rightType)) - { - swapOperands = true; - return SpirvOp::OpMatrixTimesScalar; - } - else if (IsVectorType(rightType)) - { - swapOperands = true; - return SpirvOp::OpVectorTimesScalar; - } - } - else if (IsPrimitiveType(rightType)) - { - if (IsMatrixType(leftType)) - return SpirvOp::OpMatrixTimesScalar; - else if (IsVectorType(leftType)) - return SpirvOp::OpVectorTimesScalar; - } - else if (IsMatrixType(leftType)) - { - if (IsMatrixType(rightType)) - return SpirvOp::OpMatrixTimesMatrix; - else if (IsVectorType(rightType)) - return SpirvOp::OpMatrixTimesVector; - } - else if (IsMatrixType(rightType)) - { - assert(IsVectorType(leftType)); - return SpirvOp::OpVectorTimesMatrix; - } + case ShaderAst::BinaryType::LogicalAnd: + return SpirvOp::OpLogicalAnd; - return SpirvOp::OpFMul; - } - - case ShaderAst::PrimitiveType::Int32: - case ShaderAst::PrimitiveType::UInt32: - return SpirvOp::OpIMul; - - default: - break; - } - - break; - } + case ShaderAst::BinaryType::LogicalOr: + return SpirvOp::OpLogicalOr; } assert(false);