From 1f15328fdd91677451de6ce6d2e9de3708cc6715 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Sat, 1 Jan 2022 23:01:31 +0100 Subject: [PATCH] Shader: Add initial support for arrays --- include/Nazara/Shader/Ast/AstCloner.hpp | 3 +- include/Nazara/Shader/Ast/AstCloner.inl | 10 +- include/Nazara/Shader/Ast/AstCompare.hpp | 65 +++ include/Nazara/Shader/Ast/AstCompare.inl | 494 +++++++++++++++++++ include/Nazara/Shader/Ast/ExpressionType.hpp | 26 +- include/Nazara/Shader/Ast/ExpressionType.inl | 11 + include/Nazara/Shader/GlslWriter.hpp | 1 + include/Nazara/Shader/LangWriter.hpp | 1 + include/Nazara/Shader/ShaderLangParser.hpp | 1 + include/Nazara/Shader/SpirvConstantCache.hpp | 8 +- src/Nazara/Graphics/BasicMaterial.cpp | 2 - src/Nazara/Shader/Ast/AstCloner.cpp | 20 +- src/Nazara/Shader/Ast/AstCompare.cpp | 10 + src/Nazara/Shader/Ast/AstSerializer.cpp | 56 +-- src/Nazara/Shader/Ast/ConstantValue.cpp | 1 + src/Nazara/Shader/Ast/ExpressionType.cpp | 42 ++ src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 10 +- src/Nazara/Shader/GlslWriter.cpp | 14 +- src/Nazara/Shader/LangWriter.cpp | 12 + src/Nazara/Shader/ShaderLangParser.cpp | 27 +- src/Nazara/Shader/ShaderWriter.cpp | 1 + src/Nazara/Shader/SpirvConstantCache.cpp | 23 +- 22 files changed, 781 insertions(+), 57 deletions(-) create mode 100644 include/Nazara/Shader/Ast/AstCompare.hpp create mode 100644 include/Nazara/Shader/Ast/AstCompare.inl create mode 100644 src/Nazara/Shader/Ast/AstCompare.cpp create mode 100644 src/Nazara/Shader/Ast/ExpressionType.cpp diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index 40d6357be..0a14dff0a 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -24,6 +24,7 @@ namespace Nz::ShaderAst AstCloner(AstCloner&&) = delete; ~AstCloner() = default; + template AttributeValue Clone(const AttributeValue& attribute); ExpressionPtr Clone(Expression& statement); StatementPtr Clone(Statement& statement); @@ -31,7 +32,6 @@ namespace Nz::ShaderAst AstCloner& operator=(AstCloner&&) = delete; protected: - template AttributeValue CloneAttribute(const AttributeValue& attribute); inline ExpressionPtr CloneExpression(const ExpressionPtr& expr); inline StatementPtr CloneStatement(const StatementPtr& statement); @@ -83,6 +83,7 @@ namespace Nz::ShaderAst std::vector m_statementStack; }; + template AttributeValue Clone(const AttributeValue& attribute); inline ExpressionPtr Clone(Expression& node); inline StatementPtr Clone(Statement& node); } diff --git a/include/Nazara/Shader/Ast/AstCloner.inl b/include/Nazara/Shader/Ast/AstCloner.inl index ed705fc2d..dc5158d15 100644 --- a/include/Nazara/Shader/Ast/AstCloner.inl +++ b/include/Nazara/Shader/Ast/AstCloner.inl @@ -8,7 +8,7 @@ namespace Nz::ShaderAst { template - AttributeValue AstCloner::CloneAttribute(const AttributeValue& attribute) + AttributeValue AstCloner::Clone(const AttributeValue& attribute) { if (!attribute.HasValue()) return {}; @@ -38,6 +38,14 @@ namespace Nz::ShaderAst return CloneStatement(*statement); } + + template + AttributeValue Clone(const AttributeValue& attribute) + { + AstCloner cloner; + return cloner.Clone(attribute); + } + inline ExpressionPtr Clone(Expression& node) { AstCloner cloner; diff --git a/include/Nazara/Shader/Ast/AstCompare.hpp b/include/Nazara/Shader/Ast/AstCompare.hpp new file mode 100644 index 000000000..4a338abc6 --- /dev/null +++ b/include/Nazara/Shader/Ast/AstCompare.hpp @@ -0,0 +1,65 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#pragma once + +#ifndef NAZARA_SHADER_AST_ASTCOMPARE_HPP +#define NAZARA_SHADER_AST_ASTCOMPARE_HPP + +#include +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + inline bool Compare(const Expression& lhs, const Expression& rhs); + inline bool Compare(const Statement& lhs, const Statement& rhs); + + template bool Compare(const T& lhs, const T& rhs); + template bool Compare(const std::array& lhs, const std::array& rhs); + template bool Compare(const std::vector& lhs, const std::vector& rhs); + template bool Compare(const AttributeValue& lhs, const AttributeValue& rhs); + inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs); + inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs); + inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs); + inline bool Compare(const StructDescription& lhs, const StructDescription& rhs); + inline bool Compare(const StructDescription::StructMember& lhs, const StructDescription::StructMember& rhs); + + inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs); + inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs); + inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs); + inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs); + inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs); + inline bool Compare(const CallMethodExpression& lhs, const CallMethodExpression& rhs); + inline bool Compare(const CastExpression& lhs, const CastExpression& rhs); + inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs); + inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs); + inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs); + inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs); + inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs); + inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs); + inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs); + inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs); + + inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs); + inline bool Compare(const ConditionalStatement& lhs, const ConditionalStatement& rhs); + inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs); + inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs); + inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs); + inline bool Compare(const DeclareOptionStatement& lhs, const DeclareOptionStatement& rhs); + inline bool Compare(const DeclareStructStatement& lhs, const DeclareStructStatement& rhs); + inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs); + inline bool Compare(const DiscardStatement& lhs, const DiscardStatement& rhs); + inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs); + inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs); + inline bool Compare(const NoOpStatement& lhs, const NoOpStatement& rhs); + inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs); + inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs); +} + +#include + +#endif // NAZARA_SHADER_AST_ASTCOMPARE_HPP diff --git a/include/Nazara/Shader/Ast/AstCompare.inl b/include/Nazara/Shader/Ast/AstCompare.inl new file mode 100644 index 000000000..1ee389c18 --- /dev/null +++ b/include/Nazara/Shader/Ast/AstCompare.inl @@ -0,0 +1,494 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include + +namespace Nz::ShaderAst +{ + inline bool Compare(const Expression& lhs, const Expression& rhs) + { + if (lhs.GetType() != rhs.GetType()) + return false; + + switch (lhs.GetType()) + { + case NodeType::None: break; + +#define NAZARA_SHADERAST_EXPRESSION(Node) case NodeType::Node: return Compare(static_cast(lhs), static_cast(lhs)); +#include + + default: throw std::runtime_error("unexpected node type"); + } + + return true; + } + + inline bool Compare(const Statement& lhs, const Statement& rhs) + { + if (lhs.GetType() != rhs.GetType()) + return false; + + switch (lhs.GetType()) + { + case NodeType::None: break; + +#define NAZARA_SHADERAST_STATEMENT(Node) case NodeType::Node: return Compare(static_cast(lhs), static_cast(lhs)); +#include + + default: throw std::runtime_error("unexpected node type"); + } + + return false; + } + + template + bool Compare(const T& lhs, const T& rhs) + { + return lhs == rhs; + } + + template + bool Compare(const std::array& lhs, const std::array& rhs) + { + for (std::size_t i = 0; i < S; ++i) + { + if (!Compare(lhs[i], rhs[i])) + return false; + } + + return true; + } + + template + bool Compare(const std::vector& lhs, const std::vector& rhs) + { + if (lhs.size() != rhs.size()) + return false; + + for (std::size_t i = 0; i < lhs.size(); ++i) + { + if (!Compare(lhs[i], rhs[i])) + return false; + } + + return true; + } + + template + bool Compare(const AttributeValue& lhs, const AttributeValue& rhs) + { + if (!Compare(lhs.HasValue(), rhs.HasValue())) + return false; + + if (!Compare(lhs.IsResultingValue(), rhs.IsResultingValue())) + return false; + + if (!Compare(lhs.IsExpression(), rhs.IsExpression())) + return false; + + if (lhs.IsExpression()) + { + if (!Compare(lhs.GetExpression(), rhs.GetExpression())) + return false; + } + else if (lhs.IsResultingValue()) + { + if (!Compare(lhs.GetResultingValue(), rhs.GetResultingValue())) + return false; + } + + return true; + } + + inline bool Compare(const BranchStatement::ConditionalStatement& lhs, const BranchStatement::ConditionalStatement& rhs) + { + if (!Compare(lhs.condition, rhs.condition)) + return false; + + if (!Compare(lhs.statement, rhs.statement)) + return false; + + return true; + } + + inline bool Compare(const DeclareExternalStatement::ExternalVar& lhs, const DeclareExternalStatement::ExternalVar& rhs) + { + if (!Compare(lhs.bindingIndex, rhs.bindingIndex)) + return false; + + if (!Compare(lhs.bindingSet, rhs.bindingSet)) + return false; + + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.type, rhs.type)) + return false; + + return true; + } + + inline bool Compare(const DeclareFunctionStatement::Parameter& lhs, const DeclareFunctionStatement::Parameter& rhs) + { + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.type, rhs.type)) + return false; + + return true; + } + + inline bool Compare(const StructDescription& lhs, const StructDescription& rhs) + { + if (!Compare(lhs.layout, rhs.layout)) + return false; + + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.members, rhs.members)) + return false; + + return true; + } + + inline bool Compare(const StructDescription::StructMember& lhs, const StructDescription::StructMember& rhs) + { + if (!Compare(lhs.builtin, rhs.builtin)) + return false; + + if (!Compare(lhs.cond, rhs.cond)) + return false; + + if (!Compare(lhs.locationIndex, rhs.locationIndex)) + return false; + + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.type, rhs.type)) + return false; + + return true; + } + + inline bool Compare(const AccessIdentifierExpression& lhs, const AccessIdentifierExpression& rhs) + { + if (!Compare(*lhs.expr, *rhs.expr)) + return false; + + if (!Compare(lhs.identifiers, rhs.identifiers)) + return false; + + return true; + } + + inline bool Compare(const AccessIndexExpression& lhs, const AccessIndexExpression& rhs) + { + if (!Compare(*lhs.expr, *rhs.expr)) + return false; + + if (!Compare(lhs.indices, rhs.indices)) + return false; + + return true; + } + + inline bool Compare(const AssignExpression& lhs, const AssignExpression& rhs) + { + if (!Compare(lhs.op, rhs.op)) + return false; + + if (!Compare(lhs.left, rhs.left)) + return false; + + if (!Compare(lhs.right, rhs.right)) + return false; + + return true; + } + + inline bool Compare(const BinaryExpression& lhs, const BinaryExpression& rhs) + { + if (!Compare(lhs.op, rhs.op)) + return false; + + if (!Compare(lhs.left, rhs.left)) + return false; + + if (!Compare(lhs.right, rhs.right)) + return false; + + return true; + } + + inline bool Compare(const CallFunctionExpression& lhs, const CallFunctionExpression& rhs) + { + if (!Compare(lhs.targetFunction, rhs.targetFunction)) + return false; + + if (!Compare(lhs.parameters, rhs.parameters)) + return false; + + return true; + } + + inline bool Compare(const CallMethodExpression& lhs, const CallMethodExpression& rhs) + { + if (!Compare(lhs.methodName, rhs.methodName)) + return false; + + if (!Compare(lhs.object, rhs.object)) + return false; + + if (!Compare(lhs.parameters, rhs.parameters)) + return false; + + return true; + } + + inline bool Compare(const CastExpression& lhs, const CastExpression& rhs) + { + if (!Compare(lhs.targetType, rhs.targetType)) + return false; + + if (!Compare(lhs.expressions, rhs.expressions)) + return false; + + return true; + } + + inline bool Compare(const ConditionalExpression& lhs, const ConditionalExpression& rhs) + { + if (!Compare(lhs.condition, rhs.condition)) + return false; + + if (!Compare(lhs.truePath, rhs.truePath)) + return false; + + if (!Compare(lhs.falsePath, rhs.falsePath)) + return false; + + return true; + } + + inline bool Compare(const ConstantExpression& lhs, const ConstantExpression& rhs) + { + if (!Compare(lhs.constantId, rhs.constantId)) + return false; + + return true; + } + + inline bool Compare(const ConstantValueExpression& lhs, const ConstantValueExpression& rhs) + { + if (!Compare(lhs.value, rhs.value)) + return false; + + return true; + } + + inline bool Compare(const IdentifierExpression& lhs, const IdentifierExpression& rhs) + { + if (!Compare(lhs.identifier, rhs.identifier)) + return false; + + return true; + } + + inline bool Compare(const IntrinsicExpression& lhs, const IntrinsicExpression& rhs) + { + if (!Compare(lhs.intrinsic, rhs.intrinsic)) + return false; + + if (!Compare(lhs.parameters, rhs.parameters)) + return false; + + return true; + } + + inline bool Compare(const SwizzleExpression& lhs, const SwizzleExpression& rhs) + { + if (!Compare(lhs.componentCount, rhs.componentCount)) + return false; + + if (!Compare(lhs.expression, rhs.expression)) + return false; + + if (!Compare(lhs.components, rhs.components)) + return false; + + return true; + } + + inline bool Compare(const VariableExpression& lhs, const VariableExpression& rhs) + { + if (!Compare(lhs.variableId, rhs.variableId)) + return false; + + return true; + } + + inline bool Compare(const UnaryExpression& lhs, const UnaryExpression& rhs) + { + if (!Compare(lhs.op, rhs.op)) + return false; + + if (!Compare(lhs.expression, rhs.expression)) + return false; + + return true; + } + + inline bool Compare(const BranchStatement& lhs, const BranchStatement& rhs) + { + if (!Compare(lhs.isConst, rhs.isConst)) + return false; + + if (!Compare(lhs.elseStatement, rhs.elseStatement)) + return false; + + if (!Compare(lhs.condStatements, rhs.condStatements)) + return false; + + return true; + } + + inline bool Compare(const DeclareConstStatement& lhs, const DeclareConstStatement& rhs) + { + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.type, rhs.type)) + return false; + + if (!Compare(lhs.expression, rhs.expression)) + return false; + + return true; + } + + inline bool Compare(const DeclareExternalStatement& lhs, const DeclareExternalStatement& rhs) + { + if (!Compare(lhs.bindingSet, rhs.bindingSet)) + return false; + + if (!Compare(lhs.externalVars, rhs.externalVars)) + return false; + + return true; + } + + inline bool Compare(const DeclareFunctionStatement& lhs, const DeclareFunctionStatement& rhs) + { + if (!Compare(lhs.depthWrite, rhs.depthWrite)) + return false; + + if (!Compare(lhs.earlyFragmentTests, rhs.earlyFragmentTests)) + return false; + + if (!Compare(lhs.entryStage, rhs.entryStage)) + return false; + + if (!Compare(lhs.name, rhs.name)) + return false; + + if (!Compare(lhs.parameters, rhs.parameters)) + return false; + + if (!Compare(lhs.returnType, rhs.returnType)) + return false; + + if (!Compare(lhs.statements, rhs.statements)) + return false; + + return true; + } + + inline bool Compare(const DeclareOptionStatement& lhs, const DeclareOptionStatement& rhs) + { + if (!Compare(lhs.optName, rhs.optName)) + return false; + + if (!Compare(lhs.optType, rhs.optType)) + return false; + + if (!Compare(lhs.defaultValue, rhs.defaultValue)) + return false; + + return true; + } + + inline bool Compare(const DeclareStructStatement& lhs, const DeclareStructStatement& rhs) + { + if (!Compare(lhs.description, rhs.description)) + return false; + + return true; + } + + inline bool Compare(const DeclareVariableStatement& lhs, const DeclareVariableStatement& rhs) + { + if (!Compare(lhs.varName, rhs.varName)) + return false; + + if (!Compare(lhs.varType, rhs.varType)) + return false; + + if (!Compare(lhs.initialExpression, rhs.initialExpression)) + return false; + + return true; + } + + inline bool Compare(const DiscardStatement& /*lhs*/, const DiscardStatement& /*rhs*/) + { + return true; + } + + inline bool Compare(const ExpressionStatement& lhs, const ExpressionStatement& rhs) + { + if (!Compare(lhs.expression, rhs.expression)) + return false; + + return true; + } + + inline bool Compare(const MultiStatement& lhs, const MultiStatement& rhs) + { + if (!Compare(lhs.statements, rhs.statements)) + return false; + + return true; + } + + inline bool Compare(const NoOpStatement& /*lhs*/, const NoOpStatement& /*rhs*/) + { + return true; + } + + inline bool Compare(const ReturnStatement& lhs, const ReturnStatement& rhs) + { + if (!Compare(lhs.returnExpr, rhs.returnExpr)) + return false; + + return true; + } + + inline bool Compare(const WhileStatement& lhs, const WhileStatement& rhs) + { + if (!Compare(lhs.condition, rhs.condition)) + return false; + + if (!Compare(lhs.body, rhs.body)) + return false; + + return true; + } +} + +#include diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index 75c21aa3f..2a5d44824 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -18,6 +18,24 @@ namespace Nz::ShaderAst { + struct ContainedType; + + struct NAZARA_SHADER_API ArrayType + { + ArrayType() = default; + ArrayType(const ArrayType& array); + ArrayType(ArrayType&&) noexcept = default; + + ArrayType& operator=(const ArrayType& array); + ArrayType& operator=(ArrayType&&) noexcept = default; + + AttributeValue length; + std::unique_ptr containedType; + + bool operator==(const ArrayType& rhs) const; + inline bool operator!=(const ArrayType& rhs) const; + }; + struct IdentifierType //< Alias or struct { std::string name; @@ -76,7 +94,12 @@ namespace Nz::ShaderAst inline bool operator!=(const VectorType& rhs) const; }; - using ExpressionType = std::variant; + using ExpressionType = std::variant; + + struct ContainedType + { + ExpressionType type; + }; struct StructDescription { @@ -94,6 +117,7 @@ namespace Nz::ShaderAst std::vector members; }; + inline bool IsArrayType(const ExpressionType& type); inline bool IsIdentifierType(const ExpressionType& type); inline bool IsMatrixType(const ExpressionType& type); inline bool IsNoType(const ExpressionType& type); diff --git a/include/Nazara/Shader/Ast/ExpressionType.inl b/include/Nazara/Shader/Ast/ExpressionType.inl index 6a5e413e7..0fe8d0163 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.inl +++ b/include/Nazara/Shader/Ast/ExpressionType.inl @@ -8,6 +8,12 @@ namespace Nz::ShaderAst { + inline bool ArrayType::operator!=(const ArrayType& rhs) const + { + return !operator==(rhs); + } + + inline bool IdentifierType::operator==(const IdentifierType& rhs) const { return name == rhs.name; @@ -84,6 +90,11 @@ namespace Nz::ShaderAst } + bool IsArrayType(const ExpressionType& type) + { + return std::holds_alternative(type); + } + inline bool IsIdentifierType(const ExpressionType& type) { return std::holds_alternative(type); diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 2a4ea8ce2..deb2ae505 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -50,6 +50,7 @@ namespace Nz static ShaderAst::StatementPtr Sanitize(ShaderAst::Statement& ast, std::unordered_map optionValues, std::string* error = nullptr); private: + void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); void Append(ShaderAst::BuiltinEntry builtin); void Append(const ShaderAst::IdentifierType& identifierType); diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index d62be156d..8264048f2 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -46,6 +46,7 @@ namespace Nz struct LocationAttribute; struct SetAttribute; + void Append(const ShaderAst::ArrayType& type); void Append(const ShaderAst::ExpressionType& type); void Append(const ShaderAst::IdentifierType& identifierType); void Append(const ShaderAst::MatrixType& matrixType); diff --git a/include/Nazara/Shader/ShaderLangParser.hpp b/include/Nazara/Shader/ShaderLangParser.hpp index 28e39cb5a..2bd89245e 100644 --- a/include/Nazara/Shader/ShaderLangParser.hpp +++ b/include/Nazara/Shader/ShaderLangParser.hpp @@ -112,6 +112,7 @@ namespace Nz::ShaderLang ShaderAst::ExpressionPtr ParsePrimaryExpression(); ShaderAst::ExpressionPtr ParseVariableAssignation(); + ShaderAst::ExpressionType ParseArrayType(); ShaderAst::AttributeType ParseIdentifierAsAttributeType(); const std::string& ParseIdentifierAsName(); ShaderAst::PrimitiveType ParsePrimitiveType(); diff --git a/include/Nazara/Shader/SpirvConstantCache.hpp b/include/Nazara/Shader/SpirvConstantCache.hpp index 88d14f9d9..eee6add0c 100644 --- a/include/Nazara/Shader/SpirvConstantCache.hpp +++ b/include/Nazara/Shader/SpirvConstantCache.hpp @@ -39,6 +39,12 @@ namespace Nz using ConstantPtr = std::shared_ptr; using TypePtr = std::shared_ptr; + struct Array + { + TypePtr elementType; + UInt32 length; + }; + struct Bool {}; struct Float @@ -108,7 +114,7 @@ namespace Nz std::vector decorations; }; - using AnyType = std::variant; + using AnyType = std::variant; struct ConstantBool { diff --git a/src/Nazara/Graphics/BasicMaterial.cpp b/src/Nazara/Graphics/BasicMaterial.cpp index 4de873eb2..0d62ff96c 100644 --- a/src/Nazara/Graphics/BasicMaterial.cpp +++ b/src/Nazara/Graphics/BasicMaterial.cpp @@ -8,9 +8,7 @@ #include #include #include -#include #include -#include #include #include #include diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index 10ff59610..2c793c006 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -79,7 +79,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->varIndex = node.varIndex; - clone->bindingSet = CloneAttribute(node.bindingSet); + clone->bindingSet = Clone(node.bindingSet); clone->externalVars.reserve(node.externalVars.size()); for (const auto& var : node.externalVars) @@ -87,8 +87,8 @@ namespace Nz::ShaderAst auto& cloneVar = clone->externalVars.emplace_back(); cloneVar.name = var.name; cloneVar.type = var.type; - cloneVar.bindingIndex = CloneAttribute(var.bindingIndex); - cloneVar.bindingSet = CloneAttribute(var.bindingSet); + cloneVar.bindingIndex = Clone(var.bindingIndex); + cloneVar.bindingSet = Clone(var.bindingSet); } return clone; @@ -97,9 +97,9 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) { auto clone = std::make_unique(); - clone->depthWrite = CloneAttribute(node.depthWrite); - clone->earlyFragmentTests = CloneAttribute(node.earlyFragmentTests); - clone->entryStage = CloneAttribute(node.entryStage); + clone->depthWrite = Clone(node.depthWrite); + clone->earlyFragmentTests = Clone(node.earlyFragmentTests); + clone->entryStage = Clone(node.entryStage); clone->funcIndex = node.funcIndex; clone->name = node.name; clone->parameters = node.parameters; @@ -129,7 +129,7 @@ namespace Nz::ShaderAst auto clone = std::make_unique(); clone->structIndex = node.structIndex; - clone->description.layout = CloneAttribute(node.description.layout); + clone->description.layout = Clone(node.description.layout); clone->description.name = node.description.name; clone->description.members.reserve(node.description.members.size()); @@ -138,9 +138,9 @@ namespace Nz::ShaderAst auto& cloneMember = clone->description.members.emplace_back(); cloneMember.name = member.name; cloneMember.type = member.type; - cloneMember.builtin = CloneAttribute(member.builtin); - cloneMember.cond = CloneAttribute(member.cond); - cloneMember.locationIndex = CloneAttribute(member.locationIndex); + cloneMember.builtin = Clone(member.builtin); + cloneMember.cond = Clone(member.cond); + cloneMember.locationIndex = Clone(member.locationIndex); } return clone; diff --git a/src/Nazara/Shader/Ast/AstCompare.cpp b/src/Nazara/Shader/Ast/AstCompare.cpp new file mode 100644 index 000000000..03c0b8046 --- /dev/null +++ b/src/Nazara/Shader/Ast/AstCompare.cpp @@ -0,0 +1,10 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include + +namespace Nz::ShaderAst +{ +} diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 661ec1861..df15d3de2 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -409,6 +409,12 @@ namespace Nz::ShaderAst m_stream << UInt32(arg.componentCount); m_stream << UInt32(arg.type); } + else if constexpr (std::is_same_v) + { + m_stream << UInt8(8); + Attribute(arg.length); + Type(arg.containedType->type); + } else static_assert(AlwaysFalse::value, "non-exhaustive visitor"); }, type); @@ -569,40 +575,6 @@ namespace Nz::ShaderAst switch (typeIndex) { - /* - if constexpr (std::is_same_v) - m_stream << UInt8(0); - else if constexpr (std::is_same_v) - { - m_stream << UInt8(1); - m_stream << UInt32(arg); - } - else if constexpr (std::is_same_v) - { - m_stream << UInt8(2); - m_stream << arg.name; - } - else if constexpr (std::is_same_v) - { - m_stream << UInt8(3); - m_stream << UInt32(arg.columnCount); - m_stream << UInt32(arg.rowCount); - m_stream << UInt32(arg.type); - } - else if constexpr (std::is_same_v) - { - m_stream << UInt8(4); - m_stream << UInt32(arg.dim); - m_stream << UInt32(arg.sampledType); - } - else if constexpr (std::is_same_v) - { - m_stream << UInt8(5); - m_stream << UInt32(arg.componentCount); - m_stream << UInt32(arg.type); - } - */ - case 0: //< NoType type = NoType{}; break; @@ -693,6 +665,22 @@ namespace Nz::ShaderAst break; } + case 8: //< ArrayType + { + AttributeValue length; + ExpressionType containedType; + Attribute(length); + Type(containedType); + + ArrayType arrayType; + arrayType.length = std::move(length); + arrayType.containedType = std::make_unique(); + arrayType.containedType->type = std::move(containedType); + + type = std::move(arrayType); + break; + } + default: break; } diff --git a/src/Nazara/Shader/Ast/ConstantValue.cpp b/src/Nazara/Shader/Ast/ConstantValue.cpp index 06f0eb2ab..2de3b2e83 100644 --- a/src/Nazara/Shader/Ast/ConstantValue.cpp +++ b/src/Nazara/Shader/Ast/ConstantValue.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz::ShaderAst diff --git a/src/Nazara/Shader/Ast/ExpressionType.cpp b/src/Nazara/Shader/Ast/ExpressionType.cpp new file mode 100644 index 000000000..5d507fb4e --- /dev/null +++ b/src/Nazara/Shader/Ast/ExpressionType.cpp @@ -0,0 +1,42 @@ +// Copyright (C) 2022 Jérôme "Lynix" Leclercq (lynix680@gmail.com) +// This file is part of the "Nazara Engine - Shader module" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + ArrayType::ArrayType(const ArrayType& array) + { + assert(array.containedType); + containedType = std::make_unique(*array.containedType); + length = Clone(length); + } + + ArrayType& ArrayType::operator=(const ArrayType& array) + { + assert(array.containedType); + + containedType = std::make_unique(*array.containedType); + length = Clone(length); + + return *this; + } + + bool ArrayType::operator==(const ArrayType& rhs) const + { + assert(containedType); + assert(rhs.containedType); + + if (containedType->type != rhs.containedType->type) + return false; + + if (!Compare(length, rhs.length)) + return false; + + return true; + } +} diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index 768584f1c..ce20348a5 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -1156,6 +1156,7 @@ namespace Nz::ShaderAst if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) return ResolveStruct(arg); else if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -1205,6 +1206,7 @@ namespace Nz::ShaderAst using T = std::decay_t; if constexpr (std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || @@ -1267,7 +1269,13 @@ namespace Nz::ShaderAst ExpressionType exprType = GetExpressionType(*node.expr); for (const auto& indexExpr : node.indices) { - if (IsStructType(exprType)) + if (IsArrayType(exprType)) + { + const ArrayType& arrayType = std::get(exprType); + + exprType = arrayType.containedType->type; + } + else if (IsStructType(exprType)) { const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); if (indexExpr->GetType() != NodeType::ConstantValueExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 8c65272a8..86bad6d63 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -219,6 +219,18 @@ namespace Nz return ShaderAst::Sanitize(ast, options, error); } + void GlslWriter::Append(const ShaderAst::ArrayType& type) + { + Append(type.containedType->type, "["); + + if (type.length.IsResultingValue()) + Append(type.length.GetResultingValue()); + else + type.length.GetExpression()->Visit(*this); + + Append("]"); + } + void GlslWriter::Append(const ShaderAst::ExpressionType& type) { std::visit([&](auto&& arg) @@ -963,7 +975,7 @@ namespace Nz std::size_t structIndex = std::get(uniform.containedType).structIndex; ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, structIndex); if (structInfo->layout.HasValue()) - isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; + isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; } if (!m_currentState->bindingMapping.empty() || isStd140) diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index 4c2e7ae17..f9cd93029 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -120,6 +120,18 @@ namespace Nz m_environment = std::move(environment); } + void LangWriter::Append(const ShaderAst::ArrayType& type) + { + Append("[", type.containedType->type, "; "); + + if (type.length.IsResultingValue()) + Append(type.length.GetResultingValue()); + else + type.length.GetExpression()->Visit(*this); + + Append("]"); + } + void LangWriter::Append(const ShaderAst::ExpressionType& type) { std::visit([&](auto&& arg) diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 2c9844aab..8c800843d 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -680,7 +680,7 @@ namespace Nz::ShaderLang ShaderAst::ExpressionType parameterType = ParseType(); - return { parameterName, parameterType }; + return { parameterName, std::move(parameterType) }; } ShaderAst::StatementPtr Parser::ParseOptionDeclaration() @@ -1088,7 +1088,7 @@ namespace Nz::ShaderLang ShaderAst::ExpressionPtr Parser::ParseIntegerExpression() { const Token& integerToken = Expect(Advance(), TokenType::IntegerValue); - return ShaderBuilder::Constant(static_cast(std::get(integerToken.data))); //< FIXME + return ShaderBuilder::Constant(SafeCast(std::get(integerToken.data))); //< FIXME } std::vector Parser::ParseParameters() @@ -1193,6 +1193,24 @@ namespace Nz::ShaderLang } } + ShaderAst::ExpressionType Parser::ParseArrayType() + { + ShaderAst::ArrayType arrayType; + + Expect(Advance(), TokenType::OpenSquareBracket); + + arrayType.containedType = std::make_unique(); + arrayType.containedType->type = ParseType(); + + Expect(Advance(), TokenType::Semicolon); + + arrayType.length = ParseExpression(); + + Expect(Advance(), TokenType::ClosingSquareBracket); + + return arrayType; + } + ShaderAst::AttributeType Parser::ParseIdentifierAsAttributeType() { const Token& identifierToken = Expect(Advance(), TokenType::Identifier); @@ -1240,6 +1258,9 @@ namespace Nz::ShaderLang return ShaderAst::NoType{}; } + if (Peek().type == TokenType::OpenSquareBracket) + return ParseArrayType(); + const Token& identifierToken = Expect(Peek(), TokenType::Identifier); const std::string& identifier = std::get(identifierToken.data); @@ -1250,7 +1271,7 @@ namespace Nz::ShaderLang return ShaderAst::IdentifierType{ identifier }; } - return *type; + return *std::move(type); } int Parser::GetTokenPrecedence(TokenType token) diff --git a/src/Nazara/Shader/ShaderWriter.cpp b/src/Nazara/Shader/ShaderWriter.cpp index 02b3512ad..520a51d2f 100644 --- a/src/Nazara/Shader/ShaderWriter.cpp +++ b/src/Nazara/Shader/ShaderWriter.cpp @@ -3,6 +3,7 @@ // For conditions of distribution and use, see copyright notice in Config.hpp #include +#include #include namespace Nz diff --git a/src/Nazara/Shader/SpirvConstantCache.cpp b/src/Nazara/Shader/SpirvConstantCache.cpp index 547c8f898..625cf6fa8 100644 --- a/src/Nazara/Shader/SpirvConstantCache.cpp +++ b/src/Nazara/Shader/SpirvConstantCache.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -35,6 +36,11 @@ namespace Nz return lhs.value == rhs.value; } + bool Compare(const Array& lhs, const Array& rhs) const + { + return lhs.length == rhs.length && Compare(lhs.elementType, rhs.elementType); + } + bool Compare(const Bool& /*lhs*/, const Bool& /*rhs*/) const { return true; @@ -227,6 +233,12 @@ namespace Nz { } + void Register(const Array& array) + { + assert(array.elementType); + cache.Register(*array.elementType); + } + void Register(const Bool&) {} void Register(const Float&) {} void Register(const Integer&) {} @@ -801,7 +813,9 @@ namespace Nz { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) + constants.Append(SpirvOp::OpTypeArray, resultId, GetId(*arg.elementType), arg.length); + else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeBool, resultId); else if constexpr (std::is_same_v) constants.Append(SpirvOp::OpTypeFloat, resultId, arg.width); @@ -892,7 +906,12 @@ namespace Nz { using T = std::decay_t; - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) + { + // TODO + throw std::runtime_error("todo"); + } + else if constexpr (std::is_same_v) return structOffsets.AddField(StructFieldType::Bool1); else if constexpr (std::is_same_v) {