From 64efd81bf8bfccd6a03caa4815fddaacff630fda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sun, 23 Jan 2022 19:58:04 +0100 Subject: [PATCH] Shader: Handle matrix cast properly --- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 3 +- include/Nazara/Shader/ShaderBuilder.hpp | 2 + include/Nazara/Shader/ShaderBuilder.inl | 26 ++++ src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 139 ++++++++++++++++-- src/Nazara/Shader/ShaderLangParser.cpp | 14 ++ src/Nazara/Shader/SpirvWriter.cpp | 3 +- tests/Engine/Shader/Sanitizations.cpp | 129 ++++++++++++++++ 7 files changed, 305 insertions(+), 11 deletions(-) diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index 7274b399a..4d878668b 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -40,9 +40,10 @@ namespace Nz::ShaderAst std::unordered_set reservedIdentifiers; std::unordered_map optionValues; bool makeVariableNameUnique = false; - bool removeConstDeclaration = false; bool reduceLoopsToWhile = false; + bool removeConstDeclaration = false; bool removeCompoundAssignments = false; + bool removeMatrixCast = false; bool removeOptionDeclaration = false; bool removeScalarSwizzling = false; bool splitMultipleBranches = false; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index 84b121a95..5783e8b22 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -54,6 +54,7 @@ namespace Nz::ShaderBuilder struct Cast { + inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const; inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::array expressions) const; inline std::unique_ptr operator()(ShaderAst::ExpressionType targetType, std::vector expressions) const; }; @@ -71,6 +72,7 @@ namespace Nz::ShaderBuilder struct Constant { inline std::unique_ptr operator()(ShaderAst::ConstantValue value) const; + template std::unique_ptr operator()(ShaderAst::ExpressionType type, T value) const; }; struct DeclareConst diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 0bfc90187..83a2423ab 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz::ShaderBuilder @@ -110,6 +111,15 @@ namespace Nz::ShaderBuilder return callFunctionExpression; } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const + { + auto castNode = std::make_unique(); + castNode->targetType = std::move(targetType); + castNode->expressions[0] = std::move(expression); + + return castNode; + } + inline std::unique_ptr Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array expressions) const { auto castNode = std::make_unique(); @@ -159,6 +169,22 @@ namespace Nz::ShaderBuilder return constantNode; } + template + std::unique_ptr Impl::Constant::operator()(ShaderAst::ExpressionType type, T value) const + { + assert(IsPrimitiveType(type)); + + switch (std::get(type)) + { + case ShaderAst::PrimitiveType::Boolean: return ShaderBuilder::Constant(value != T(0)); + case ShaderAst::PrimitiveType::Float32: return ShaderBuilder::Constant(SafeCast(value)); + case ShaderAst::PrimitiveType::Int32: return ShaderBuilder::Constant(SafeCast(value)); + case ShaderAst::PrimitiveType::UInt32: return ShaderBuilder::Constant(SafeCast(value)); + } + + throw std::runtime_error("unexpected primitive type"); + } + inline std::unique_ptr Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const { auto declareConstNode = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 5e1449f11..448075fb5 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -386,6 +387,88 @@ namespace Nz::ShaderAst auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); Validate(*clone); + if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType)) + { + const MatrixType& targetMatrixType = std::get(clone->targetType); + + const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front()); + bool isMatrixCast = IsMatrixType(frontExprType); + if (isMatrixCast && std::get(frontExprType) == targetMatrixType) + { + // Nothing to do + return std::move(clone->expressions.front()); + } + + auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required + Validate(*variableDeclaration); + + std::size_t variableIndex = *variableDeclaration->varIndex; + + m_context->currentStatementList->emplace_back(std::move(variableDeclaration)); + + for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) + { + // temp[i] + auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i))); + Validate(*columnExpr); + + // vector expression + ExpressionPtr vectorExpr; + std::size_t vectorComponentCount; + if (isMatrixCast) + { + // fromMatrix[i] + auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i))); + Validate(*matrixColumnExpr); + + vectorExpr = std::move(matrixColumnExpr); + vectorComponentCount = std::get(frontExprType).rowCount; + } + else + { + // parameter #i + vectorExpr = std::move(clone->expressions[i]); + vectorComponentCount = std::get(GetExpressionType(*vectorExpr)).componentCount; + } + + // cast expression (turn fromMatrix[i] to vec3(fromMatrix[i])) + ExpressionPtr castExpr; + if (vectorComponentCount != targetMatrixType.rowCount) + { + CastExpressionPtr vecCast; + if (vectorComponentCount < targetMatrixType.rowCount) + { + std::array expressions; + expressions[0] = std::move(vectorExpr); + for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j) + expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal + + vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions)); + Validate(*vecCast); + + castExpr = std::move(vecCast); + } + else + { + std::array swizzleComponents; + std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0); + + auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount); + Validate(*swizzleExpr); + + castExpr = std::move(swizzleExpr); + } + } + else + castExpr = std::move(vectorExpr); + + // temp[i] = castExpr + m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr)))); + } + + return ShaderBuilder::Variable(variableIndex, clone->targetType); + } + return clone; } @@ -653,7 +736,7 @@ namespace Nz::ShaderAst else if (IsSamplerType(extVar.type)) varType = extVar.type; else - throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; + throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" }; std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType)); if (!clone->varIndex) @@ -820,6 +903,18 @@ namespace Nz::ShaderAst declaredMembers.insert(member.name); member.type = ResolveType(member.type); + if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140) + { + if (IsPrimitiveType(member.type) && std::get(member.type) == PrimitiveType::Boolean) + throw AstError{ "boolean type is not allowed in std140 layout" }; + else if (IsStructType(member.type)) + { + std::size_t structIndex = std::get(member.type).structIndex; + const StructDescription* desc = m_context->structs[structIndex]; + if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue()) + throw AstError{ "inner struct layout mismatch" }; + } + } } clone->structIndex = RegisterStruct(clone->description.name, &clone->description); @@ -1695,18 +1790,44 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(CastExpression& node) { - node.cachedExpressionType = node.targetType; node.targetType = ResolveType(node.targetType); + node.cachedExpressionType = node.targetType; - // Allow casting a matrix to itself (wtf?) - // FIXME: Make proper rules - if (IsMatrixType(node.targetType) && node.expressions.front()) + const auto& firstExprPtr = node.expressions.front(); + if (!firstExprPtr) + throw AstError{ "expected at least one expression" }; + + if (IsMatrixType(node.targetType)) { - const ExpressionType& exprType = GetExpressionType(*node.expressions.front()); - if (IsMatrixType(exprType) && !node.expressions[1]) + const MatrixType& targetMatrixType = std::get(node.targetType); + + const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr); + if (IsMatrixType(firstExprType)) { + if (node.expressions[1]) + throw AstError{ "too many expressions" }; + + // Matrix to matrix cast: always valid return; } + else + { + assert(targetMatrixType.columnCount <= 4); + for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) + { + const auto& exprPtr = node.expressions[i]; + if (!exprPtr) + throw AstError{ "component count doesn't match required component count" }; + + const ExpressionType& exprType = GetExpressionType(*exprPtr); + if (!IsVectorType(exprType)) + throw AstError{ "expected vector type" }; + + const VectorType& vecType = std::get(exprType); + if (vecType.componentCount != targetMatrixType.rowCount) + throw AstError{ "vector component count must match target matrix row count" }; + } + } } auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t @@ -1936,9 +2057,9 @@ namespace Nz::ShaderAst if (node.componentCount > 4) throw AstError{ "cannot swizzle more than four elements" }; - for (UInt32 swizzleIndex : node.components) + for (std::size_t i = 0; i < node.componentCount; ++i) { - if (swizzleIndex >= componentCount) + if (node.components[i] >= componentCount) throw AstError{ "invalid swizzle" }; } diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 210d20c01..f08ad1ca4 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -235,6 +235,20 @@ namespace Nz::ShaderLang return matrixType; } + else if (identifier == "mat2") + { + Consume(); + + ShaderAst::MatrixType matrixType; + matrixType.columnCount = 2; + matrixType.rowCount = 2; + + Expect(Advance(), TokenType::LessThan); //< '<' + matrixType.type = ParsePrimitiveType(); + Expect(Advance(), TokenType::GreaterThan); //< '>' + + return matrixType; + } else if (identifier == "sampler2D") { Consume(); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 13a314b1f..a4ecc52d5 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -491,7 +491,8 @@ namespace Nz options.optionValues = states.optionValues; options.reduceLoopsToWhile = true; options.removeCompoundAssignments = true; - options.removeOptionDeclaration = true; + options.removeMatrixCast = true; + options.removeOptionDeclaration = true; options.splitMultipleBranches = true; options.useIdentifierAccessesForStructs = false; diff --git a/tests/Engine/Shader/Sanitizations.cpp b/tests/Engine/Shader/Sanitizations.cpp index e28256adb..76025fee9 100644 --- a/tests/Engine/Shader/Sanitizations.cpp +++ b/tests/Engine/Shader/Sanitizations.cpp @@ -124,6 +124,135 @@ fn main() } } +)"); + + } + + WHEN("removing matrix casts") + { + std::string_view nzslSource = R"( +fn testMat2ToMat2(input: mat2) -> mat2 +{ + return mat2(input); +} + +fn testMat2ToMat3(input: mat2) -> mat3 +{ + return mat3(input); +} + +fn testMat2ToMat4(input: mat2) -> mat4 +{ + return mat4(input); +} + +fn testMat3ToMat2(input: mat3) -> mat2 +{ + return mat2(input); +} + +fn testMat3ToMat3(input: mat3) -> mat3 +{ + return mat3(input); +} + +fn testMat3ToMat4(input: mat3) -> mat4 +{ + return mat4(input); +} + +fn testMat4ToMat2(input: mat4) -> mat2 +{ + return mat2(input); +} + +fn testMat4ToMat3(input: mat4) -> mat3 +{ + return mat3(input); +} + +fn testMat4ToMat4(input: mat4) -> mat4 +{ + return mat4(input); +} +)"; + + Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource); + + Nz::ShaderAst::SanitizeVisitor::Options options; + options.removeMatrixCast = true; + + REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options)); + + ExpectNZSL(*shader, R"( +fn testMat2ToMat2(input: mat2) -> mat2 +{ + return input; +} + +fn testMat2ToMat3(input: mat2) -> mat3 +{ + let temp: mat3; + temp[0] = vec3(input[0], 0.000000); + temp[1] = vec3(input[1], 0.000000); + temp[2] = vec3(input[2], 1.000000); + return temp; +} + +fn testMat2ToMat4(input: mat2) -> mat4 +{ + let temp: mat4; + temp[0] = vec4(input[0], 0.000000, 0.000000); + temp[1] = vec4(input[1], 0.000000, 0.000000); + temp[2] = vec4(input[2], 1.000000, 0.000000); + temp[3] = vec4(input[3], 0.000000, 1.000000); + return temp; +} + +fn testMat3ToMat2(input: mat3) -> mat2 +{ + let temp: mat2; + temp[0] = input[0].xy; + temp[1] = input[1].xy; + return temp; +} + +fn testMat3ToMat3(input: mat3) -> mat3 +{ + return input; +} + +fn testMat3ToMat4(input: mat3) -> mat4 +{ + let temp: mat4; + temp[0] = vec4(input[0], 0.000000); + temp[1] = vec4(input[1], 0.000000); + temp[2] = vec4(input[2], 0.000000); + temp[3] = vec4(input[3], 1.000000); + return temp; +} + +fn testMat4ToMat2(input: mat4) -> mat2 +{ + let temp: mat2; + temp[0] = input[0].xy; + temp[1] = input[1].xy; + return temp; +} + +fn testMat4ToMat3(input: mat4) -> mat3 +{ + let temp: mat3; + temp[0] = input[0].xyz; + temp[1] = input[1].xyz; + temp[2] = input[2].xyz; + return temp; +} + +fn testMat4ToMat4(input: mat4) -> mat4 +{ + return input; +} )"); }