Shader: Handle matrix cast properly
This commit is contained in:
parent
249aebac05
commit
64efd81bf8
|
|
@ -40,9 +40,10 @@ namespace Nz::ShaderAst
|
|||
std::unordered_set<std::string> reservedIdentifiers;
|
||||
std::unordered_map<std::size_t, ConstantValue> optionValues;
|
||||
bool makeVariableNameUnique = false;
|
||||
bool removeConstDeclaration = false;
|
||||
bool reduceLoopsToWhile = false;
|
||||
bool removeConstDeclaration = false;
|
||||
bool removeCompoundAssignments = false;
|
||||
bool removeMatrixCast = false;
|
||||
bool removeOptionDeclaration = false;
|
||||
bool removeScalarSwizzling = false;
|
||||
bool splitMultipleBranches = false;
|
||||
|
|
|
|||
|
|
@ -54,6 +54,7 @@ namespace Nz::ShaderBuilder
|
|||
|
||||
struct Cast
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const;
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const;
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> operator()(ShaderAst::ExpressionType targetType, std::vector<ShaderAst::ExpressionPtr> expressions) const;
|
||||
};
|
||||
|
|
@ -71,6 +72,7 @@ namespace Nz::ShaderBuilder
|
|||
struct Constant
|
||||
{
|
||||
inline std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ConstantValue value) const;
|
||||
template<typename T> std::unique_ptr<ShaderAst::ConstantValueExpression> operator()(ShaderAst::ExpressionType type, T value) const;
|
||||
};
|
||||
|
||||
struct DeclareConst
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderBuilder.hpp>
|
||||
#include <stdexcept>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz::ShaderBuilder
|
||||
|
|
@ -110,6 +111,15 @@ namespace Nz::ShaderBuilder
|
|||
return callFunctionExpression;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, ShaderAst::ExpressionPtr expression) const
|
||||
{
|
||||
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
||||
castNode->targetType = std::move(targetType);
|
||||
castNode->expressions[0] = std::move(expression);
|
||||
|
||||
return castNode;
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::CastExpression> Impl::Cast::operator()(ShaderAst::ExpressionType targetType, std::array<ShaderAst::ExpressionPtr, 4> expressions) const
|
||||
{
|
||||
auto castNode = std::make_unique<ShaderAst::CastExpression>();
|
||||
|
|
@ -159,6 +169,22 @@ namespace Nz::ShaderBuilder
|
|||
return constantNode;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
std::unique_ptr<ShaderAst::ConstantValueExpression> Impl::Constant::operator()(ShaderAst::ExpressionType type, T value) const
|
||||
{
|
||||
assert(IsPrimitiveType(type));
|
||||
|
||||
switch (std::get<ShaderAst::PrimitiveType>(type))
|
||||
{
|
||||
case ShaderAst::PrimitiveType::Boolean: return ShaderBuilder::Constant(value != T(0));
|
||||
case ShaderAst::PrimitiveType::Float32: return ShaderBuilder::Constant(SafeCast<float>(value));
|
||||
case ShaderAst::PrimitiveType::Int32: return ShaderBuilder::Constant(SafeCast<Int32>(value));
|
||||
case ShaderAst::PrimitiveType::UInt32: return ShaderBuilder::Constant(SafeCast<UInt32>(value));
|
||||
}
|
||||
|
||||
throw std::runtime_error("unexpected primitive type");
|
||||
}
|
||||
|
||||
inline std::unique_ptr<ShaderAst::DeclareConstStatement> Impl::DeclareConst::operator()(std::string name, ShaderAst::ExpressionPtr initialValue) const
|
||||
{
|
||||
auto declareConstNode = std::make_unique<ShaderAst::DeclareConstStatement>();
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <Nazara/Shader/Ast/AstOptimizer.hpp>
|
||||
#include <Nazara/Shader/Ast/AstRecursiveVisitor.hpp>
|
||||
#include <Nazara/Shader/Ast/AstUtils.hpp>
|
||||
#include <numeric>
|
||||
#include <stdexcept>
|
||||
#include <unordered_set>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
|
@ -386,6 +387,88 @@ namespace Nz::ShaderAst
|
|||
auto clone = static_unique_pointer_cast<CastExpression>(AstCloner::Clone(node));
|
||||
Validate(*clone);
|
||||
|
||||
if (m_context->options.removeMatrixCast && IsMatrixType(clone->targetType))
|
||||
{
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(clone->targetType);
|
||||
|
||||
const ShaderAst::ExpressionType& frontExprType = GetExpressionType(*clone->expressions.front());
|
||||
bool isMatrixCast = IsMatrixType(frontExprType);
|
||||
if (isMatrixCast && std::get<MatrixType>(frontExprType) == targetMatrixType)
|
||||
{
|
||||
// Nothing to do
|
||||
return std::move(clone->expressions.front());
|
||||
}
|
||||
|
||||
auto variableDeclaration = ShaderBuilder::DeclareVariable("temp", clone->targetType); //< Validation will prevent name-clash if required
|
||||
Validate(*variableDeclaration);
|
||||
|
||||
std::size_t variableIndex = *variableDeclaration->varIndex;
|
||||
|
||||
m_context->currentStatementList->emplace_back(std::move(variableDeclaration));
|
||||
|
||||
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
|
||||
{
|
||||
// temp[i]
|
||||
auto columnExpr = ShaderBuilder::AccessIndex(ShaderBuilder::Variable(variableIndex, clone->targetType), ShaderBuilder::Constant(UInt32(i)));
|
||||
Validate(*columnExpr);
|
||||
|
||||
// vector expression
|
||||
ExpressionPtr vectorExpr;
|
||||
std::size_t vectorComponentCount;
|
||||
if (isMatrixCast)
|
||||
{
|
||||
// fromMatrix[i]
|
||||
auto matrixColumnExpr = ShaderBuilder::AccessIndex(CloneExpression(clone->expressions.front()), ShaderBuilder::Constant(UInt32(i)));
|
||||
Validate(*matrixColumnExpr);
|
||||
|
||||
vectorExpr = std::move(matrixColumnExpr);
|
||||
vectorComponentCount = std::get<MatrixType>(frontExprType).rowCount;
|
||||
}
|
||||
else
|
||||
{
|
||||
// parameter #i
|
||||
vectorExpr = std::move(clone->expressions[i]);
|
||||
vectorComponentCount = std::get<VectorType>(GetExpressionType(*vectorExpr)).componentCount;
|
||||
}
|
||||
|
||||
// cast expression (turn fromMatrix[i] to vec3<f32>(fromMatrix[i]))
|
||||
ExpressionPtr castExpr;
|
||||
if (vectorComponentCount != targetMatrixType.rowCount)
|
||||
{
|
||||
CastExpressionPtr vecCast;
|
||||
if (vectorComponentCount < targetMatrixType.rowCount)
|
||||
{
|
||||
std::array<ExpressionPtr, 4> expressions;
|
||||
expressions[0] = std::move(vectorExpr);
|
||||
for (std::size_t j = 0; j < targetMatrixType.rowCount - vectorComponentCount; ++j)
|
||||
expressions[j + 1] = ShaderBuilder::Constant(targetMatrixType.type, (i == j + vectorComponentCount) ? 1 : 0); //< set 1 to diagonal
|
||||
|
||||
vecCast = ShaderBuilder::Cast(VectorType{ targetMatrixType.rowCount, targetMatrixType.type }, std::move(expressions));
|
||||
Validate(*vecCast);
|
||||
|
||||
castExpr = std::move(vecCast);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::array<UInt32, 4> swizzleComponents;
|
||||
std::iota(swizzleComponents.begin(), swizzleComponents.begin() + targetMatrixType.rowCount, 0);
|
||||
|
||||
auto swizzleExpr = ShaderBuilder::Swizzle(std::move(vectorExpr), swizzleComponents, targetMatrixType.rowCount);
|
||||
Validate(*swizzleExpr);
|
||||
|
||||
castExpr = std::move(swizzleExpr);
|
||||
}
|
||||
}
|
||||
else
|
||||
castExpr = std::move(vectorExpr);
|
||||
|
||||
// temp[i] = castExpr
|
||||
m_context->currentStatementList->emplace_back(ShaderBuilder::ExpressionStatement(ShaderBuilder::Assign(AssignType::Simple, std::move(columnExpr), std::move(castExpr))));
|
||||
}
|
||||
|
||||
return ShaderBuilder::Variable(variableIndex, clone->targetType);
|
||||
}
|
||||
|
||||
return clone;
|
||||
}
|
||||
|
||||
|
|
@ -653,7 +736,7 @@ namespace Nz::ShaderAst
|
|||
else if (IsSamplerType(extVar.type))
|
||||
varType = extVar.type;
|
||||
else
|
||||
throw AstError{ "External variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||
throw AstError{ "external variable " + extVar.name + " is of wrong type: only uniform and sampler are allowed in external blocks" };
|
||||
|
||||
std::size_t varIndex = RegisterVariable(extVar.name, std::move(varType));
|
||||
if (!clone->varIndex)
|
||||
|
|
@ -820,6 +903,18 @@ namespace Nz::ShaderAst
|
|||
declaredMembers.insert(member.name);
|
||||
|
||||
member.type = ResolveType(member.type);
|
||||
if (clone->description.layout.HasValue() && clone->description.layout.GetResultingValue() == StructLayout::Std140)
|
||||
{
|
||||
if (IsPrimitiveType(member.type) && std::get<PrimitiveType>(member.type) == PrimitiveType::Boolean)
|
||||
throw AstError{ "boolean type is not allowed in std140 layout" };
|
||||
else if (IsStructType(member.type))
|
||||
{
|
||||
std::size_t structIndex = std::get<StructType>(member.type).structIndex;
|
||||
const StructDescription* desc = m_context->structs[structIndex];
|
||||
if (!desc->layout.HasValue() || desc->layout.GetResultingValue() != clone->description.layout.GetResultingValue())
|
||||
throw AstError{ "inner struct layout mismatch" };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
clone->structIndex = RegisterStruct(clone->description.name, &clone->description);
|
||||
|
|
@ -1695,18 +1790,44 @@ namespace Nz::ShaderAst
|
|||
|
||||
void SanitizeVisitor::Validate(CastExpression& node)
|
||||
{
|
||||
node.cachedExpressionType = node.targetType;
|
||||
node.targetType = ResolveType(node.targetType);
|
||||
node.cachedExpressionType = node.targetType;
|
||||
|
||||
// Allow casting a matrix to itself (wtf?)
|
||||
// FIXME: Make proper rules
|
||||
if (IsMatrixType(node.targetType) && node.expressions.front())
|
||||
const auto& firstExprPtr = node.expressions.front();
|
||||
if (!firstExprPtr)
|
||||
throw AstError{ "expected at least one expression" };
|
||||
|
||||
if (IsMatrixType(node.targetType))
|
||||
{
|
||||
const ExpressionType& exprType = GetExpressionType(*node.expressions.front());
|
||||
if (IsMatrixType(exprType) && !node.expressions[1])
|
||||
const MatrixType& targetMatrixType = std::get<MatrixType>(node.targetType);
|
||||
|
||||
const ExpressionType& firstExprType = GetExpressionType(*firstExprPtr);
|
||||
if (IsMatrixType(firstExprType))
|
||||
{
|
||||
if (node.expressions[1])
|
||||
throw AstError{ "too many expressions" };
|
||||
|
||||
// Matrix to matrix cast: always valid
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(targetMatrixType.columnCount <= 4);
|
||||
for (std::size_t i = 0; i < targetMatrixType.columnCount; ++i)
|
||||
{
|
||||
const auto& exprPtr = node.expressions[i];
|
||||
if (!exprPtr)
|
||||
throw AstError{ "component count doesn't match required component count" };
|
||||
|
||||
const ExpressionType& exprType = GetExpressionType(*exprPtr);
|
||||
if (!IsVectorType(exprType))
|
||||
throw AstError{ "expected vector type" };
|
||||
|
||||
const VectorType& vecType = std::get<VectorType>(exprType);
|
||||
if (vecType.componentCount != targetMatrixType.rowCount)
|
||||
throw AstError{ "vector component count must match target matrix row count" };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto GetComponentCount = [](const ExpressionType& exprType) -> std::size_t
|
||||
|
|
@ -1936,9 +2057,9 @@ namespace Nz::ShaderAst
|
|||
if (node.componentCount > 4)
|
||||
throw AstError{ "cannot swizzle more than four elements" };
|
||||
|
||||
for (UInt32 swizzleIndex : node.components)
|
||||
for (std::size_t i = 0; i < node.componentCount; ++i)
|
||||
{
|
||||
if (swizzleIndex >= componentCount)
|
||||
if (node.components[i] >= componentCount)
|
||||
throw AstError{ "invalid swizzle" };
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -235,6 +235,20 @@ namespace Nz::ShaderLang
|
|||
|
||||
return matrixType;
|
||||
}
|
||||
else if (identifier == "mat2")
|
||||
{
|
||||
Consume();
|
||||
|
||||
ShaderAst::MatrixType matrixType;
|
||||
matrixType.columnCount = 2;
|
||||
matrixType.rowCount = 2;
|
||||
|
||||
Expect(Advance(), TokenType::LessThan); //< '<'
|
||||
matrixType.type = ParsePrimitiveType();
|
||||
Expect(Advance(), TokenType::GreaterThan); //< '>'
|
||||
|
||||
return matrixType;
|
||||
}
|
||||
else if (identifier == "sampler2D")
|
||||
{
|
||||
Consume();
|
||||
|
|
|
|||
|
|
@ -491,7 +491,8 @@ namespace Nz
|
|||
options.optionValues = states.optionValues;
|
||||
options.reduceLoopsToWhile = true;
|
||||
options.removeCompoundAssignments = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
options.removeMatrixCast = true;
|
||||
options.removeOptionDeclaration = true;
|
||||
options.splitMultipleBranches = true;
|
||||
options.useIdentifierAccessesForStructs = false;
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,135 @@ fn main()
|
|||
}
|
||||
|
||||
}
|
||||
)");
|
||||
|
||||
}
|
||||
|
||||
WHEN("removing matrix casts")
|
||||
{
|
||||
std::string_view nzslSource = R"(
|
||||
fn testMat2ToMat2(input: mat2<f32>) -> mat2<f32>
|
||||
{
|
||||
return mat2<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat2ToMat3(input: mat2<f32>) -> mat3<f32>
|
||||
{
|
||||
return mat3<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat2ToMat4(input: mat2<f32>) -> mat4<f32>
|
||||
{
|
||||
return mat4<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat3ToMat2(input: mat3<f32>) -> mat2<f32>
|
||||
{
|
||||
return mat2<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat3ToMat3(input: mat3<f32>) -> mat3<f32>
|
||||
{
|
||||
return mat3<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat3ToMat4(input: mat3<f32>) -> mat4<f32>
|
||||
{
|
||||
return mat4<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat4ToMat2(input: mat4<f32>) -> mat2<f32>
|
||||
{
|
||||
return mat2<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat4ToMat3(input: mat4<f32>) -> mat3<f32>
|
||||
{
|
||||
return mat3<f32>(input);
|
||||
}
|
||||
|
||||
fn testMat4ToMat4(input: mat4<f32>) -> mat4<f32>
|
||||
{
|
||||
return mat4<f32>(input);
|
||||
}
|
||||
)";
|
||||
|
||||
Nz::ShaderAst::StatementPtr shader = Nz::ShaderLang::Parse(nzslSource);
|
||||
|
||||
Nz::ShaderAst::SanitizeVisitor::Options options;
|
||||
options.removeMatrixCast = true;
|
||||
|
||||
REQUIRE_NOTHROW(shader = Nz::ShaderAst::Sanitize(*shader, options));
|
||||
|
||||
ExpectNZSL(*shader, R"(
|
||||
fn testMat2ToMat2(input: mat2<f32>) -> mat2<f32>
|
||||
{
|
||||
return input;
|
||||
}
|
||||
|
||||
fn testMat2ToMat3(input: mat2<f32>) -> mat3<f32>
|
||||
{
|
||||
let temp: mat3<f32>;
|
||||
temp[0] = vec3<f32>(input[0], 0.000000);
|
||||
temp[1] = vec3<f32>(input[1], 0.000000);
|
||||
temp[2] = vec3<f32>(input[2], 1.000000);
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat2ToMat4(input: mat2<f32>) -> mat4<f32>
|
||||
{
|
||||
let temp: mat4<f32>;
|
||||
temp[0] = vec4<f32>(input[0], 0.000000, 0.000000);
|
||||
temp[1] = vec4<f32>(input[1], 0.000000, 0.000000);
|
||||
temp[2] = vec4<f32>(input[2], 1.000000, 0.000000);
|
||||
temp[3] = vec4<f32>(input[3], 0.000000, 1.000000);
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat3ToMat2(input: mat3<f32>) -> mat2<f32>
|
||||
{
|
||||
let temp: mat2<f32>;
|
||||
temp[0] = input[0].xy;
|
||||
temp[1] = input[1].xy;
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat3ToMat3(input: mat3<f32>) -> mat3<f32>
|
||||
{
|
||||
return input;
|
||||
}
|
||||
|
||||
fn testMat3ToMat4(input: mat3<f32>) -> mat4<f32>
|
||||
{
|
||||
let temp: mat4<f32>;
|
||||
temp[0] = vec4<f32>(input[0], 0.000000);
|
||||
temp[1] = vec4<f32>(input[1], 0.000000);
|
||||
temp[2] = vec4<f32>(input[2], 0.000000);
|
||||
temp[3] = vec4<f32>(input[3], 1.000000);
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat4ToMat2(input: mat4<f32>) -> mat2<f32>
|
||||
{
|
||||
let temp: mat2<f32>;
|
||||
temp[0] = input[0].xy;
|
||||
temp[1] = input[1].xy;
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat4ToMat3(input: mat4<f32>) -> mat3<f32>
|
||||
{
|
||||
let temp: mat3<f32>;
|
||||
temp[0] = input[0].xyz;
|
||||
temp[1] = input[1].xyz;
|
||||
temp[2] = input[2].xyz;
|
||||
return temp;
|
||||
}
|
||||
|
||||
fn testMat4ToMat4(input: mat4<f32>) -> mat4<f32>
|
||||
{
|
||||
return input;
|
||||
}
|
||||
)");
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue