Shader: Add type functions

This commit is contained in:
Lynix 2020-09-22 21:50:07 +02:00
parent 9c2c62b063
commit 79c53061e0
8 changed files with 143 additions and 29 deletions

View File

@ -15,6 +15,13 @@
namespace Nz
{
using ShaderExpressionType = std::variant<ShaderNodes::BasicType, std::string>;
inline bool IsBasicType(const ShaderExpressionType& type);
inline bool IsMatrixType(const ShaderExpressionType& type);
inline bool IsSamplerType(const ShaderExpressionType& type);
inline bool IsStructType(const ShaderExpressionType& type);
}
#include <Nazara/Shader/ShaderExpressionType.inl>
#endif // NAZARA_SHADER_EXPRESSIONTYPE_HPP

View File

@ -0,0 +1,107 @@
// Copyright (C) 2020 Jérôme Leclercq
// This file is part of the "Nazara Engine - Shader generator"
// For conditions of distribution and use, see copyright notice in Config.hpp
#include <Nazara/Shader/ShaderExpressionType.hpp>
#include <Nazara/Shader/Debug.hpp>
namespace Nz
{
inline bool IsBasicType(const ShaderExpressionType& type)
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
return true;
else if constexpr (std::is_same_v<T, std::string>)
return false;
else
static_assert(AlwaysFalse<U>::value, "non-exhaustive visitor");
}, type);
}
inline bool IsMatrixType(const ShaderExpressionType& type)
{
using namespace ShaderNodes;
if (!IsBasicType(type))
return false;
switch (std::get<BasicType>(type))
{
case BasicType::Mat4x4:
return true;
case BasicType::Boolean:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::Sampler2D:
case BasicType::Void:
case BasicType::UInt1:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
return false;
}
return false;
}
inline bool IsSamplerType(const ShaderExpressionType& type)
{
using namespace ShaderNodes;
if (!IsBasicType(type))
return false;
switch (std::get<BasicType>(type))
{
case BasicType::Sampler2D:
return true;
case BasicType::Boolean:
case BasicType::Float1:
case BasicType::Float2:
case BasicType::Float3:
case BasicType::Float4:
case BasicType::Int1:
case BasicType::Int2:
case BasicType::Int3:
case BasicType::Int4:
case BasicType::Mat4x4:
case BasicType::Void:
case BasicType::UInt1:
case BasicType::UInt2:
case BasicType::UInt3:
case BasicType::UInt4:
return false;
}
return false;
}
inline bool IsStructType(const ShaderExpressionType& type)
{
return std::visit([&](auto&& arg)
{
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
return false;
else if constexpr (std::is_same_v<T, std::string>)
return true;
else
static_assert(AlwaysFalse<U>::value, "non-exhaustive visitor");
}, type);
}
}
#include <Nazara/Shader/DebugOff.hpp>

View File

@ -273,7 +273,7 @@ namespace Nz
if (remainingMembers > 1)
{
assert(std::holds_alternative<std::string>(member.type));
assert(IsStructType(member.type));
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
}
@ -363,7 +363,7 @@ namespace Nz
Visit(node.structExpr, true);
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
assert(std::holds_alternative<std::string>(exprType));
assert(IsStructType(exprType));
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
}

View File

