Shader: Fix validation and cast from u32 to i32

This commit is contained in:
Lynix 2022-04-02 21:20:01 +02:00
parent 5cd9f6fdcd
commit 83de0939bb
4 changed files with 115 additions and 13 deletions

View File

@ -60,6 +60,7 @@ NAZARA_SHADERLANG_COMPILER_ERROR(BinaryIncompatibleTypes, "incompatibles types (
NAZARA_SHADERLANG_COMPILER_ERROR(BinaryUnsupported, "{} type ({}) does not support this binary operation", std::string, std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(BranchOutsideOfFunction, "non-const branching statements can only exist inside a function")
NAZARA_SHADERLANG_COMPILER_ERROR(CastComponentMismatch, "component count ({}) doesn't match required component count ({})", UInt32, UInt32)
NAZARA_SHADERLANG_COMPILER_ERROR(CastIncompatibleBaseTypes, "incompatibles base types (expected {}, got {})", std::string, std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(CastIncompatibleTypes, "incompatibles types ({} and {})", std::string, std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(CastMatrixExpectedVector, "expected vector type, got {}", std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(CastMatrixVectorComponentMismatch, "vector component count ({}) doesn't match target matrix row count ({})", UInt32, UInt32)
@ -102,6 +103,7 @@ NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicExpectedParameterCount, "expected {} p
NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicExpectedType, "expected type {1} for parameter #{0}, got {2}", UInt32, std::string, std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicUnexpectedBoolean, "boolean parameters are not allowed")
NAZARA_SHADERLANG_COMPILER_ERROR(IntrinsicUnmatchingParameterType, "all types must match")
NAZARA_SHADERLANG_COMPILER_ERROR(InvalidCast, "invalid cast to type {}", std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(InvalidScalarSwizzle, "invalid swizzle for scalar")
NAZARA_SHADERLANG_COMPILER_ERROR(InvalidSwizzle, "invalid swizzle {}", std::string)
NAZARA_SHADERLANG_COMPILER_ERROR(MissingOptionValue, "option {} requires a value (no default value set)", std::string)

View File

@ -122,7 +122,7 @@ fn main(input: VertToFrag) -> FragOut
else
normal = normalize(input.normal);
for i in 0 -> lightData.lightCount
for i in u32(0) -> lightData.lightCount
{
let light = lightData.lights[i];

View File

@ -3159,6 +3159,13 @@ namespace Nz::ShaderAst
auto& firstExprPtr = MandatoryExpr(node.expressions.front(), node.sourceLocation);
std::size_t expressionCount = 0;
for (; expressionCount < node.expressions.size(); ++expressionCount)
{
if (!node.expressions[expressionCount])
break;
}
if (IsMatrixType(targetType))
{
const MatrixType& targetMatrixType = std::get<MatrixType>(targetType);
@ -3169,22 +3176,23 @@ namespace Nz::ShaderAst
if (IsMatrixType(ResolveAlias(*firstExprType)))
{
if (node.expressions[1])
throw ShaderLang::CompilerCastComponentMismatchError{ node.expressions[1]->sourceLocation, 2, 1 }; //< get real component count
if (expressionCount != 1)
throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast<UInt32>(expressionCount), 1 };
// Matrix to matrix cast: always valid
}
else
{
// Matrix builder (from vectors)
assert(targetMatrixType.columnCount <= 4);
UInt32 expressionCount = 0;
if (expressionCount != targetMatrixType.columnCount)
throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast<UInt32>(expressionCount), SafeCast<UInt32>(targetMatrixType.columnCount) };
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
{
const auto& exprPtr = node.expressions[i];
if (!exprPtr)
throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, expressionCount, SafeCast<UInt32>(targetMatrixType.columnCount) };
expressionCount++;
assert(exprPtr);
const ExpressionType* exprType = GetExpressionType(*exprPtr);
if (!exprType)
@ -3197,11 +3205,90 @@ namespace Nz::ShaderAst
const VectorType& vecType = std::get<VectorType>(resolvedExprType);
if (vecType.componentCount != targetMatrixType.rowCount)
throw ShaderLang::CompilerCastMatrixVectorComponentMismatchError{ node.expressions[i]->sourceLocation, SafeCast<UInt32>(vecType.componentCount), SafeCast<UInt32>(targetMatrixType.rowCount) };
if (vecType.type != targetMatrixType.type)
throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ node.expressions[i]->sourceLocation, ToString(targetMatrixType.type, node.sourceLocation), ToString(vecType.type, node.sourceLocation) };
}
}
}
else
else if (IsPrimitiveType(targetType))
{
// Cast between primitive types
if (expressionCount != 1)
throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast<UInt32>(expressionCount), 1 };
const ExpressionType* fromType = GetExpressionType(*node.expressions[0]);
if (!fromType)
return ValidationResult::Unresolved;
const ExpressionType& resolvedFromType = ResolveAlias(*fromType);
if (!IsPrimitiveType(resolvedFromType))
throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) };
PrimitiveType fromPrimitiveType = std::get<PrimitiveType>(resolvedFromType);
PrimitiveType targetPrimitiveType = std::get<PrimitiveType>(targetType);
bool areTypeCompatibles = [&]
{
switch (targetPrimitiveType)
{
case PrimitiveType::Boolean:
case PrimitiveType::String:
return false;
case PrimitiveType::Float32:
{
switch (fromPrimitiveType)
{
case PrimitiveType::Boolean:
case PrimitiveType::String:
return false;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
return true;
}
}
case PrimitiveType::Int32:
{
switch (fromPrimitiveType)
{
case PrimitiveType::Boolean:
case PrimitiveType::String:
return false;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
return true;
}
}
case PrimitiveType::UInt32:
{
switch (fromPrimitiveType)
{
case PrimitiveType::Boolean:
case PrimitiveType::String:
return false;
case PrimitiveType::Float32:
case PrimitiveType::Int32:
case PrimitiveType::UInt32:
return true;
}
}
}
}();
if (!areTypeCompatibles)
throw ShaderLang::CompilerCastIncompatibleTypesError{ node.expressions[0]->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedFromType, node.sourceLocation) };
}
else if (IsVectorType(targetType))
{
PrimitiveType targetBaseType = std::get<VectorType>(targetType).type;
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
{
if (IsVectorType(exprType))
@ -3226,7 +3313,19 @@ namespace Nz::ShaderAst
return ValidationResult::Unresolved;
const ExpressionType& resolvedExprType = ResolveAlias(*exprType);
if (!IsPrimitiveType(resolvedExprType) && !IsVectorType(resolvedExprType))
if (IsPrimitiveType(resolvedExprType))
{
PrimitiveType primitiveType = std::get<PrimitiveType>(resolvedExprType);
if (primitiveType != targetBaseType)
throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) };
}
else if (IsVectorType(resolvedExprType))
{
PrimitiveType primitiveType = std::get<VectorType>(resolvedExprType).type;
if (primitiveType != targetBaseType)
throw ShaderLang::CompilerCastIncompatibleBaseTypesError{ exprPtr->sourceLocation, ToString(targetBaseType, node.sourceLocation), ToString(primitiveType, exprPtr->sourceLocation) };
}
else
throw ShaderLang::CompilerCastIncompatibleTypesError{ exprPtr->sourceLocation, ToString(targetType, node.sourceLocation), ToString(resolvedExprType, exprPtr->sourceLocation) };
componentCount += GetComponentCount(resolvedExprType);
@ -3235,6 +3334,8 @@ namespace Nz::ShaderAst
if (componentCount != requiredComponents)
throw ShaderLang::CompilerCastComponentMismatchError{ node.sourceLocation, SafeCast<UInt32>(componentCount), SafeCast<UInt32>(requiredComponents) };
}
else
throw ShaderLang::CompilerInvalidCastError{ node.sourceLocation, ToString(targetType, node.sourceLocation) };
node.cachedExpressionType = targetType;
node.targetType = targetType;

View File

@ -509,8 +509,7 @@ namespace Nz
break; //< Already handled
case ShaderAst::PrimitiveType::UInt32:
castOp = SpirvOp::OpSConvert;
break;
throw std::runtime_error("unsupported cast from int32");
case ShaderAst::PrimitiveType::String:
throw std::runtime_error("unexpected string type");
@ -530,7 +529,7 @@ namespace Nz
break;
case ShaderAst::PrimitiveType::Int32:
castOp = SpirvOp::OpUConvert;
castOp = SpirvOp::OpBitcast;
break;
case ShaderAst::PrimitiveType::UInt32: