From f9af35b48995de393b56d3fa6d1ffcdbea69f34d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Leclercq?= Date: Wed, 7 Jul 2021 11:41:58 +0200 Subject: [PATCH] Shader: Attribute can now have expressions as values and struct fields can be conditionally supported --- include/Nazara/Shader/Ast/AstCloner.hpp | 3 + include/Nazara/Shader/Ast/AstCloner.inl | 15 + include/Nazara/Shader/Ast/AstNodeList.hpp | 1 + include/Nazara/Shader/Ast/AstOptimizer.hpp | 21 +- include/Nazara/Shader/Ast/AstOptimizer.inl | 40 +- .../Nazara/Shader/Ast/AstRecursiveVisitor.hpp | 1 + include/Nazara/Shader/Ast/AstSerializer.hpp | 2 + include/Nazara/Shader/Ast/AstSerializer.inl | 67 +++ include/Nazara/Shader/Ast/AstUtils.hpp | 1 + include/Nazara/Shader/Ast/Attribute.hpp | 37 +- include/Nazara/Shader/Ast/Attribute.inl | 72 +++ include/Nazara/Shader/Ast/ConstantValue.hpp | 5 +- include/Nazara/Shader/Ast/Enums.hpp | 2 +- include/Nazara/Shader/Ast/ExpressionType.hpp | 7 +- include/Nazara/Shader/Ast/Nodes.hpp | 54 ++- include/Nazara/Shader/Ast/SanitizeVisitor.hpp | 20 +- include/Nazara/Shader/Ast/SanitizeVisitor.inl | 9 - include/Nazara/Shader/GlslWriter.hpp | 6 +- include/Nazara/Shader/LangWriter.hpp | 5 +- include/Nazara/Shader/ShaderBuilder.hpp | 4 +- include/Nazara/Shader/ShaderBuilder.inl | 12 +- include/Nazara/Shader/SpirvAstVisitor.hpp | 6 +- include/Nazara/Shader/SpirvAstVisitor.inl | 4 +- src/Nazara/Graphics/UberShader.cpp | 7 +- .../OpenGLRenderer/OpenGLShaderModule.cpp | 2 +- src/Nazara/Shader/Ast/AstCloner.cpp | 49 ++- src/Nazara/Shader/Ast/AstOptimizer.cpp | 55 ++- src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp | 6 + src/Nazara/Shader/Ast/AstSerializer.cpp | 30 +- src/Nazara/Shader/Ast/AstUtils.cpp | 5 + src/Nazara/Shader/Ast/SanitizeVisitor.cpp | 409 +++++++++++------- src/Nazara/Shader/GlslWriter.cpp | 95 ++-- src/Nazara/Shader/LangWriter.cpp | 187 ++++---- src/Nazara/Shader/ShaderLangParser.cpp | 237 ++++------ src/Nazara/Shader/SpirvAstVisitor.cpp | 16 +- src/Nazara/Shader/SpirvWriter.cpp | 53 ++- 36 files changed, 945 insertions(+), 600 deletions(-) create mode 100644 include/Nazara/Shader/Ast/Attribute.inl diff --git a/include/Nazara/Shader/Ast/AstCloner.hpp b/include/Nazara/Shader/Ast/AstCloner.hpp index 87f480391..a81acfdff 100644 --- a/include/Nazara/Shader/Ast/AstCloner.hpp +++ b/include/Nazara/Shader/Ast/AstCloner.hpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -30,6 +31,7 @@ namespace Nz::ShaderAst AstCloner& operator=(AstCloner&&) = delete; protected: + template AttributeValue CloneAttribute(const AttributeValue& attribute); inline ExpressionPtr CloneExpression(const ExpressionPtr& expr); inline StatementPtr CloneStatement(const StatementPtr& statement); @@ -44,6 +46,7 @@ namespace Nz::ShaderAst virtual ExpressionPtr Clone(CallMethodExpression& node); virtual ExpressionPtr Clone(CastExpression& node); virtual ExpressionPtr Clone(ConditionalExpression& node); + virtual ExpressionPtr Clone(ConstantIndexExpression& node); virtual ExpressionPtr Clone(ConstantExpression& node); virtual ExpressionPtr Clone(IdentifierExpression& node); virtual ExpressionPtr Clone(IntrinsicExpression& node); diff --git a/include/Nazara/Shader/Ast/AstCloner.inl b/include/Nazara/Shader/Ast/AstCloner.inl index 4e9375b76..ea509bd08 100644 --- a/include/Nazara/Shader/Ast/AstCloner.inl +++ b/include/Nazara/Shader/Ast/AstCloner.inl @@ -7,6 +7,21 @@ namespace Nz::ShaderAst { + template + AttributeValue AstCloner::CloneAttribute(const AttributeValue& attribute) + { + if (!attribute.HasValue()) + return {}; + + if (attribute.IsExpression()) + return CloneExpression(attribute.GetExpression()); + else + { + assert(attribute.IsResultingValue()); + return attribute.GetResultingValue(); + } + } + ExpressionPtr AstCloner::CloneExpression(const ExpressionPtr& expr) { if (!expr) diff --git a/include/Nazara/Shader/Ast/AstNodeList.hpp b/include/Nazara/Shader/Ast/AstNodeList.hpp index 37f694d5d..60978d8bd 100644 --- a/include/Nazara/Shader/Ast/AstNodeList.hpp +++ b/include/Nazara/Shader/Ast/AstNodeList.hpp @@ -35,6 +35,7 @@ NAZARA_SHADERAST_EXPRESSION(CallMethodExpression) NAZARA_SHADERAST_EXPRESSION(CastExpression) NAZARA_SHADERAST_EXPRESSION(ConditionalExpression) NAZARA_SHADERAST_EXPRESSION(ConstantExpression) +NAZARA_SHADERAST_EXPRESSION(ConstantIndexExpression) NAZARA_SHADERAST_EXPRESSION(IdentifierExpression) NAZARA_SHADERAST_EXPRESSION(IntrinsicExpression) NAZARA_SHADERAST_EXPRESSION(SelectOptionExpression) diff --git a/include/Nazara/Shader/Ast/AstOptimizer.hpp b/include/Nazara/Shader/Ast/AstOptimizer.hpp index afc63ec93..9b541e504 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.hpp +++ b/include/Nazara/Shader/Ast/AstOptimizer.hpp @@ -18,21 +18,32 @@ namespace Nz::ShaderAst class NAZARA_SHADER_API AstOptimizer : public AstCloner { public: + struct Options; + AstOptimizer() = default; AstOptimizer(const AstOptimizer&) = delete; AstOptimizer(AstOptimizer&&) = delete; ~AstOptimizer() = default; - StatementPtr Optimise(Statement& statement); - StatementPtr Optimise(Statement& statement, UInt64 enabledConditions); + inline ExpressionPtr Optimise(Expression& expression); + inline ExpressionPtr Optimise(Expression& expression, const Options& options); + inline StatementPtr Optimise(Statement& statement); + inline StatementPtr Optimise(Statement& statement, const Options& options); AstOptimizer& operator=(const AstOptimizer&) = delete; AstOptimizer& operator=(AstOptimizer&&) = delete; + struct Options + { + std::function constantQueryCallback; + std::optional enabledOptions = 0; + }; + protected: ExpressionPtr Clone(BinaryExpression& node) override; ExpressionPtr Clone(CastExpression& node) override; ExpressionPtr Clone(ConditionalExpression& node) override; + ExpressionPtr Clone(ConstantIndexExpression& node) override; ExpressionPtr Clone(UnaryExpression& node) override; StatementPtr Clone(BranchStatement& node) override; StatementPtr Clone(ConditionalStatement& node) override; @@ -45,11 +56,13 @@ namespace Nz::ShaderAst template ExpressionPtr PropagateVec4Cast(TargetType v1, TargetType v2, TargetType v3, TargetType v4); private: - std::optional m_enabledOptions; + Options m_options; }; + inline ExpressionPtr Optimize(Expression& expr); + inline ExpressionPtr Optimize(Expression& expr, const AstOptimizer::Options& options); inline StatementPtr Optimize(Statement& ast); - inline StatementPtr Optimize(Statement& ast, UInt64 enabledConditions); + inline StatementPtr Optimize(Statement& ast, const AstOptimizer::Options& options); } #include diff --git a/include/Nazara/Shader/Ast/AstOptimizer.inl b/include/Nazara/Shader/Ast/AstOptimizer.inl index 398ea643d..f501117b7 100644 --- a/include/Nazara/Shader/Ast/AstOptimizer.inl +++ b/include/Nazara/Shader/Ast/AstOptimizer.inl @@ -7,16 +7,52 @@ namespace Nz::ShaderAst { + inline ExpressionPtr AstOptimizer::Optimise(Expression& expression) + { + m_options = {}; + return CloneExpression(expression); + } + + inline ExpressionPtr AstOptimizer::Optimise(Expression& expression, const Options& options) + { + m_options = options; + return CloneExpression(expression); + } + + inline StatementPtr AstOptimizer::Optimise(Statement& statement) + { + m_options = {}; + return CloneStatement(statement); + } + + inline StatementPtr AstOptimizer::Optimise(Statement& statement, const Options& options) + { + m_options = options; + return CloneStatement(statement); + } + + inline ExpressionPtr Optimize(Expression& ast) + { + AstOptimizer optimize; + return optimize.Optimise(ast); + } + + inline ExpressionPtr Optimize(Expression& ast, const AstOptimizer::Options& options) + { + AstOptimizer optimize; + return optimize.Optimise(ast, options); + } + inline StatementPtr Optimize(Statement& ast) { AstOptimizer optimize; return optimize.Optimise(ast); } - inline StatementPtr Optimize(Statement& ast, UInt64 enabledConditions) + inline StatementPtr Optimize(Statement& ast, const AstOptimizer::Options& options) { AstOptimizer optimize; - return optimize.Optimise(ast, enabledConditions); + return optimize.Optimise(ast, options); } } diff --git a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp index 5b1428633..9410b4cb6 100644 --- a/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp +++ b/include/Nazara/Shader/Ast/AstRecursiveVisitor.hpp @@ -29,6 +29,7 @@ namespace Nz::ShaderAst void Visit(CastExpression& node) override; void Visit(ConditionalExpression& node) override; void Visit(ConstantExpression& node) override; + void Visit(ConstantIndexExpression& node) override; void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SelectOptionExpression& node) override; diff --git a/include/Nazara/Shader/Ast/AstSerializer.hpp b/include/Nazara/Shader/Ast/AstSerializer.hpp index 932f7991a..9756fcfca 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.hpp +++ b/include/Nazara/Shader/Ast/AstSerializer.hpp @@ -30,6 +30,7 @@ namespace Nz::ShaderAst void Serialize(CallFunctionExpression& node); void Serialize(CallMethodExpression& node); void Serialize(CastExpression& node); + void Serialize(ConstantIndexExpression& node); void Serialize(ConditionalExpression& node); void Serialize(ConstantExpression& node); void Serialize(IdentifierExpression& node); @@ -53,6 +54,7 @@ namespace Nz::ShaderAst void Serialize(ReturnStatement& node); protected: + template void Attribute(AttributeValue& attribute); template void Container(T& container); template void Enum(T& enumVal); template void OptEnum(std::optional& optVal); diff --git a/include/Nazara/Shader/Ast/AstSerializer.inl b/include/Nazara/Shader/Ast/AstSerializer.inl index 025e869b9..6bbb5832a 100644 --- a/include/Nazara/Shader/Ast/AstSerializer.inl +++ b/include/Nazara/Shader/Ast/AstSerializer.inl @@ -7,6 +7,73 @@ namespace Nz::ShaderAst { + template + void AstSerializerBase::Attribute(AttributeValue& attribute) + { + UInt32 valueType; + if (IsWriting()) + { + if (!attribute.HasValue()) + valueType = 0; + else if (attribute.IsExpression()) + valueType = 1; + else if (attribute.IsResultingValue()) + valueType = 2; + else + throw std::runtime_error("unexpected attribute"); + } + + Value(valueType); + + switch (valueType) + { + case 0: + if (!IsWriting()) + attribute = {}; + + break; + + case 1: + { + if (!IsWriting()) + { + ExpressionPtr expr; + Node(expr); + + attribute = std::move(expr); + } + else + Node(const_cast(attribute.GetExpression())); //< not used for writing + + break; + } + + case 2: + { + if (!IsWriting()) + { + T value; + if constexpr (std::is_enum_v) + Enum(value); + else + Value(value); + + attribute = std::move(value); + } + else + { + T& value = const_cast(attribute.GetResultingValue()); //< not used for writing + if constexpr (std::is_enum_v) + Enum(value); + else + Value(value); + } + + break; + } + } + } + template void AstSerializerBase::Container(T& container) { diff --git a/include/Nazara/Shader/Ast/AstUtils.hpp b/include/Nazara/Shader/Ast/AstUtils.hpp index e020c10ed..e32d9bd8d 100644 --- a/include/Nazara/Shader/Ast/AstUtils.hpp +++ b/include/Nazara/Shader/Ast/AstUtils.hpp @@ -40,6 +40,7 @@ namespace Nz::ShaderAst void Visit(CastExpression& node) override; void Visit(ConditionalExpression& node) override; void Visit(ConstantExpression& node) override; + void Visit(ConstantIndexExpression& node) override; void Visit(IdentifierExpression& node) override; void Visit(IntrinsicExpression& node) override; void Visit(SelectOptionExpression& node) override; diff --git a/include/Nazara/Shader/Ast/Attribute.hpp b/include/Nazara/Shader/Ast/Attribute.hpp index 5bd1cf589..bbe1db5d0 100644 --- a/include/Nazara/Shader/Ast/Attribute.hpp +++ b/include/Nazara/Shader/Ast/Attribute.hpp @@ -9,17 +9,52 @@ #include #include +#include +#include #include namespace Nz::ShaderAst { + struct Expression; + + using ExpressionPtr = std::unique_ptr; + + template + class AttributeValue + { + public: + AttributeValue() = default; + AttributeValue(T value); + AttributeValue(ExpressionPtr expr); + AttributeValue(const AttributeValue&) = default; + AttributeValue(AttributeValue&&) = default; + ~AttributeValue() = default; + + ExpressionPtr&& GetExpression() &&; + const ExpressionPtr& GetExpression() const &; + const T& GetResultingValue() const; + + bool IsExpression() const; + bool IsResultingValue() const; + + bool HasValue() const; + + AttributeValue& operator=(const AttributeValue&) = default; + AttributeValue& operator=(AttributeValue&&) = default; + + private: + std::variant m_value; + }; + struct Attribute { - using Param = std::variant; + using Param = std::optional; AttributeType type; Param args; }; } +#include + #endif diff --git a/include/Nazara/Shader/Ast/Attribute.inl b/include/Nazara/Shader/Ast/Attribute.inl new file mode 100644 index 000000000..b6662aede --- /dev/null +++ b/include/Nazara/Shader/Ast/Attribute.inl @@ -0,0 +1,72 @@ +// Copyright (C) 2020 Jérôme Leclercq +// This file is part of the "Nazara Engine - Shader generator" +// For conditions of distribution and use, see copyright notice in Config.hpp + +#include +#include +#include +#include + +namespace Nz::ShaderAst +{ + template + AttributeValue::AttributeValue(T value) : + m_value(std::move(value)) + { + } + + template + AttributeValue::AttributeValue(ExpressionPtr expr) + { + assert(expr); + m_value = std::move(expr); + } + + template + ExpressionPtr&& AttributeValue::GetExpression() && + { + if (!IsExpression()) + throw std::runtime_error("excepted expression"); + + return std::get(std::move(m_value)); + } + + template + const ExpressionPtr& AttributeValue::GetExpression() const & + { + if (!IsExpression()) + throw std::runtime_error("excepted expression"); + + assert(std::get(m_value)); + return std::get(m_value); + } + + template + const T& AttributeValue::GetResultingValue() const + { + if (!IsResultingValue()) + throw std::runtime_error("excepted resulting value"); + + return std::get(m_value); + } + + template + bool AttributeValue::IsExpression() const + { + return std::holds_alternative(m_value); + } + + template + bool AttributeValue::IsResultingValue() const + { + return std::holds_alternative(m_value); + } + + template + bool AttributeValue::HasValue() const + { + return !std::holds_alternative(m_value); + } +} + +#include diff --git a/include/Nazara/Shader/Ast/ConstantValue.hpp b/include/Nazara/Shader/Ast/ConstantValue.hpp index 3064603cf..1b04f98b7 100644 --- a/include/Nazara/Shader/Ast/ConstantValue.hpp +++ b/include/Nazara/Shader/Ast/ConstantValue.hpp @@ -8,6 +8,7 @@ #define NAZARA_SHADER_CONSTANTVALUE_HPP #include +#include #include #include #include @@ -17,7 +18,7 @@ namespace Nz::ShaderAst { - using ConstantValue = std::variant< + using ConstantTypes = TypeList< bool, float, Int32, @@ -30,6 +31,8 @@ namespace Nz::ShaderAst Vector4i32 >; + using ConstantValue = TypeListInstantiate; + NAZARA_SHADER_API ExpressionType GetExpressionType(const ConstantValue& constant); } diff --git a/include/Nazara/Shader/Ast/Enums.hpp b/include/Nazara/Shader/Ast/Enums.hpp index 027b6c500..f6b30890b 100644 --- a/include/Nazara/Shader/Ast/Enums.hpp +++ b/include/Nazara/Shader/Ast/Enums.hpp @@ -23,12 +23,12 @@ namespace Nz { Binding, //< Binding (external var only) - has argument index Builtin, //< Builtin (struct member only) - has argument type + Cond, //< Conditional compilation option - has argument expr DepthWrite, //< Depth write mode (function only) - has argument type EarlyFragmentTests, //< Entry point (function only) - has argument on/off Entry, //< Entry point (function only) - has argument type Layout, //< Struct layout (struct only) - has argument style Location, //< Location (struct member only) - has argument index - Option, //< Conditional compilation option - has argument expr Set, //< Binding set (external var only) - has argument index }; diff --git a/include/Nazara/Shader/Ast/ExpressionType.hpp b/include/Nazara/Shader/Ast/ExpressionType.hpp index 185dc12ab..8d68fd802 100644 --- a/include/Nazara/Shader/Ast/ExpressionType.hpp +++ b/include/Nazara/Shader/Ast/ExpressionType.hpp @@ -82,13 +82,14 @@ namespace Nz::ShaderAst { struct StructMember { - std::optional builtin; - std::optional locationIndex; + AttributeValue builtin; + AttributeValue cond; + AttributeValue locationIndex; std::string name; ExpressionType type; }; - std::optional layout; + AttributeValue layout; std::string name; std::vector members; }; diff --git a/include/Nazara/Shader/Ast/Nodes.hpp b/include/Nazara/Shader/Ast/Nodes.hpp index 1863518c5..c1b8d4de8 100644 --- a/include/Nazara/Shader/Ast/Nodes.hpp +++ b/include/Nazara/Shader/Ast/Nodes.hpp @@ -64,7 +64,7 @@ namespace Nz::ShaderAst std::optional cachedExpressionType; }; - struct NAZARA_SHADER_API AccessIdentifierExpression : public Expression + struct NAZARA_SHADER_API AccessIdentifierExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -73,7 +73,7 @@ namespace Nz::ShaderAst std::vector identifiers; }; - struct NAZARA_SHADER_API AccessIndexExpression : public Expression + struct NAZARA_SHADER_API AccessIndexExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -82,7 +82,7 @@ namespace Nz::ShaderAst std::vector indices; }; - struct NAZARA_SHADER_API AssignExpression : public Expression + struct NAZARA_SHADER_API AssignExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -92,7 +92,7 @@ namespace Nz::ShaderAst ExpressionPtr right; }; - struct NAZARA_SHADER_API BinaryExpression : public Expression + struct NAZARA_SHADER_API BinaryExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -102,7 +102,7 @@ namespace Nz::ShaderAst ExpressionPtr right; }; - struct NAZARA_SHADER_API CallFunctionExpression : public Expression + struct NAZARA_SHADER_API CallFunctionExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -111,7 +111,7 @@ namespace Nz::ShaderAst std::vector parameters; }; - struct NAZARA_SHADER_API CallMethodExpression : public Expression + struct NAZARA_SHADER_API CallMethodExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -121,7 +121,7 @@ namespace Nz::ShaderAst std::vector parameters; }; - struct NAZARA_SHADER_API CastExpression : public Expression + struct NAZARA_SHADER_API CastExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -130,17 +130,17 @@ namespace Nz::ShaderAst std::array expressions; }; - struct NAZARA_SHADER_API ConditionalExpression : public Expression + struct NAZARA_SHADER_API ConditionalExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; - std::size_t optionIndex; + ExpressionPtr condition; ExpressionPtr falsePath; ExpressionPtr truePath; }; - struct NAZARA_SHADER_API ConstantExpression : public Expression + struct NAZARA_SHADER_API ConstantExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -148,7 +148,15 @@ namespace Nz::ShaderAst ShaderAst::ConstantValue value; }; - struct NAZARA_SHADER_API IdentifierExpression : public Expression + struct NAZARA_SHADER_API ConstantIndexExpression : Expression + { + NodeType GetType() const override; + void Visit(AstExpressionVisitor& visitor) override; + + std::size_t constantId; + }; + + struct NAZARA_SHADER_API IdentifierExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -156,7 +164,7 @@ namespace Nz::ShaderAst std::string identifier; }; - struct NAZARA_SHADER_API IntrinsicExpression : public Expression + struct NAZARA_SHADER_API IntrinsicExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -165,7 +173,7 @@ namespace Nz::ShaderAst std::vector parameters; }; - struct NAZARA_SHADER_API SelectOptionExpression : public Expression + struct NAZARA_SHADER_API SelectOptionExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -175,7 +183,7 @@ namespace Nz::ShaderAst ExpressionPtr truePath; }; - struct NAZARA_SHADER_API SwizzleExpression : public Expression + struct NAZARA_SHADER_API SwizzleExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -193,7 +201,7 @@ namespace Nz::ShaderAst std::size_t variableId; }; - struct NAZARA_SHADER_API UnaryExpression : public Expression + struct NAZARA_SHADER_API UnaryExpression : Expression { NodeType GetType() const override; void Visit(AstExpressionVisitor& visitor) override; @@ -221,7 +229,7 @@ namespace Nz::ShaderAst Statement& operator=(Statement&&) noexcept = default; }; - struct NAZARA_SHADER_API BranchStatement : public Statement + struct NAZARA_SHADER_API BranchStatement : Statement { NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; @@ -241,7 +249,7 @@ namespace Nz::ShaderAst NodeType GetType() const override; void Visit(AstStatementVisitor& visitor) override; - std::size_t optionIndex; + ExpressionPtr condition; StatementPtr statement; }; @@ -252,12 +260,13 @@ namespace Nz::ShaderAst struct ExternalVar { - std::optional bindingIndex; - std::optional bindingSet; + AttributeValue bindingIndex; + AttributeValue bindingSet; std::string name; ExpressionType type; }; + AttributeValue bindingSet; std::optional varIndex; std::vector externalVars; }; @@ -273,12 +282,11 @@ namespace Nz::ShaderAst ExpressionType type; }; - std::optional depthWrite; - std::optional earlyFragmentTests; - std::optional entryStage; + AttributeValue depthWrite; + AttributeValue earlyFragmentTests; + AttributeValue entryStage; std::optional funcIndex; std::optional varIndex; - std::string optionName; std::string name; std::vector parameters; std::vector statements; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp index cf70e918b..1fbb305e0 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.hpp +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.hpp @@ -36,6 +36,7 @@ namespace Nz::ShaderAst struct Options { std::unordered_set reservedIdentifiers; + UInt64 enabledOptions = 0; bool makeVariableNameUnique = false; bool removeOptionDeclaration = true; }; @@ -71,7 +72,7 @@ namespace Nz::ShaderAst StatementPtr Clone(ExpressionStatement& node) override; StatementPtr Clone(MultiStatement& node) override; - inline const Identifier* FindIdentifier(const std::string_view& identifierName) const; + const Identifier* FindIdentifier(const std::string_view& identifierName) const; Expression& MandatoryExpr(ExpressionPtr& node); Statement& MandatoryStatement(StatementPtr& node); @@ -81,14 +82,17 @@ namespace Nz::ShaderAst void PushScope(); void PopScope(); + template const T& ComputeAttributeValue(AttributeValue& attribute); + ConstantValue ComputeConstantValue(Expression& expr); + std::size_t DeclareFunction(DeclareFunctionStatement& funcDecl); void PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen); + std::size_t RegisterConstant(std::string name, ConstantValue value); FunctionData& RegisterFunction(std::size_t functionIndex); std::size_t RegisterIntrinsic(std::string name, IntrinsicType type); - std::size_t RegisterOption(std::string name, ExpressionType type); - std::size_t RegisterStruct(std::string name, StructDescription description); + std::size_t RegisterStruct(std::string name, StructDescription* description); std::size_t RegisterVariable(std::string name, ExpressionType type); void ResolveFunctions(); @@ -118,9 +122,9 @@ namespace Nz::ShaderAst enum class Type { Alias, + Constant, Function, Intrinsic, - Option, Struct, Variable }; @@ -130,14 +134,6 @@ namespace Nz::ShaderAst Type type; }; - std::vector m_identifiersInScope; - std::vector m_functions; - std::vector m_intrinsics; - std::vector m_options; - std::vector m_structs; - std::vector m_variableTypes; - std::vector m_scopeSizes; - struct Context; Context* m_context; }; diff --git a/include/Nazara/Shader/Ast/SanitizeVisitor.inl b/include/Nazara/Shader/Ast/SanitizeVisitor.inl index c4def12fe..7c0e157ba 100644 --- a/include/Nazara/Shader/Ast/SanitizeVisitor.inl +++ b/include/Nazara/Shader/Ast/SanitizeVisitor.inl @@ -12,15 +12,6 @@ namespace Nz::ShaderAst return Sanitize(statement, {}, error); } - inline auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* - { - auto it = std::find_if(m_identifiersInScope.rbegin(), m_identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); - if (it == m_identifiersInScope.rend()) - return nullptr; - - return &*it; - } - inline StatementPtr Sanitize(Statement& ast, std::string* error) { SanitizeVisitor sanitizer; diff --git a/include/Nazara/Shader/GlslWriter.hpp b/include/Nazara/Shader/GlslWriter.hpp index 988489bac..5bb96e34c 100644 --- a/include/Nazara/Shader/GlslWriter.hpp +++ b/include/Nazara/Shader/GlslWriter.hpp @@ -46,7 +46,7 @@ namespace Nz }; static const char* GetFlipYUniformName(); - static ShaderAst::StatementPtr Sanitize(ShaderAst::Statement& ast, std::string* error = nullptr); + static ShaderAst::StatementPtr Sanitize(ShaderAst::Statement& ast, UInt64 enabledConditions, std::string* error = nullptr); private: void Append(const ShaderAst::ExpressionType& type); @@ -76,7 +76,7 @@ namespace Nz void HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node); void HandleInOut(); - void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc); + void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterVariable(std::size_t varIndex, std::string varName); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); @@ -86,7 +86,6 @@ namespace Nz void Visit(ShaderAst::BinaryExpression& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; - void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; @@ -94,7 +93,6 @@ namespace Nz void Visit(ShaderAst::UnaryExpression& node) override; void Visit(ShaderAst::BranchStatement& node) override; - void Visit(ShaderAst::ConditionalStatement& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; void Visit(ShaderAst::DeclareOptionStatement& node) override; diff --git a/include/Nazara/Shader/LangWriter.hpp b/include/Nazara/Shader/LangWriter.hpp index 5982d3627..97e5d8b90 100644 --- a/include/Nazara/Shader/LangWriter.hpp +++ b/include/Nazara/Shader/LangWriter.hpp @@ -78,8 +78,8 @@ namespace Nz void EnterScope(); void LeaveScope(bool skipLine = true); - void RegisterOption(std::size_t optionIndex, std::string optionName); - void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc); + void RegisterConstant(std::size_t constantIndex, std::string constantName); + void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc); void RegisterVariable(std::size_t varIndex, std::string varName); void Visit(ShaderAst::ExpressionPtr& expr, bool encloseIfRequired = false); @@ -90,6 +90,7 @@ namespace Nz void Visit(ShaderAst::CastExpression& node) override; void Visit(ShaderAst::ConditionalExpression& node) override; void Visit(ShaderAst::ConstantExpression& node) override; + void Visit(ShaderAst::ConstantIndexExpression& node) override; void Visit(ShaderAst::IntrinsicExpression& node) override; void Visit(ShaderAst::SwizzleExpression& node) override; void Visit(ShaderAst::VariableExpression& node) override; diff --git a/include/Nazara/Shader/ShaderBuilder.hpp b/include/Nazara/Shader/ShaderBuilder.hpp index d4fd91862..945f3f725 100644 --- a/include/Nazara/Shader/ShaderBuilder.hpp +++ b/include/Nazara/Shader/ShaderBuilder.hpp @@ -57,12 +57,12 @@ namespace Nz::ShaderBuilder struct ConditionalExpression { - inline std::unique_ptr operator()(std::size_t optionIndex, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr condition, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const; }; struct ConditionalStatement { - inline std::unique_ptr operator()(std::size_t optionIndex, ShaderAst::StatementPtr statement) const; + inline std::unique_ptr operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr statement) const; }; struct Constant diff --git a/include/Nazara/Shader/ShaderBuilder.inl b/include/Nazara/Shader/ShaderBuilder.inl index 9b9ecc510..912edd345 100644 --- a/include/Nazara/Shader/ShaderBuilder.inl +++ b/include/Nazara/Shader/ShaderBuilder.inl @@ -109,20 +109,20 @@ namespace Nz::ShaderBuilder return castNode; } - inline std::unique_ptr Impl::ConditionalExpression::operator()(std::size_t optionIndex, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const + inline std::unique_ptr Impl::ConditionalExpression::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::ExpressionPtr truePath, ShaderAst::ExpressionPtr falsePath) const { auto condExprNode = std::make_unique(); - condExprNode->optionIndex = optionIndex; + condExprNode->condition = std::move(condition); condExprNode->falsePath = std::move(falsePath); condExprNode->truePath = std::move(truePath); return condExprNode; } - inline std::unique_ptr Impl::ConditionalStatement::operator()(std::size_t optionIndex, ShaderAst::StatementPtr statement) const + inline std::unique_ptr Impl::ConditionalStatement::operator()(ShaderAst::ExpressionPtr condition, ShaderAst::StatementPtr statement) const { auto condStatementNode = std::make_unique(); - condStatementNode->optionIndex = optionIndex; + condStatementNode->condition = std::move(condition); condStatementNode->statement = std::move(statement); return condStatementNode; @@ -159,7 +159,9 @@ namespace Nz::ShaderBuilder inline std::unique_ptr Impl::DeclareFunction::operator()(std::optional entryStage, std::string name, std::vector parameters, std::vector statements, ShaderAst::ExpressionType returnType) const { auto declareFunctionNode = std::make_unique(); - declareFunctionNode->entryStage = entryStage; + if (entryStage) + declareFunctionNode->entryStage = *entryStage; + declareFunctionNode->name = std::move(name); declareFunctionNode->parameters = std::move(parameters); declareFunctionNode->returnType = std::move(returnType); diff --git a/include/Nazara/Shader/SpirvAstVisitor.hpp b/include/Nazara/Shader/SpirvAstVisitor.hpp index d6301d5f3..bc17eed57 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.hpp +++ b/include/Nazara/Shader/SpirvAstVisitor.hpp @@ -47,8 +47,6 @@ namespace Nz void Visit(ShaderAst::BranchStatement& node) override; void Visit(ShaderAst::CallFunctionExpression& node) override; void Visit(ShaderAst::CastExpression& node) override; - void Visit(ShaderAst::ConditionalExpression& node) override; - void Visit(ShaderAst::ConditionalStatement& node) override; void Visit(ShaderAst::ConstantExpression& node) override; void Visit(ShaderAst::DeclareExternalStatement& node) override; void Visit(ShaderAst::DeclareFunctionStatement& node) override; @@ -142,7 +140,7 @@ namespace Nz UInt32 PopResultId(); inline void RegisterExternalVariable(std::size_t varIndex, const ShaderAst::ExpressionType& type); - inline void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc); + inline void RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* structDesc); inline void RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass); std::size_t m_extVarIndex; @@ -150,7 +148,7 @@ namespace Nz std::size_t m_funcIndex; std::vector m_scopeSizes; std::vector& m_funcData; - std::vector m_structs; + std::vector m_structs; std::vector> m_variables; std::vector m_functionBlocks; std::vector m_resultIds; diff --git a/include/Nazara/Shader/SpirvAstVisitor.inl b/include/Nazara/Shader/SpirvAstVisitor.inl index f67692f6f..48ad986c7 100644 --- a/include/Nazara/Shader/SpirvAstVisitor.inl +++ b/include/Nazara/Shader/SpirvAstVisitor.inl @@ -25,12 +25,12 @@ namespace Nz RegisterVariable(varIndex, m_writer.GetTypeId(type), pointerId, storageClass); } - inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription structDesc) + inline void SpirvAstVisitor::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* structDesc) { if (structIndex >= m_structs.size()) m_structs.resize(structIndex + 1); - m_structs[structIndex] = std::move(structDesc); + m_structs[structIndex] = structDesc; } inline void SpirvAstVisitor::RegisterVariable(std::size_t varIndex, UInt32 typeId, UInt32 pointerId, SpirvStorageClass storageClass) diff --git a/src/Nazara/Graphics/UberShader.cpp b/src/Nazara/Graphics/UberShader.cpp index 763c8ffc8..e7a3ef611 100644 --- a/src/Nazara/Graphics/UberShader.cpp +++ b/src/Nazara/Graphics/UberShader.cpp @@ -16,10 +16,8 @@ namespace Nz UberShader::UberShader(ShaderStageType shaderStage, const ShaderAst::StatementPtr& shaderAst) : m_shaderStage(shaderStage) { - ShaderAst::SanitizeVisitor::Options options; - options.removeOptionDeclaration = false; - - m_shaderAst = ShaderAst::Sanitize(*shaderAst, options); + //TODO: Try to partially sanitize shader? + m_shaderAst = ShaderAst::Clone(*shaderAst); std::size_t optionCount = 0; @@ -59,7 +57,6 @@ namespace Nz { ShaderWriter::States states; states.enabledOptions = combination; - states.sanitized = true; std::shared_ptr stage = Graphics::Instance()->GetRenderDevice()->InstantiateShaderModule(m_shaderStage, *m_shaderAst, std::move(states)); diff --git a/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp b/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp index 4d70be832..27aa72b3d 100644 --- a/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp +++ b/src/Nazara/OpenGLRenderer/OpenGLShaderModule.cpp @@ -140,7 +140,7 @@ namespace Nz { m_states = states; m_states.sanitized = true; //< Shader is always sanitized (because of keywords) - std::shared_ptr sanitized = GlslWriter::Sanitize(shaderAst); + std::shared_ptr sanitized = GlslWriter::Sanitize(shaderAst, states.enabledOptions); for (std::size_t i = 0; i < ShaderStageTypeCount; ++i) { diff --git a/src/Nazara/Shader/Ast/AstCloner.cpp b/src/Nazara/Shader/Ast/AstCloner.cpp index c3780769c..f4a37efff 100644 --- a/src/Nazara/Shader/Ast/AstCloner.cpp +++ b/src/Nazara/Shader/Ast/AstCloner.cpp @@ -56,7 +56,7 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(ConditionalStatement& node) { auto clone = std::make_unique(); - clone->optionIndex = node.optionIndex; + clone->condition = CloneExpression(node.condition); clone->statement = CloneStatement(node.statement); return clone; @@ -65,21 +65,31 @@ namespace Nz::ShaderAst StatementPtr AstCloner::Clone(DeclareExternalStatement& node) { auto clone = std::make_unique(); - clone->externalVars = node.externalVars; clone->varIndex = node.varIndex; + clone->bindingSet = CloneAttribute(node.bindingSet); + + clone->externalVars.reserve(node.externalVars.size()); + for (const auto& var : node.externalVars) + { + auto& cloneVar = clone->externalVars.emplace_back(); + cloneVar.name = var.name; + cloneVar.type = var.type; + cloneVar.bindingIndex = CloneAttribute(var.bindingIndex); + cloneVar.bindingSet = CloneAttribute(var.bindingSet); + } + return clone; } StatementPtr AstCloner::Clone(DeclareFunctionStatement& node) { auto clone = std::make_unique(); - clone->depthWrite = node.depthWrite; - clone->earlyFragmentTests = node.earlyFragmentTests; - clone->entryStage = node.entryStage; + clone->depthWrite = CloneAttribute(node.depthWrite); + clone->earlyFragmentTests = CloneAttribute(node.earlyFragmentTests); + clone->entryStage = CloneAttribute(node.entryStage); clone->funcIndex = node.funcIndex; clone->name = node.name; - clone->optionName = node.optionName; clone->parameters = node.parameters; clone->returnType = node.returnType; clone->varIndex = node.varIndex; @@ -106,7 +116,20 @@ namespace Nz::ShaderAst { auto clone = std::make_unique(); clone->structIndex = node.structIndex; - clone->description = node.description; + + clone->description.layout = CloneAttribute(node.description.layout); + clone->description.name = node.description.name; + + clone->description.members.reserve(node.description.members.size()); + for (const auto& member : node.description.members) + { + auto& cloneMember = clone->description.members.emplace_back(); + cloneMember.name = member.name; + cloneMember.type = member.type; + cloneMember.builtin = CloneAttribute(member.builtin); + cloneMember.cond = CloneAttribute(member.cond); + cloneMember.locationIndex = CloneAttribute(member.locationIndex); + } return clone; } @@ -259,7 +282,7 @@ namespace Nz::ShaderAst ExpressionPtr AstCloner::Clone(ConditionalExpression& node) { auto clone = std::make_unique(); - clone->optionIndex = node.optionIndex; + clone->condition = CloneExpression(node.condition); clone->falsePath = CloneExpression(node.falsePath); clone->truePath = CloneExpression(node.truePath); @@ -268,6 +291,16 @@ namespace Nz::ShaderAst return clone; } + ExpressionPtr AstCloner::Clone(ConstantIndexExpression& node) + { + auto clone = std::make_unique(); + clone->constantId = node.constantId; + + clone->cachedExpressionType = node.cachedExpressionType; + + return clone; + } + ExpressionPtr AstCloner::Clone(ConstantExpression& node) { auto clone = std::make_unique(); diff --git a/src/Nazara/Shader/Ast/AstOptimizer.cpp b/src/Nazara/Shader/Ast/AstOptimizer.cpp index 4ac579cab..102d18012 100644 --- a/src/Nazara/Shader/Ast/AstOptimizer.cpp +++ b/src/Nazara/Shader/Ast/AstOptimizer.cpp @@ -531,19 +531,6 @@ namespace Nz::ShaderAst #undef EnableOptimisation } - StatementPtr AstOptimizer::Optimise(Statement& statement) - { - m_enabledOptions.reset(); - return CloneStatement(statement); - } - - StatementPtr AstOptimizer::Optimise(Statement& statement, UInt64 enabledConditions) - { - m_enabledOptions = enabledConditions; - - return CloneStatement(statement); - } - ExpressionPtr AstOptimizer::Clone(BinaryExpression& node) { auto lhs = CloneExpression(node.left); @@ -785,15 +772,36 @@ namespace Nz::ShaderAst ExpressionPtr AstOptimizer::Clone(ConditionalExpression& node) { - if (!m_enabledOptions) + if (!m_options.enabledOptions) return AstCloner::Clone(node); - if (TestBit(*m_enabledOptions, node.optionIndex)) + auto cond = CloneExpression(node.condition); + if (cond->GetType() != NodeType::ConstantExpression) + throw std::runtime_error("conditional expression condition must be a constant expression"); + + auto& constant = static_cast(*cond); + + assert(constant.cachedExpressionType); + const ExpressionType& constantType = constant.cachedExpressionType.value(); + + if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) + throw std::runtime_error("conditional expression condition must resolve to a boolean"); + + bool cValue = std::get(constant.value); + if (cValue) return AstCloner::Clone(*node.truePath); else return AstCloner::Clone(*node.falsePath); } + ExpressionPtr AstOptimizer::Clone(ConstantIndexExpression& node) + { + if (!m_options.constantQueryCallback) + return AstCloner::Clone(node); + + return ShaderBuilder::Constant(m_options.constantQueryCallback(node.constantId)); + } + ExpressionPtr AstOptimizer::Clone(UnaryExpression& node) { auto expr = CloneExpression(node.expression); @@ -830,10 +838,23 @@ namespace Nz::ShaderAst StatementPtr AstOptimizer::Clone(ConditionalStatement& node) { - if (!m_enabledOptions) + if (!m_options.enabledOptions) return AstCloner::Clone(node); - if (TestBit(*m_enabledOptions, node.optionIndex)) + auto cond = CloneExpression(node.condition); + if (cond->GetType() != NodeType::ConstantExpression) + throw std::runtime_error("conditional expression condition must be a constant expression"); + + auto& constant = static_cast(*cond); + + assert(constant.cachedExpressionType); + const ExpressionType& constantType = constant.cachedExpressionType.value(); + + if (!IsPrimitiveType(constantType) || std::get(constantType) != PrimitiveType::Boolean) + throw std::runtime_error("conditional expression condition must resolve to a boolean"); + + bool cValue = std::get(constant.value); + if (cValue) return AstCloner::Clone(node); else return ShaderBuilder::NoOp(); diff --git a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp index 1268b8bf7..52d8fdc8d 100644 --- a/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp +++ b/src/Nazara/Shader/Ast/AstRecursiveVisitor.cpp @@ -4,6 +4,7 @@ #include #include +#include "..\..\..\..\include\Nazara\Shader\Ast\AstRecursiveVisitor.hpp" namespace Nz::ShaderAst { @@ -67,6 +68,11 @@ namespace Nz::ShaderAst /* Nothing to do */ } + void AstRecursiveVisitor::Visit(ConstantIndexExpression& /*node*/) + { + /* Nothing to do */ + } + void AstRecursiveVisitor::Visit(IdentifierExpression& /*node*/) { /* Nothing to do */ diff --git a/src/Nazara/Shader/Ast/AstSerializer.cpp b/src/Nazara/Shader/Ast/AstSerializer.cpp index 05578f94c..0b5024415 100644 --- a/src/Nazara/Shader/Ast/AstSerializer.cpp +++ b/src/Nazara/Shader/Ast/AstSerializer.cpp @@ -111,9 +111,14 @@ namespace Nz::ShaderAst Node(expr); } + void AstSerializerBase::Serialize(ConstantIndexExpression& node) + { + SizeT(node.constantId); + } + void AstSerializerBase::Serialize(ConditionalExpression& node) { - SizeT(node.optionIndex); + Node(node.condition); Node(node.truePath); Node(node.falsePath); } @@ -207,7 +212,7 @@ namespace Nz::ShaderAst void AstSerializerBase::Serialize(ConditionalStatement& node) { - SizeT(node.optionIndex); + Node(node.condition); Node(node.statement); } @@ -215,13 +220,15 @@ namespace Nz::ShaderAst { OptVal(node.varIndex); + Attribute(node.bindingSet); + Container(node.externalVars); for (auto& extVar : node.externalVars) { Value(extVar.name); Type(extVar.type); - OptVal(extVar.bindingIndex); - OptVal(extVar.bindingSet); + Attribute(extVar.bindingIndex); + Attribute(extVar.bindingSet); } } @@ -229,11 +236,10 @@ namespace Nz::ShaderAst { Value(node.name); Type(node.returnType); - OptEnum(node.depthWrite); - OptVal(node.earlyFragmentTests); - OptEnum(node.entryStage); + Attribute(node.depthWrite); + Attribute(node.earlyFragmentTests); + Attribute(node.entryStage); OptVal(node.funcIndex); - Value(node.optionName); OptVal(node.varIndex); Container(node.parameters); @@ -261,15 +267,16 @@ namespace Nz::ShaderAst OptVal(node.structIndex); Value(node.description.name); - OptEnum(node.description.layout); + Attribute(node.description.layout); Container(node.description.members); for (auto& member : node.description.members) { Value(member.name); Type(member.type); - OptEnum(member.builtin); - OptVal(member.locationIndex); + Attribute(member.builtin); + Attribute(member.cond); + Attribute(member.locationIndex); } } @@ -535,6 +542,7 @@ namespace Nz::ShaderAst #define NAZARA_SHADERAST_STATEMENT(Node) case NodeType:: Node : node = std::make_unique(); break; #include +#include "..\..\..\..\include\Nazara\Shader\Ast\AstSerializer.hpp" default: throw std::runtime_error("unexpected node type"); } diff --git a/src/Nazara/Shader/Ast/AstUtils.cpp b/src/Nazara/Shader/Ast/AstUtils.cpp index db9b50895..62e837b88 100644 --- a/src/Nazara/Shader/Ast/AstUtils.cpp +++ b/src/Nazara/Shader/Ast/AstUtils.cpp @@ -67,6 +67,11 @@ namespace Nz::ShaderAst m_expressionCategory = ExpressionCategory::RValue; } + void ShaderAstValueCategory::Visit(ConstantIndexExpression& /*node*/) + { + m_expressionCategory = ExpressionCategory::LValue; + } + void ShaderAstValueCategory::Visit(IdentifierExpression& /*node*/) { m_expressionCategory = ExpressionCategory::LValue; diff --git a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp index b257a6182..88c684553 100644 --- a/src/Nazara/Shader/Ast/SanitizeVisitor.cpp +++ b/src/Nazara/Shader/Ast/SanitizeVisitor.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,7 @@ namespace Nz::ShaderAst struct SanitizeVisitor::Context { - struct FunctionData + struct CurrentFunctionData { std::optional stageType; Bitset<> calledFunctions; @@ -38,11 +39,19 @@ namespace Nz::ShaderAst FunctionFlags flags; }; + std::size_t nextOptionIndex = 0; Options options; std::array entryFunctions = {}; std::unordered_set declaredExternalVar; std::unordered_set usedBindingIndexes; - FunctionData* currentFunction = nullptr; + std::vector identifiersInScope; + std::vector constantValues; + std::vector functions; + std::vector intrinsics; + std::vector structs; + std::vector variableTypes; + std::vector scopeSizes; + CurrentFunctionData* currentFunction = nullptr; }; StatementPtr SanitizeVisitor::Sanitize(Statement& statement, const Options& options, std::string* error) @@ -123,14 +132,23 @@ namespace Nz::ShaderAst accessIndexPtr = static_cast(indexedExpr.get()); std::size_t structIndex = ResolveStruct(exprType); - assert(structIndex < m_structs.size()); - const StructDescription& s = m_structs[structIndex]; + assert(structIndex < m_context->structs.size()); + const StructDescription* s = m_context->structs[structIndex]; - auto it = std::find_if(s.members.begin(), s.members.end(), [&](const auto& field) { return field.name == identifier; }); - if (it == s.members.end()) + auto it = std::find_if(s->members.begin(), s->members.end(), [&](const auto& field) + { + if (field.name != identifier) + return false; + + if (field.cond.HasValue() && !field.cond.GetResultingValue()) + return false; + + return true; + }); + if (it == s->members.end()) throw AstError{ "unknown field " + identifier }; - accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s.members.begin(), it)))); + accessIndexPtr->indices.push_back(ShaderBuilder::Constant(Int32(std::distance(s->members.begin(), it)))); accessIndexPtr->cachedExpressionType = ResolveType(it->type); } else if (IsVectorType(exprType)) @@ -419,7 +437,7 @@ namespace Nz::ShaderAst for (const auto& param : node.parameters) parameters.push_back(CloneExpression(param)); - auto intrinsic = ShaderBuilder::Intrinsic(m_intrinsics[identifier->index], std::move(parameters)); + auto intrinsic = ShaderBuilder::Intrinsic(m_context->intrinsics[identifier->index], std::move(parameters)); Validate(*intrinsic); return intrinsic; @@ -437,11 +455,11 @@ namespace Nz::ShaderAst else { // Identifier not found, maybe the function is declared later - auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; }); - if (it == m_functions.end()) + auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node->name == functionName; }); + if (it == m_context->functions.end()) throw AstError{ "function " + functionName + " does not exist" }; - targetFuncIndex = std::distance(m_functions.begin(), it); + targetFuncIndex = std::distance(m_context->functions.begin(), it); clone->targetFunction = targetFuncIndex; } @@ -451,7 +469,7 @@ namespace Nz::ShaderAst m_context->currentFunction->calledFunctions.UnboundedSet(targetFuncIndex); - Validate(*clone, m_functions[targetFuncIndex].node); + Validate(*clone, m_context->functions[targetFuncIndex].node); return clone; } @@ -507,18 +525,18 @@ namespace Nz::ShaderAst ExpressionPtr SanitizeVisitor::Clone(ConditionalExpression& node) { + MandatoryExpr(node.condition); MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); - auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); + if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; - const ExpressionType& leftExprType = GetExpressionType(*clone->truePath); - if (leftExprType != GetExpressionType(*clone->falsePath)) - throw AstError{ "true path type must match false path type" }; - - clone->cachedExpressionType = leftExprType; - - return clone; + if (std::get(conditionValue)) + return AstCloner::Clone(*node.truePath); + else + return AstCloner::Clone(*node.falsePath); } ExpressionPtr SanitizeVisitor::Clone(ConstantExpression& node) @@ -537,15 +555,31 @@ namespace Nz::ShaderAst if (!identifier) throw AstError{ "unknown identifier " + node.identifier }; - if (identifier->type != Identifier::Type::Variable) - throw AstError{ "expected variable identifier" }; + switch (identifier->type) + { + case Identifier::Type::Constant: + { + // Replace IdentifierExpression by ConstantIndexExpression + auto constantExpr = std::make_unique(); + constantExpr->cachedExpressionType = GetExpressionType(m_context->constantValues[identifier->index]); + constantExpr->constantId = identifier->index; - // Replace IdentifierExpression by VariableExpression - auto varExpr = std::make_unique(); - varExpr->cachedExpressionType = m_variableTypes[identifier->index]; - varExpr->variableId = identifier->index; + return constantExpr; + } - return varExpr; + case Identifier::Type::Variable: + { + // Replace IdentifierExpression by VariableExpression + auto varExpr = std::make_unique(); + varExpr->cachedExpressionType = m_context->variableTypes[identifier->index]; + varExpr->variableId = identifier->index; + + return varExpr; + } + + default: + throw AstError{ "expected constant or variable identifier" }; + } } ExpressionPtr SanitizeVisitor::Clone(IntrinsicExpression& node) @@ -561,26 +595,20 @@ namespace Nz::ShaderAst MandatoryExpr(node.truePath); MandatoryExpr(node.falsePath); - auto condExpr = std::make_unique(); - condExpr->truePath = CloneExpression(node.truePath); - condExpr->falsePath = CloneExpression(node.falsePath); - const Identifier* identifier = FindIdentifier(node.optionName); if (!identifier) - throw AstError{ "unknown option " + node.optionName }; + throw AstError{ "unknown constant " + node.optionName }; - if (identifier->type != Identifier::Type::Option) - throw AstError{ "expected option identifier" }; + if (identifier->type != Identifier::Type::Constant) + throw AstError{ "expected constant identifier" }; - condExpr->optionIndex = identifier->index; + if (GetExpressionType(m_context->constantValues[identifier->index]) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "constant is not a boolean" }; - const ExpressionType& leftExprType = GetExpressionType(*condExpr->truePath); - if (leftExprType != GetExpressionType(*condExpr->falsePath)) - throw AstError{ "true path type must match false path type" }; - - condExpr->cachedExpressionType = leftExprType; - - return condExpr; + if (std::get(m_context->constantValues[identifier->index])) + return CloneExpression(node.truePath); + else + return CloneExpression(node.falsePath); } ExpressionPtr SanitizeVisitor::Clone(SwizzleExpression& node) @@ -687,43 +715,54 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(ConditionalStatement& node) { + MandatoryExpr(node.condition); MandatoryStatement(node.statement); - PushScope(); - - auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + ConstantValue conditionValue = ComputeConstantValue(*AstCloner::Clone(*node.condition)); + if (GetExpressionType(conditionValue) != ExpressionType{ PrimitiveType::Boolean }) + throw AstError{ "expected a boolean value" }; - PopScope(); - - return clone; + if (std::get(conditionValue)) + return AstCloner::Clone(*node.statement); + else + return ShaderBuilder::NoOp(); } StatementPtr SanitizeVisitor::Clone(DeclareExternalStatement& node) { assert(m_context); - for (const auto& extVar : node.externalVars) + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); + + UInt32 defaultBlockSet = 0; + if (clone->bindingSet.HasValue()) + defaultBlockSet = ComputeAttributeValue(clone->bindingSet); + + for (auto& extVar : clone->externalVars) { - if (!extVar.bindingIndex) + if (!extVar.bindingIndex.HasValue()) throw AstError{ "external variable " + extVar.name + " requires a binding index" }; - UInt64 bindingIndex = *extVar.bindingIndex; - UInt64 bindingSet = extVar.bindingSet.value_or(0); + if (extVar.bindingSet.HasValue()) + ComputeAttributeValue(extVar.bindingSet); + else + extVar.bindingSet = defaultBlockSet; + + UInt64 bindingSet = extVar.bindingSet.GetResultingValue(); + + UInt64 bindingIndex = ComputeAttributeValue(extVar.bindingIndex); + UInt64 bindingKey = bindingSet << 32 | bindingIndex; if (m_context->usedBindingIndexes.find(bindingKey) != m_context->usedBindingIndexes.end()) - throw AstError{ "Binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; + throw AstError{ "binding (set=" + std::to_string(bindingSet) + ", binding=" + std::to_string(bindingIndex) + ") is already in use" }; m_context->usedBindingIndexes.insert(bindingKey); if (m_context->declaredExternalVar.find(extVar.name) != m_context->declaredExternalVar.end()) - throw AstError{ "External variable " + extVar.name + " is already declared" }; + throw AstError{ "external variable " + extVar.name + " is already declared" }; m_context->declaredExternalVar.insert(extVar.name); - } - auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - for (auto& extVar : clone->externalVars) - { extVar.type = ResolveType(extVar.type); ExpressionType varType; @@ -746,9 +785,23 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(DeclareFunctionStatement& node) { - if (node.entryStage) + auto clone = std::make_unique(); + clone->name = node.name; + clone->parameters = node.parameters; + clone->returnType = ResolveType(node.returnType); + + if (node.depthWrite.HasValue()) + clone->depthWrite = ComputeAttributeValue(node.depthWrite); + + if (node.earlyFragmentTests.HasValue()) + clone->earlyFragmentTests = ComputeAttributeValue(node.earlyFragmentTests); + + if (node.entryStage.HasValue()) + clone->entryStage = ComputeAttributeValue(node.entryStage); + + if (clone->entryStage.HasValue()) { - ShaderStageType stageType = *node.entryStage; + ShaderStageType stageType = clone->entryStage.GetResultingValue(); if (m_context->entryFunctions[UnderlyingCast(stageType)]) throw AstError{ "the same entry type has been defined multiple times" }; @@ -760,26 +813,17 @@ namespace Nz::ShaderAst if (stageType != ShaderStageType::Fragment) { - if (node.depthWrite.has_value()) + if (node.depthWrite.HasValue()) throw AstError{ "only fragment entry-points can have the depth_write attribute" }; - if (node.earlyFragmentTests.has_value()) + if (node.earlyFragmentTests.HasValue()) throw AstError{ "only functions with entry(frag) attribute can have the early_fragments_tests attribute" }; } } - auto clone = std::make_unique(); - clone->depthWrite = node.depthWrite; - clone->earlyFragmentTests = node.earlyFragmentTests; - clone->entryStage = node.entryStage; - clone->name = node.name; - clone->optionName = node.optionName; - clone->parameters = node.parameters; - clone->returnType = ResolveType(node.returnType); - - - Context::FunctionData tempFuncData; - tempFuncData.stageType = node.entryStage; + Context::CurrentFunctionData tempFuncData; + if (node.entryStage.HasValue()) + tempFuncData.stageType = node.entryStage.GetResultingValue(); m_context->currentFunction = &tempFuncData; @@ -803,31 +847,17 @@ namespace Nz::ShaderAst m_context->currentFunction = nullptr; - if (clone->earlyFragmentTests.has_value() && *clone->earlyFragmentTests) + if (clone->earlyFragmentTests.HasValue() && clone->earlyFragmentTests.GetResultingValue()) { //TODO: warning and disable early fragment tests throw AstError{ "discard is not compatible with early fragment tests" }; } - if (!clone->optionName.empty()) - { - const Identifier* identifier = FindIdentifier(node.optionName); - if (!identifier) - throw AstError{ "unknown option " + node.optionName }; - - if (identifier->type != Identifier::Type::Option) - throw AstError{ "expected option identifier" }; - - std::size_t optionIndex = identifier->index; - - return ShaderBuilder::ConditionalStatement(optionIndex, std::move(clone)); - } - - auto it = std::find_if(m_functions.begin(), m_functions.end(), [&](const auto& funcData) { return funcData.node == &node; }); - assert(it != m_functions.end()); + auto it = std::find_if(m_context->functions.begin(), m_context->functions.end(), [&](const auto& funcData) { return funcData.node == &node; }); + assert(it != m_context->functions.end()); assert(!it->defined); - std::size_t funcIndex = std::distance(m_functions.begin(), it); + std::size_t funcIndex = std::distance(m_context->functions.begin(), it); clone->funcIndex = funcIndex; @@ -836,8 +866,8 @@ namespace Nz::ShaderAst for (std::size_t i = tempFuncData.calledFunctions.FindFirst(); i != tempFuncData.calledFunctions.npos; i = tempFuncData.calledFunctions.FindNext(i)) { - assert(i < m_functions.size()); - auto& targetFunc = m_functions[i]; + assert(i < m_context->functions.size()); + auto& targetFunc = m_context->functions[i]; targetFunc.calledByFunctions.UnboundedSet(funcIndex); } @@ -854,7 +884,9 @@ namespace Nz::ShaderAst if (clone->initialValue && clone->optType != GetExpressionType(*clone->initialValue)) throw AstError{ "option " + clone->optName + " initial expression must be of the same type than the option" }; - clone->optIndex = RegisterOption(clone->optName, clone->optType); + std::size_t optionIndex = m_context->nextOptionIndex++; + + clone->optIndex = RegisterConstant(clone->optName, TestBit(m_context->options.enabledOptions, optionIndex)); if (m_context->options.removeOptionDeclaration) return ShaderBuilder::NoOp(); @@ -864,22 +896,33 @@ namespace Nz::ShaderAst StatementPtr SanitizeVisitor::Clone(DeclareStructStatement& node) { - std::unordered_set declaredMembers; + auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - for (auto& member : node.description.members) + std::unordered_set declaredMembers; + for (auto& member : clone->description.members) { + if (member.cond.HasValue()) + { + member.cond = ComputeAttributeValue(member.cond); + if (!member.cond.GetResultingValue()) + continue; + } + + if (member.builtin.HasValue()) + member.builtin = ComputeAttributeValue(member.builtin); + + if (member.locationIndex.HasValue()) + member.locationIndex = ComputeAttributeValue(member.locationIndex); + if (declaredMembers.find(member.name) != declaredMembers.end()) throw AstError{ "struct member " + member.name + " found multiple time" }; declaredMembers.insert(member.name); + + member.type = ResolveType(member.type); } - auto clone = static_unique_pointer_cast(AstCloner::Clone(node)); - - for (auto& member : clone->description.members) - member.type = ResolveType(member.type); - - clone->structIndex = RegisterStruct(clone->description.name, clone->description); + clone->structIndex = RegisterStruct(clone->description.name, &clone->description); SanitizeIdentifier(clone->description.name); @@ -951,6 +994,15 @@ namespace Nz::ShaderAst return clone; } + auto SanitizeVisitor::FindIdentifier(const std::string_view& identifierName) const -> const Identifier* + { + auto it = std::find_if(m_context->identifiersInScope.rbegin(), m_context->identifiersInScope.rend(), [&](const Identifier& identifier) { return identifier.name == identifierName; }); + if (it == m_context->identifiersInScope.rend()) + return nullptr; + + return &*it; + } + Expression& SanitizeVisitor::MandatoryExpr(ExpressionPtr& node) { if (!node) @@ -969,20 +1021,69 @@ namespace Nz::ShaderAst void SanitizeVisitor::PushScope() { - m_scopeSizes.push_back(m_identifiersInScope.size()); + m_context->scopeSizes.push_back(m_context->identifiersInScope.size()); } void SanitizeVisitor::PopScope() { - assert(!m_scopeSizes.empty()); - m_identifiersInScope.resize(m_scopeSizes.back()); - m_scopeSizes.pop_back(); + assert(!m_context->scopeSizes.empty()); + m_context->identifiersInScope.resize(m_context->scopeSizes.back()); + m_context->scopeSizes.pop_back(); + } + + template + const T& SanitizeVisitor::ComputeAttributeValue(AttributeValue& attribute) + { + if (!attribute.HasValue()) + throw AstError{"attribute expected a value"}; + + if (attribute.IsExpression()) + { + ConstantValue value = ComputeConstantValue(*attribute.GetExpression()); + if constexpr (TypeListFind) + { + if (!std::holds_alternative(value)) + { + // HAAAAAX + if (std::holds_alternative(value) && std::is_same_v) + attribute = static_cast(std::get(value)); + else + throw AstError{ "unexpected attribute type" }; + } + else + attribute = std::get(value); + } + else + throw AstError{ "unexpected expression for this type" }; + } + + assert(attribute.IsResultingValue()); + return attribute.GetResultingValue(); + } + + ConstantValue SanitizeVisitor::ComputeConstantValue(Expression& expr) + { + AstOptimizer::Options optimizerOptions; + optimizerOptions.constantQueryCallback = [this](std::size_t constantId) + { + assert(constantId < m_context->constantValues.size()); + return m_context->constantValues[constantId]; + }; + + optimizerOptions.enabledOptions = m_context->options.enabledOptions; + + // Run optimizer on constant value to hopefully retrieve a single constant value + ExpressionPtr optimizedExpr = Optimize(expr, optimizerOptions); + if (optimizedExpr->GetType() != NodeType::ConstantExpression) + throw AstError{"expected a constant expression"}; + + return static_cast(*optimizedExpr).value; } std::size_t SanitizeVisitor::DeclareFunction(DeclareFunctionStatement& funcDecl) { - std::size_t functionIndex = m_functions.size(); - auto& funcData = m_functions.emplace_back(); + std::size_t functionIndex = m_context->functions.size(); + auto& funcData = m_context->functions.emplace_back(); funcData.node = &funcDecl; return functionIndex; @@ -990,8 +1091,8 @@ namespace Nz::ShaderAst void SanitizeVisitor::PropagateFunctionFlags(std::size_t funcIndex, FunctionFlags flags, Bitset<>& seen) { - assert(funcIndex < m_functions.size()); - auto& funcData = m_functions[funcIndex]; + assert(funcIndex < m_context->functions.size()); + auto& funcData = m_context->functions[funcIndex]; assert(funcData.defined); funcData.flags |= flags; @@ -999,11 +1100,28 @@ namespace Nz::ShaderAst for (std::size_t i = funcData.calledByFunctions.FindFirst(); i != funcData.calledByFunctions.npos; i = funcData.calledByFunctions.FindNext(i)) PropagateFunctionFlags(i, funcData.flags, seen); } + + std::size_t SanitizeVisitor::RegisterConstant(std::string name, ConstantValue value) + { + if (FindIdentifier(name)) + throw AstError{ name + " is already used" }; + + std::size_t constantIndex = m_context->constantValues.size(); + m_context->constantValues.emplace_back(std::move(value)); + + m_context->identifiersInScope.push_back({ + std::move(name), + constantIndex, + Identifier::Type::Constant + }); + + return constantIndex; + } auto SanitizeVisitor::RegisterFunction(std::size_t functionIndex) -> FunctionData& { - assert(m_functions.size() >= functionIndex); - auto& funcData = m_functions[functionIndex]; + assert(m_context->functions.size() >= functionIndex); + auto& funcData = m_context->functions[functionIndex]; assert(!funcData.defined); funcData.defined = true; @@ -1012,10 +1130,10 @@ namespace Nz::ShaderAst bool duplicate = true; // Functions cannot be declared twice, except for entry ones if their stages are different - if (funcData.node->entryStage && identifier->type == Identifier::Type::Function) + if (funcData.node->entryStage.HasValue() && identifier->type == Identifier::Type::Function) { - auto& otherFunction = m_functions[identifier->index]; - if (funcData.node->entryStage != otherFunction.node->entryStage) + auto& otherFunction = m_context->functions[identifier->index]; + if (funcData.node->entryStage.GetResultingValue() != otherFunction.node->entryStage.GetResultingValue()) duplicate = false; } @@ -1023,7 +1141,7 @@ namespace Nz::ShaderAst throw AstError{ funcData.node->name + " is already used" }; } - m_identifiersInScope.push_back({ + m_context->identifiersInScope.push_back({ funcData.node->name, functionIndex, Identifier::Type::Function @@ -1037,10 +1155,10 @@ namespace Nz::ShaderAst if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t intrinsicIndex = m_intrinsics.size(); - m_intrinsics.push_back(type); + std::size_t intrinsicIndex = m_context->intrinsics.size(); + m_context->intrinsics.push_back(type); - m_identifiersInScope.push_back({ + m_context->identifiersInScope.push_back({ std::move(name), intrinsicIndex, Identifier::Type::Intrinsic @@ -1049,32 +1167,15 @@ namespace Nz::ShaderAst return intrinsicIndex; } - std::size_t SanitizeVisitor::RegisterOption(std::string name, ExpressionType type) + std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription* description) { if (FindIdentifier(name)) throw AstError{ name + " is already used" }; - std::size_t optionIndex = m_options.size(); - m_options.emplace_back(std::move(type)); + std::size_t structIndex = m_context->structs.size(); + m_context->structs.emplace_back(description); - m_identifiersInScope.push_back({ - std::move(name), - optionIndex, - Identifier::Type::Option - }); - - return optionIndex; - } - - std::size_t SanitizeVisitor::RegisterStruct(std::string name, StructDescription description) - { - if (FindIdentifier(name)) - throw AstError{ name + " is already used" }; - - std::size_t structIndex = m_structs.size(); - m_structs.emplace_back(std::move(description)); - - m_identifiersInScope.push_back({ + m_context->identifiersInScope.push_back({ std::move(name), structIndex, Identifier::Type::Struct @@ -1089,10 +1190,10 @@ namespace Nz::ShaderAst if (auto* identifier = FindIdentifier(name); identifier && identifier->type != Identifier::Type::Variable) throw AstError{ name + " is already used" }; - std::size_t varIndex = m_variableTypes.size(); - m_variableTypes.emplace_back(std::move(type)); + std::size_t varIndex = m_context->variableTypes.size(); + m_context->variableTypes.emplace_back(std::move(type)); - m_identifiersInScope.push_back({ + m_context->identifiersInScope.push_back({ std::move(name), varIndex, Identifier::Type::Variable @@ -1106,17 +1207,17 @@ namespace Nz::ShaderAst // Once every function is known, we can propagate flags Bitset<> seen; - for (std::size_t funcIndex = 0; funcIndex < m_functions.size(); ++funcIndex) + for (std::size_t funcIndex = 0; funcIndex < m_context->functions.size(); ++funcIndex) { - auto& funcData = m_functions[funcIndex]; + auto& funcData = m_context->functions[funcIndex]; PropagateFunctionFlags(funcIndex, funcData.flags, seen); seen.Clear(); } - for (const FunctionData& funcData : m_functions) + for (const FunctionData& funcData : m_context->functions) { - if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage && *funcData.node->entryStage != ShaderStageType::Fragment) + if (funcData.flags.Test(ShaderAst::FunctionFlag::DoesDiscard) && funcData.node->entryStage.HasValue() && funcData.node->entryStage.GetResultingValue() != ShaderStageType::Fragment) throw AstError{ "discard can only be used in the fragment stage" }; } } @@ -1246,18 +1347,18 @@ namespace Nz::ShaderAst auto& indexExpr = node.indices[i]; const ShaderAst::ExpressionType& indexType = GetExpressionType(*indexExpr); - if (indexExpr->GetType() != NodeType::ConstantExpression) - throw AstError{ "struct can only be accessed with constant indices" }; + if (indexExpr->GetType() != NodeType::ConstantExpression || indexType != ExpressionType{ PrimitiveType::Int32 }) + throw AstError{ "struct can only be accessed with constant i32 indices" }; ConstantExpression& constantExpr = static_cast(*indexExpr); Int32 index = std::get(constantExpr.value); std::size_t structIndex = ResolveStruct(exprType); - assert(structIndex < m_structs.size()); - const StructDescription& s = m_structs[structIndex]; + assert(structIndex < m_context->structs.size()); + const StructDescription* s = m_context->structs[structIndex]; - exprType = ResolveType(s.members[index].type); + exprType = ResolveType(s->members[index].type); } else if (IsMatrixType(exprType)) { @@ -1283,7 +1384,7 @@ namespace Nz::ShaderAst void SanitizeVisitor::Validate(CallFunctionExpression& node, const DeclareFunctionStatement* referenceDeclaration) { - if (referenceDeclaration->entryStage) + if (referenceDeclaration->entryStage.HasValue()) throw AstError{ referenceDeclaration->name + " is an entry function which cannot be called by the program" }; for (std::size_t i = 0; i < node.parameters.size(); ++i) diff --git a/src/Nazara/Shader/GlslWriter.cpp b/src/Nazara/Shader/GlslWriter.cpp index 5f50ae3f7..e236d8017 100644 --- a/src/Nazara/Shader/GlslWriter.cpp +++ b/src/Nazara/Shader/GlslWriter.cpp @@ -46,20 +46,27 @@ namespace Nz currentFunction->calledFunctions.UnboundedSet(std::get(node.targetFunction)); } + void Visit(ShaderAst::ConditionalExpression& node) override + { + throw std::runtime_error("unexpected conditional expression, is shader sanitized?"); + } + void Visit(ShaderAst::ConditionalStatement& node) override { - if (TestBit(enabledOptions, node.optionIndex)) - node.statement->Visit(*this); + throw std::runtime_error("unexpected conditional statement, is shader sanitized?"); } void Visit(ShaderAst::DeclareFunctionStatement& node) override { // Dismiss function if it's an entry point of another type than the one selected - if (node.entryStage) + if (node.entryStage.HasValue()) { if (selectedStage) { - ShaderStageType stage = *node.entryStage; + if (!node.entryStage.IsResultingValue()) + throw std::runtime_error("unexpected unresolved value for entry attribute, is shader sanitized?"); + + ShaderStageType stage = node.entryStage.GetResultingValue(); if (stage != *selectedStage) return; @@ -132,7 +139,7 @@ namespace Nz std::optional stage; std::stringstream stream; - std::unordered_map structs; + std::unordered_map structs; std::unordered_map variableNames; std::vector inputFields; std::vector outputFields; @@ -161,7 +168,7 @@ namespace Nz ShaderAst::Statement* targetAst; if (!states.sanitized) { - sanitizedAst = Sanitize(shader); + sanitizedAst = Sanitize(shader, states.enabledOptions); targetAst = sanitizedAst.get(); } else @@ -196,10 +203,11 @@ namespace Nz return s_flipYUniformName; } - ShaderAst::StatementPtr GlslWriter::Sanitize(ShaderAst::Statement& ast, std::string* error) + ShaderAst::StatementPtr GlslWriter::Sanitize(ShaderAst::Statement& ast, UInt64 enabledConditions, std::string* error) { // Always sanitize for reserved identifiers ShaderAst::SanitizeVisitor::Options options; + options.enabledOptions = enabledConditions; options.makeVariableNameUnique = true; options.reservedIdentifiers = { // All reserved GLSL keywords as of GLSL ES 3.2 @@ -296,8 +304,8 @@ namespace Nz void GlslWriter::Append(const ShaderAst::StructType& structType) { - const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex); - Append(structDesc.name); + ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); + Append(structDesc->name); } void GlslWriter::Append(const ShaderAst::UniformType& /*uniformType*/) @@ -377,13 +385,13 @@ namespace Nz void GlslWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers) { - const auto& structDesc = Retrieve(m_currentState->structs, structIndex); + ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression); auto& constantValue = static_cast(**memberIndices); Int32 index = std::get(constantValue.value); - const auto& member = structDesc.members[index]; + const auto& member = structDesc->members[index]; Append("."); Append(member.name); @@ -529,7 +537,7 @@ namespace Nz void GlslWriter::HandleEntryPoint(ShaderAst::DeclareFunctionStatement& node) { - if (node.entryStage == ShaderStageType::Fragment && node.earlyFragmentTests && *node.earlyFragmentTests) + if (node.entryStage.GetResultingValue() == ShaderStageType::Fragment && node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue()) { if ((m_environment.glES && m_environment.glMajorVersion >= 3 && m_environment.glMinorVersion >= 1) || (!m_environment.glES && m_environment.glMajorVersion >= 4 && m_environment.glMinorVersion >= 2) || (m_environment.extCallback && m_environment.extCallback("GL_ARB_shader_image_load_store"))) { @@ -553,9 +561,9 @@ namespace Nz assert(IsStructType(parameter.type)); std::size_t structIndex = std::get(parameter.type).structIndex; - const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex); + const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); - AppendLine(structDesc.name, " ", varName, ";"); + AppendLine(structDesc->name, " ", varName, ";"); for (const auto& [memberName, targetName] : m_currentState->inputFields) AppendLine(varName, ".", memberName, " = ", targetName, ";"); @@ -578,9 +586,9 @@ namespace Nz { for (const auto& member : structDesc.members) { - if (member.builtin) + if (member.builtin.HasValue()) { - auto it = s_builtinMapping.find(member.builtin.value()); + auto it = s_builtinMapping.find(member.builtin.GetResultingValue()); assert(it != s_builtinMapping.end()); const Builtin& builtin = it->second; @@ -592,10 +600,10 @@ namespace Nz builtin.identifier }); } - else if (member.locationIndex) + else if (member.locationIndex.HasValue()) { Append("layout(location = "); - Append(*member.locationIndex); + Append(member.locationIndex.GetResultingValue()); Append(") "); Append(keyword); Append(" "); @@ -625,7 +633,7 @@ namespace Nz assert(std::holds_alternative(parameter.type)); std::size_t inputStructIndex = std::get(parameter.type).structIndex; - inputStruct = &Retrieve(m_currentState->structs, inputStructIndex); + inputStruct = Retrieve(m_currentState->structs, inputStructIndex); AppendCommentSection("Inputs"); AppendInOut(*inputStruct, m_currentState->inputFields, "in", s_inputPrefix); @@ -642,17 +650,17 @@ namespace Nz assert(std::holds_alternative(node.returnType)); std::size_t outputStructIndex = std::get(node.returnType).structIndex; - const ShaderAst::StructDescription& outputStruct = Retrieve(m_currentState->structs, outputStructIndex); + const ShaderAst::StructDescription* outputStruct = Retrieve(m_currentState->structs, outputStructIndex); AppendCommentSection("Outputs"); - AppendInOut(outputStruct, m_currentState->outputFields, "out", s_outputPrefix); + AppendInOut(*outputStruct, m_currentState->outputFields, "out", s_outputPrefix); } } - void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc) + void GlslWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) { assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); - m_currentState->structs.emplace(structIndex, std::move(desc)); + m_currentState->structs.emplace(structIndex, desc); } void GlslWriter::RegisterVariable(std::size_t varIndex, std::string varName) @@ -797,20 +805,6 @@ namespace Nz Append(")"); } - void GlslWriter::Visit(ShaderAst::ConditionalExpression& node) - { - if (TestBit(m_currentState->enabledOptions, node.optionIndex)) - Visit(node.truePath); - else - Visit(node.falsePath); - } - - void GlslWriter::Visit(ShaderAst::ConditionalStatement& node) - { - if (TestBit(m_currentState->enabledOptions, node.optionIndex)) - node.statement->Visit(*this); - } - void GlslWriter::Visit(ShaderAst::ConstantExpression& node) { std::visit([&](auto&& arg) @@ -849,8 +843,9 @@ namespace Nz assert(std::holds_alternative(uniform.containedType)); std::size_t structIndex = std::get(uniform.containedType).structIndex; - auto& structInfo = Retrieve(m_currentState->structs, structIndex); - isStd140 = structInfo.layout == StructLayout::Std140; + ShaderAst::StructDescription* structInfo = Retrieve(m_currentState->structs, structIndex); + if (structInfo->layout.HasValue()) + isStd140 = structInfo->layout.GetResultingValue() == StructLayout::Std140; } if (!m_currentState->bindingMapping.empty() || isStd140) @@ -858,10 +853,14 @@ namespace Nz if (!m_currentState->bindingMapping.empty()) { - assert(externalVar.bindingIndex); + assert(externalVar.bindingIndex.HasValue()); - UInt64 bindingIndex = *externalVar.bindingIndex; - UInt64 bindingSet = externalVar.bindingSet.value_or(0); + UInt64 bindingIndex = externalVar.bindingIndex.GetResultingValue(); + UInt64 bindingSet; + if (externalVar.bindingSet.HasValue()) + bindingSet = externalVar.bindingSet.GetResultingValue(); + else + bindingSet = 0; auto bindingIt = m_currentState->bindingMapping.find(bindingSet << 32 | bindingIndex); if (bindingIt == m_currentState->bindingMapping.end()) @@ -894,7 +893,7 @@ namespace Nz auto& structDesc = Retrieve(m_currentState->structs, structIndex); bool first = true; - for (const auto& member : structDesc.members) + for (const auto& member : structDesc->members) { if (!first) AppendLine(); @@ -927,7 +926,7 @@ namespace Nz { NazaraAssert(m_currentState, "This function should only be called while processing an AST"); - if (node.entryStage && m_currentState->previsitor.entryPoint != &node) + if (node.entryStage.HasValue() && m_currentState->previsitor.entryPoint != &node) return; //< Ignore other entry points assert(node.funcIndex); @@ -951,7 +950,7 @@ namespace Nz if (hasPredeclaration) AppendLine(); - if (node.entryStage) + if (node.entryStage.HasValue()) return HandleEntryPoint(node); std::optional varIndexOpt = node.varIndex; @@ -981,7 +980,7 @@ namespace Nz void GlslWriter::Visit(ShaderAst::DeclareStructStatement& node) { assert(node.structIndex); - RegisterStruct(*node.structIndex, node.description); + RegisterStruct(*node.structIndex, &node.description); Append("struct "); AppendLine(node.description.name); @@ -1096,7 +1095,7 @@ namespace Nz const ShaderAst::ExpressionType& returnType = GetExpressionType(*node.returnExpr); assert(IsStructType(returnType)); std::size_t structIndex = std::get(returnType).structIndex; - const ShaderAst::StructDescription& structDesc = Retrieve(m_currentState->structs, structIndex); + const ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); std::string outputStructVarName; if (node.returnExpr->GetType() == ShaderAst::NodeType::VariableExpression) @@ -1104,7 +1103,7 @@ namespace Nz else { AppendLine(); - Append(structDesc.name, " ", s_outputVarName, " = "); + Append(structDesc->name, " ", s_outputVarName, " = "); node.returnExpr->Visit(*this); AppendLine(";"); diff --git a/src/Nazara/Shader/LangWriter.cpp b/src/Nazara/Shader/LangWriter.cpp index f6429f242..d71ea6d58 100644 --- a/src/Nazara/Shader/LangWriter.cpp +++ b/src/Nazara/Shader/LangWriter.cpp @@ -29,66 +29,66 @@ namespace Nz struct LangWriter::BindingAttribute { - std::optional bindingIndex; + const ShaderAst::AttributeValue& bindingIndex; - inline bool HasValue() const { return bindingIndex.has_value(); } + inline bool HasValue() const { return bindingIndex.HasValue(); } }; struct LangWriter::BuiltinAttribute { - std::optional builtin; + const ShaderAst::AttributeValue& builtin; - inline bool HasValue() const { return builtin.has_value(); } + inline bool HasValue() const { return builtin.HasValue(); } }; struct LangWriter::DepthWriteAttribute { - std::optional writeMode; + const ShaderAst::AttributeValue& writeMode; - inline bool HasValue() const { return writeMode.has_value(); } + inline bool HasValue() const { return writeMode.HasValue(); } }; struct LangWriter::EarlyFragmentTestsAttribute { - std::optional earlyFragmentTests; + const ShaderAst::AttributeValue& earlyFragmentTests; - inline bool HasValue() const { return earlyFragmentTests.has_value(); } + inline bool HasValue() const { return earlyFragmentTests.HasValue(); } }; struct LangWriter::EntryAttribute { - std::optional stageType; + const ShaderAst::AttributeValue& stageType; - inline bool HasValue() const { return stageType.has_value(); } + inline bool HasValue() const { return stageType.HasValue(); } }; struct LangWriter::LayoutAttribute { - std::optional layout; + const ShaderAst::AttributeValue& layout; - inline bool HasValue() const { return layout.has_value(); } + inline bool HasValue() const { return layout.HasValue(); } }; struct LangWriter::LocationAttribute { - std::optional locationIndex; + const ShaderAst::AttributeValue& locationIndex; - inline bool HasValue() const { return locationIndex.has_value(); } + inline bool HasValue() const { return locationIndex.HasValue(); } }; struct LangWriter::SetAttribute { - std::optional setIndex; + const ShaderAst::AttributeValue& setIndex; - inline bool HasValue() const { return setIndex.has_value(); } + inline bool HasValue() const { return setIndex.HasValue(); } }; struct LangWriter::State { const States* states = nullptr; std::stringstream stream; - std::unordered_map optionNames; - std::unordered_map structs; + std::unordered_map constantNames; + std::unordered_map structs; std::unordered_map variableNames; bool isInEntryPoint = false; unsigned int indentLevel = 0; @@ -181,8 +181,8 @@ namespace Nz void LangWriter::Append(const ShaderAst::StructType& structType) { - const auto& structDesc = Retrieve(m_currentState->structs, structType.structIndex); - Append(structDesc.name); + ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structType.structIndex); + Append(structDesc->name); } void LangWriter::Append(const ShaderAst::UniformType& uniformType) @@ -265,7 +265,10 @@ namespace Nz if (!binding.HasValue()) return; - Append("binding(", *binding.bindingIndex, ")"); + if (binding.bindingIndex.IsResultingValue()) + Append("binding(", binding.bindingIndex.GetResultingValue(), ")"); + else + binding.bindingIndex.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(BuiltinAttribute builtin) @@ -273,20 +276,25 @@ namespace Nz if (!builtin.HasValue()) return; - switch (*builtin.builtin) + if (builtin.builtin.IsResultingValue()) { - case ShaderAst::BuiltinEntry::FragCoord: - Append("builtin(fragcoord)"); - break; + switch (builtin.builtin.GetResultingValue()) + { + case ShaderAst::BuiltinEntry::FragCoord: + Append("builtin(fragcoord)"); + break; - case ShaderAst::BuiltinEntry::FragDepth: - Append("builtin(fragdepth)"); - break; + case ShaderAst::BuiltinEntry::FragDepth: + Append("builtin(fragdepth)"); + break; - case ShaderAst::BuiltinEntry::VertexPosition: - Append("builtin(position)"); - break; + case ShaderAst::BuiltinEntry::VertexPosition: + Append("builtin(position)"); + break; + } } + else + builtin.builtin.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(DepthWriteAttribute depthWrite) @@ -294,24 +302,29 @@ namespace Nz if (!depthWrite.HasValue()) return; - switch (*depthWrite.writeMode) + if (depthWrite.writeMode.IsResultingValue()) { - case ShaderAst::DepthWriteMode::Greater: - Append("depth_write(greater)"); - break; + switch (depthWrite.writeMode.GetResultingValue()) + { + case ShaderAst::DepthWriteMode::Greater: + Append("depth_write(greater)"); + break; - case ShaderAst::DepthWriteMode::Less: - Append("depth_write(less)"); - break; + case ShaderAst::DepthWriteMode::Less: + Append("depth_write(less)"); + break; - case ShaderAst::DepthWriteMode::Replace: - Append("depth_write(replace)"); - break; + case ShaderAst::DepthWriteMode::Replace: + Append("depth_write(replace)"); + break; - case ShaderAst::DepthWriteMode::Unchanged: - Append("depth_write(unchanged)"); - break; + case ShaderAst::DepthWriteMode::Unchanged: + Append("depth_write(unchanged)"); + break; + } } + else + depthWrite.writeMode.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(EarlyFragmentTestsAttribute earlyFragmentTests) @@ -319,10 +332,15 @@ namespace Nz if (!earlyFragmentTests.HasValue()) return; - if (*earlyFragmentTests.earlyFragmentTests) - Append("early_fragment_tests(on)"); + if (earlyFragmentTests.earlyFragmentTests.IsResultingValue()) + { + if (earlyFragmentTests.earlyFragmentTests.GetResultingValue()) + Append("early_fragment_tests(true)"); + else + Append("early_fragment_tests(false)"); + } else - Append("early_fragment_tests(off)"); + earlyFragmentTests.earlyFragmentTests.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(EntryAttribute entry) @@ -330,16 +348,21 @@ namespace Nz if (!entry.HasValue()) return; - switch (*entry.stageType) + if (entry.stageType.IsResultingValue()) { - case ShaderStageType::Fragment: - Append("entry(frag)"); - break; + switch (entry.stageType.GetResultingValue()) + { + case ShaderStageType::Fragment: + Append("entry(frag)"); + break; - case ShaderStageType::Vertex: - Append("entry(vert)"); - break; + case ShaderStageType::Vertex: + Append("entry(vert)"); + break; + } } + else + entry.stageType.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(LayoutAttribute entry) @@ -347,12 +370,17 @@ namespace Nz if (!entry.HasValue()) return; - switch (*entry.layout) + if (entry.layout.IsResultingValue()) { - case StructLayout::Std140: - Append("layout(std140)"); - break; + switch (entry.layout.GetResultingValue()) + { + case StructLayout::Std140: + Append("layout(std140)"); + break; + } } + else + entry.layout.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(LocationAttribute location) @@ -360,7 +388,10 @@ namespace Nz if (!location.HasValue()) return; - Append("location(", *location.locationIndex, ")"); + if (location.locationIndex.IsResultingValue()) + Append("location(", location.locationIndex.GetResultingValue(), ")"); + else + location.locationIndex.GetExpression()->Visit(*this); } void LangWriter::AppendAttribute(SetAttribute set) @@ -368,7 +399,10 @@ namespace Nz if (!set.HasValue()) return; - Append("set(", *set.setIndex, ")"); + if (set.setIndex.IsResultingValue()) + Append("set(", set.setIndex.GetResultingValue(), ")"); + else + set.setIndex.GetExpression()->Visit(*this); } void LangWriter::AppendCommentSection(const std::string& section) @@ -382,13 +416,13 @@ namespace Nz void LangWriter::AppendField(std::size_t structIndex, const ShaderAst::ExpressionPtr* memberIndices, std::size_t remainingMembers) { - const auto& structDesc = Retrieve(m_currentState->structs, structIndex); + ShaderAst::StructDescription* structDesc = Retrieve(m_currentState->structs, structIndex); assert((*memberIndices)->GetType() == ShaderAst::NodeType::ConstantExpression); auto& constantValue = static_cast(**memberIndices); Int32 index = std::get(constantValue.value); - const auto& member = structDesc.members[index]; + const auto& member = structDesc->members[index]; Append("."); Append(member.name); @@ -449,16 +483,16 @@ namespace Nz Append("}"); } - void LangWriter::RegisterOption(std::size_t optionIndex, std::string optionName) + void LangWriter::RegisterConstant(std::size_t constantIndex, std::string constantName) { - assert(m_currentState->optionNames.find(optionIndex) == m_currentState->optionNames.end()); - m_currentState->optionNames.emplace(optionIndex, std::move(optionName)); + assert(m_currentState->constantNames.find(constantIndex) == m_currentState->constantNames.end()); + m_currentState->constantNames.emplace(constantIndex, std::move(constantName)); } - void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription desc) + void LangWriter::RegisterStruct(std::size_t structIndex, ShaderAst::StructDescription* desc) { assert(m_currentState->structs.find(structIndex) == m_currentState->structs.end()); - m_currentState->structs.emplace(structIndex, std::move(desc)); + m_currentState->structs.emplace(structIndex, desc); } void LangWriter::RegisterVariable(std::size_t varIndex, std::string varName) @@ -589,16 +623,14 @@ namespace Nz void LangWriter::Visit(ShaderAst::ConditionalExpression& node) { - Append("select_opt(", Retrieve(m_currentState->optionNames, node.optionIndex), ", "); - node.truePath->Visit(*this); - Append(", "); - node.falsePath->Visit(*this); - Append(")"); + throw std::runtime_error("fixme"); } void LangWriter::Visit(ShaderAst::ConditionalStatement& node) { - Append("[opt(", Retrieve(m_currentState->optionNames, node.optionIndex), ")]"); + Append("[cond("); + node.condition->Visit(*this); + AppendLine(")]"); node.statement->Visit(*this); } @@ -629,6 +661,11 @@ namespace Nz }, node.value); } + void LangWriter::Visit(ShaderAst::ConstantIndexExpression& node) + { + Append(Retrieve(m_currentState->constantNames, node.constantId)); + } + void LangWriter::Visit(ShaderAst::DeclareExternalStatement& node) { assert(node.varIndex); @@ -690,7 +727,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareOptionStatement& node) { assert(node.optIndex); - RegisterOption(*node.optIndex, node.optName); + RegisterConstant(*node.optIndex, node.optName); Append("option ", node.optName, ": ", node.optType); if (node.initialValue) @@ -705,7 +742,7 @@ namespace Nz void LangWriter::Visit(ShaderAst::DeclareStructStatement& node) { assert(node.structIndex); - RegisterStruct(*node.structIndex, node.description); + RegisterStruct(*node.structIndex, &node.description); AppendAttributes(true, LayoutAttribute{ node.description.layout }); Append("struct "); diff --git a/src/Nazara/Shader/ShaderLangParser.cpp b/src/Nazara/Shader/ShaderLangParser.cpp index 738787f9c..d6e52ae85 100644 --- a/src/Nazara/Shader/ShaderLangParser.cpp +++ b/src/Nazara/Shader/ShaderLangParser.cpp @@ -29,12 +29,12 @@ namespace Nz::ShaderLang std::unordered_map s_identifierToAttributeType = { { "binding", ShaderAst::AttributeType::Binding }, { "builtin", ShaderAst::AttributeType::Builtin }, + { "cond", ShaderAst::AttributeType::Cond }, { "depth_write", ShaderAst::AttributeType::DepthWrite }, { "early_fragment_tests", ShaderAst::AttributeType::EarlyFragmentTests }, { "entry", ShaderAst::AttributeType::Entry }, { "layout", ShaderAst::AttributeType::Layout }, { "location", ShaderAst::AttributeType::Location }, - { "opt", ShaderAst::AttributeType::Option }, { "set", ShaderAst::AttributeType::Set }, }; @@ -61,6 +61,41 @@ namespace Nz::ShaderLang return static_cast(val); } + + template + void HandleUniqueAttribute(const std::string_view& attributeName, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param, bool requireValue = true) + { + if (targetAttribute.HasValue()) + throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" }; + + if (!param && requireValue) + throw AttributeError{ "attribute " + std::string(attributeName) + " requires a parameter" }; + + targetAttribute = std::move(*param); + } + + template + void HandleUniqueStringAttribute(const std::string_view& attributeName, const std::unordered_map& map, ShaderAst::AttributeValue& targetAttribute, ShaderAst::Attribute::Param&& param) + { + if (targetAttribute.HasValue()) + throw AttributeError{ "attribute " + std::string(attributeName) + " must be present once" }; + + //FIXME: This should be handled with global values at sanitization stage + if (!param) + throw AttributeError{ "attribute " + std::string(attributeName) + " requires a value" }; + + const ShaderAst::ExpressionPtr& expr = *param; + if (expr->GetType() != ShaderAst::NodeType::IdentifierExpression) + throw AttributeError{ "attribute " + std::string(attributeName) + " can only be an identifier for now" }; + + const std::string& exprStr = static_cast(*expr).identifier; + + auto it = map.find(exprStr); + if (it == map.end()) + throw AttributeError{ ("invalid parameter " + exprStr + " for " + std::string(attributeName) + " attribute").c_str() }; + + targetAttribute = it->second; + } } ShaderAst::StatementPtr Parser::Parse(const std::vector& tokens) @@ -347,17 +382,7 @@ namespace Nz::ShaderLang { Consume(); - const Token& n = Peek(); - if (n.type == TokenType::Identifier) - { - arg = std::get(n.data); - Consume(); - } - else if (n.type == TokenType::IntegerValue) - { - arg = std::get(n.data); - Consume(); - } + arg = ParseExpression(); Expect(Advance(), TokenType::ClosingParenthesis); } @@ -418,37 +443,24 @@ namespace Nz::ShaderLang ShaderAst::StatementPtr Parser::ParseExternalBlock(std::vector attributes) { - std::optional blockSetIndex; - for (const auto& [attributeType, arg] : attributes) + Expect(Advance(), TokenType::External); + Expect(Advance(), TokenType::OpenCurlyBracket); + + std::unique_ptr externalStatement = std::make_unique(); + + for (auto&& [attributeType, arg] : attributes) { switch (attributeType) { case ShaderAst::AttributeType::Set: - { - if (blockSetIndex) - throw AttributeError{ "attribute set must be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute set requires a string parameter" }; - - std::optional bindingIndex = BoundCast(std::get(arg)); - if (!bindingIndex) - throw AttributeError{ "invalid set index" }; - - blockSetIndex = bindingIndex.value(); + HandleUniqueAttribute("set", externalStatement->bindingSet, std::move(arg)); break; - } default: throw AttributeError{ "unhandled attribute for external block" }; } } - Expect(Advance(), TokenType::External); - Expect(Advance(), TokenType::OpenCurlyBracket); - - std::unique_ptr externalStatement = std::make_unique(); - bool first = true; for (;;) { @@ -474,41 +486,17 @@ namespace Nz::ShaderLang if (token.type == TokenType::OpenSquareBracket) { - for (const auto& [attributeType, arg] : ParseAttributes()) + for (auto&& [attributeType, arg] : ParseAttributes()) { switch (attributeType) { case ShaderAst::AttributeType::Binding: - { - if (extVar.bindingIndex) - throw AttributeError{ "attribute binding must be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute binding requires a string parameter" }; - - std::optional bindingIndex = BoundCast(std::get(arg)); - if (!bindingIndex) - throw AttributeError{ "invalid binding index" }; - - extVar.bindingIndex = bindingIndex.value(); + HandleUniqueAttribute("binding", extVar.bindingIndex, std::move(arg)); break; - } case ShaderAst::AttributeType::Set: - { - if (extVar.bindingSet) - throw AttributeError{ "attribute set must be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute set requires a string parameter" }; - - std::optional bindingIndex = BoundCast(std::get(arg)); - if (!bindingIndex) - throw AttributeError{ "invalid set index" }; - - extVar.bindingSet = bindingIndex.value(); + HandleUniqueAttribute("set", extVar.bindingSet, std::move(arg)); break; - } default: throw AttributeError{ "unhandled attribute for external variable" }; @@ -520,9 +508,6 @@ namespace Nz::ShaderLang Expect(Advance(), TokenType::Colon); extVar.type = ParseType(); - if (!extVar.bindingSet && blockSetIndex) - extVar.bindingSet = *blockSetIndex; - RegisterVariable(extVar.name); } @@ -583,90 +568,37 @@ namespace Nz::ShaderLang auto func = ShaderBuilder::DeclareFunction(std::move(functionName), std::move(parameters), std::move(functionBody), std::move(returnType)); - for (const auto& [attributeType, arg] : attributes) + ShaderAst::AttributeValue condition; + + for (auto&& [attributeType, arg] : attributes) { switch (attributeType) { - case ShaderAst::AttributeType::DepthWrite: - { - if (func->depthWrite) - throw AttributeError{ "attribute depth_write can only be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute entry requires a string parameter" }; - - const std::string& argStr = std::get(arg); - - auto it = s_depthWriteModes.find(argStr); - if (it == s_depthWriteModes.end()) - throw AttributeError{ ("invalid parameter " + argStr + " for depth_write attribute").c_str() }; - - func->depthWrite = it->second; + case ShaderAst::AttributeType::Cond: + HandleUniqueAttribute("cond", condition, std::move(arg)); break; - } - - case ShaderAst::AttributeType::EarlyFragmentTests: - { - if (func->earlyFragmentTests) - throw AttributeError{ "attribute early_fragment_tests can only be present once" }; - - if (std::holds_alternative(arg)) - { - const std::string& argStr = std::get(arg); - if (argStr == "true" || argStr == "on") - func->earlyFragmentTests = true; - else if (argStr == "false" || argStr == "off") - func->earlyFragmentTests = false; - else - throw AttributeError{ "expected boolean value (got " + argStr + ")" }; - } - else if (std::holds_alternative(arg)) - { - // No parameter, default to true - func->earlyFragmentTests = true; - } - else - throw AttributeError{ "unexpected value for early_fragment_tests" }; - - break; - } case ShaderAst::AttributeType::Entry: - { - if (func->entryStage) - throw AttributeError{ "attribute entry can only be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute entry requires a string parameter" }; - - const std::string& argStr = std::get(arg); - - auto it = s_entryPoints.find(argStr); - if (it == s_entryPoints.end()) - throw AttributeError{ ("invalid parameter " + argStr + " for entry attribute").c_str() }; - - func->entryStage = it->second; + HandleUniqueStringAttribute("entry", s_entryPoints, func->entryStage, std::move(arg)); break; - } - case ShaderAst::AttributeType::Option: - { - if (!func->optionName.empty()) - throw AttributeError{ "attribute opt must be present once" }; - - if (!std::holds_alternative(arg)) - throw AttributeError{ "attribute opt requires a string parameter" }; - - func->optionName = std::get(arg); + case ShaderAst::AttributeType::DepthWrite: + HandleUniqueStringAttribute("depth_write", s_depthWriteModes, func->depthWrite, std::move(arg)); + break; + + case ShaderAst::AttributeType::EarlyFragmentTests: + HandleUniqueAttribute("early_fragment_tests", func->earlyFragmentTests, std::move(arg), false); break; - } default: throw AttributeError{ "unhandled attribute for function" }; } } - return func; + if (condition.HasValue()) + return ShaderBuilder::ConditionalStatement(std::move(condition).GetExpression(), std::move(func)); + else + return func; } ShaderAst::DeclareFunctionStatement::Parameter Parser::ParseFunctionParameter() @@ -710,22 +642,13 @@ namespace Nz::ShaderLang ShaderAst::StructDescription description; description.name = ParseIdentifierAsName(); - for (const auto& [attributeType, attributeParam] : attributes) + for (auto&& [attributeType, attributeParam] : attributes) { switch (attributeType) { case ShaderAst::AttributeType::Layout: - { - if (description.layout) - throw AttributeError{ "attribute layout must be present once" }; - - auto it = s_layoutMapping.find(std::get(attributeParam)); - if (it == s_layoutMapping.end()) - throw AttributeError{ "unknown layout" }; - - description.layout = it->second; + HandleUniqueStringAttribute("layout", s_layoutMapping, description.layout, std::move(attributeParam)); break; - } default: throw AttributeError{ "unexpected attribute" }; @@ -760,42 +683,28 @@ namespace Nz::ShaderLang if (token.type == TokenType::OpenSquareBracket) { - for (const auto& [attributeType, attributeParam] : ParseAttributes()) + for (auto&& [attributeType, arg] : ParseAttributes()) { switch (attributeType) { case ShaderAst::AttributeType::Builtin: - { - if (structField.builtin) - throw AttributeError{ "attribute builtin must be present once" }; - - auto it = s_builtinMapping.find(std::get(attributeParam)); - - if (it == s_builtinMapping.end()) - throw AttributeError{ "unknown builtin" }; - - structField.builtin = it->second; + HandleUniqueStringAttribute("builtin", s_builtinMapping, structField.builtin, std::move(arg)); + break; + + case ShaderAst::AttributeType::Cond: + HandleUniqueAttribute("cond", structField.cond, std::move(arg)); break; - } case ShaderAst::AttributeType::Location: - { - if (structField.locationIndex) - throw AttributeError{ "attribute location must be present once" }; - - structField.locationIndex = BoundCast(std::get(attributeParam)); - if (!structField.locationIndex) - throw AttributeError{ "invalid location index" }; - + HandleUniqueAttribute("location", structField.locationIndex, std::move(arg)); break; - } default: throw AttributeError{ "unexpected attribute" }; } } - if (structField.builtin && structField.locationIndex) + if (structField.builtin.HasValue() && structField.locationIndex.HasValue()) throw AttributeError{ "A struct field cannot have both builtin and location attributes" }; } diff --git a/src/Nazara/Shader/SpirvAstVisitor.cpp b/src/Nazara/Shader/SpirvAstVisitor.cpp index b1980efa4..a6d2e27d4 100644 --- a/src/Nazara/Shader/SpirvAstVisitor.cpp +++ b/src/Nazara/Shader/SpirvAstVisitor.cpp @@ -570,20 +570,6 @@ namespace Nz } } - void SpirvAstVisitor::Visit(ShaderAst::ConditionalExpression& node) - { - if (m_writer.IsOptionEnabled(node.optionIndex)) - node.truePath->Visit(*this); - else - node.falsePath->Visit(*this); - } - - void SpirvAstVisitor::Visit(ShaderAst::ConditionalStatement& node) - { - if (m_writer.IsOptionEnabled(node.optionIndex)) - node.statement->Visit(*this); - } - void SpirvAstVisitor::Visit(ShaderAst::ConstantExpression& node) { std::visit([&] (const auto& value) @@ -678,7 +664,7 @@ namespace Nz void SpirvAstVisitor::Visit(ShaderAst::DeclareStructStatement& node) { assert(node.structIndex); - RegisterStruct(*node.structIndex, node.description); + RegisterStruct(*node.structIndex, &node.description); } void SpirvAstVisitor::Visit(ShaderAst::DeclareVariableStatement& node) diff --git a/src/Nazara/Shader/SpirvWriter.cpp b/src/Nazara/Shader/SpirvWriter.cpp index 0dbb15dfc..9640c80c4 100644 --- a/src/Nazara/Shader/SpirvWriter.cpp +++ b/src/Nazara/Shader/SpirvWriter.cpp @@ -57,7 +57,7 @@ namespace Nz using ExtVarContainer = std::unordered_map; using LocalContainer = std::unordered_set; using FunctionContainer = std::vector>; - using StructContainer = std::vector; + using StructContainer = std::vector; PreVisitor(const SpirvWriter::States& conditions, SpirvConstantCache& constantCache, std::vector& funcs) : m_states(conditions), @@ -68,7 +68,7 @@ namespace Nz m_constantCache.SetStructCallback([this](std::size_t structIndex) -> const ShaderAst::StructDescription& { assert(structIndex < declaredStructs.size()); - return declaredStructs[structIndex]; + return *declaredStructs[structIndex]; }); } @@ -88,18 +88,12 @@ namespace Nz void Visit(ShaderAst::ConditionalExpression& node) override { - if (TestBit(m_states.enabledOptions, node.optionIndex)) - node.truePath->Visit(*this); - else - node.falsePath->Visit(*this); - - m_constantCache.Register(*m_constantCache.BuildType(node.cachedExpressionType.value())); + throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?"); } void Visit(ShaderAst::ConditionalStatement& node) override { - if (TestBit(m_states.enabledOptions, node.optionIndex)) - node.statement->Visit(*this); + throw std::runtime_error("unexpected conditional expression, did you forget to sanitize the shader?"); } void Visit(ShaderAst::ConstantExpression& node) override @@ -123,12 +117,12 @@ namespace Nz variable.storageClass = (ShaderAst::IsSamplerType(extVar.type)) ? SpirvStorageClass::UniformConstant : SpirvStorageClass::Uniform; variable.type = m_constantCache.BuildPointerType(extVar.type, variable.storageClass); - assert(extVar.bindingIndex); + assert(extVar.bindingIndex.IsResultingValue()); UniformVar& uniformVar = extVars[varIndex++]; uniformVar.pointerId = m_constantCache.Register(variable); - uniformVar.bindingIndex = *extVar.bindingIndex; - uniformVar.descriptorSet = extVar.bindingSet.value_or(0); + uniformVar.bindingIndex = extVar.bindingIndex.GetResultingValue(); + uniformVar.descriptorSet = (extVar.bindingSet.HasValue()) ? extVar.bindingSet.GetResultingValue() : 0; } } @@ -151,7 +145,9 @@ namespace Nz void Visit(ShaderAst::DeclareFunctionStatement& node) override { - std::optional entryPointType = node.entryStage; + std::optional entryPointType; + if (node.entryStage.HasValue()) + entryPointType = node.entryStage.GetResultingValue(); assert(node.funcIndex); std::size_t funcIndex = *node.funcIndex; @@ -188,14 +184,14 @@ namespace Nz if (*entryPointType == ShaderStageType::Fragment) { executionModes.push_back(SpirvExecutionMode::OriginUpperLeft); - if (node.earlyFragmentTests && *node.earlyFragmentTests) + if (node.earlyFragmentTests.HasValue() && node.earlyFragmentTests.GetResultingValue()) executionModes.push_back(SpirvExecutionMode::EarlyFragmentTests); - if (node.depthWrite) + if (node.depthWrite.HasValue()) { executionModes.push_back(SpirvExecutionMode::DepthReplacing); - switch (*node.depthWrite) + switch (node.depthWrite.GetResultingValue()) { case ShaderAst::DepthWriteMode::Replace: break; case ShaderAst::DepthWriteMode::Greater: executionModes.push_back(SpirvExecutionMode::DepthGreater); break; @@ -217,10 +213,10 @@ namespace Nz assert(std::holds_alternative(parameter.type)); std::size_t structIndex = std::get(parameter.type).structIndex; - const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex]; + const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex]; std::size_t memberIndex = 0; - for (const auto& member : structDesc.members) + for (const auto& member : structDesc->members) { if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Input); varId != 0) { @@ -247,10 +243,10 @@ namespace Nz assert(std::holds_alternative(node.returnType)); std::size_t structIndex = std::get(node.returnType).structIndex; - const ShaderAst::StructDescription& structDesc = declaredStructs[structIndex]; + const ShaderAst::StructDescription* structDesc = declaredStructs[structIndex]; std::size_t memberIndex = 0; - for (const auto& member : structDesc.members) + for (const auto& member : structDesc->members) { if (UInt32 varId = HandleEntryInOutType(*entryPointType, funcIndex, member, SpirvStorageClass::Output); varId != 0) { @@ -291,7 +287,7 @@ namespace Nz if (structIndex >= declaredStructs.size()) declaredStructs.resize(structIndex + 1); - declaredStructs[structIndex] = node.description; + declaredStructs[structIndex] = &node.description; m_constantCache.Register(*m_constantCache.BuildType(node.description)); } @@ -357,9 +353,9 @@ namespace Nz UInt32 HandleEntryInOutType(ShaderStageType entryPointType, std::size_t funcIndex, const ShaderAst::StructDescription::StructMember& member, SpirvStorageClass storageClass) { - if (member.builtin) + if (member.builtin.HasValue()) { - auto it = s_builtinMapping.find(*member.builtin); + auto it = s_builtinMapping.find(member.builtin.GetResultingValue()); assert(it != s_builtinMapping.end()); Builtin& builtin = it->second; @@ -379,7 +375,7 @@ namespace Nz return varId; } - else if (member.locationIndex) + else if (member.locationIndex.HasValue()) { SpirvConstantCache::Variable variable; variable.debugName = member.name; @@ -388,7 +384,7 @@ namespace Nz variable.type = m_constantCache.BuildPointerType(member.type, storageClass); UInt32 varId = m_constantCache.Register(variable); - locationDecorations[varId] = *member.locationIndex; + locationDecorations[varId] = member.locationIndex.GetResultingValue(); return varId; } @@ -453,7 +449,10 @@ namespace Nz ShaderAst::StatementPtr sanitizedAst; if (!states.sanitized) { - sanitizedAst = ShaderAst::Sanitize(shader); + ShaderAst::SanitizeVisitor::Options options; + options.enabledOptions = states.enabledOptions; + + sanitizedAst = ShaderAst::Sanitize(shader, options); targetAst = sanitizedAst.get(); }