Shader: Handle matrix cast properly

This commit is contained in:
Jérôme Leclercq
2022-01-23 19:58:04 +01:00
parent 249aebac05
commit 64efd81bf8
7 changed files with 305 additions and 11 deletions

View File

@@ -10,6 +10,7 @@
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <numeric>
#include <stdexcept>
#include <unordered_set>
#include <Nazara/Shader/Debug.hpp>
@@ -386,6 +387,88 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
Validate(*clone);
if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType))
{
const MatrixType& targetMatrixType = std::get<MatrixType>(clone->targetType);
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
bool isMatrixCast = IsMatrixType(frontExprType);
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
{
// Nothing to do
return std::move(clone->expressions.front());
}
auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required
Validate(*variableDeclaration);
std::size_t variableIndex = *variableDeclaration->varIndex;
m_context->currentStatementList->emplace_back(std::move(variableDeclaration));
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
{
// temp[i]
auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i)));
Validate(*columnExpr);
// vector expression
ExpressionPtr vectorExpr;
std::size_t vectorComponentCount;
if (isMatrixCast)
{
// fromMatrix[i]
auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i)));
Validate(*matrixColumnExpr);
vectorExpr = std::move(matrixColumnExpr);
vectorComponentCount = std::get<MatrixType>(frontExprType).rowCount;
}
else
{
// parameter #i
vectorExpr = std::move(clone->expressions[i]);
vectorComponentCount = std::get<VectorType>(GetExpressionType(*vectorExpr)).componentCount;
}
// cast expression (turn fromMatrix[i] to vec3<f32>(fromMatrix[i]))
ExpressionPtr castExpr;
if (vectorComponentCount != targetMatrixType.rowCount)
{
CastExpressionPtr vecCast;
if (vectorComponentCount < targetMatrixType.rowCount)
{
std::array<ExpressionPtr, 4> expressions;
expressions[0] = std::move(vectorExpr);
for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j)
expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal
vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions));
Validate(*vecCast);
castExpr = std::move(vecCast);
}
else
{
std::array<UInt32, 4> swizzleComponents;
std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0);
auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount);
Validate(*swizzleExpr);
castExpr = std::move(swizzleExpr);
}
}
else
castExpr = std::move(vectorExpr);
// temp[i] = castExpr
m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr))));
}
return ShaderBuilder::Variable(variableIndex, clone->targetType);
}
return clone;
}
@@ -653,7 +736,7 @@ namespace Nz::ShaderAst
else if (IsSamplerType(extVar.type))
varType = extVar.type;
else
throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType));
if (!clone->varIndex)
@@ -820,6 +903,18 @@ namespace Nz::ShaderAst
declaredMembers.insert(member.name);
member.type = ResolveType(member.type);
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
{
if (IsPrimitiveType(member.type) && std::get<PrimitiveType>(member.type) == PrimitiveType::Boolean)
throw AstError{ "boolean type is not allowed in std140 layout" };
else if (IsStructType(member.type))
{
std::size_t structIndex = std::get<StructType>(member.type).structIndex;
const StructDescription* desc = m_context->structs[structIndex];
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" };
}
}
}
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
@@ -1695,18 +1790,44 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CastExpression& node)
{
node.cachedExpressionType = node.targetType;
node.targetType = ResolveType(node.targetType);
node.cachedExpressionType = node.targetType;
// Allow casting a matrix to itself (wtf?)
// FIXME: Make proper rules
if (IsMatrixType(node.targetType) && node.expressions.front())
const auto& firstExprPtr = node.expressions.front();
if (!firstExprPtr)
throw AstError{ "expected at least one expression" };
if (IsMatrixType(node.targetType))
{
const ExpressionType& exprType = GetExpressionType(*node.expressions.front());
if (IsMatrixType(exprType) && !node.expressions[1])
const MatrixType& targetMatrixType = std::get<MatrixType>(node.targetType);
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
if (IsMatrixType(firstExprType))
{
if (node.expressions[1])
throw AstError{ "too many expressions" };
// Matrix to matrix cast: always valid
return;
}
else
{
assert(targetMatrixType.columnCount <= 4);
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
{
const auto& exprPtr = node.expressions[i];
if (!exprPtr)
throw AstError{ "component count doesn't match required component count" };
const ExpressionType& exprType = GetExpressionType(*exprPtr);
if (!IsVectorType(exprType))
throw AstError{ "expected vector type" };
const VectorType& vecType = std::get<VectorType>(exprType);
if (vecType.componentCount != targetMatrixType.rowCount)
throw AstError{ "vector component count must match target matrix row count" };
}
}
}
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
@@ -1936,9 +2057,9 @@ namespace Nz::ShaderAst
if (node.componentCount > 4)
throw AstError{ "cannot swizzle more than four elements" };
for (UInt32 swizzleIndex : node.components)
for (std::size_t i = 0; i < node.componentCount; ++i)
{
if (swizzleIndex >= componentCount)
if (node.components[i] >= componentCount)
throw AstError{ "invalid swizzle" };
}

View File

@@ -235,6 +235,20 @@ namespace Nz::ShaderLang
return matrixType;
}
else if (identifier == "mat2")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 2;
matrixType.rowCount = 2;
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType;
}
else if (identifier == "sampler2D")
{
Consume();

View File

@@ -491,7 +491,8 @@ namespace Nz
options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true;
options.removeCompoundAssignments = true;
options.removeOptionDeclaration = true;
options.removeMatrixCast = true;
options.removeOptionDeclaration = true;
options.splitMultipleBranches = true;
options.useIdentifierAccessesForStructs = false;