Shader: Add type functions

This commit is contained in:
Lynix
2020-09-22 21:50:07 +02:00
parent 9c2c62b063
commit 79c53061e0
8 changed files with 143 additions and 29 deletions

View File

@@ -273,7 +273,7 @@ namespace Nz
if (remainingMembers > 1)
{
assert(std::holds_alternative<std::string>(member.type));
assert(IsStructType(member.type));
AppendField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
}
}
@@ -363,7 +363,7 @@ namespace Nz
Visit(node.structExpr, true);
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
assert(std::holds_alternative<std::string>(exprType));
assert(IsStructType(exprType));
AppendField(std::get<std::string>(exprType), node.memberIndices.data(), node.memberIndices.size());
}

View File

@@ -98,7 +98,7 @@ namespace Nz
if (remainingMembers > 1)
{
if (!std::holds_alternative<std::string>(member.type))
if (!IsStructType(member.type))
throw AstError{ "member type does not match node type" };
return CheckField(std::get<std::string>(member.type), memberIndex + 1, remainingMembers - 1);
@@ -110,7 +110,7 @@ namespace Nz
void ShaderAstValidator::Visit(ShaderNodes::AccessMember& node)
{
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
if (!std::holds_alternative<std::string>(exprType))
if (!IsStructType(exprType))
throw AstError{ "expression is not a structure" };
const std::string& structName = std::get<std::string>(exprType);
@@ -138,11 +138,11 @@ namespace Nz
MandatoryNode(node.right);
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(leftExprType))
if (!IsBasicType(leftExprType))
throw AstError{ "left expression type does not support binary operation" };
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(rightExprType))
if (!IsBasicType(rightExprType))
throw AstError{ "right expression type does not support binary operation" };
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
@@ -229,7 +229,7 @@ namespace Nz
break;
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
if (!IsBasicType(exprType))
throw AstError{ "incompatible type" };
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
@@ -352,7 +352,7 @@ namespace Nz
throw AstError{ "Cannot swizzle more than four elements" };
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
if (!IsBasicType(exprType))
throw AstError{ "Cannot swizzle this type" };
switch (std::get<ShaderNodes::BasicType>(exprType))
@@ -379,7 +379,7 @@ namespace Nz
switch (var.entry)
{
case ShaderNodes::BuiltinEntry::VertexPosition:
if (!std::holds_alternative<ShaderNodes::BasicType>(var.type) ||
if (!IsBasicType(var.type) ||
std::get<ShaderNodes::BasicType>(var.type) != ShaderNodes::BasicType::Float4)
throw AstError{ "Builtin is not of the expected type" };

View File

@@ -100,10 +100,10 @@ namespace Nz::ShaderNodes
case BinaryType::Multiply:
{
const ShaderExpressionType& leftExprType = left->GetExpressionType();
assert(std::holds_alternative<BasicType>(leftExprType));
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = right->GetExpressionType();
assert(std::holds_alternative<BasicType>(rightExprType));
assert(IsBasicType(rightExprType));
switch (std::get<BasicType>(leftExprType))
{
@@ -212,7 +212,7 @@ namespace Nz::ShaderNodes
ShaderExpressionType SwizzleOp::GetExpressionType() const
{
const ShaderExpressionType& exprType = expression->GetExpressionType();
assert(std::holds_alternative<BasicType>(exprType));
assert(IsBasicType(exprType));
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
}

View File

@@ -39,13 +39,13 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::BinaryOp& node)
{
ShaderExpressionType resultExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(resultExprType));
assert(IsBasicType(resultExprType));
const ShaderExpressionType& leftExprType = node.left->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(leftExprType));
assert(IsBasicType(leftExprType));
const ShaderExpressionType& rightExprType = node.right->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(rightExprType));
assert(IsBasicType(rightExprType));
ShaderNodes::BasicType resultType = std::get<ShaderNodes::BasicType>(resultExprType);
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
@@ -285,7 +285,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::Cast& node)
{
const ShaderExpressionType& targetExprType = node.exprType;
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);
@@ -351,7 +351,7 @@ namespace Nz
case ShaderNodes::IntrinsicType::DotProduct:
{
const ShaderExpressionType& vecExprType = node.parameters[0]->GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(vecExprType));
assert(IsBasicType(vecExprType));
ShaderNodes::BasicType vecType = std::get<ShaderNodes::BasicType>(vecExprType);
@@ -394,7 +394,7 @@ namespace Nz
void SpirvAstVisitor::Visit(ShaderNodes::SwizzleOp& node)
{
const ShaderExpressionType& targetExprType = node.GetExpressionType();
assert(std::holds_alternative<ShaderNodes::BasicType>(targetExprType));
assert(IsBasicType(targetExprType));
ShaderNodes::BasicType targetType = std::get<ShaderNodes::BasicType>(targetExprType);

View File

@@ -293,13 +293,13 @@ namespace Nz
else if constexpr (std::is_same_v<T, float>)
cache.Register({ Float{ 32 } });
else if constexpr (std::is_same_v<T, Int32>)
cache.Register({ Integer{ 32, 1 } });
cache.Register({ Integer{ 32, true } });
else if constexpr (std::is_same_v<T, Int64>)
cache.Register({ Integer{ 64, 1 } });
cache.Register({ Integer{ 64, true } });
else if constexpr (std::is_same_v<T, UInt32>)
cache.Register({ Integer{ 32, 0 } });
cache.Register({ Integer{ 32, false } });
else if constexpr (std::is_same_v<T, UInt64>)
cache.Register({ Integer{ 64, 0 } });
cache.Register({ Integer{ 64, false } });
else
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
@@ -593,7 +593,7 @@ namespace Nz
return Float{ 32 };
case ShaderNodes::BasicType::Int1:
return Integer{ 32, 1 };
return Integer{ 32, true };
case ShaderNodes::BasicType::Float2:
case ShaderNodes::BasicType::Float3:
@@ -615,7 +615,7 @@ namespace Nz
return Matrix{ BuildType(ShaderNodes::BasicType::Float4), 4u };
case ShaderNodes::BasicType::UInt1:
return Integer{ 32, 0 };
return Integer{ 32, false };
case ShaderNodes::BasicType::Void:
return Void{};
@@ -706,13 +706,13 @@ namespace Nz
else if constexpr (std::is_same_v<ValueType, float>)
typeId = GetId({ Float{ 32 } });
else if constexpr (std::is_same_v<ValueType, Int32>)
typeId = GetId({ Integer{ 32, 1 } });
typeId = GetId({ Integer{ 32, true } });
else if constexpr (std::is_same_v<ValueType, Int64>)
typeId = GetId({ Integer{ 64, 1 } });
typeId = GetId({ Integer{ 64, true } });
else if constexpr (std::is_same_v<ValueType, UInt32>)
typeId = GetId({ Integer{ 32, 0 } });
typeId = GetId({ Integer{ 32, false } });
else if constexpr (std::is_same_v<ValueType, UInt64>)
typeId = GetId({ Integer{ 64, 0 } });
typeId = GetId({ Integer{ 64, false } });
else
static_assert(AlwaysFalse<ValueType>::value, "non-exhaustive visitor");

View File

@@ -268,7 +268,7 @@ namespace Nz
}
const ShaderExpressionType& builtinExprType = builtin->type;
assert(std::holds_alternative<ShaderNodes::BasicType>(builtinExprType));
assert(IsBasicType(builtinExprType));
ShaderNodes::BasicType builtinType = std::get<ShaderNodes::BasicType>(builtinExprType);