@ -98,7 +98,7 @@ namespace Nz
if (remainingMembers > 1)
{
if (!std::holds_alternative<std::string>(member.type))
if (!IsStructType(member.type))
throw AstError{ "member type does not match node type" };
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
@ -110,7 +110,7 @@ namespace Nz
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
{
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
if (!std::holds_alternative<std::string>(exprType))
if (!IsStructType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
@ -138,11 +138,11 @@ namespace Nz
MandatoryNode(node.right);
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(leftExprType))
if (!IsBasicType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(rightExprType))
if (!IsBasicType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
@ -229,7 +229,7 @@ namespace Nz
break;
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
if (!IsBasicType(exprType))
throw AstError{ "incompatible type" };
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
@ -352,7 +352,7 @@ namespace Nz
throw AstError{ "Cannot swizzle more than four elements" };
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
if (!IsBasicType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<ShaderNodes::BasicType>(exprType))
@ -379,7 +379,7 @@ namespace Nz
switch (var.entry)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
if (!std::holds_alternative<ShaderNodes::BasicType>(var.type) ||
if (!IsBasicType(var.type) ||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
throw AstError{ "Builtin is not of the expected type" };

View File

@ -100,10 +100,10 @@ namespace Nz::ShaderNodes
case BinaryType::Multiply:
{
const ShaderExpressionType& leftExprType = left->GetExpressionType();
assert(std::holds_alternative<BasicType>(leftExprType));
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = right->GetExpressionType();
assert(std::holds_alternative<BasicType>(rightExprType));
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
@ -212,7 +212,7 @@ namespace Nz::ShaderNodes
ShaderExpressionType SwizzleOp::GetExpressionType() const
{
const ShaderExpressionType& exprType = expression->GetExpressionType();
assert(std::holds_alternative<BasicType>(exprType));
assert(IsBasicType(exprType));
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
}

View File

@ -39,13 +39,13 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
assert(IsBasicType(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
assert(IsBasicType(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
@ -285,7 +285,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
@ -351,7 +351,7 @@ namespace Nz
case ShaderNodes::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
assert(IsBasicType(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
@ -394,7 +394,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);

View File

@ -293,13 +293,13 @@ namespace Nz
else if constexpr (std::is_same_v<T, float>)
cache.Register({ Float{ 32 } });
else if constexpr (std::is_same_v<T, Int32>)
cache.Register({ Integer{ 32, 1 } });
cache.Register({ Integer{ 32, true } });
else if constexpr (std::is_same_v<T, Int64>)
cache.Register({ Integer{ 64, 1 } });
cache.Register({ Integer{ 64, true } });
else if constexpr (std::is_same_v<T, UInt32>)
cache.Register({ Integer{ 32, 0 } });
cache.Register({ Integer{ 32, false } });
else if constexpr (std::is_same_v<T, UInt64>)
cache.Register({ Integer{ 64, 0 } });
cache.Register({ Integer{ 64, false } });
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
@ -593,7 +593,7 @@ namespace Nz
return Float{ 32 };
case ShaderNodes::BasicType::Int1:
return Integer{ 32, 1 };
return Integer{ 32, true };
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
@ -615,7 +615,7 @@ namespace Nz
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
case ShaderNodes::BasicType::UInt1:
return Integer{ 32, 0 };
return Integer{ 32, false };
case ShaderNodes::BasicType::Void:
return Void{};
@ -706,13 +706,13 @@ namespace Nz
else if constexpr (std::is_same_v<ValueType, float>)
typeId = GetId({ Float{ 32 } });
else if constexpr (std::is_same_v<ValueType, Int32>)
typeId = GetId({ Integer{ 32, 1 } });
typeId = GetId({ Integer{ 32, true } });
else if constexpr (std::is_same_v<ValueType, Int64>)
typeId = GetId({ Integer{ 64, 1 } });
typeId = GetId({ Integer{ 64, true } });
else if constexpr (std::is_same_v<ValueType, UInt32>)
typeId = GetId({ Integer{ 32, 0 } });
typeId = GetId({ Integer{ 32, false } });
else if constexpr (std::is_same_v<ValueType, UInt64>)
typeId = GetId({ Integer{ 64, 0 } });
typeId = GetId({ Integer{ 64, false } });
else
static_assert(AlwaysFalse<ValueType>::value, "non-exhaustive visitor");

View File

@ -268,7 +268,7 @@ namespace Nz
}
const ShaderExpressionType& builtinExprType = builtin->type;
assert(std::holds_alternative<ShaderNodes::BasicType>(builtinExprType));
assert(IsBasicType(builtinExprType));
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);