From 83de0939bbe88a932c5cd26da1eda18329ef4ccc Mon Sep 17 00:00:00 2001 From: Lynix Date: Sat, 2 Apr 2022 21:20:01 +0200 Subject: [PATCH] Shader: Fix validation and cast from u32 to i32 --- include/Nazara/Shader/ShaderLangErrorList.hpp | 2 + .../Resources/Shaders/PhongMaterial.nzsl | 2 +- src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 119 ++++++++++++++++-- src/Nazara/Shader/SpirvAstVisitor.cpp | 5 +- 4 files changed, 115 insertions(+), 13 deletions(-) diff --git a/include/Nazara/Shader/ShaderLangErrorList.hpp b/include/Nazara/Shader/ShaderLangErrorList.hpp index e19d8bec9..f69ab9279 100644 --- a/include/Nazara/Shader/ShaderLangErrorList.hpp +++ b/include/Nazara/Shader/ShaderLangErrorList.hpp @@ -60,6 +60,7 @@ NAZARA_SHADERLANG_COMPILER_ERROR(BinaryIncompatibleTypes, "incompatibles types ( NAZARA_SHADERLANG_COMPILER_ERROR(BinaryUnsupported, "{} type ({}) does not support this binary operation", std::string, std::string) NAZARA_SHADERLANG_COMPILER_ERROR(BranchOutsideOfFunction, "non-const branching statements can only exist inside a function") NAZARA_SHADERLANG_COMPILER_ERROR(CastComponentMismatch, "component count ({}) doesn't match required component count ({})", UInt32, UInt32) +NAZARA_SHADERLANG_COMPILER_ERROR(CastIncompatibleBaseTypes, "incompatibles base types (expected {}, got {})", std::string, std::string) NAZARA_SHADERLANG_COMPILER_ERROR(CastIncompatibleTypes, "incompatibles types ({} and {})", std::string, std::string) NAZARA_SHADERLANG_COMPILER_ERROR(CastMatrixExpectedVector, "expected vector type, got {}", std::string) NAZARA_SHADERLANG_COMPILER_ERROR(CastMatrixVectorComponentMismatch, "vector component count ({}) doesn't match target matrix row count ({})", UInt32, UInt32) @@ -102,6 +103,7 @@ NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicExpectedParameterCount, "expected {} p NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicExpectedType, "expected type {1} for parameter #{0}, got {2}", UInt32, std::string, std::string) NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicUnexpectedBoolean, "boolean parameters are not allowed") NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicUnmatchingParameterType, "all types must match") +NAZARA_SHADERLANG_COMPILER_ERROR(InvalidCast, "invalid cast to type {}", std::string) NAZARA_SHADERLANG_COMPILER_ERROR(InvalidScalarSwizzle, "invalid swizzle for scalar") NAZARA_SHADERLANG_COMPILER_ERROR(InvalidSwizzle, "invalid swizzle {}", std::string) NAZARA_SHADERLANG_COMPILER_ERROR(MissingOptionValue, "option {} requires a value (no default value set)", std::string) diff --git a/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl b/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl index 769621143..38fd0fde9 100644 --- a/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl +++ b/src/Nazara/Graphics/Resources/Shaders/PhongMaterial.nzsl @@ -122,7 +122,7 @@ fn main(input: VertToFrag) -> FragOut else normal = normalize(input.normal); - for i in 0 -> lightData.lightCount + for i in u32(0) -> lightData.lightCount { let light = lightData.lights[i]; diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index ab7ab145f..aa3ff9ff4 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -3159,6 +3159,13 @@ namespace Nz::ShaderAst auto& firstExprPtr = MandatoryExpr(node.expressions.front(), node.sourceLocation); + std::size_t expressionCount = 0; + for (; expressionCount < node.expressions.size(); ++expressionCount) + { + if (!node.expressions[expressionCount]) + break; + } + if (IsMatrixType(targetType)) { const MatrixType& targetMatrixType = std::get(targetType); @@ -3169,22 +3176,23 @@ namespace Nz::ShaderAst if (IsMatrixType(ResolveAlias(*firstExprType))) { - if (node.expressions[1]) - throw ShaderLang::CompilerCastComponentMismatchError{ node.expressions[1]->sourceLocation, 2, 1 }; //< get real component count + if (expressionCount != 1) + throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), 1 }; // Matrix to matrix cast: always valid } else { + // Matrix builder (from vectors) + assert(targetMatrixType.columnCount <= 4); - UInt32 expressionCount = 0; + if (expressionCount != targetMatrixType.columnCount) + throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), SafeCast(targetMatrixType.columnCount) }; + for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i) { const auto& exprPtr = node.expressions[i]; - if (!exprPtr) - throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, expressionCount, SafeCast(targetMatrixType.columnCount) }; - - expressionCount++; + assert(exprPtr); const ExpressionType* exprType = GetExpressionType(*exprPtr); if (!exprType) @@ -3197,11 +3205,90 @@ namespace Nz::ShaderAst const VectorType& vecType = std::get(resolvedExprType); if (vecType.componentCount != targetMatrixType.rowCount) throw ShaderLang::CompilerCastMatrixVectorComponentMismatchError{ node.expressions[i]->sourceLocation, SafeCast(vecType.componentCount), SafeCast(targetMatrixType.rowCount) }; + + if (vecType.type != targetMatrixType.type) + throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ node.expressions[i]->sourceLocation, ToString(targetMatrixType.type, node.sourceLocation), ToString(vecType.type, node.sourceLocation) }; } } } - else + else if (IsPrimitiveType(targetType)) { + // Cast between primitive types + if (expressionCount != 1) + throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(expressionCount), 1 }; + + const ExpressionType* fromType = GetExpressionType(*node.expressions[0]); + if (!fromType) + return ValidationResult::Unresolved; + + const ExpressionType& resolvedFromType = ResolveAlias(*fromType); + if (!IsPrimitiveType(resolvedFromType)) + throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) }; + + PrimitiveType fromPrimitiveType = std::get(resolvedFromType); + PrimitiveType targetPrimitiveType = std::get(targetType); + + bool areTypeCompatibles = [&] + { + switch (targetPrimitiveType) + { + case PrimitiveType::Boolean: + case PrimitiveType::String: + return false; + + case PrimitiveType::Float32: + { + switch (fromPrimitiveType) + { + case PrimitiveType::Boolean: + case PrimitiveType::String: + return false; + + case PrimitiveType::Float32: + case PrimitiveType::Int32: + case PrimitiveType::UInt32: + return true; + } + } + + case PrimitiveType::Int32: + { + switch (fromPrimitiveType) + { + case PrimitiveType::Boolean: + case PrimitiveType::String: + return false; + + case PrimitiveType::Float32: + case PrimitiveType::Int32: + return true; + } + } + + case PrimitiveType::UInt32: + { + switch (fromPrimitiveType) + { + case PrimitiveType::Boolean: + case PrimitiveType::String: + return false; + + case PrimitiveType::Float32: + case PrimitiveType::Int32: + case PrimitiveType::UInt32: + return true; + } + } + } + }(); + + if (!areTypeCompatibles) + throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) }; + } + else if (IsVectorType(targetType)) + { + PrimitiveType targetBaseType = std::get(targetType).type; + auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t { if (IsVectorType(exprType)) @@ -3226,7 +3313,19 @@ namespace Nz::ShaderAst return ValidationResult::Unresolved; const ExpressionType& resolvedExprType = ResolveAlias(*exprType); - if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType)) + if (IsPrimitiveType(resolvedExprType)) + { + PrimitiveType primitiveType = std::get(resolvedExprType); + if (primitiveType != targetBaseType) + throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) }; + } + else if (IsVectorType(resolvedExprType)) + { + PrimitiveType primitiveType = std::get(resolvedExprType).type; + if (primitiveType != targetBaseType) + throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) }; + } + else throw ShaderLang::CompilerCastIncompatibleTypesError{ exprPtr->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedExprType, exprPtr->sourceLocation) }; componentCount += GetComponentCount(resolvedExprType); @@ -3235,6 +3334,8 @@ namespace Nz::ShaderAst if (componentCount != requiredComponents) throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast(componentCount), SafeCast(requiredComponents) }; } + else + throw ShaderLang::CompilerInvalidCastError{ node.sourceLocation, ToString(targetType, node.sourceLocation) }; node.cachedExpressionType = targetType; node.targetType = targetType; diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index d7a7c4d62..8a02b45d6 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -509,8 +509,7 @@ namespace Nz break; //< Already handled case ShaderAst::PrimitiveType::UInt32: - castOp = SpirvOp::OpSConvert; - break; + throw std::runtime_error("unsupported cast from int32"); case ShaderAst::PrimitiveType::String: throw std::runtime_error("unexpected string type"); @@ -530,7 +529,7 @@ namespace Nz break; case ShaderAst::PrimitiveType::Int32: - castOp = SpirvOp::OpUConvert; + castOp = SpirvOp::OpBitcast; break; case ShaderAst::PrimitiveType::UInt32: