Renderer/ShaderNodes: Add support for accessing struct fields
This commit is contained in:
@@ -21,6 +21,8 @@ namespace Nz
|
||||
if (!ValidateShader(shader, &error))
|
||||
throw std::runtime_error("Invalid shader AST: " + error);
|
||||
|
||||
m_context.shader = &shader;
|
||||
|
||||
State state;
|
||||
m_currentState = &state;
|
||||
CallOnExit onExit([this]()
|
||||
@@ -294,7 +296,30 @@ namespace Nz
|
||||
AppendLine();
|
||||
AppendLine("}");
|
||||
}
|
||||
|
||||
|
||||
void GlslWriter::Visit(const ShaderNodes::AccessMember& node)
|
||||
{
|
||||
Append("(");
|
||||
Visit(node.structExpr);
|
||||
Append(")");
|
||||
|
||||
const ShaderExpressionType& exprType = node.structExpr->GetExpressionType();
|
||||
assert(std::holds_alternative<std::string>(exprType));
|
||||
|
||||
const std::string& structName = std::get<std::string>(exprType);
|
||||
|
||||
const auto& structs = m_context.shader->GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
|
||||
assert(it != structs.end());
|
||||
|
||||
const ShaderAst::Struct& s = *it;
|
||||
assert(node.memberIndex < s.members.size());
|
||||
|
||||
const auto& member = s.members[node.memberIndex];
|
||||
Append(".");
|
||||
Append(member.name);
|
||||
}
|
||||
|
||||
void GlslWriter::Visit(const ShaderNodes::AssignOp& node)
|
||||
{
|
||||
Visit(node.left);
|
||||
@@ -374,9 +399,7 @@ namespace Nz
|
||||
Append(node.exprType);
|
||||
Append("(");
|
||||
|
||||
unsigned int i = 0;
|
||||
unsigned int requiredComponents = ShaderNodes::Node::GetComponentCount(node.exprType);
|
||||
while (requiredComponents > 0)
|
||||
for (std::size_t i = 0; node.expressions[i]; ++i)
|
||||
{
|
||||
if (i != 0)
|
||||
m_currentState->stream << ", ";
|
||||
@@ -385,7 +408,6 @@ namespace Nz
|
||||
NazaraAssert(exprPtr, "Invalid expression");
|
||||
|
||||
Visit(exprPtr);
|
||||
requiredComponents -= ShaderNodes::Node::GetComponentCount(exprPtr->GetExpressionType());
|
||||
}
|
||||
|
||||
Append(")");
|
||||
|
||||
@@ -47,7 +47,7 @@ namespace Nz::ShaderNodes
|
||||
return ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
BasicType Identifier::GetExpressionType() const
|
||||
ShaderExpressionType Identifier::GetExpressionType() const
|
||||
{
|
||||
assert(var);
|
||||
return var->type;
|
||||
@@ -58,8 +58,22 @@ namespace Nz::ShaderNodes
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
ExpressionCategory ShaderNodes::AccessMember::GetExpressionCategory() const
|
||||
{
|
||||
return ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
BasicType AssignOp::GetExpressionType() const
|
||||
ShaderExpressionType AccessMember::GetExpressionType() const
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
|
||||
void AccessMember::Visit(ShaderVisitor& visitor)
|
||||
{
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
ShaderExpressionType AssignOp::GetExpressionType() const
|
||||
{
|
||||
return left->GetExpressionType();
|
||||
}
|
||||
@@ -70,31 +84,39 @@ namespace Nz::ShaderNodes
|
||||
}
|
||||
|
||||
|
||||
BasicType BinaryOp::GetExpressionType() const
|
||||
ShaderExpressionType BinaryOp::GetExpressionType() const
|
||||
{
|
||||
ShaderNodes::BasicType exprType = ShaderNodes::BasicType::Void;
|
||||
std::optional<ShaderExpressionType> exprType;
|
||||
|
||||
switch (op)
|
||||
{
|
||||
case ShaderNodes::BinaryType::Add:
|
||||
case ShaderNodes::BinaryType::Substract:
|
||||
case BinaryType::Add:
|
||||
case BinaryType::Substract:
|
||||
exprType = left->GetExpressionType();
|
||||
break;
|
||||
|
||||
case ShaderNodes::BinaryType::Divide:
|
||||
case ShaderNodes::BinaryType::Multiply:
|
||||
//FIXME
|
||||
exprType = static_cast<BasicType>(std::max(UnderlyingCast(left->GetExpressionType()), UnderlyingCast(right->GetExpressionType())));
|
||||
break;
|
||||
case BinaryType::Divide:
|
||||
case BinaryType::Multiply:
|
||||
{
|
||||
const ShaderExpressionType& leftExprType = left->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(leftExprType));
|
||||
|
||||
case ShaderNodes::BinaryType::Equality:
|
||||
const ShaderExpressionType& rightExprType = right->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(rightExprType));
|
||||
|
||||
//FIXME
|
||||
exprType = static_cast<BasicType>(std::max(UnderlyingCast(std::get<BasicType>(leftExprType)), UnderlyingCast(std::get<BasicType>(rightExprType))));
|
||||
break;
|
||||
}
|
||||
|
||||
case BinaryType::Equality:
|
||||
exprType = BasicType::Boolean;
|
||||
break;
|
||||
}
|
||||
|
||||
NazaraAssert(exprType != ShaderNodes::BasicType::Void, "Unhandled builtin");
|
||||
NazaraAssert(exprType.has_value(), "Unhandled builtin");
|
||||
|
||||
return exprType;
|
||||
return *exprType;
|
||||
}
|
||||
|
||||
void BinaryOp::Visit(ShaderVisitor& visitor)
|
||||
@@ -109,7 +131,7 @@ namespace Nz::ShaderNodes
|
||||
}
|
||||
|
||||
|
||||
BasicType Constant::GetExpressionType() const
|
||||
ShaderExpressionType Constant::GetExpressionType() const
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
@@ -119,7 +141,7 @@ namespace Nz::ShaderNodes
|
||||
visitor.Visit(*this);
|
||||
}
|
||||
|
||||
BasicType Cast::GetExpressionType() const
|
||||
ShaderExpressionType Cast::GetExpressionType() const
|
||||
{
|
||||
return exprType;
|
||||
}
|
||||
@@ -135,9 +157,12 @@ namespace Nz::ShaderNodes
|
||||
return ExpressionCategory::LValue;
|
||||
}
|
||||
|
||||
BasicType SwizzleOp::GetExpressionType() const
|
||||
ShaderExpressionType SwizzleOp::GetExpressionType() const
|
||||
{
|
||||
return static_cast<BasicType>(UnderlyingCast(GetComponentType(expression->GetExpressionType())) + componentCount - 1);
|
||||
const ShaderExpressionType& exprType = expression->GetExpressionType();
|
||||
assert(std::holds_alternative<BasicType>(exprType));
|
||||
|
||||
return static_cast<BasicType>(UnderlyingCast(GetComponentType(std::get<BasicType>(exprType))) + componentCount - 1);
|
||||
}
|
||||
|
||||
void SwizzleOp::Visit(ShaderVisitor& visitor)
|
||||
@@ -146,7 +171,7 @@ namespace Nz::ShaderNodes
|
||||
}
|
||||
|
||||
|
||||
BasicType Sample2D::GetExpressionType() const
|
||||
ShaderExpressionType Sample2D::GetExpressionType() const
|
||||
{
|
||||
return BasicType::Float4;
|
||||
}
|
||||
@@ -157,7 +182,7 @@ namespace Nz::ShaderNodes
|
||||
}
|
||||
|
||||
|
||||
BasicType IntrinsicCall::GetExpressionType() const
|
||||
ShaderExpressionType IntrinsicCall::GetExpressionType() const
|
||||
{
|
||||
switch (intrinsic)
|
||||
{
|
||||
|
||||
@@ -22,6 +22,11 @@ namespace Nz
|
||||
{
|
||||
}
|
||||
|
||||
void Visit(const ShaderNodes::AccessMember& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
}
|
||||
|
||||
void Visit(const ShaderNodes::AssignOp& node) override
|
||||
{
|
||||
Serialize(node);
|
||||
@@ -125,6 +130,13 @@ namespace Nz
|
||||
};
|
||||
}
|
||||
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::AccessMember& node)
|
||||
{
|
||||
Value(node.memberIndex);
|
||||
Node(node.structExpr);
|
||||
Type(node.exprType);
|
||||
}
|
||||
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::AssignOp& node)
|
||||
{
|
||||
Enum(node.op);
|
||||
@@ -153,8 +165,8 @@ namespace Nz
|
||||
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::BuiltinVariable& node)
|
||||
{
|
||||
Enum(node.type);
|
||||
Enum(node.type);
|
||||
Enum(node.entry);
|
||||
Type(node.type);
|
||||
}
|
||||
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::Cast& node)
|
||||
@@ -219,7 +231,7 @@ namespace Nz
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::NamedVariable& node)
|
||||
{
|
||||
Value(node.name);
|
||||
Enum(node.type);
|
||||
Type(node.type);
|
||||
}
|
||||
|
||||
void ShaderSerializerBase::Serialize(ShaderNodes::Sample2D& node)
|
||||
@@ -348,6 +360,26 @@ namespace Nz
|
||||
}
|
||||
}
|
||||
|
||||
void ShaderSerializer::Type(ShaderExpressionType& type)
|
||||
{
|
||||
std::visit([&](auto&& arg)
|
||||
{
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, ShaderNodes::BasicType>)
|
||||
{
|
||||
m_stream << UInt8(0);
|
||||
m_stream << UInt32(arg);
|
||||
}
|
||||
else if constexpr (std::is_same_v<T, std::string>)
|
||||
{
|
||||
m_stream << UInt8(1);
|
||||
m_stream << arg;
|
||||
}
|
||||
else
|
||||
static_assert(AlwaysFalse<T>::value, "non-exhaustive visitor");
|
||||
}, type);
|
||||
}
|
||||
|
||||
void ShaderSerializer::Node(const ShaderNodes::NodePtr& node)
|
||||
{
|
||||
Node(const_cast<ShaderNodes::NodePtr&>(node)); //< Yes const_cast is ugly but it won't be used for writing
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace Nz
|
||||
struct Local
|
||||
{
|
||||
std::string name;
|
||||
ShaderNodes::BasicType type;
|
||||
ShaderExpressionType type;
|
||||
};
|
||||
|
||||
const ShaderAst::Function* currentFunction;
|
||||
@@ -83,6 +83,28 @@ namespace Nz
|
||||
throw AstError{ "Left expression type must match right expression type" };
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::AccessMember& node)
|
||||
{
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.structExpr)->GetExpressionType();
|
||||
if (!std::holds_alternative<std::string>(exprType))
|
||||
throw AstError{ "expression is not a structure" };
|
||||
|
||||
const std::string& structName = std::get<std::string>(exprType);
|
||||
|
||||
const auto& structs = m_shader.GetStructs();
|
||||
auto it = std::find_if(structs.begin(), structs.end(), [&](const auto& s) { return s.name == structName; });
|
||||
if (it == structs.end())
|
||||
throw AstError{ "invalid structure" };
|
||||
|
||||
const ShaderAst::Struct& s = *it;
|
||||
if (node.memberIndex >= s.members.size())
|
||||
throw AstError{ "member index out of bounds" };
|
||||
|
||||
const auto& member = s.members[node.memberIndex];
|
||||
if (member.type != node.exprType)
|
||||
throw AstError{ "member type does not match node type" };
|
||||
}
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::AssignOp& node)
|
||||
{
|
||||
MandatoryNode(node.left);
|
||||
@@ -101,8 +123,16 @@ namespace Nz
|
||||
MandatoryNode(node.left);
|
||||
MandatoryNode(node.right);
|
||||
|
||||
ShaderNodes::BasicType leftType = node.left->GetExpressionType();
|
||||
ShaderNodes::BasicType rightType = node.right->GetExpressionType();
|
||||
const ShaderExpressionType& leftExprType = MandatoryExpr(node.left)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(leftExprType))
|
||||
throw AstError{ "left expression type does not support binary operation" };
|
||||
|
||||
const ShaderExpressionType& rightExprType = MandatoryExpr(node.right)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(rightExprType))
|
||||
throw AstError{ "right expression type does not support binary operation" };
|
||||
|
||||
ShaderNodes::BasicType leftType = std::get<ShaderNodes::BasicType>(leftExprType);
|
||||
ShaderNodes::BasicType rightType = std::get<ShaderNodes::BasicType>(rightExprType);
|
||||
|
||||
switch (node.op)
|
||||
{
|
||||
@@ -179,7 +209,11 @@ namespace Nz
|
||||
if (!exprPtr)
|
||||
break;
|
||||
|
||||
componentCount += node.GetComponentCount(exprPtr->GetExpressionType());
|
||||
const ShaderExpressionType& exprType = exprPtr->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
|
||||
throw AstError{ "incompatible type" };
|
||||
|
||||
componentCount += node.GetComponentCount(std::get<ShaderNodes::BasicType>(exprType));
|
||||
Visit(exprPtr);
|
||||
}
|
||||
|
||||
@@ -318,7 +352,7 @@ namespace Nz
|
||||
for (auto& param : node.parameters)
|
||||
MandatoryNode(param);
|
||||
|
||||
ShaderNodes::BasicType type = node.parameters.front()->GetExpressionType();
|
||||
ShaderExpressionType type = node.parameters.front()->GetExpressionType();
|
||||
for (std::size_t i = 1; i < node.parameters.size(); ++i)
|
||||
{
|
||||
if (type != node.parameters[i]->GetExpressionType())
|
||||
@@ -333,7 +367,7 @@ namespace Nz
|
||||
{
|
||||
case ShaderNodes::IntrinsicType::CrossProduct:
|
||||
{
|
||||
if (node.parameters[0]->GetExpressionType() != ShaderNodes::BasicType::Float3)
|
||||
if (node.parameters[0]->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float3 })
|
||||
throw AstError{ "CrossProduct only works with Float3 expressions" };
|
||||
|
||||
break;
|
||||
@@ -349,10 +383,10 @@ namespace Nz
|
||||
|
||||
void ShaderValidator::Visit(const ShaderNodes::Sample2D& node)
|
||||
{
|
||||
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderNodes::BasicType::Sampler2D)
|
||||
if (MandatoryExpr(node.sampler)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Sampler2D })
|
||||
throw AstError{ "Sampler must be a Sampler2D" };
|
||||
|
||||
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderNodes::BasicType::Float2)
|
||||
if (MandatoryExpr(node.coordinates)->GetExpressionType() != ShaderExpressionType{ ShaderNodes::BasicType::Float2 })
|
||||
throw AstError{ "Coordinates must be a Float2" };
|
||||
|
||||
Visit(node.sampler);
|
||||
@@ -378,7 +412,11 @@ namespace Nz
|
||||
if (node.componentCount > 4)
|
||||
throw AstError{ "Cannot swizzle more than four elements" };
|
||||
|
||||
switch (MandatoryExpr(node.expression)->GetExpressionType())
|
||||
const ShaderExpressionType& exprType = MandatoryExpr(node.expression)->GetExpressionType();
|
||||
if (!std::holds_alternative<ShaderNodes::BasicType>(exprType))
|
||||
throw AstError{ "Cannot swizzle this type" };
|
||||
|
||||
switch (std::get<ShaderNodes::BasicType>(exprType))
|
||||
{
|
||||
case ShaderNodes::BasicType::Float1:
|
||||
case ShaderNodes::BasicType::Float2:
|
||||
|
||||
Reference in New Issue
Block a user