diff --git a/include/Nazara/Shader/ShaderExpressionType.hpp b/include/Nazara/Shader/ShaderExpressionType.hpp index 69b53b06a..6d5385121 100644 --- a/include/Nazara/Shader/ShaderExpressionType.hpp +++ b/include/Nazara/Shader/ShaderExpressionType.hpp @@ -15,6 +15,13 @@ namespace Nz { using ShaderExpressionType = std::variant; + + inline bool IsBasicType(const ShaderExpressionType& type); + inline bool IsMatrixType(const ShaderExpressionType& type); + inline bool IsSamplerType(const ShaderExpressionType& type); + inline bool IsStructType(const ShaderExpressionType& type); } +#include + #endif // NAZARA_SHADER_EXPRESSIONTYPE_HPP diff --git a/include/Nazara/Shader/ShaderExpressionType.inl b/include/Nazara/Shader/ShaderExpressionType.inl new file mode 100644 index 000000000..15f602a5b --- /dev/null +++ b/include/Nazara/Shader/ShaderExpressionType.inl @@ -0,0 +1,107 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz +{ + inline bool IsBasicType(const ShaderExpressionType& type) + { + return std::visit([&](auto&& arg) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + return true; + else if constexpr (std::is_same_v) + return false; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + }, type); + } + + inline bool IsMatrixType(const ShaderExpressionType& type) + { + using namespace ShaderNodes; + + if (!IsBasicType(type)) + return false; + + switch (std::get(type)) + { + case BasicType::Mat4x4: + return true; + + case BasicType::Boolean: + case BasicType::Float1: + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + case BasicType::Int1: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: + case BasicType::Sampler2D: + case BasicType::Void: + case BasicType::UInt1: + case BasicType::UInt2: + case BasicType::UInt3: + case BasicType::UInt4: + return false; + } + + return false; + } + + inline bool IsSamplerType(const ShaderExpressionType& type) + { + using namespace ShaderNodes; + + if (!IsBasicType(type)) + return false; + + switch (std::get(type)) + { + case BasicType::Sampler2D: + return true; + + case BasicType::Boolean: + case BasicType::Float1: + case BasicType::Float2: + case BasicType::Float3: + case BasicType::Float4: + case BasicType::Int1: + case BasicType::Int2: + case BasicType::Int3: + case BasicType::Int4: + case BasicType::Mat4x4: + case BasicType::Void: + case BasicType::UInt1: + case BasicType::UInt2: + case BasicType::UInt3: + case BasicType::UInt4: + return false; + } + + return false; + } + + inline bool IsStructType(const ShaderExpressionType& type) + { + return std::visit([&](auto&& arg) + { + using T = std::decay_t; + if constexpr (std::is_same_v) + return false; + else if constexpr (std::is_same_v) + return true; + else + static_assert(AlwaysFalse::value, "non-exhaustive visitor"); + + }, type); + } +} + +#include diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 7be512c28..4e22aeffc 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -273,7 +273,7 @@ namespace Nz if (remainingMembers > 1) { - assert(std::holds_alternative(member.type)); + assert(IsStructType(member.type)); AppendField(std::get(member.type), memberIndex + 1, remainingMembers - 1); } } @@ -363,7 +363,7 @@ namespace Nz Visit(node.structExpr, true); const ShaderExpressionType& exprType = node.structExpr->GetExpressionType(); - assert(std::holds_alternative(exprType)); + assert(IsStructType(exprType)); AppendField(std::get(exprType), node.memberIndices.data(), node.memberIndices.size()); } diff --git a/src/Nazara/Shader/ShaderAstValidator.cpp b/src/Nazara/Shader/ShaderAstValidator.cpp index b614ec4a2..95ae20df0 100644 --- a/src/Nazara/Shader/ShaderAstValidator.cpp +++ b/src/Nazara/Shader/ShaderAstValidator.cpp @@ -98,7 +98,7 @@ namespace Nz if (remainingMembers > 1) { - if (!std::holds_alternative(member.type)) + if (!IsStructType(member.type)) throw AstError{ "member type does not match node type" }; return CheckField(std::get(member.type), memberIndex + 1, remainingMembers - 1); @@ -110,7 +110,7 @@ namespace Nz void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node) { const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType(); - if (!std::holds_alternative(exprType)) + if (!IsStructType(exprType)) throw AstError{ "expression is not a structure" }; const std::string& structName = std::get(exprType); @@ -138,11 +138,11 @@ namespace Nz MandatoryNode(node.right); const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType(); - if (!std::holds_alternative(leftExprType)) + if (!IsBasicType(leftExprType)) throw AstError{ "left expression type does not support binary operation" }; const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType(); - if (!std::holds_alternative(rightExprType)) + if (!IsBasicType(rightExprType)) throw AstError{ "right expression type does not support binary operation" }; ShaderNodes::BasicType leftType = std::get(leftExprType); @@ -229,7 +229,7 @@ namespace Nz break; const ShaderExpressionType& exprType = exprPtr->GetExpressionType(); - if (!std::holds_alternative(exprType)) + if (!IsBasicType(exprType)) throw AstError{ "incompatible type" }; componentCount += node.GetComponentCount(std::get(exprType)); @@ -352,7 +352,7 @@ namespace Nz throw AstError{ "Cannot swizzle more than four elements" }; const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType(); - if (!std::holds_alternative(exprType)) + if (!IsBasicType(exprType)) throw AstError{ "Cannot swizzle this type" }; switch (std::get(exprType)) @@ -379,7 +379,7 @@ namespace Nz switch (var.entry) { case ShaderNodes::BuiltinEntry::VertexPosition: - if (!std::holds_alternative(var.type) || + if (!IsBasicType(var.type) || std::get(var.type) != ShaderNodes::BasicType::Float4) throw AstError{ "Builtin is not of the expected type" }; diff --git a/src/Nazara/Shader/ShaderNodes.cpp b/src/Nazara/Shader/ShaderNodes.cpp index 15c7e40f0..bdc2897db 100644 --- a/src/Nazara/Shader/ShaderNodes.cpp +++ b/src/Nazara/Shader/ShaderNodes.cpp @@ -100,10 +100,10 @@ namespace Nz::ShaderNodes case BinaryType::Multiply: { const ShaderExpressionType& leftExprType = left->GetExpressionType(); - assert(std::holds_alternative(leftExprType)); + assert(IsBasicType(leftExprType)); const ShaderExpressionType& rightExprType = right->GetExpressionType(); - assert(std::holds_alternative(rightExprType)); + assert(IsBasicType(rightExprType)); switch (std::get(leftExprType)) { @@ -212,7 +212,7 @@ namespace Nz::ShaderNodes ShaderExpressionType SwizzleOp::GetExpressionType() const { const ShaderExpressionType& exprType = expression->GetExpressionType(); - assert(std::holds_alternative(exprType)); + assert(IsBasicType(exprType)); return static_cast(UnderlyingCast(GetComponentType(std::get(exprType))) + componentCount - 1); } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index e066337a9..4c419f987 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -39,13 +39,13 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node) { ShaderExpressionType resultExprType = node.GetExpressionType(); - assert(std::holds_alternative(resultExprType)); + assert(IsBasicType(resultExprType)); const ShaderExpressionType& leftExprType = node.left->GetExpressionType(); - assert(std::holds_alternative(leftExprType)); + assert(IsBasicType(leftExprType)); const ShaderExpressionType& rightExprType = node.right->GetExpressionType(); - assert(std::holds_alternative(rightExprType)); + assert(IsBasicType(rightExprType)); ShaderNodes::BasicType resultType = std::get(resultExprType); ShaderNodes::BasicType leftType = std::get(leftExprType); @@ -285,7 +285,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::Cast& node) { const ShaderExpressionType& targetExprType = node.exprType; - assert(std::holds_alternative(targetExprType)); + assert(IsBasicType(targetExprType)); ShaderNodes::BasicType targetType = std::get(targetExprType); @@ -351,7 +351,7 @@ namespace Nz case ShaderNodes::IntrinsicType::DotProduct: { const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType(); - assert(std::holds_alternative(vecExprType)); + assert(IsBasicType(vecExprType)); ShaderNodes::BasicType vecType = std::get(vecExprType); @@ -394,7 +394,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node) { const ShaderExpressionType& targetExprType = node.GetExpressionType(); - assert(std::holds_alternative(targetExprType)); + assert(IsBasicType(targetExprType)); ShaderNodes::BasicType targetType = std::get(targetExprType); diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 74dd25ff2..b8489ce76 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -293,13 +293,13 @@ namespace Nz else if constexpr (std::is_same_v) cache.Register({ Float{ 32 } }); else if constexpr (std::is_same_v) - cache.Register({ Integer{ 32, 1 } }); + cache.Register({ Integer{ 32, true } }); else if constexpr (std::is_same_v) - cache.Register({ Integer{ 64, 1 } }); + cache.Register({ Integer{ 64, true } }); else if constexpr (std::is_same_v) - cache.Register({ Integer{ 32, 0 } }); + cache.Register({ Integer{ 32, false } }); else if constexpr (std::is_same_v) - cache.Register({ Integer{ 64, 0 } }); + cache.Register({ Integer{ 64, false } }); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); @@ -593,7 +593,7 @@ namespace Nz return Float{ 32 }; case ShaderNodes::BasicType::Int1: - return Integer{ 32, 1 }; + return Integer{ 32, true }; case ShaderNodes::BasicType::Float2: case ShaderNodes::BasicType::Float3: @@ -615,7 +615,7 @@ namespace Nz return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u }; case ShaderNodes::BasicType::UInt1: - return Integer{ 32, 0 }; + return Integer{ 32, false }; case ShaderNodes::BasicType::Void: return Void{}; @@ -706,13 +706,13 @@ namespace Nz else if constexpr (std::is_same_v) typeId = GetId({ Float{ 32 } }); else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 32, 1 } }); + typeId = GetId({ Integer{ 32, true } }); else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 64, 1 } }); + typeId = GetId({ Integer{ 64, true } }); else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 32, 0 } }); + typeId = GetId({ Integer{ 32, false } }); else if constexpr (std::is_same_v) - typeId = GetId({ Integer{ 64, 0 } }); + typeId = GetId({ Integer{ 64, false } }); else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 81f9d1d7f..522dd699d 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -268,7 +268,7 @@ namespace Nz } const ShaderExpressionType& builtinExprType = builtin->type; - assert(std::holds_alternative(builtinExprType)); + assert(IsBasicType(builtinExprType)); ShaderNodes::BasicType builtinType = std::get(builtinExprType);