Shader: Handle matrix cast properly

This commit is contained in:
Jérôme Leclercq 2022-01-23 19:58:04 +01:00
parent 249aebac05
commit 64efd81bf8
7 changed files with 305 additions and 11 deletions

View File

@ -40,9 +40,10 @@ namespace Nz::ShaderAst
std::unordered_set<std::string> reservedIdentifiers;
std::unordered_map<std::size_t, ConstantValue> optionValues;
bool makeVariableNameUnique = false;
bool removeConstDeclaration = false;
bool reduceLoopsToWhile = false;
bool removeConstDeclaration = false;
bool removeCompoundAssignments = false;
bool removeMatrixCast = false;
bool removeOptionDeclaration = false;
bool removeScalarSwizzling = false;
bool splitMultipleBranches = false;

View File

@ -54,6 +54,7 @@ namespace Nz::ShaderBuilder
struct Cast
{
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const;
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const;
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
};
@ -71,6 +72,7 @@ namespace Nz::ShaderBuilder
struct Constant
{
inline std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ConstantValue value) const;
template<typename T> std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ExpressionType type, T value) const;
};
struct DeclareConst

View File

@ -3,6 +3,7 @@
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderBuilder.hpp>
#include <stdexcept>
#include <Nazara/Shader/Debug.hpp>
namespace Nz::ShaderBuilder
@ -110,6 +111,15 @@ namespace Nz::ShaderBuilder
return callFunctionExpression;
}
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const
{
auto castNode = std::make_unique<ShaderAst::CastExpression>();
castNode->targetType = std::move(targetType);
castNode->expressions[0] = std::move(expression);
return castNode;
}
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const
{
auto castNode = std::make_unique<ShaderAst::CastExpression>();
@ -159,6 +169,22 @@ namespace Nz::ShaderBuilder
return constantNode;
}
template<typename T>
std::unique_ptr<ShaderAst::ConstantValueExpression> Impl::Constant::operator()(ShaderAst::ExpressionType type, T value) const
{
assert(IsPrimitiveType(type));
switch (std::get<ShaderAst::PrimitiveType>(type))
{
case ShaderAst::PrimitiveType::Boolean: return ShaderBuilder::Constant(value != T(0));
case ShaderAst::PrimitiveType::Float32: return ShaderBuilder::Constant(SafeCast<float>(value));
case ShaderAst::PrimitiveType::Int32: return ShaderBuilder::Constant(SafeCast<Int32>(value));
case ShaderAst::PrimitiveType::UInt32: return ShaderBuilder::Constant(SafeCast<UInt32>(value));
}
throw std::runtime_error("unexpected primitive type");
}
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
{
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();

View File

@ -10,6 +10,7 @@
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
#include <Nazara/Shader/Ast/AstUtils.hpp>
#include <numeric>
#include <stdexcept>
#include <unordered_set>
#include <Nazara/Shader/Debug.hpp>
@ -386,6 +387,88 @@ namespace Nz::ShaderAst
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
Validate(*clone);
if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType))
{
const MatrixType& targetMatrixType = std::get<MatrixType>(clone->targetType);
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
bool isMatrixCast = IsMatrixType(frontExprType);
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
{
// Nothing to do
return std::move(clone->expressions.front());
}
auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required
Validate(*variableDeclaration);
std::size_t variableIndex = *variableDeclaration->varIndex;
m_context->currentStatementList->emplace_back(std::move(variableDeclaration));
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
{
// temp[i]
auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i)));
Validate(*columnExpr);
// vector expression
ExpressionPtr vectorExpr;
std::size_t vectorComponentCount;
if (isMatrixCast)
{
// fromMatrix[i]
auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i)));
Validate(*matrixColumnExpr);
vectorExpr = std::move(matrixColumnExpr);
vectorComponentCount = std::get<MatrixType>(frontExprType).rowCount;
}
else
{
// parameter #i
vectorExpr = std::move(clone->expressions[i]);
vectorComponentCount = std::get<VectorType>(GetExpressionType(*vectorExpr)).componentCount;
}
// cast expression (turn fromMatrix[i] to vec3<f32>(fromMatrix[i]))
ExpressionPtr castExpr;
if (vectorComponentCount != targetMatrixType.rowCount)
{
CastExpressionPtr vecCast;
if (vectorComponentCount < targetMatrixType.rowCount)
{
std::array<ExpressionPtr, 4> expressions;
expressions[0] = std::move(vectorExpr);
for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j)
expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal
vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions));
Validate(*vecCast);
castExpr = std::move(vecCast);
}
else
{
std::array<UInt32, 4> swizzleComponents;
std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0);
auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount);
Validate(*swizzleExpr);
castExpr = std::move(swizzleExpr);
}
}
else
castExpr = std::move(vectorExpr);
// temp[i] = castExpr
m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr))));
}
return ShaderBuilder::Variable(variableIndex, clone->targetType);
}
return clone;
}
@ -653,7 +736,7 @@ namespace Nz::ShaderAst
else if (IsSamplerType(extVar.type))
varType = extVar.type;
else
throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType));
if (!clone->varIndex)
@ -820,6 +903,18 @@ namespace Nz::ShaderAst
declaredMembers.insert(member.name);
member.type = ResolveType(member.type);
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
{
if (IsPrimitiveType(member.type) && std::get<PrimitiveType>(member.type) == PrimitiveType::Boolean)
throw AstError{ "boolean type is not allowed in std140 layout" };
else if (IsStructType(member.type))
{
std::size_t structIndex = std::get<StructType>(member.type).structIndex;
const StructDescription* desc = m_context->structs[structIndex];
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
throw AstError{ "inner struct layout mismatch" };
}
}
}
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
@ -1695,18 +1790,44 @@ namespace Nz::ShaderAst
void SanitizeVisitor::Validate(CastExpression& node)
{
node.cachedExpressionType = node.targetType;
node.targetType = ResolveType(node.targetType);
node.cachedExpressionType = node.targetType;
// Allow casting a matrix to itself (wtf?)
// FIXME: Make proper rules
if (IsMatrixType(node.targetType) && node.expressions.front())
const auto& firstExprPtr = node.expressions.front();
if (!firstExprPtr)
throw AstError{ "expected at least one expression" };
if (IsMatrixType(node.targetType))
{
const ExpressionType& exprType = GetExpressionType(*node.expressions.front());
if (IsMatrixType(exprType) && !node.expressions[1])
const MatrixType& targetMatrixType = std::get<MatrixType>(node.targetType);
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
if (IsMatrixType(firstExprType))
{
if (node.expressions[1])
throw AstError{ "too many expressions" };
// Matrix to matrix cast: always valid
return;
}
else
{
assert(targetMatrixType.columnCount <= 4);
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
{
const auto& exprPtr = node.expressions[i];
if (!exprPtr)
throw AstError{ "component count doesn't match required component count" };
const ExpressionType& exprType = GetExpressionType(*exprPtr);
if (!IsVectorType(exprType))
throw AstError{ "expected vector type" };
const VectorType& vecType = std::get<VectorType>(exprType);
if (vecType.componentCount != targetMatrixType.rowCount)
throw AstError{ "vector component count must match target matrix row count" };
}
}
}
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
@ -1936,9 +2057,9 @@ namespace Nz::ShaderAst
if (node.componentCount > 4)
throw AstError{ "cannot swizzle more than four elements" };
for (UInt32 swizzleIndex : node.components)
for (std::size_t i = 0; i < node.componentCount; ++i)
{
if (swizzleIndex >= componentCount)
if (node.components[i] >= componentCount)
throw AstError{ "invalid swizzle" };
}

View File

@ -235,6 +235,20 @@ namespace Nz::ShaderLang
return matrixType;
}
else if (identifier == "mat2")
{
Consume();
ShaderAst::MatrixType matrixType;
matrixType.columnCount = 2;
matrixType.rowCount = 2;
Expect(Advance(), TokenType::LessThan); //< '<'
matrixType.type = ParsePrimitiveType();
Expect(Advance(), TokenType::GreaterThan); //< '>'
return matrixType;
}
else if (identifier == "sampler2D")
{
Consume();

View File

@ -491,7 +491,8 @@ namespace Nz
options.optionValues = states.optionValues;
options.reduceLoopsToWhile = true;
options.removeCompoundAssignments = true;
options.removeOptionDeclaration = true;
options.removeMatrixCast = true;
options.removeOptionDeclaration = true;
options.splitMultipleBranches = true;
options.useIdentifierAccessesForStructs = false;

View File

@ -124,6 +124,135 @@ fn main()
}
}
)");
}
WHEN("removing matrix casts")
{
std::string_view nzslSource = R"(
fn testMat2ToMat2(input: mat2<f32>) -> mat2<f32>
{
return mat2<f32>(input);
}
fn testMat2ToMat3(input: mat2<f32>) -> mat3<f32>
{
return mat3<f32>(input);
}
fn testMat2ToMat4(input: mat2<f32>) -> mat4<f32>
{
return mat4<f32>(input);
}
fn testMat3ToMat2(input: mat3<f32>) -> mat2<f32>
{
return mat2<f32>(input);
}
fn testMat3ToMat3(input: mat3<f32>) -> mat3<f32>
{
return mat3<f32>(input);
}
fn testMat3ToMat4(input: mat3<f32>) -> mat4<f32>
{
return mat4<f32>(input);
}
fn testMat4ToMat2(input: mat4<f32>) -> mat2<f32>
{
return mat2<f32>(input);
}
fn testMat4ToMat3(input: mat4<f32>) -> mat3<f32>
{
return mat3<f32>(input);
}
fn testMat4ToMat4(input: mat4<f32>) -> mat4<f32>
{
return mat4<f32>(input);
}
)";
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
Nz::ShaderAst::SanitizeVisitor::Options options;
options.removeMatrixCast = true;
REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options));
ExpectNZSL(*shader, R"(
fn testMat2ToMat2(input: mat2<f32>) -> mat2<f32>
{
return input;
}
fn testMat2ToMat3(input: mat2<f32>) -> mat3<f32>
{
let temp: mat3<f32>;
temp[0] = vec3<f32>(input[0], 0.000000);
temp[1] = vec3<f32>(input[1], 0.000000);
temp[2] = vec3<f32>(input[2], 1.000000);
return temp;
}
fn testMat2ToMat4(input: mat2<f32>) -> mat4<f32>
{
let temp: mat4<f32>;
temp[0] = vec4<f32>(input[0], 0.000000, 0.000000);
temp[1] = vec4<f32>(input[1], 0.000000, 0.000000);
temp[2] = vec4<f32>(input[2], 1.000000, 0.000000);
temp[3] = vec4<f32>(input[3], 0.000000, 1.000000);
return temp;
}
fn testMat3ToMat2(input: mat3<f32>) -> mat2<f32>
{
let temp: mat2<f32>;
temp[0] = input[0].xy;
temp[1] = input[1].xy;
return temp;
}
fn testMat3ToMat3(input: mat3<f32>) -> mat3<f32>
{
return input;
}
fn testMat3ToMat4(input: mat3<f32>) -> mat4<f32>
{
let temp: mat4<f32>;
temp[0] = vec4<f32>(input[0], 0.000000);
temp[1] = vec4<f32>(input[1], 0.000000);
temp[2] = vec4<f32>(input[2], 0.000000);
temp[3] = vec4<f32>(input[3], 1.000000);
return temp;
}
fn testMat4ToMat2(input: mat4<f32>) -> mat2<f32>
{
let temp: mat2<f32>;
temp[0] = input[0].xy;
temp[1] = input[1].xy;
return temp;
}
fn testMat4ToMat3(input: mat4<f32>) -> mat3<f32>
{
let temp: mat3<f32>;
temp[0] = input[0].xyz;
temp[1] = input[1].xyz;
temp[2] = input[2].xyz;
return temp;
}
fn testMat4ToMat4(input: mat4<f32>) -> mat4<f32>
{
return input;
}
)");
}