Shader: Add type functions
This commit is contained in:
parent
9c2c62b063
commit
79c53061e0
|
|
@ -15,6 +15,13 @@
|
|||
namespace Nz
|
||||
{
|
||||
using ShaderExpressionType = std::variant<ShaderNodes::BasicType, std::string>;
|
||||
|
||||
inline bool IsBasicType(const ShaderExpressionType& type);
|
||||
inline bool IsMatrixType(const ShaderExpressionType& type);
|
||||
inline bool IsSamplerType(const ShaderExpressionType& type);
|
||||
inline bool IsStructType(const ShaderExpressionType& type);
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/ShaderExpressionType.inl>
|
||||
|
||||
#endif // NAZARA_SHADER_EXPRESSIONTYPE_HPP
|
||||
|
|
|
|||
|
|
@ -0,0 +1,107 @@
|
|||
// Copyright (C) 2020 Jérôme Leclercq
|
||||
// This file is part of the "Nazara Engine - Shader generator"
|
||||
// For conditions of distribution and use, see copyright notice in Config.hpp
|
||||
|
||||
#include <Nazara/Shader/ShaderExpressionType.hpp>
|
||||
#include <Nazara/Shader/Debug.hpp>
|
||||
|
||||
namespace Nz
|
||||
{
|
||||
inline bool IsBasicType(const ShaderExpressionType& type)
|
||||
{
|
||||
return std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
return true;
|
||||
else if constexpr (std::is_same_v<T, std::string>)
|
||||
return false;
|
||||
else
|
||||
static_assert(AlwaysFalse<U>::value, "non-exhaustive visitor");
|
||||
|
||||
}, type);
|
||||
}
|
||||
|
||||
inline bool IsMatrixType(const ShaderExpressionType& type)
|
||||
{
|
||||
using namespace ShaderNodes;
|
||||
|
||||
if (!IsBasicType(type))
|
||||
return false;
|
||||
|
||||
switch (std::get<BasicType>(type))
|
||||
{
|
||||
case BasicType::Mat4x4:
|
||||
return true;
|
||||
|
||||
case BasicType::Boolean:
|
||||
case BasicType::Float1:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int1:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
case BasicType::Sampler2D:
|
||||
case BasicType::Void:
|
||||
case BasicType::UInt1:
|
||||
case BasicType::UInt2:
|
||||
case BasicType::UInt3:
|
||||
case BasicType::UInt4:
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsSamplerType(const ShaderExpressionType& type)
|
||||
{
|
||||
using namespace ShaderNodes;
|
||||
|
||||
if (!IsBasicType(type))
|
||||
return false;
|
||||
|
||||
switch (std::get<BasicType>(type))
|
||||
{
|
||||
case BasicType::Sampler2D:
|
||||
return true;
|
||||
|
||||
case BasicType::Boolean:
|
||||
case BasicType::Float1:
|
||||
case BasicType::Float2:
|
||||
case BasicType::Float3:
|
||||
case BasicType::Float4:
|
||||
case BasicType::Int1:
|
||||
case BasicType::Int2:
|
||||
case BasicType::Int3:
|
||||
case BasicType::Int4:
|
||||
case BasicType::Mat4x4:
|
||||
case BasicType::Void:
|
||||
case BasicType::UInt1:
|
||||
case BasicType::UInt2:
|
||||
case BasicType::UInt3:
|
||||
case BasicType::UInt4:
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool IsStructType(const ShaderExpressionType& type)
|
||||
{
|
||||
return std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
return false;
|
||||
else if constexpr (std::is_same_v<T, std::string>)
|
||||
return true;
|
||||
else
|
||||
static_assert(AlwaysFalse<U>::value, "non-exhaustive visitor");
|
||||
|
||||
}, type);
|
||||
}
|
||||
}
|
||||
|
||||
#include <Nazara/Shader/DebugOff.hpp>
|
||||
|
|
@ -273,7 +273,7 @@ namespace Nz
|
|||
|
||||
if (remainingMembers > 1)
|
||||
{
|
||||
assert(std::holds_alternative<std::string>(member.type));
|
||||
assert(IsStructType(member.type));
|
||||
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -363,7 +363,7 @@ namespace Nz
|
|||
Visit(node.structExpr, true);
|
||||
|
||||
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
|
||||
assert(std::holds_alternative<std::string>(exprType));
|
||||
assert(IsStructType(exprType));
|
||||
|
||||
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -98,7 +98,7 @@ namespace Nz
|
|||
|
||||
if (remainingMembers > 1)
|
||||
{
|
||||
if (!std::holds_alternative<std::string>(member.type))
|
||||
if (!IsStructType(member.type))
|
||||
throw AstError{ "member type does not match node type" };
|
||||
|
||||
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
|
||||
|
|
@ -110,7 +110,7 @@ namespace Nz
|
|||
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
|
||||
{
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
|
||||
if (!std::holds_alternative<std::string>(exprType))
|
||||
if (!IsStructType(exprType))
|
||||
throw AstError{ "expression is not a structure" };
|
||||
|
||||
const std::string& structName = std::get<std::string>(exprType);
|
||||
|
|
@ -138,11 +138,11 @@ namespace Nz
|
|||
MandatoryNode(node.right);
|
||||
|
||||
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(leftExprType))
|
||||
if (!IsBasicType(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(rightExprType))
|
||||
if (!IsBasicType(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
|
||||
|
|
@ -229,7 +229,7 @@ namespace Nz
|
|||
break;
|
||||
|
||||
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
|
||||
if (!IsBasicType(exprType))
|
||||
throw AstError{ "incompatible type" };
|
||||
|
||||
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
|
||||
|
|
@ -352,7 +352,7 @@ namespace Nz
|
|||
throw AstError{ "Cannot swizzle more than four elements" };
|
||||
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
|
||||
if (!IsBasicType(exprType))
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
|
||||
switch (std::get<ShaderNodes::BasicType>(exprType))
|
||||
|
|
@ -379,7 +379,7 @@ namespace Nz
|
|||
switch (var.entry)
|
||||
{
|
||||
case ShaderNodes::BuiltinEntry::VertexPosition:
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(var.type) ||
|
||||
if (!IsBasicType(var.type) ||
|
||||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
|
||||
throw AstError{ "Builtin is not of the expected type" };
|
||||
|
||||
|
|
|
|||
|
|
@ -100,10 +100,10 @@ namespace Nz::ShaderNodes
|
|||
case BinaryType::Multiply:
|
||||
{
|
||||
const ShaderExpressionType& leftExprType = left->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(leftExprType));
|
||||
assert(IsBasicType(leftExprType));
|
||||
|
||||
const ShaderExpressionType& rightExprType = right->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(rightExprType));
|
||||
assert(IsBasicType(rightExprType));
|
||||
|
||||
switch (std::get<BasicType>(leftExprType))
|
||||
{
|
||||
|
|
@ -212,7 +212,7 @@ namespace Nz::ShaderNodes
|
|||
ShaderExpressionType SwizzleOp::GetExpressionType() const
|
||||
{
|
||||
const ShaderExpressionType& exprType = expression->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(exprType));
|
||||
assert(IsBasicType(exprType));
|
||||
|
||||
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -39,13 +39,13 @@ namespace Nz
|
|||
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
|
||||
{
|
||||
ShaderExpressionType resultExprType = node.GetExpressionType();
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
|
||||
assert(IsBasicType(resultExprType));
|
||||
|
||||
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
|
||||
assert(IsBasicType(leftExprType));
|
||||
|
||||
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
|
||||
assert(IsBasicType(rightExprType));
|
||||
|
||||
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
|
||||
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
|
||||
|
|
@ -285,7 +285,7 @@ namespace Nz
|
|||
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
|
||||
{
|
||||
const ShaderExpressionType& targetExprType = node.exprType;
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
|
||||
assert(IsBasicType(targetExprType));
|
||||
|
||||
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
|
||||
|
||||
|
|
@ -351,7 +351,7 @@ namespace Nz
|
|||
case ShaderNodes::IntrinsicType::DotProduct:
|
||||
{
|
||||
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
|
||||
assert(IsBasicType(vecExprType));
|
||||
|
||||
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
|
||||
|
||||
|
|
@ -394,7 +394,7 @@ namespace Nz
|
|||
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
|
||||
{
|
||||
const ShaderExpressionType& targetExprType = node.GetExpressionType();
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
|
||||
assert(IsBasicType(targetExprType));
|
||||
|
||||
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
|
||||
|
||||
|
|
|
|||
|
|
@ -293,13 +293,13 @@ namespace Nz
|
|||
else if constexpr (std::is_same_v<T, float>)
|
||||
cache.Register({ Float{ 32 } });
|
||||
else if constexpr (std::is_same_v<T, Int32>)
|
||||
cache.Register({ Integer{ 32, 1 } });
|
||||
cache.Register({ Integer{ 32, true } });
|
||||
else if constexpr (std::is_same_v<T, Int64>)
|
||||
cache.Register({ Integer{ 64, 1 } });
|
||||
cache.Register({ Integer{ 64, true } });
|
||||
else if constexpr (std::is_same_v<T, UInt32>)
|
||||
cache.Register({ Integer{ 32, 0 } });
|
||||
cache.Register({ Integer{ 32, false } });
|
||||
else if constexpr (std::is_same_v<T, UInt64>)
|
||||
cache.Register({ Integer{ 64, 0 } });
|
||||
cache.Register({ Integer{ 64, false } });
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
|
||||
|
|
@ -593,7 +593,7 @@ namespace Nz
|
|||
return Float{ 32 };
|
||||
|
||||
case ShaderNodes::BasicType::Int1:
|
||||
return Integer{ 32, 1 };
|
||||
return Integer{ 32, true };
|
||||
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
case ShaderNodes::BasicType::Float3:
|
||||
|
|
@ -615,7 +615,7 @@ namespace Nz
|
|||
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
|
||||
|
||||
case ShaderNodes::BasicType::UInt1:
|
||||
return Integer{ 32, 0 };
|
||||
return Integer{ 32, false };
|
||||
|
||||
case ShaderNodes::BasicType::Void:
|
||||
return Void{};
|
||||
|
|
@ -706,13 +706,13 @@ namespace Nz
|
|||
else if constexpr (std::is_same_v<ValueType, float>)
|
||||
typeId = GetId({ Float{ 32 } });
|
||||
else if constexpr (std::is_same_v<ValueType, Int32>)
|
||||
typeId = GetId({ Integer{ 32, 1 } });
|
||||
typeId = GetId({ Integer{ 32, true } });
|
||||
else if constexpr (std::is_same_v<ValueType, Int64>)
|
||||
typeId = GetId({ Integer{ 64, 1 } });
|
||||
typeId = GetId({ Integer{ 64, true } });
|
||||
else if constexpr (std::is_same_v<ValueType, UInt32>)
|
||||
typeId = GetId({ Integer{ 32, 0 } });
|
||||
typeId = GetId({ Integer{ 32, false } });
|
||||
else if constexpr (std::is_same_v<ValueType, UInt64>)
|
||||
typeId = GetId({ Integer{ 64, 0 } });
|
||||
typeId = GetId({ Integer{ 64, false } });
|
||||
else
|
||||
static_assert(AlwaysFalse<ValueType>::value, "non-exhaustive visitor");
|
||||
|
||||
|
|
|
|||
|
|
@ -268,7 +268,7 @@ namespace Nz
|
|||
}
|
||||
|
||||
const ShaderExpressionType& builtinExprType = builtin->type;
|
||||
assert(std::holds_alternative<ShaderNodes::BasicType>(builtinExprType));
|
||||
assert(IsBasicType(builtinExprType));
|
||||
|
||||
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue