Shader: Handle matrix cast properly
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
|
||||
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||
#include <numeric>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
@@ -386,6 +387,88 @@ namespace Nz::ShaderAst
|
||||
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
|
||||
if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType))
|
||||
{
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(clone->targetType);
|
||||
|
||||
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
||||
bool isMatrixCast = IsMatrixType(frontExprType);
|
||||
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
||||
{
|
||||
// Nothing to do
|
||||
return std::move(clone->expressions.front());
|
||||
}
|
||||
|
||||
auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required
|
||||
Validate(*variableDeclaration);
|
||||
|
||||
std::size_t variableIndex = *variableDeclaration->varIndex;
|
||||
|
||||
m_context->currentStatementList->emplace_back(std::move(variableDeclaration));
|
||||
|
||||
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
|
||||
{
|
||||
// temp[i]
|
||||
auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i)));
|
||||
Validate(*columnExpr);
|
||||
|
||||
// vector expression
|
||||
ExpressionPtr vectorExpr;
|
||||
std::size_t vectorComponentCount;
|
||||
if (isMatrixCast)
|
||||
{
|
||||
// fromMatrix[i]
|
||||
auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i)));
|
||||
Validate(*matrixColumnExpr);
|
||||
|
||||
vectorExpr = std::move(matrixColumnExpr);
|
||||
vectorComponentCount = std::get<MatrixType>(frontExprType).rowCount;
|
||||
}
|
||||
else
|
||||
{
|
||||
// parameter #i
|
||||
vectorExpr = std::move(clone->expressions[i]);
|
||||
vectorComponentCount = std::get<VectorType>(GetExpressionType(*vectorExpr)).componentCount;
|
||||
}
|
||||
|
||||
// cast expression (turn fromMatrix[i] to vec3<f32>(fromMatrix[i]))
|
||||
ExpressionPtr castExpr;
|
||||
if (vectorComponentCount != targetMatrixType.rowCount)
|
||||
{
|
||||
CastExpressionPtr vecCast;
|
||||
if (vectorComponentCount < targetMatrixType.rowCount)
|
||||
{
|
||||
std::array<ExpressionPtr, 4> expressions;
|
||||
expressions[0] = std::move(vectorExpr);
|
||||
for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j)
|
||||
expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal
|
||||
|
||||
vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions));
|
||||
Validate(*vecCast);
|
||||
|
||||
castExpr = std::move(vecCast);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<UInt32, 4> swizzleComponents;
|
||||
std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0);
|
||||
|
||||
auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount);
|
||||
Validate(*swizzleExpr);
|
||||
|
||||
castExpr = std::move(swizzleExpr);
|
||||
}
|
||||
}
|
||||
else
|
||||
castExpr = std::move(vectorExpr);
|
||||
|
||||
// temp[i] = castExpr
|
||||
m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr))));
|
||||
}
|
||||
|
||||
return ShaderBuilder::Variable(variableIndex, clone->targetType);
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
@@ -653,7 +736,7 @@ namespace Nz::ShaderAst
|
||||
else if (IsSamplerType(extVar.type))
|
||||
varType = extVar.type;
|
||||
else
|
||||
throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||
|
||||
std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType));
|
||||
if (!clone->varIndex)
|
||||
@@ -820,6 +903,18 @@ namespace Nz::ShaderAst
|
||||
declaredMembers.insert(member.name);
|
||||
|
||||
member.type = ResolveType(member.type);
|
||||
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
||||
{
|
||||
if (IsPrimitiveType(member.type) && std::get<PrimitiveType>(member.type) == PrimitiveType::Boolean)
|
||||
throw AstError{ "boolean type is not allowed in std140 layout" };
|
||||
else if (IsStructType(member.type))
|
||||
{
|
||||
std::size_t structIndex = std::get<StructType>(member.type).structIndex;
|
||||
const StructDescription* desc = m_context->structs[structIndex];
|
||||
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
||||
throw AstError{ "inner struct layout mismatch" };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
|
||||
@@ -1695,18 +1790,44 @@ namespace Nz::ShaderAst
|
||||
|
||||
void SanitizeVisitor::Validate(CastExpression& node)
|
||||
{
|
||||
node.cachedExpressionType = node.targetType;
|
||||
node.targetType = ResolveType(node.targetType);
|
||||
node.cachedExpressionType = node.targetType;
|
||||
|
||||
// Allow casting a matrix to itself (wtf?)
|
||||
// FIXME: Make proper rules
|
||||
if (IsMatrixType(node.targetType) && node.expressions.front())
|
||||
const auto& firstExprPtr = node.expressions.front();
|
||||
if (!firstExprPtr)
|
||||
throw AstError{ "expected at least one expression" };
|
||||
|
||||
if (IsMatrixType(node.targetType))
|
||||
{
|
||||
const ExpressionType& exprType = GetExpressionType(*node.expressions.front());
|
||||
if (IsMatrixType(exprType) && !node.expressions[1])
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(node.targetType);
|
||||
|
||||
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
||||
if (IsMatrixType(firstExprType))
|
||||
{
|
||||
if (node.expressions[1])
|
||||
throw AstError{ "too many expressions" };
|
||||
|
||||
// Matrix to matrix cast: always valid
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(targetMatrixType.columnCount <= 4);
|
||||
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
|
||||
{
|
||||
const auto& exprPtr = node.expressions[i];
|
||||
if (!exprPtr)
|
||||
throw AstError{ "component count doesn't match required component count" };
|
||||
|
||||
const ExpressionType& exprType = GetExpressionType(*exprPtr);
|
||||
if (!IsVectorType(exprType))
|
||||
throw AstError{ "expected vector type" };
|
||||
|
||||
const VectorType& vecType = std::get<VectorType>(exprType);
|
||||
if (vecType.componentCount != targetMatrixType.rowCount)
|
||||
throw AstError{ "vector component count must match target matrix row count" };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
||||
@@ -1936,9 +2057,9 @@ namespace Nz::ShaderAst
|
||||
if (node.componentCount > 4)
|
||||
throw AstError{ "cannot swizzle more than four elements" };
|
||||
|
||||
for (UInt32 swizzleIndex : node.components)
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
if (swizzleIndex >= componentCount)
|
||||
if (node.components[i] >= componentCount)
|
||||
throw AstError{ "invalid swizzle" };
|
||||
}
|
||||
|
||||
|
||||
@@ -235,6 +235,20 @@ namespace Nz::ShaderLang
|
||||
|
||||
return matrixType;
|
||||
}
|
||||
else if (identifier == "mat2")
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::MatrixType matrixType;
|
||||
matrixType.columnCount = 2;
|
||||
matrixType.rowCount = 2;
|
||||
|
||||
Expect(Advance(), TokenType::LessThan); //< '<'
|
||||
matrixType.type = ParsePrimitiveType();
|
||||
Expect(Advance(), TokenType::GreaterThan); //< '>'
|
||||
|
||||
return matrixType;
|
||||
}
|
||||
else if (identifier == "sampler2D")
|
||||
{
|
||||
Consume();
|
||||
|
||||
@@ -491,7 +491,8 @@ namespace Nz
|
||||
options.optionValues = states.optionValues;
|
||||
options.reduceLoopsToWhile = true;
|
||||
options.removeCompoundAssignments = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
options.removeMatrixCast = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
options.splitMultipleBranches = true;
|
||||
options.useIdentifierAccessesForStructs = false;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user